Commit 950a0e2405

Veikka Tuominen <git@vexu.eu>
2022-09-27 13:56:56
Sema: implement `inline else` for errors enums and bools
1 parent 0e77259
Changed files (3)
src
test
behavior
cases
src/Sema.zig
@@ -9309,8 +9309,19 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         break :blk sema.typeOf(raw_operand);
     };
     const union_originally = maybe_union_ty.zigTypeTag() == .Union;
-    var seen_union_fields: []?Module.SwitchProngSrc = &.{};
-    defer gpa.free(seen_union_fields);
+
+    // Duplicate checking variables later also used for `inline else`.
+    var seen_enum_fields: []?Module.SwitchProngSrc = &.{};
+    var seen_errors = SwitchErrorSet.init(gpa);
+    var range_set = RangeSet.init(gpa, sema.mod);
+    var true_count: u8 = 0;
+    var false_count: u8 = 0;
+
+    defer {
+        range_set.deinit();
+        gpa.free(seen_enum_fields);
+        seen_errors.deinit();
+    }
 
     var empty_enum = false;
 
@@ -9347,15 +9358,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     switch (operand_ty.zigTypeTag()) {
         .Union => unreachable, // handled in zirSwitchCond
         .Enum => {
-            var seen_fields = try gpa.alloc(?Module.SwitchProngSrc, operand_ty.enumFieldCount());
-            empty_enum = seen_fields.len == 0 and !operand_ty.isNonexhaustiveEnum();
-            defer if (!union_originally) gpa.free(seen_fields);
-            if (union_originally) seen_union_fields = seen_fields;
-            mem.set(?Module.SwitchProngSrc, seen_fields, null);
-
-            // This is used for non-exhaustive enum values that do not correspond to any tags.
-            var range_set = RangeSet.init(gpa, sema.mod);
-            defer range_set.deinit();
+            seen_enum_fields = try gpa.alloc(?Module.SwitchProngSrc, operand_ty.enumFieldCount());
+            empty_enum = seen_enum_fields.len == 0 and !operand_ty.isNonexhaustiveEnum();
+            mem.set(?Module.SwitchProngSrc, seen_enum_fields, null);
+            // `range_set` is used for non-exhaustive enum values that do not correspond to any tags.
 
             var extra_index: usize = special.end;
             {
@@ -9369,7 +9375,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
 
                     try sema.validateSwitchItemEnum(
                         block,
-                        seen_fields,
+                        seen_enum_fields,
                         &range_set,
                         item_ref,
                         src_node_offset,
@@ -9392,7 +9398,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     for (items) |item_ref, item_i| {
                         try sema.validateSwitchItemEnum(
                             block,
-                            seen_fields,
+                            seen_enum_fields,
                             &range_set,
                             item_ref,
                             src_node_offset,
@@ -9403,7 +9409,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                     try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset);
                 }
             }
-            const all_tags_handled = for (seen_fields) |seen_src| {
+            const all_tags_handled = for (seen_enum_fields) |seen_src| {
                 if (seen_src == null) break false;
             } else true;
 
@@ -9423,7 +9429,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
                         .{},
                     );
                     errdefer msg.destroy(sema.gpa);
-                    for (seen_fields) |seen_src, i| {
+                    for (seen_enum_fields) |seen_src, i| {
                         if (seen_src != null) continue;
 
                         const field_name = operand_ty.enumFieldName(i);
@@ -9454,9 +9460,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             }
         },
         .ErrorSet => {
-            var seen_errors = SwitchErrorSet.init(gpa);
-            defer seen_errors.deinit();
-
             var extra_index: usize = special.end;
             {
                 var scalar_i: u32 = 0;
@@ -9596,9 +9599,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             }
         },
         .Int, .ComptimeInt => {
-            var range_set = RangeSet.init(gpa, sema.mod);
-            defer range_set.deinit();
-
             var extra_index: usize = special.end;
             {
                 var scalar_i: u32 = 0;
@@ -9694,9 +9694,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             }
         },
         .Bool => {
-            var true_count: u8 = 0;
-            var false_count: u8 = 0;
-
             var extra_index: usize = special.end;
             {
                 var scalar_i: u32 = 0;
@@ -9950,16 +9947,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
         return sema.resolveBlockBody(block, src, &child_block, special.body, inst, merges);
     }
 
-    if (scalar_cases_len + multi_cases_len == 0) {
+    if (scalar_cases_len + multi_cases_len == 0 and !special.is_inline) {
         if (empty_enum) {
             return Air.Inst.Ref.void_value;
         }
         if (special_prong == .none) {
             return sema.fail(block, src, "switch must handle all possibilities", .{});
         }
-        if (special.is_inline) {
-            return sema.fail(block, src, "TODO special.is_inline", .{});
-        }
         if (err_set and try sema.maybeErrorUnwrap(block, special.body, operand)) {
             return Air.Inst.Ref.unreachable_value;
         }
@@ -10323,16 +10317,181 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     if (special.body.len != 0 or !is_first or case_block.wantSafety()) {
         var wip_captures = try WipCaptureScope.init(gpa, sema.perm_arena, child_block.wip_capture_scope);
         defer wip_captures.deinit();
+        if (special.is_inline) switch (operand_ty.zigTypeTag()) {
+            .Enum => {
+                if (operand_ty.isNonexhaustiveEnum() and !union_originally) {
+                    return sema.fail(block, special_prong_src, "cannot enumerate values of type '{}' for 'inline else'", .{
+                        operand_ty.fmt(sema.mod),
+                    });
+                }
+                var emit_bb = false;
+                for (seen_enum_fields) |f, i| {
+                    if (f != null) continue;
+                    cases_len += 1;
+
+                    const item_val = try Value.Tag.enum_field_index.create(sema.arena, @intCast(u32, i));
+                    const item_ref = try sema.addConstant(operand_ty, item_val);
+                    case_block.inline_case_capture = item_ref;
+
+                    case_block.instructions.shrinkRetainingCapacity(0);
+                    case_block.wip_capture_scope = child_block.wip_capture_scope;
+
+                    const analyze_body = if (union_originally) blk: {
+                        const field_ty = maybe_union_ty.unionFieldType(item_val, sema.mod);
+                        break :blk field_ty.zigTypeTag() != .NoReturn;
+                    } else true;
+
+                    if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src);
+                    emit_bb = true;
+
+                    if (analyze_body) {
+                        _ = sema.analyzeBodyInner(&case_block, special.body) catch |err| switch (err) {
+                            error.ComptimeBreak => {
+                                const zir_datas = sema.code.instructions.items(.data);
+                                const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                                try sema.addRuntimeBreak(&case_block, .{
+                                    .block_inst = break_data.block_inst,
+                                    .operand = break_data.operand,
+                                    .inst = sema.comptime_break_inst,
+                                });
+                            },
+                            else => |e| return e,
+                        };
+                    } else {
+                        _ = try case_block.addNoOp(.unreach);
+                    }
+
+                    // try wip_captures.finalize();
+
+                    try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
+                    cases_extra.appendAssumeCapacity(1); // items_len
+                    cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len));
+                    cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture));
+                    cases_extra.appendSliceAssumeCapacity(case_block.instructions.items);
+                }
+            },
+            .ErrorSet => {
+                if (operand_ty.isAnyError()) {
+                    return sema.fail(block, special_prong_src, "cannot enumerate values of type '{}' for 'inline else'", .{
+                        operand_ty.fmt(sema.mod),
+                    });
+                }
+                var emit_bb = false;
+                for (operand_ty.errorSetNames()) |error_name| {
+                    if (seen_errors.contains(error_name)) continue;
+                    cases_len += 1;
+
+                    const item_val = try Value.Tag.@"error".create(sema.arena, .{ .name = error_name });
+                    const item_ref = try sema.addConstant(operand_ty, item_val);
+                    case_block.inline_case_capture = item_ref;
+
+                    case_block.instructions.shrinkRetainingCapacity(0);
+                    case_block.wip_capture_scope = child_block.wip_capture_scope;
+
+                    if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src);
+                    emit_bb = true;
+
+                    _ = sema.analyzeBodyInner(&case_block, special.body) catch |err| switch (err) {
+                        error.ComptimeBreak => {
+                            const zir_datas = sema.code.instructions.items(.data);
+                            const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                            try sema.addRuntimeBreak(&case_block, .{
+                                .block_inst = break_data.block_inst,
+                                .operand = break_data.operand,
+                                .inst = sema.comptime_break_inst,
+                            });
+                        },
+                        else => |e| return e,
+                    };
+
+                    // try wip_captures.finalize();
+
+                    try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
+                    cases_extra.appendAssumeCapacity(1); // items_len
+                    cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len));
+                    cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture));
+                    cases_extra.appendSliceAssumeCapacity(case_block.instructions.items);
+                }
+            },
+            .Int => {
+                return sema.fail(block, special_prong_src, "TODO 'inline else' Int", .{});
+            },
+            .Bool => {
+                var emit_bb = false;
+                if (true_count == 0) {
+                    cases_len += 1;
+                    case_block.inline_case_capture = Air.Inst.Ref.bool_true;
+
+                    case_block.instructions.shrinkRetainingCapacity(0);
+                    case_block.wip_capture_scope = child_block.wip_capture_scope;
+
+                    if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src);
+                    emit_bb = true;
+
+                    _ = sema.analyzeBodyInner(&case_block, special.body) catch |err| switch (err) {
+                        error.ComptimeBreak => {
+                            const zir_datas = sema.code.instructions.items(.data);
+                            const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                            try sema.addRuntimeBreak(&case_block, .{
+                                .block_inst = break_data.block_inst,
+                                .operand = break_data.operand,
+                                .inst = sema.comptime_break_inst,
+                            });
+                        },
+                        else => |e| return e,
+                    };
+
+                    // try wip_captures.finalize();
+
+                    try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
+                    cases_extra.appendAssumeCapacity(1); // items_len
+                    cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len));
+                    cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture));
+                    cases_extra.appendSliceAssumeCapacity(case_block.instructions.items);
+                }
+                if (false_count == 0) {
+                    cases_len += 1;
+                    case_block.inline_case_capture = Air.Inst.Ref.bool_false;
+
+                    case_block.instructions.shrinkRetainingCapacity(0);
+                    case_block.wip_capture_scope = child_block.wip_capture_scope;
+
+                    if (emit_bb) try sema.emitBackwardBranch(block, special_prong_src);
+                    emit_bb = true;
+
+                    _ = sema.analyzeBodyInner(&case_block, special.body) catch |err| switch (err) {
+                        error.ComptimeBreak => {
+                            const zir_datas = sema.code.instructions.items(.data);
+                            const break_data = zir_datas[sema.comptime_break_inst].@"break";
+                            try sema.addRuntimeBreak(&case_block, .{
+                                .block_inst = break_data.block_inst,
+                                .operand = break_data.operand,
+                                .inst = sema.comptime_break_inst,
+                            });
+                        },
+                        else => |e| return e,
+                    };
+
+                    // try wip_captures.finalize();
+
+                    try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len);
+                    cases_extra.appendAssumeCapacity(1); // items_len
+                    cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len));
+                    cases_extra.appendAssumeCapacity(@enumToInt(case_block.inline_case_capture));
+                    cases_extra.appendSliceAssumeCapacity(case_block.instructions.items);
+                }
+            },
+            else => return sema.fail(block, special_prong_src, "cannot enumerate values of type '{}' for 'inline else'", .{
+                operand_ty.fmt(sema.mod),
+            }),
+        };
 
         case_block.instructions.shrinkRetainingCapacity(0);
         case_block.wip_capture_scope = wip_captures.scope;
         case_block.inline_case_capture = .none;
-        if (special.is_inline) {
-            return sema.fail(block, src, "TODO special.is_inline", .{});
-        }
 
-        const analyze_body = if (union_originally)
-            for (seen_union_fields) |seen_field, index| {
+        const analyze_body = if (union_originally and !special.is_inline)
+            for (seen_enum_fields) |seen_field, index| {
                 if (seen_field != null) continue;
                 const union_obj = maybe_union_ty.cast(Type.Payload.Union).?.data;
                 const field_ty = union_obj.fields.values()[index].ty;
@@ -10344,7 +10503,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
             try sema.maybeErrorUnwrap(&case_block, special.body, operand))
         {
             // nothing to do here
-        } else if (special.body.len != 0 and analyze_body) {
+        } else if (special.body.len != 0 and analyze_body and !special.is_inline) {
             _ = sema.analyzeBodyInner(&case_block, special.body) catch |err| switch (err) {
                 error.ComptimeBreak => {
                     const zir_datas = sema.code.instructions.items(.data);
test/behavior/inline_switch.zig
@@ -65,3 +65,36 @@ test "inline switch unions" {
         },
     }
 }
+
+test "inline else bool" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    var a = true;
+    switch (a) {
+        true => {},
+        inline else => |val| if (val != false) @compileError("bad"),
+    }
+}
+
+test "inline else error" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    const Err = error{ a, b, c };
+    var a = Err.a;
+    switch (a) {
+        error.a => {},
+        inline else => |val| comptime if (val == error.a) @compileError("bad"),
+    }
+}
+
+test "inline else enum" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const E2 = enum(u8) { a = 2, b = 3, c = 4, d = 5 };
+    var a: E2 = .a;
+    switch (a) {
+        .a, .b => {},
+        inline else => |val| comptime if (@enumToInt(val) < 4) @compileError("bad"),
+    }
+}
test/cases/compile_errors/invalid_inline_else_type.zig
@@ -0,0 +1,27 @@
+pub export fn entry1() void {
+    var a: anyerror = undefined;
+    switch (a) {
+        inline else => {},
+    }
+}
+const E = enum(u8) { a, _ };
+pub export fn entry2() void {
+    var a: E = undefined;
+    switch (a) {
+        inline else => {},
+    }
+}
+pub export fn entry3() void {
+    var a: *u32 = undefined;
+    switch (a) {
+        inline else => {},
+    }
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :4:21: error: cannot enumerate values of type 'anyerror' for 'inline else'
+// :11:21: error: cannot enumerate values of type 'tmp.E' for 'inline else'
+// :17:21: error: cannot enumerate values of type '*u32' for 'inline else'