Commit 8f28e26e7a

Andrew Kelley <andrew@ziglang.org>
2021-04-08 01:39:10
Sema: implement switch validation for enums
1 parent ccdba77
Changed files (3)
src
test
stage2
src/Sema.zig
@@ -2729,6 +2729,8 @@ fn analyzeSwitch(
     src_node_offset: i32,
 ) InnerError!*Inst {
     const gpa = sema.gpa;
+    const mod = sema.mod;
+
     const special: struct { body: []const zir.Inst.Index, end: usize } = switch (special_prong) {
         .none => .{ .body = &.{}, .end = extra_end },
         .under, .@"else" => blk: {
@@ -2748,14 +2750,14 @@ fn analyzeSwitch(
     // Validate usage of '_' prongs.
     if (special_prong == .under and !operand.ty.isExhaustiveEnum()) {
         const msg = msg: {
-            const msg = try sema.mod.errMsg(
+            const msg = try mod.errMsg(
                 &block.base,
                 src,
                 "'_' prong only allowed when switching on non-exhaustive enums",
                 .{},
             );
             errdefer msg.destroy(gpa);
-            try sema.mod.errNote(
+            try mod.errNote(
                 &block.base,
                 special_prong_src,
                 msg,
@@ -2764,14 +2766,121 @@ fn analyzeSwitch(
             );
             break :msg msg;
         };
-        return sema.mod.failWithOwnedErrorMsg(&block.base, msg);
+        return mod.failWithOwnedErrorMsg(&block.base, msg);
     }
 
     // Validate for duplicate items, missing else prong, and invalid range.
     switch (operand.ty.zigTypeTag()) {
-        .Enum => return sema.mod.fail(&block.base, src, "TODO validate switch .Enum", .{}),
-        .ErrorSet => return sema.mod.fail(&block.base, src, "TODO validate switch .ErrorSet", .{}),
-        .Union => return sema.mod.fail(&block.base, src, "TODO validate switch .Union", .{}),
+        .Enum => {
+            var seen_fields = try gpa.alloc(?AstGen.SwitchProngSrc, operand.ty.enumFieldCount());
+            defer gpa.free(seen_fields);
+
+            var extra_index: usize = special.end;
+            {
+                var scalar_i: u32 = 0;
+                while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
+                    const item_ref = @intToEnum(zir.Inst.Ref, sema.code.extra[extra_index]);
+                    extra_index += 1;
+                    const body_len = sema.code.extra[extra_index];
+                    extra_index += 1;
+                    const body = sema.code.extra[extra_index..][0..body_len];
+                    extra_index += body_len;
+
+                    try sema.validateSwitchItemEnum(
+                        block,
+                        seen_fields,
+                        item_ref,
+                        src_node_offset,
+                        .{ .scalar = scalar_i },
+                    );
+                }
+            }
+            {
+                var multi_i: u32 = 0;
+                while (multi_i < multi_cases_len) : (multi_i += 1) {
+                    const items_len = sema.code.extra[extra_index];
+                    extra_index += 1;
+                    const ranges_len = sema.code.extra[extra_index];
+                    extra_index += 1;
+                    const body_len = sema.code.extra[extra_index];
+                    extra_index += 1;
+                    const items = sema.code.refSlice(extra_index, items_len);
+                    extra_index += items_len + body_len;
+
+                    for (items) |item_ref, item_i| {
+                        try sema.validateSwitchItemEnum(
+                            block,
+                            seen_fields,
+                            item_ref,
+                            src_node_offset,
+                            .{ .multi = .{ .prong = multi_i, .item = @intCast(u32, item_i) } },
+                        );
+                    }
+
+                    try sema.validateSwitchNoRange(block, ranges_len, operand.ty, src_node_offset);
+                }
+            }
+            const all_tags_handled = for (seen_fields) |seen_src| {
+                if (seen_src == null) break false;
+            } else true;
+
+            switch (special_prong) {
+                .none => {
+                    if (!all_tags_handled) {
+                        const msg = msg: {
+                            const msg = try mod.errMsg(
+                                &block.base,
+                                src,
+                                "switch must handle all possibilities",
+                                .{},
+                            );
+                            errdefer msg.destroy(sema.gpa);
+                            try mod.errNoteNonLazy(
+                                operand.ty.declSrcLoc(),
+                                msg,
+                                "enum '{}' declared here",
+                                .{operand.ty},
+                            );
+                            for (seen_fields) |seen_src, i| {
+                                if (seen_src != null) continue;
+
+                                const field_name = operand.ty.enumFieldName(i);
+
+                                // TODO have this point to the tag decl instead of here
+                                try mod.errNote(
+                                    &block.base,
+                                    src,
+                                    msg,
+                                    "unhandled enumeration value: '{s}",
+                                    .{field_name},
+                                );
+                            }
+                            break :msg msg;
+                        };
+                        return mod.failWithOwnedErrorMsg(&block.base, msg);
+                    }
+                },
+                .under => {
+                    if (all_tags_handled) return mod.fail(
+                        &block.base,
+                        special_prong_src,
+                        "unreachable '_' prong; all cases already handled",
+                        .{},
+                    );
+                },
+                .@"else" => {
+                    if (all_tags_handled) return mod.fail(
+                        &block.base,
+                        special_prong_src,
+                        "unreachable else prong; all cases already handled",
+                        .{},
+                    );
+                },
+            }
+        },
+
+        .ErrorSet => return mod.fail(&block.base, src, "TODO validate switch .ErrorSet", .{}),
+        .Union => return mod.fail(&block.base, src, "TODO validate switch .Union", .{}),
         .Int, .ComptimeInt => {
             var range_set = RangeSet.init(gpa);
             defer range_set.deinit();
@@ -2844,11 +2953,11 @@ fn analyzeSwitch(
                     var arena = std.heap.ArenaAllocator.init(gpa);
                     defer arena.deinit();
 
-                    const min_int = try operand.ty.minInt(&arena, sema.mod.getTarget());
-                    const max_int = try operand.ty.maxInt(&arena, sema.mod.getTarget());
+                    const min_int = try operand.ty.minInt(&arena, mod.getTarget());
+                    const max_int = try operand.ty.maxInt(&arena, mod.getTarget());
                     if (try range_set.spans(min_int, max_int)) {
                         if (special_prong == .@"else") {
-                            return sema.mod.fail(
+                            return mod.fail(
                                 &block.base,
                                 special_prong_src,
                                 "unreachable else prong; all cases already handled",
@@ -2859,7 +2968,7 @@ fn analyzeSwitch(
                     }
                 }
                 if (special_prong != .@"else") {
-                    return sema.mod.fail(
+                    return mod.fail(
                         &block.base,
                         src,
                         "switch must handle all possibilities",
@@ -2922,7 +3031,7 @@ fn analyzeSwitch(
             switch (special_prong) {
                 .@"else" => {
                     if (true_count + false_count == 2) {
-                        return sema.mod.fail(
+                        return mod.fail(
                             &block.base,
                             src,
                             "unreachable else prong; all cases already handled",
@@ -2932,7 +3041,7 @@ fn analyzeSwitch(
                 },
                 .under, .none => {
                     if (true_count + false_count < 2) {
-                        return sema.mod.fail(
+                        return mod.fail(
                             &block.base,
                             src,
                             "switch must handle all possibilities",
@@ -2944,7 +3053,7 @@ fn analyzeSwitch(
         },
         .EnumLiteral, .Void, .Fn, .Pointer, .Type => {
             if (special_prong != .@"else") {
-                return sema.mod.fail(
+                return mod.fail(
                     &block.base,
                     src,
                     "else prong required when switching on type '{}'",
@@ -3016,7 +3125,7 @@ fn analyzeSwitch(
         .AnyFrame,
         .ComptimeFloat,
         .Float,
-        => return sema.mod.fail(&block.base, operand_src, "invalid switch operand type '{}'", .{
+        => return mod.fail(&block.base, operand_src, "invalid switch operand type '{}'", .{
             operand.ty,
         }),
     }
@@ -3291,7 +3400,7 @@ fn resolveSwitchItemVal(
     switch_node_offset: i32,
     switch_prong_src: AstGen.SwitchProngSrc,
     range_expand: AstGen.SwitchProngSrc.RangeExpand,
-) InnerError!Value {
+) InnerError!TypedValue {
     const item = try sema.resolveInst(item_ref);
     // We have to avoid the other helper functions here because we cannot construct a LazySrcLoc
     // because we only have the switch AST node. Only if we know for sure we need to report
@@ -3301,7 +3410,7 @@ fn resolveSwitchItemVal(
             const src = switch_prong_src.resolve(block.src_decl, switch_node_offset, range_expand);
             return sema.failWithUseOfUndef(block, src);
         }
-        return val;
+        return TypedValue{ .ty = item.ty, .val = val };
     }
     const src = switch_prong_src.resolve(block.src_decl, switch_node_offset, range_expand);
     return sema.failWithNeededComptime(block, src);
@@ -3316,8 +3425,8 @@ fn validateSwitchRange(
     src_node_offset: i32,
     switch_prong_src: AstGen.SwitchProngSrc,
 ) InnerError!void {
-    const first_val = try sema.resolveSwitchItemVal(block, first_ref, src_node_offset, switch_prong_src, .first);
-    const last_val = try sema.resolveSwitchItemVal(block, last_ref, src_node_offset, switch_prong_src, .last);
+    const first_val = (try sema.resolveSwitchItemVal(block, first_ref, src_node_offset, switch_prong_src, .first)).val;
+    const last_val = (try sema.resolveSwitchItemVal(block, last_ref, src_node_offset, switch_prong_src, .last)).val;
     const maybe_prev_src = try range_set.add(first_val, last_val, switch_prong_src);
     return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset);
 }
@@ -3330,11 +3439,46 @@ fn validateSwitchItem(
     src_node_offset: i32,
     switch_prong_src: AstGen.SwitchProngSrc,
 ) InnerError!void {
-    const item_val = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none);
+    const item_val = (try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none)).val;
     const maybe_prev_src = try range_set.add(item_val, item_val, switch_prong_src);
     return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset);
 }
 
+fn validateSwitchItemEnum(
+    sema: *Sema,
+    block: *Scope.Block,
+    seen_fields: []?AstGen.SwitchProngSrc,
+    item_ref: zir.Inst.Ref,
+    src_node_offset: i32,
+    switch_prong_src: AstGen.SwitchProngSrc,
+) InnerError!void {
+    const mod = sema.mod;
+    const item_tv = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none);
+    const field_index = item_tv.ty.enumTagFieldIndex(item_tv.val) orelse {
+        const msg = msg: {
+            const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .none);
+            const msg = try mod.errMsg(
+                &block.base,
+                src,
+                "enum '{}' has no tag with value '{}'",
+                .{ item_tv.ty, item_tv.val },
+            );
+            errdefer msg.destroy(sema.gpa);
+            try mod.errNoteNonLazy(
+                item_tv.ty.declSrcLoc(),
+                msg,
+                "enum declared here",
+                .{},
+            );
+            break :msg msg;
+        };
+        return mod.failWithOwnedErrorMsg(&block.base, msg);
+    };
+    const maybe_prev_src = seen_fields[field_index];
+    seen_fields[field_index] = switch_prong_src;
+    return sema.validateSwitchDupe(block, maybe_prev_src, switch_prong_src, src_node_offset);
+}
+
 fn validateSwitchDupe(
     sema: *Sema,
     block: *Scope.Block,
@@ -3343,17 +3487,18 @@ fn validateSwitchDupe(
     src_node_offset: i32,
 ) InnerError!void {
     const prev_prong_src = maybe_prev_src orelse return;
+    const mod = sema.mod;
     const src = switch_prong_src.resolve(block.src_decl, src_node_offset, .none);
     const prev_src = prev_prong_src.resolve(block.src_decl, src_node_offset, .none);
     const msg = msg: {
-        const msg = try sema.mod.errMsg(
+        const msg = try mod.errMsg(
             &block.base,
             src,
             "duplicate switch value",
             .{},
         );
         errdefer msg.destroy(sema.gpa);
-        try sema.mod.errNote(
+        try mod.errNote(
             &block.base,
             prev_src,
             msg,
@@ -3362,7 +3507,7 @@ fn validateSwitchDupe(
         );
         break :msg msg;
     };
-    return sema.mod.failWithOwnedErrorMsg(&block.base, msg);
+    return mod.failWithOwnedErrorMsg(&block.base, msg);
 }
 
 fn validateSwitchItemBool(
@@ -3374,7 +3519,7 @@ fn validateSwitchItemBool(
     src_node_offset: i32,
     switch_prong_src: AstGen.SwitchProngSrc,
 ) InnerError!void {
-    const item_val = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none);
+    const item_val = (try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none)).val;
     if (item_val.toBool()) {
         true_count.* += 1;
     } else {
@@ -3396,7 +3541,7 @@ fn validateSwitchItemSparse(
     src_node_offset: i32,
     switch_prong_src: AstGen.SwitchProngSrc,
 ) InnerError!void {
-    const item_val = try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none);
+    const item_val = (try sema.resolveSwitchItemVal(block, item_ref, src_node_offset, switch_prong_src, .none)).val;
     const entry = (try seen_values.fetchPut(item_val, switch_prong_src)) orelse return;
     return sema.validateSwitchDupe(block, entry.value, switch_prong_src, src_node_offset);
 }
src/type.zig
@@ -2126,6 +2126,34 @@ pub const Type = extern union {
         };
     }
 
+    pub fn enumFieldCount(ty: Type) usize {
+        switch (ty.tag()) {
+            .enum_full, .enum_nonexhaustive => {
+                const enum_full = ty.cast(Payload.EnumFull).?.data;
+                return enum_full.fields.count();
+            },
+            .enum_simple => {
+                const enum_simple = ty.castTag(.enum_simple).?.data;
+                return enum_simple.fields.count();
+            },
+            else => unreachable,
+        }
+    }
+
+    pub fn enumFieldName(ty: Type, field_index: usize) []const u8 {
+        switch (ty.tag()) {
+            .enum_full, .enum_nonexhaustive => {
+                const enum_full = ty.cast(Payload.EnumFull).?.data;
+                return enum_full.fields.entries.items[field_index].key;
+            },
+            .enum_simple => {
+                const enum_simple = ty.castTag(.enum_simple).?.data;
+                return enum_simple.fields.entries.items[field_index].key;
+            },
+            else => unreachable,
+        }
+    }
+
     pub fn enumFieldIndex(ty: Type, field_name: []const u8) ?usize {
         switch (ty.tag()) {
             .enum_full, .enum_nonexhaustive => {
@@ -2140,6 +2168,42 @@ pub const Type = extern union {
         }
     }
 
+    /// Asserts `ty` is an enum. `enum_tag` can either be `enum_field_index` or
+    /// an integer which represents the enum value. Returns the field index in
+    /// declaration order, or `null` if `enum_tag` does not match any field.
+    pub fn enumTagFieldIndex(ty: Type, enum_tag: Value) ?usize {
+        if (enum_tag.castTag(.enum_field_index)) |payload| {
+            return @as(usize, payload.data);
+        }
+        const S = struct {
+            fn fieldWithRange(int_val: Value, end: usize) ?usize {
+                if (int_val.compareWithZero(.lt)) return null;
+                var end_payload: Value.Payload.U64 = .{
+                    .base = .{ .tag = .int_u64 },
+                    .data = end,
+                };
+                const end_val = Value.initPayload(&end_payload.base);
+                if (int_val.compare(.gte, end_val)) return null;
+                return int_val.toUnsignedInt();
+            }
+        };
+        switch (ty.tag()) {
+            .enum_full, .enum_nonexhaustive => {
+                const enum_full = ty.cast(Payload.EnumFull).?.data;
+                if (enum_full.values.count() == 0) {
+                    return S.fieldWithRange(enum_tag, enum_full.fields.count());
+                } else {
+                    return enum_full.values.getIndex(enum_tag);
+                }
+            },
+            .enum_simple => {
+                const enum_simple = ty.castTag(.enum_simple).?.data;
+                return S.fieldWithRange(enum_tag, enum_simple.fields.count());
+            },
+            else => unreachable,
+        }
+    }
+
     pub fn declSrcLoc(ty: Type) Module.SrcLoc {
         switch (ty.tag()) {
             .enum_full, .enum_nonexhaustive => {
test/stage2/cbe.zig
@@ -552,7 +552,11 @@ pub fn addCases(ctx: *TestContext) !void {
             \\    if (@enumToInt(number3) != 2) return 1;
             \\    var x: Number = .Two;
             \\    if (number2 != x) return 1;
-            \\    return 0;
+            \\    switch (x) {
+            \\        .One => return 1,
+            \\        .Two => return 0,
+            \\        number3 => return 2,
+            \\    }
             \\}
         , "");
     }