Commit 0e77259f44

Veikka Tuominen <git@vexu.eu>
2022-09-26 16:30:24
add inline switch union tag captures
1 parent 5baaf90
lib/std/zig/parse.zig
@@ -3100,7 +3100,7 @@ const Parser = struct {
         return identifier;
     }
 
-    /// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrPayload? AssignExpr
+    /// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrIndexPayload? AssignExpr
     /// SwitchCase
     ///     <- SwitchItem (COMMA SwitchItem)* COMMA?
     ///      / KEYWORD_else
@@ -3123,7 +3123,7 @@ const Parser = struct {
             }
         }
         const arrow_token = try p.expectToken(.equal_angle_bracket_right);
-        _ = try p.parsePtrPayload();
+        _ = try p.parsePtrIndexPayload();
 
         const items = p.scratch.items[scratch_top..];
         switch (items.len) {
lib/std/zig/parser_test.zig
@@ -3276,6 +3276,8 @@ test "zig fmt: switch" {
         \\    switch (u) {
         \\        Union.Int => |int| {},
         \\        Union.Float => |*float| unreachable,
+        \\        1 => |a, b| unreachable,
+        \\        2 => |*a, b| unreachable,
         \\    }
         \\}
         \\
lib/std/zig/render.zig
@@ -1541,13 +1541,17 @@ fn renderSwitchCase(
 
     if (switch_case.payload_token) |payload_token| {
         try renderToken(ais, tree, payload_token - 1, .none); // pipe
+        const ident = payload_token + @boolToInt(token_tags[payload_token] == .asterisk);
         if (token_tags[payload_token] == .asterisk) {
             try renderToken(ais, tree, payload_token, .none); // asterisk
-            try renderToken(ais, tree, payload_token + 1, .none); // identifier
-            try renderToken(ais, tree, payload_token + 2, pre_target_space); // pipe
+        }
+        try renderToken(ais, tree, ident, .none); // identifier
+        if (token_tags[ident + 1] == .comma) {
+            try renderToken(ais, tree, ident + 1, .space); // ,
+            try renderToken(ais, tree, ident + 2, .none); // identifier
+            try renderToken(ais, tree, ident + 3, pre_target_space); // pipe
         } else {
-            try renderToken(ais, tree, payload_token, .none); // identifier
-            try renderToken(ais, tree, payload_token + 1, pre_target_space); // pipe
+            try renderToken(ais, tree, ident + 1, pre_target_space); // pipe
         }
     }
 
src/arch/x86_64/Emit.zig
@@ -2159,7 +2159,7 @@ const RegisterOrMemory = union(enum) {
     /// Returns size in bits.
     fn size(reg_or_mem: RegisterOrMemory) u64 {
         return switch (reg_or_mem) {
-            .register => |reg| reg.size(),
+            .register => |register| register.size(),
             .memory => |memory| memory.size(),
         };
     }
src/stage1/parser.cpp
@@ -2306,17 +2306,17 @@ static Optional<PtrIndexPayload> ast_parse_ptr_index_payload(ParseContext *pc) {
     return Optional<PtrIndexPayload>::some(res);
 }
 
-// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrPayload? AssignExpr
+// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrIndexPayload? AssignExpr
 static AstNode *ast_parse_switch_prong(ParseContext *pc) {
     AstNode *res = ast_parse_switch_case(pc);
     if (res == nullptr)
         return nullptr;
 
     expect_token(pc, TokenIdFatArrow);
-    Optional<PtrPayload> opt_payload = ast_parse_ptr_payload(pc);
+    Optional<PtrIndexPayload> opt_payload = ast_parse_ptr_index_payload(pc);
     AstNode *expr = ast_expect(pc, ast_parse_assign_expr);
 
-    PtrPayload payload;
+    PtrIndexPayload payload;
     assert(res->type == NodeTypeSwitchProng);
     res->data.switch_prong.expr = expr;
     if (opt_payload.unwrap(&payload)) {
src/AstGen.zig
@@ -2373,6 +2373,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .switch_capture_ref,
             .switch_capture_multi,
             .switch_capture_multi_ref,
+            .switch_capture_tag,
             .struct_init_empty,
             .struct_init,
             .struct_init_ref,
@@ -6378,8 +6379,12 @@ fn switchExpr(
 
         var dbg_var_name: ?u32 = null;
         var dbg_var_inst: Zir.Inst.Ref = undefined;
+        var dbg_var_tag_name: ?u32 = null;
+        var dbg_var_tag_inst: Zir.Inst.Ref = undefined;
         var capture_inst: Zir.Inst.Index = 0;
+        var tag_inst: Zir.Inst.Index = 0;
         var capture_val_scope: Scope.LocalVal = undefined;
+        var tag_scope: Scope.LocalVal = undefined;
         const sub_scope = blk: {
             const payload_token = case.payload_token orelse break :blk &case_scope.base;
             const ident = if (token_tags[payload_token] == .asterisk)
@@ -6387,59 +6392,96 @@ fn switchExpr(
             else
                 payload_token;
             const is_ptr = ident != payload_token;
-            if (mem.eql(u8, tree.tokenSlice(ident), "_")) {
+            const ident_slice = tree.tokenSlice(ident);
+            var payload_sub_scope: *Scope = undefined;
+            if (mem.eql(u8, ident_slice, "_")) {
                 if (is_ptr) {
                     return astgen.failTok(payload_token, "pointer modifier invalid on discard", .{});
                 }
-                break :blk &case_scope.base;
-            }
-            if (case_node == special_node) {
-                const capture_tag: Zir.Inst.Tag = if (is_ptr)
-                    .switch_capture_ref
-                else
-                    .switch_capture;
-                capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
-                try astgen.instructions.append(gpa, .{
-                    .tag = capture_tag,
-                    .data = .{
-                        .switch_capture = .{
-                            .switch_inst = switch_block,
-                            // Max int communicates that this is the else/underscore prong.
-                            .prong_index = std.math.maxInt(u32),
-                        },
-                    },
-                });
+                payload_sub_scope = &case_scope.base;
             } else {
-                const is_multi_case_bits: u2 = @boolToInt(is_multi_case);
-                const is_ptr_bits: u2 = @boolToInt(is_ptr);
-                const capture_tag: Zir.Inst.Tag = switch ((is_multi_case_bits << 1) | is_ptr_bits) {
-                    0b00 => .switch_capture,
-                    0b01 => .switch_capture_ref,
-                    0b10 => .switch_capture_multi,
-                    0b11 => .switch_capture_multi_ref,
+                if (case_node == special_node) {
+                    const capture_tag: Zir.Inst.Tag = if (is_ptr)
+                        .switch_capture_ref
+                    else
+                        .switch_capture;
+                    capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
+                    try astgen.instructions.append(gpa, .{
+                        .tag = capture_tag,
+                        .data = .{
+                            .switch_capture = .{
+                                .switch_inst = switch_block,
+                                // Max int communicates that this is the else/underscore prong.
+                                .prong_index = std.math.maxInt(u32),
+                            },
+                        },
+                    });
+                } else {
+                    const is_multi_case_bits: u2 = @boolToInt(is_multi_case);
+                    const is_ptr_bits: u2 = @boolToInt(is_ptr);
+                    const capture_tag: Zir.Inst.Tag = switch ((is_multi_case_bits << 1) | is_ptr_bits) {
+                        0b00 => .switch_capture,
+                        0b01 => .switch_capture_ref,
+                        0b10 => .switch_capture_multi,
+                        0b11 => .switch_capture_multi_ref,
+                    };
+                    const capture_index = if (is_multi_case) multi_case_index else scalar_case_index;
+                    capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
+                    try astgen.instructions.append(gpa, .{
+                        .tag = capture_tag,
+                        .data = .{ .switch_capture = .{
+                            .switch_inst = switch_block,
+                            .prong_index = capture_index,
+                        } },
+                    });
+                }
+                const capture_name = try astgen.identAsString(ident);
+                try astgen.detectLocalShadowing(&case_scope.base, capture_name, ident, ident_slice);
+                capture_val_scope = .{
+                    .parent = &case_scope.base,
+                    .gen_zir = &case_scope,
+                    .name = capture_name,
+                    .inst = indexToRef(capture_inst),
+                    .token_src = payload_token,
+                    .id_cat = .@"capture",
                 };
-                const capture_index = if (is_multi_case) multi_case_index else scalar_case_index;
-                capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
-                try astgen.instructions.append(gpa, .{
-                    .tag = capture_tag,
-                    .data = .{ .switch_capture = .{
-                        .switch_inst = switch_block,
-                        .prong_index = capture_index,
-                    } },
-                });
+                dbg_var_name = capture_name;
+                dbg_var_inst = indexToRef(capture_inst);
+                payload_sub_scope = &capture_val_scope.base;
             }
-            const capture_name = try astgen.identAsString(ident);
-            capture_val_scope = .{
-                .parent = &case_scope.base,
+
+            const tag_token = if (token_tags[ident + 1] == .comma)
+                ident + 2
+            else
+                break :blk payload_sub_scope;
+            const tag_slice = tree.tokenSlice(tag_token);
+            if (mem.eql(u8, tag_slice, "_")) {
+                return astgen.failTok(tag_token, "discard of tag capture; omit it instead", .{});
+            } else if (case.inline_token == null) {
+                return astgen.failTok(tag_token, "tag capture on non-inline prong", .{});
+            }
+            const tag_name = try astgen.identAsString(tag_token);
+            try astgen.detectLocalShadowing(payload_sub_scope, tag_name, tag_token, tag_slice);
+            tag_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
+            try astgen.instructions.append(gpa, .{
+                .tag = .switch_capture_tag,
+                .data = .{ .un_tok = .{
+                    .operand = cond,
+                    .src_tok = case_scope.tokenIndexToRelative(tag_token),
+                } },
+            });
+
+            tag_scope = .{
+                .parent = payload_sub_scope,
                 .gen_zir = &case_scope,
-                .name = capture_name,
-                .inst = indexToRef(capture_inst),
-                .token_src = payload_token,
-                .id_cat = .@"capture",
+                .name = tag_name,
+                .inst = indexToRef(tag_inst),
+                .token_src = tag_token,
+                .id_cat = .@"switch tag capture",
             };
-            dbg_var_name = capture_name;
-            dbg_var_inst = indexToRef(capture_inst);
-            break :blk &capture_val_scope.base;
+            dbg_var_tag_name = tag_name;
+            dbg_var_tag_inst = indexToRef(tag_inst);
+            break :blk &tag_scope.base;
         };
 
         const header_index = @intCast(u32, payloads.items.len);
@@ -6494,10 +6536,14 @@ fn switchExpr(
             defer case_scope.unstack();
 
             if (capture_inst != 0) try case_scope.instructions.append(gpa, capture_inst);
+            if (tag_inst != 0) try case_scope.instructions.append(gpa, tag_inst);
             try case_scope.addDbgBlockBegin();
             if (dbg_var_name) |some| {
                 try case_scope.addDbgVar(.dbg_var_val, some, dbg_var_inst);
             }
+            if (dbg_var_tag_name) |some| {
+                try case_scope.addDbgVar(.dbg_var_val, some, dbg_var_tag_inst);
+            }
             const case_result = try expr(&case_scope, sub_scope, block_scope.break_result_loc, case.ast.target_expr);
             try checkUsed(parent_gz, &case_scope.base, sub_scope);
             try case_scope.addDbgBlockEnd();
@@ -10073,6 +10119,7 @@ const Scope = struct {
         @"local constant",
         @"local variable",
         @"loop index capture",
+        @"switch tag capture",
         @"capture",
     };
 
src/print_zir.zig
@@ -237,6 +237,7 @@ const Writer = struct {
             .ret_tok,
             .ensure_err_payload_void,
             .closure_capture,
+            .switch_capture_tag,
             => try self.writeUnTok(stream, inst),
 
             .bool_br_and,
src/Sema.zig
@@ -799,6 +799,7 @@ fn analyzeBodyInner(
             .switch_capture_ref           => try sema.zirSwitchCapture(block, inst, false, true),
             .switch_capture_multi         => try sema.zirSwitchCapture(block, inst, true, false),
             .switch_capture_multi_ref     => try sema.zirSwitchCapture(block, inst, true, true),
+            .switch_capture_tag           => try sema.zirSwitchCaptureTag(block, inst),
             .type_info                    => try sema.zirTypeInfo(block, inst),
             .size_of                      => try sema.zirSizeOf(block, inst),
             .bit_size_of                  => try sema.zirBitSizeOf(block, inst),
@@ -9164,6 +9165,33 @@ fn zirSwitchCapture(
     }
 }
 
+fn zirSwitchCaptureTag(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
+    const zir_datas = sema.code.instructions.items(.data);
+    const inst_data = zir_datas[inst].un_tok;
+    const src = inst_data.src();
+
+    const switch_tag = sema.code.instructions.items(.tag)[Zir.refToIndex(inst_data.operand).?];
+    const is_ref = switch_tag == .switch_cond_ref;
+    const cond_data = zir_datas[Zir.refToIndex(inst_data.operand).?].un_node;
+    const operand_ptr = try sema.resolveInst(cond_data.operand);
+    const operand_ptr_ty = sema.typeOf(operand_ptr);
+    const operand_ty = if (is_ref) operand_ptr_ty.childType() else operand_ptr_ty;
+
+    if (operand_ty.zigTypeTag() != .Union) {
+        const msg = msg: {
+            const msg = try sema.errMsg(block, src, "cannot capture tag of non-union type '{}'", .{
+                operand_ty.fmt(sema.mod),
+            });
+            errdefer msg.destroy(sema.gpa);
+            try sema.addDeclaredHereNote(msg, operand_ty);
+            break :msg msg;
+        };
+        return sema.failWithOwnedErrorMsg(msg);
+    }
+
+    return block.inline_case_capture;
+}
+
 fn zirSwitchCond(
     sema: *Sema,
     block: *Block,
src/Zir.zig
@@ -683,6 +683,9 @@ pub const Inst = struct {
         /// Result is a pointer to the value.
         /// Uses the `switch_capture` field.
         switch_capture_multi_ref,
+        /// Produces the capture value for an inline switch prong tag capture.
+        /// Uses the `un_tok` field.
+        switch_capture_tag,
         /// Given a
         ///   *A returns *A
         ///   *E!A returns *A
@@ -1128,6 +1131,7 @@ pub const Inst = struct {
                 .switch_capture_ref,
                 .switch_capture_multi,
                 .switch_capture_multi_ref,
+                .switch_capture_tag,
                 .switch_block,
                 .switch_cond,
                 .switch_cond_ref,
@@ -1422,6 +1426,7 @@ pub const Inst = struct {
                 .switch_capture_ref,
                 .switch_capture_multi,
                 .switch_capture_multi_ref,
+                .switch_capture_tag,
                 .switch_block,
                 .switch_cond,
                 .switch_cond_ref,
@@ -1681,6 +1686,7 @@ pub const Inst = struct {
                 .switch_capture_ref = .switch_capture,
                 .switch_capture_multi = .switch_capture,
                 .switch_capture_multi_ref = .switch_capture,
+                .switch_capture_tag = .un_tok,
                 .array_base_ptr = .un_node,
                 .field_base_ptr = .un_node,
                 .validate_array_init_ty = .pl_node,
test/behavior/inline_switch.zig
@@ -47,11 +47,21 @@ test "inline switch unions" {
 
     var x: U = .a;
     switch (x) {
-        inline .a, .b => |aorb| {
-            try expect(@TypeOf(aorb) == void or @TypeOf(aorb) == u2);
+        inline .a, .b => |aorb, tag| {
+            if (tag == .a) {
+                try expect(@TypeOf(aorb) == void);
+            } else {
+                try expect(tag == .b);
+                try expect(@TypeOf(aorb) == u2);
+            }
         },
-        inline .c, .d => |cord| {
-            try expect(@TypeOf(cord) == u3 or @TypeOf(cord) == u4);
+        inline .c, .d => |cord, tag| {
+            if (tag == .c) {
+                try expect(@TypeOf(cord) == u3);
+            } else {
+                try expect(tag == .d);
+                try expect(@TypeOf(cord) == u4);
+            }
         },
     }
 }
test/cases/compile_errors/invalid_tag_capture.zig
@@ -0,0 +1,15 @@
+const E = enum { a, b, c, d };
+pub export fn entry() void {
+    var x: E = .a;
+    switch (x) {
+        inline .a, .b => |aorb, d| @compileLog(aorb, d),
+        inline .c, .d => |*cord| @compileLog(cord),
+    }
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :5:33: error: cannot capture tag of non-union type 'tmp.E'
+// :1:11: note: enum declared here
test/cases/compile_errors/tag_capture_on_non_inline_prong.zig
@@ -0,0 +1,14 @@
+const E = enum { a, b, c, d };
+pub export fn entry() void {
+    var x: E = .a;
+    switch (x) {
+        .a, .b => |aorb, d| @compileLog(aorb, d),
+        inline .c, .d => |*cord| @compileLog(cord),
+    }
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :5:26: error: tag capture on non-inline prong