Commit e272c29c16

Andrew Kelley <andrew@ziglang.org>
2021-04-01 00:06:03
Sema: implement switch validation for ranges
1 parent c7b09be
src/AstGen.zig
@@ -2538,18 +2538,106 @@ fn forExpr(
+/// Returns `node` itself when its tag is `.switch_range`, and null for any
+/// other tag. A `.grouped_expression` tag is asserted to be unreachable here
+/// rather than being unwrapped.
 fn getRangeNode(
     node_tags: []const ast.Node.Tag,
     node_datas: []const ast.Node.Data,
-    start_node: ast.Node.Index,
+    node: ast.Node.Index,
 ) ?ast.Node.Index {
-    var node = start_node;
-    while (true) {
-        switch (node_tags[node]) {
-            .switch_range => return node,
-            .grouped_expression => node = node_datas[node].lhs,
-            else => return null,
-        }
+    switch (node_tags[node]) {
+        .switch_range => return node,
+        .grouped_expression => unreachable,
+        else => return null,
     }
 }
 
+/// Identifies a switch item by its position within the switch expression
+/// instead of by source location; `resolve` converts it to a `LazySrcLoc`.
+pub const SwitchProngSrc = union(enum) {
+    /// Index of a prong, counting only prongs with exactly one non-range item.
+    scalar: u32,
+    /// A non-range item inside a multi prong.
+    multi: Multi,
+    /// A range item inside a multi prong.
+    range: Multi,
+
+    pub const Multi = struct {
+        /// Index of the prong, counting only multi prongs (more than one item,
+        /// or a single range item).
+        prong: u32,
+        /// Index of the item within the prong. For `multi` only non-range items
+        /// are counted; for `range` only range items are counted.
+        item: u32,
+    };
+
+    /// This function is intended to be called only when it is certain that we need
+    /// the LazySrcLoc in order to emit a compile error.
+    ///
+    /// Walks the case nodes of the switch at `switch_node_offset` (relative to
+    /// `decl`) until the prong/item described by `prong_src` is found.
+    /// `range_expand` is only consulted for `.range`: `.none` yields the range
+    /// item itself, `.first` its lhs node and `.last` its rhs node.
+    /// Not finding a matching prong or item is asserted to be unreachable.
+    pub fn resolve(
+        prong_src: SwitchProngSrc,
+        decl: *Decl,
+        switch_node_offset: i32,
+        range_expand: enum { none, first, last },
+    ) LazySrcLoc {
+        @setCold(true);
+        const switch_node = decl.relativeToNodeIndex(switch_node_offset);
+        const tree = decl.container.file_scope.base.tree();
+        const main_tokens = tree.nodes.items(.main_token);
+        const node_datas = tree.nodes.items(.data);
+        const node_tags = tree.nodes.items(.tag);
+        const extra = tree.extraData(node_datas[switch_node].rhs, ast.Node.SubRange);
+        const case_nodes = tree.extra_data[extra.start..extra.end];
+
+        // Scalar and multi prongs are numbered independently of each other.
+        var multi_i: u32 = 0;
+        var scalar_i: u32 = 0;
+        for (case_nodes) |case_node| {
+            const case = switch (node_tags[case_node]) {
+                .switch_case_one => tree.switchCaseOne(case_node),
+                .switch_case => tree.switchCase(case_node),
+                else => unreachable,
+            };
+            // Prongs without any values (presumably `else`) are not counted.
+            if (case.ast.values.len == 0)
+                continue;
+            // A prong whose only value is the identifier `_` is not counted either.
+            if (case.ast.values.len == 1 and
+                node_tags[case.ast.values[0]] == .identifier and
+                mem.eql(u8, tree.tokenSlice(main_tokens[case.ast.values[0]]), "_"))
+            {
+                continue;
+            }
+            const is_multi = case.ast.values.len != 1 or
+                getRangeNode(node_tags, node_datas, case.ast.values[0]) != null;
+
+            switch (prong_src) {
+                .scalar => |i| if (!is_multi and i == scalar_i) return LazySrcLoc{
+                    .node_offset = decl.nodeIndexToRelative(case.ast.values[0]),
+                },
+                .multi => |s| if (is_multi and s.prong == multi_i) {
+                    // Count only the non-range items of this prong.
+                    var item_i: u32 = 0;
+                    for (case.ast.values) |item_node| {
+                        if (getRangeNode(node_tags, node_datas, item_node) != null)
+                            continue;
+
+                        if (item_i == s.item) return LazySrcLoc{
+                            .node_offset = decl.nodeIndexToRelative(item_node),
+                        };
+                        item_i += 1;
+                    } else unreachable;
+                },
+                .range => |s| if (is_multi and s.prong == multi_i) {
+                    // Count only the range items of this prong.
+                    var range_i: u32 = 0;
+                    for (case.ast.values) |item_node| {
+                        const range = getRangeNode(node_tags, node_datas, item_node) orelse continue;
+
+                        if (range_i == s.item) switch (range_expand) {
+                            .none => return LazySrcLoc{
+                                .node_offset = decl.nodeIndexToRelative(item_node),
+                            },
+                            .first => return LazySrcLoc{
+                                .node_offset = decl.nodeIndexToRelative(node_datas[range].lhs),
+                            },
+                            .last => return LazySrcLoc{
+                                .node_offset = decl.nodeIndexToRelative(node_datas[range].rhs),
+                            },
+                        };
+                        range_i += 1;
+                    } else unreachable;
+                },
+            }
+            if (is_multi) {
+                multi_i += 1;
+            } else {
+                scalar_i += 1;
+            }
+        } else unreachable;
+    }
+};
+
 fn switchExpr(
     parent_gz: *GenZir,
     scope: *Scope,
src/codegen.zig
@@ -3994,7 +3994,10 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
+        /// Records a formatted error message in `self.err_msg` and returns
+        /// `error.CodegenFail`. Asserts that no error message is already set.
+        /// When `src` is `.unneeded`, `self.src_loc` is used as the location
+        /// instead of resolving `src` against the function's owner decl.
         fn fail(self: *Self, src: LazySrcLoc, comptime format: []const u8, args: anytype) InnerError {
             @setCold(true);
             assert(self.err_msg == null);
-            const src_loc = src.toSrcLocWithDecl(self.mod_fn.owner_decl);
+            const src_loc = if (src != .unneeded)
+                src.toSrcLocWithDecl(self.mod_fn.owner_decl)
+            else
+                self.src_loc;
             self.err_msg = try ErrorMsg.create(self.bin_file.allocator, src_loc, format, args);
             return error.CodegenFail;
         }
src/RangeSet.zig
@@ -2,13 +2,14 @@ const std = @import("std");
 const Order = std.math.Order;
 const Value = @import("value.zig").Value;
 const RangeSet = @This();
+const SwitchProngSrc = @import("AstGen.zig").SwitchProngSrc;
 
 ranges: std.ArrayList(Range),
 
+/// One entry of the set: a value range plus the switch item it came from.
 pub const Range = struct {
+    // NOTE(review): `add` compares with `.gte`/`.lte`, so both bounds look
+    // inclusive; confirm against the rest of this file.
     start: Value,
     end: Value,
-    src: usize,
+    // Identifies the switch item that contributed this range.
+    src: SwitchProngSrc,
 };
 
 pub fn init(allocator: *std.mem.Allocator) RangeSet {
@@ -21,7 +22,7 @@ pub fn deinit(self: *RangeSet) void {
     self.ranges.deinit();
 }
 
-pub fn add(self: *RangeSet, start: Value, end: Value, src: usize) !?usize {
+pub fn add(self: *RangeSet, start: Value, end: Value, src: SwitchProngSrc) !?SwitchProngSrc {
     for (self.ranges.items) |range| {
         if ((start.compare(.gte, range.start) and start.compare(.lte, range.end)) or
             (end.compare(.gte, range.start) and end.compare(.lte, range.end)))
src/Sema.zig
@@ -60,6 +60,7 @@ const InnerError = Module.InnerError;
 const Decl = Module.Decl;
 const LazySrcLoc = Module.LazySrcLoc;
 const RangeSet = @import("RangeSet.zig");
+const AstGen = @import("AstGen.zig");
 
 const ValueSrcMap = std.HashMap(Value, LazySrcLoc, Value.hash, Value.eql, std.hash_map.DefaultMaxLoadPercentage);
 
@@ -419,19 +420,27 @@ fn resolveType(sema: *Sema, block: *Scope.Block, src: LazySrcLoc, zir_ref: zir.I
 
+/// Returns the defined value of `base`, or fails with a compile error at `src`
+/// when `resolveDefinedValue` yields null.
 fn resolveConstValue(sema: *Sema, block: *Scope.Block, src: LazySrcLoc, base: *ir.Inst) !Value {
     return (try sema.resolveDefinedValue(block, src, base)) orelse
-        return sema.mod.fail(&block.base, src, "unable to resolve comptime value", .{});
+        return sema.failWithNeededComptime(block, src);
 }
 
+/// Returns the value of `base` if it has one, or null if it has none.
+/// A value that is undef is reported as a compile error at `src`.
 fn resolveDefinedValue(sema: *Sema, block: *Scope.Block, src: LazySrcLoc, base: *ir.Inst) !?Value {
     if (base.value()) |val| {
         if (val.isUndef()) {
-            return sema.mod.fail(&block.base, src, "use of undefined value here causes undefined behavior", .{});
+            return sema.failWithUseOfUndef(block, src);
         }
         return val;
     }
     return null;
 }
 
+/// Reports "unable to resolve comptime value" at `src` and returns the
+/// resulting error.
+fn failWithNeededComptime(sema: *Sema, block: *Scope.Block, src: LazySrcLoc) InnerError {
+    return sema.mod.fail(&block.base, src, "unable to resolve comptime value", .{});
+}
+
+/// Reports "use of undefined value here causes undefined behavior" at `src`
+/// and returns the resulting error.
+fn failWithUseOfUndef(sema: *Sema, block: *Scope.Block, src: LazySrcLoc) InnerError {
+    return sema.mod.fail(&block.base, src, "use of undefined value here causes undefined behavior", .{});
+}
+
 /// Appropriate to call when the coercion has already been done by result
 /// location semantics. Asserts the value fits in the provided `Int` type.
 /// Only supports `Int` types 64 bits or less.
@@ -2368,7 +2377,7 @@ fn analyzeSwitch(
 
             var extra_index: usize = special.end;
             {
-                var scalar_i: usize = 0;
+                var scalar_i: u32 = 0;
                 while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
                     const item_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[extra_index]);
                     extra_index += 1;
@@ -2382,11 +2391,12 @@ fn analyzeSwitch(
                         &range_set,
                         item_ref,
                         src_node_offset,
+                        .{ .scalar = scalar_i },
                     );
                 }
             }
             {
-                var multi_i: usize = 0;
+                var multi_i: u32 = 0;
                 while (multi_i < multi_cases_len) : (multi_i += 1) {
                     const items_len = sema.code.extra[extra_index];
                     extra_index += 1;
@@ -2397,16 +2407,17 @@ fn analyzeSwitch(
                     const items = sema.code.refSlice(extra_index, items_len);
                     extra_index += items_len;
 
-                    for (items) |item_ref| {
+                    for (items) |item_ref, item_i| {
                         try sema.validateSwitchItem(
                             block,
                             &range_set,
                             item_ref,
                             src_node_offset,
+                            .{ .multi = .{ .prong = multi_i, .item = @intCast(u32, item_i) } },
                         );
                     }
 
-                    var range_i: usize = 0;
+                    var range_i: u32 = 0;
                     while (range_i < ranges_len) : (range_i += 1) {
                         const item_first = @intToEnum(zir.Inst.Ref, sema.code.extra[extra_index]);
                         extra_index += 1;
@@ -2419,6 +2430,7 @@ fn analyzeSwitch(
                             item_first,
                             item_last,
                             src_node_offset,
+                            .{ .range = .{ .prong = multi_i, .item = range_i } },
                         );
                     }
 
@@ -2723,8 +2735,9 @@ fn analyzeSwitch(
         extra_index += body_len;
 
         case_block.instructions.shrinkRetainingCapacity(0);
-        const item = try sema.resolveInst(item_ref);
-        const item_val = try sema.resolveConstValue(&case_block, item.src, item);
+        // We validate these above; these two calls are guaranteed to succeed.
+        const item = sema.resolveInst(item_ref) catch unreachable;
+        const item_val = sema.resolveConstValue(&case_block, .unneeded, item) catch unreachable;
 
         _ = try sema.analyzeBody(&case_block, body);
 
@@ -2836,48 +2849,133 @@ fn analyzeSwitch(
         prev_condbr = new_condbr;
     }
 
-    case_block.instructions.shrinkRetainingCapacity(0);
-    _ = try sema.analyzeBody(&case_block, special.body);
-    const else_body: Body = .{
-        .instructions = try sema.arena.dupe(*Inst, case_block.instructions.items),
-    };
-    first_condbr.else_body = else_body;
-
-    const final_else_body: Body = .{
-        .instructions = try sema.arena.dupe(*Inst, &[1]*Inst{&first_condbr.base}),
+    const final_else_body: Body = blk: {
+        if (special.body.len != 0) {
+            case_block.instructions.shrinkRetainingCapacity(0);
+            _ = try sema.analyzeBody(&case_block, special.body);
+            const else_body: Body = .{
+                .instructions = try sema.arena.dupe(*Inst, case_block.instructions.items),
+            };
+            if (prev_condbr != null) {
+                first_condbr.else_body = else_body;
+                break :blk .{
+                    .instructions = try sema.arena.dupe(*Inst, &[1]*Inst{&first_condbr.base}),
+                };
+            } else {
+                break :blk else_body;
+            }
+        } else {
+            break :blk .{ .instructions = &.{} };
+        }
     };
 
     _ = try child_block.addSwitchBr(src, operand, cases, final_else_body);
     return sema.analyzeBlockBody(block, &child_block, merges);
 }
 
+/// Validates a single `first...last` range item of a switch prong.
+///
+/// Both endpoints must have a value that is not undef; otherwise a compile
+/// error is reported at the offending endpoint (`.first` or `.last` of the
+/// range identified by `switch_prong_src`). The range is then added to
+/// `range_set`, and a "duplicate switch value" error is reported when
+/// `range_set.add` returns a previously added item.
+fn validateSwitchRange(
+    sema: *Sema,
+    block: *Scope.Block,
+    range_set: *RangeSet,
+    first_ref: zir.Inst.Ref,
+    last_ref: zir.Inst.Ref,
+    src_node_offset: i32,
+    switch_prong_src: AstGen.SwitchProngSrc,
+) InnerError!void {
+    const first = try sema.resolveInst(first_ref);
+    const last = try sema.resolveInst(last_ref);
+    // We have to avoid the helper functions here because we cannot construct a LazySrcLoc
+    // because we only have the switch AST node. Only if we know for sure we need to report
+    // a compile error do we resolve the full source locations.
+    // The start of the range comes from `first`; errors point at the `.first` operand.
+    const first_val = val: {
+        if (first.value()) |val| {
+            if (val.isUndef()) {
+                const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .first);
+                return sema.failWithUseOfUndef(block, src);
+            }
+            break :val val;
+        }
+        const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .first);
+        return sema.failWithNeededComptime(block, src);
+    };
+    // The end of the range comes from `last`; errors point at the `.last` operand.
+    const last_val = val: {
+        if (last.value()) |val| {
+            if (val.isUndef()) {
+                const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .last);
+                return sema.failWithUseOfUndef(block, src);
+            }
+            break :val val;
+        }
+        const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .last);
+        return sema.failWithNeededComptime(block, src);
+    };
+    const maybe_prev_src = try range_set.add(first_val, last_val, switch_prong_src);
+    return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset);
+}
+
+/// Validates a single non-range switch item.
+///
+/// The item must have a value that is not undef; otherwise a compile error is
+/// reported at the item identified by `switch_prong_src`. The value is then
+/// added to `range_set` as a one-element range, and a "duplicate switch value"
+/// error is reported when `range_set.add` returns a previously added item.
 fn validateSwitchItem(
     sema: *Sema,
     block: *Scope.Block,
     range_set: *RangeSet,
     item_ref: zir.Inst.Ref,
     src_node_offset: i32,
+    switch_prong_src: AstGen.SwitchProngSrc,
 ) InnerError!void {
-    @panic("TODO");
+    const item = try sema.resolveInst(item_ref);
+    // We have to avoid the helper functions here because we cannot construct a LazySrcLoc
+    // because we only have the switch AST node. Only if we know for sure we need to report
+    // a compile error do we resolve the full source locations.
+    const value = val: {
+        if (item.value()) |val| {
+            if (val.isUndef()) {
+                const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .none);
+                return sema.failWithUseOfUndef(block, src);
+            }
+            break :val val;
+        }
+        const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .none);
+        return sema.failWithNeededComptime(block, src);
+    };
+    const maybe_prev_src = try range_set.add(value, value, switch_prong_src);
+    return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset);
 }
 
-fn validateSwitchItemBool(
+/// Reports a "duplicate switch value" compile error when `maybe_prev_src` is
+/// non-null, with a "previous value here" note attached at the earlier item.
+/// Returns without error when `maybe_prev_src` is null. Source locations are
+/// resolved only on the error path.
+fn validateSwitchDupe(
     sema: *Sema,
     block: *Scope.Block,
-    true_count: *u8,
-    false_count: *u8,
-    item_ref: zir.Inst.Ref,
+    maybe_prev_src: ?AstGen.SwitchProngSrc,
+    switch_prong_src: AstGen.SwitchProngSrc,
     src_node_offset: i32,
 ) InnerError!void {
-    @panic("TODO");
+    const prev_prong_src = maybe_prev_src orelse return;
+    const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .none);
+    const prev_src = prev_prong_src.resolve(block.src_decl, src_node_offset, .none);
+    const msg = msg: {
+        const msg = try sema.mod.errMsg(
+            &block.base,
+            src,
+            "duplicate switch value",
+            .{},
+        );
+        // Destroy the message if attaching the note fails.
+        errdefer msg.destroy(sema.gpa);
+        try sema.mod.errNote(
+            &block.base,
+            prev_src,
+            msg,
+            "previous value here",
+            .{},
+        );
+        break :msg msg;
+    };
+    return sema.mod.failWithOwnedErrorMsg(&block.base, msg);
 }
 
-fn validateSwitchRange(
+fn validateSwitchItemBool(
     sema: *Sema,
     block: *Scope.Block,
-    range_set: *RangeSet,
-    item_first: zir.Inst.Ref,
-    item_last: zir.Inst.Ref,
+    true_count: *u8,
+    false_count: *u8,
+    item_ref: zir.Inst.Ref,
     src_node_offset: i32,
 ) InnerError!void {
     @panic("TODO");