Commit 8509e7111d

Andrew Kelley <andrew@ziglang.org>
2021-10-26 00:06:47
stage2: fix switch on tagged union capture-by-pointer
* AstGen: always use `typeof` and never `typeof_elem` on the `switch_cond`/`switch_cond_ref` instruction because both variants return a value and not a pointer.
  - Delete the `typeof_elem` ZIR instruction since it is no longer needed.
* Sema: validateUnionInit now recognizes a comptime mutable value and no longer emits a compile error saying "cannot evaluate constant expression".
  - Still to-do is detecting comptime union values in a function that is not being executed at compile-time.
  - This is still to-do for structs too.
* Sema: when emitting a call AIR instruction, call resolveTypeLayout on all the parameter types as well as the return type.
* `Type.structFieldOffset` now works for unions in addition to structs.
1 parent a132190
src/AstGen.zig
@@ -2109,7 +2109,6 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: Ast.Node.Index) Inner
             .negate,
             .negate_wrap,
             .typeof,
-            .typeof_elem,
             .xor,
             .optional_type,
             .optional_payload_safe,
@@ -6028,8 +6027,7 @@ fn switchExpr(
     const cond_tag: Zir.Inst.Tag = if (any_payload_is_ref) .switch_cond_ref else .switch_cond;
     const cond = try parent_gz.addUnNode(cond_tag, raw_operand, operand_node);
     // We need the type of the operand to use as the result location for all the prong items.
-    const typeof_tag: Zir.Inst.Tag = if (any_payload_is_ref) .typeof_elem else .typeof;
-    const cond_ty_inst = try parent_gz.addUnNode(typeof_tag, cond, operand_node);
+    const cond_ty_inst = try parent_gz.addUnNode(.typeof, cond, operand_node);
     const item_rl: ResultLoc = .{ .ty = cond_ty_inst };
 
     // These contain the data that goes into the `extra` array for the SwitchBlock/SwitchBlockMulti.
@@ -6214,7 +6212,7 @@ fn switchExpr(
             .has_multi_cases = multi_cases_len != 0,
             .has_else = special_prong == .@"else",
             .has_under = special_prong == .under,
-            .scalar_cases_len = @intCast(u28, scalar_cases_len),
+            .scalar_cases_len = @intCast(Zir.Inst.SwitchBlock.Bits.ScalarCasesLen, scalar_cases_len),
         },
     });
 
src/print_zir.zig
@@ -184,7 +184,6 @@ const Writer = struct {
             .is_non_err,
             .is_non_err_ptr,
             .typeof,
-            .typeof_elem,
             .struct_init_empty,
             .type_info,
             .size_of,
src/Sema.zig
@@ -608,7 +608,6 @@ pub fn analyzeBody(
             .size_of                      => try sema.zirSizeOf(block, inst),
             .bit_size_of                  => try sema.zirBitSizeOf(block, inst),
             .typeof                       => try sema.zirTypeof(block, inst),
-            .typeof_elem                  => try sema.zirTypeofElem(block, inst),
             .log2_int_type                => try sema.zirLog2IntType(block, inst),
             .typeof_log2_int_type         => try sema.zirTypeofLog2IntType(block, inst),
             .xor                          => try sema.zirBitwise(block, inst, .xor),
@@ -2337,11 +2336,21 @@ fn validateUnionInit(
         return sema.failWithBadUnionFieldAccess(block, union_obj, field_src, field_name);
     const field_index = @intCast(u32, field_index_big);
 
-    // TODO here we need to go back and see if we need to convert the union
-    // to a comptime-known value. This will involve editing the AIR code we have
-    // generated so far - in particular deleting some runtime pointer bitcast
-    // instructions which are not actually needed if the initialization expression
-    // ends up being comptime-known.
+    // Handle the possibility of the union value being comptime-known.
+    const union_ptr_inst = Air.refToIndex(sema.resolveInst(field_ptr_extra.lhs)).?;
+    switch (sema.air_instructions.items(.tag)[union_ptr_inst]) {
+        .constant => return, // In this case the tag has already been set. No validation to do.
+        .bitcast => {
+            // TODO here we need to go back and see if we need to convert the union
+            // to a comptime-known value. In such case, we must delete all the instructions
+            // added to the current block starting with the bitcast.
+            // If the bitcast result ptr is an alloc, the alloc should be replaced with
+            // a constant decl_ref.
+            // Otherwise, the bitcast should be preserved and a store instruction should be
+            // emitted to store the constant union value through the bitcast.
+        },
+        else => unreachable,
+    }
 
     // Otherwise, we set the new union tag now.
     const new_tag = try sema.addConstant(
@@ -4091,18 +4100,20 @@ fn analyzeCall(
             zir_tags,
         );
     } else res: {
+        try sema.requireRuntimeBlock(block, call_src);
+
         const args = try sema.arena.alloc(Air.Inst.Ref, uncasted_args.len);
         for (uncasted_args) |uncasted_arg, i| {
+            const arg_src = call_src; // TODO: better source location
             if (i < fn_params_len) {
                 const param_ty = func_ty.fnParamType(i);
-                const arg_src = call_src; // TODO: better source location
+                try sema.resolveTypeLayout(block, arg_src, param_ty);
                 args[i] = try sema.coerce(block, param_ty, uncasted_arg, arg_src);
             } else {
                 args[i] = uncasted_arg;
             }
         }
 
-        try sema.requireRuntimeBlock(block, call_src);
         try sema.resolveTypeLayout(block, call_src, func_ty_info.return_type);
 
         try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.Call).Struct.fields.len +
@@ -4173,6 +4184,7 @@ fn finishGenericCall(
                 const param_ty = new_fn_ty.fnParamType(runtime_i);
                 const arg_src = call_src; // TODO: better source location
                 const uncasted_arg = uncasted_args[total_i];
+                try sema.resolveTypeLayout(block, arg_src, param_ty);
                 const casted_arg = try sema.coerce(block, param_ty, uncasted_arg, arg_src);
                 runtime_args[runtime_i] = casted_arg;
                 runtime_i += 1;
@@ -5548,7 +5560,7 @@ fn zirSwitchCapture(
                     );
                 }
                 try sema.requireRuntimeBlock(block, operand_src);
-                return block.addStructFieldPtr(operand_ptr, field_index, field.ty);
+                return block.addStructFieldPtr(operand_ptr, field_index, field_ty_ptr);
             }
 
             const operand = if (operand_is_ref)
@@ -5669,11 +5681,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     const special_prong_src: LazySrcLoc = .{ .node_offset_switch_special_prong = src_node_offset };
     const extra = sema.code.extraData(Zir.Inst.SwitchBlock, inst_data.payload_index);
 
-    const operand_ptr = sema.resolveInst(extra.data.operand);
-    const operand = if (extra.data.bits.is_ref)
-        try sema.analyzeLoad(block, src, operand_ptr, operand_src)
-    else
-        operand_ptr;
+    const operand = sema.resolveInst(extra.data.operand);
 
     var header_extra_index: usize = extra.end;
 
@@ -8675,14 +8683,6 @@ fn zirTypeof(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
     return sema.addType(operand_ty);
 }
 
-fn zirTypeofElem(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
-    _ = block;
-    const inst_data = sema.code.instructions.items(.data)[inst].un_node;
-    const operand_ptr = sema.resolveInst(inst_data.operand);
-    const elem_ty = sema.typeOf(operand_ptr).elemType();
-    return sema.addType(elem_ty);
-}
-
 fn zirTypeofLog2IntType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const src = inst_data.src();
src/type.zig
@@ -3391,34 +3391,49 @@ pub const Type = extern union {
         }
     }
 
+    /// Supports structs and unions.
     pub fn structFieldOffset(ty: Type, index: usize, target: Target) u64 {
-        const fields = ty.structFields();
-        if (ty.castTag(.@"struct")) |payload| {
-            const struct_obj = payload.data;
-            assert(struct_obj.status == .have_layout);
-            const is_packed = struct_obj.layout == .Packed;
-            if (is_packed) @panic("TODO packed structs");
-        }
+        switch (ty.tag()) {
+            .@"struct" => {
+                const struct_obj = ty.castTag(.@"struct").?.data;
+                assert(struct_obj.status == .have_layout);
+                const is_packed = struct_obj.layout == .Packed;
+                if (is_packed) @panic("TODO packed structs");
 
-        var offset: u64 = 0;
-        var big_align: u32 = 0;
-        for (fields.values()) |field, i| {
-            if (!field.ty.hasCodeGenBits()) continue;
+                var offset: u64 = 0;
+                var big_align: u32 = 0;
+                for (struct_obj.fields.values()) |field, i| {
+                    if (!field.ty.hasCodeGenBits()) continue;
 
-            const field_align = a: {
-                if (field.abi_align.tag() == .abi_align_default) {
-                    break :a field.ty.abiAlignment(target);
+                    const field_align = a: {
+                        if (field.abi_align.tag() == .abi_align_default) {
+                            break :a field.ty.abiAlignment(target);
+                        } else {
+                            break :a @intCast(u32, field.abi_align.toUnsignedInt());
+                        }
+                    };
+                    big_align = @maximum(big_align, field_align);
+                    offset = std.mem.alignForwardGeneric(u64, offset, field_align);
+                    if (i == index) return offset;
+                    offset += field.ty.abiSize(target);
+                }
+                offset = std.mem.alignForwardGeneric(u64, offset, big_align);
+                return offset;
+            },
+            .@"union" => return 0,
+            .union_tagged => {
+                const union_obj = ty.castTag(.union_tagged).?.data;
+                const layout = union_obj.getLayout(target, true);
+                if (layout.tag_align >= layout.payload_align) {
+                    // {Tag, Payload}
+                    return std.mem.alignForwardGeneric(u64, layout.tag_size, layout.payload_align);
                 } else {
-                    break :a @intCast(u32, field.abi_align.toUnsignedInt());
+                    // {Payload, Tag}
+                    return 0;
                 }
-            };
-            big_align = @maximum(big_align, field_align);
-            offset = std.mem.alignForwardGeneric(u64, offset, field_align);
-            if (i == index) return offset;
-            offset += field.ty.abiSize(target);
+            },
+            else => unreachable,
         }
-        offset = std.mem.alignForwardGeneric(u64, offset, big_align);
-        return offset;
     }
 
     pub fn declSrcLoc(ty: Type) Module.SrcLoc {
src/Zir.zig
@@ -544,9 +544,6 @@ pub const Inst = struct {
         /// Returns the type of a value.
         /// Uses the `un_node` field.
         typeof,
-        /// Given a value which is a pointer, returns the element type.
-        /// Uses the `un_node` field.
-        typeof_elem,
         /// Given a value, look at the type of it, which must be an integer type.
         /// Returns the integer type for the RHS of a shift operation.
         /// Uses the `un_node` field.
@@ -1045,7 +1042,6 @@ pub const Inst = struct {
                 .negate,
                 .negate_wrap,
                 .typeof,
-                .typeof_elem,
                 .xor,
                 .optional_type,
                 .optional_payload_safe,
@@ -1312,7 +1308,6 @@ pub const Inst = struct {
                 .negate = .un_node,
                 .negate_wrap = .un_node,
                 .typeof = .un_node,
-                .typeof_elem = .un_node,
                 .typeof_log2_int_type = .un_node,
                 .log2_int_type = .un_node,
                 .@"unreachable" = .@"unreachable",
@@ -2443,6 +2438,13 @@ pub const Inst = struct {
     ///        body member Index for every body_len
     ///    }
     pub const SwitchBlock = struct {
+        /// This is always a `switch_cond` or `switch_cond_ref` instruction.
+        /// If it is a `switch_cond_ref` instruction, bits.is_ref is always true.
+        /// If it is a `switch_cond` instruction, bits.is_ref is always false.
+        /// Both `switch_cond` and `switch_cond_ref` return a value, not a pointer,
+        /// that is useful for the case items, but cannot be used for capture values.
+        /// For the capture values, Sema is expected to look up the operand of that
+        /// `switch_cond`/`switch_cond_ref` instruction and use it instead.
         operand: Ref,
         bits: Bits,
 
@@ -2454,8 +2456,11 @@ pub const Inst = struct {
             /// If true, there is an underscore prong. This is mutually exclusive with `has_else`.
             has_under: bool,
             /// If true, the `operand` is a pointer to the value being switched on.
+            /// TODO this flag is redundant with the tag of operand and can be removed.
             is_ref: bool,
-            scalar_cases_len: u28,
+            scalar_cases_len: ScalarCasesLen,
+
+            pub const ScalarCasesLen = u28;
 
             pub fn specialProng(bits: Bits) SpecialProng {
                 const has_else: u2 = @boolToInt(bits.has_else);
test/behavior/switch.zig
@@ -219,3 +219,46 @@ test "switch on global mutable var isn't constant-folded" {
         poll();
     }
 }
+
+const SwitchProngWithVarEnum = union(enum) {
+    One: i32,
+    Two: f32,
+    Meh: void,
+};
+
+test "switch prong with variable" {
+    try switchProngWithVarFn(SwitchProngWithVarEnum{ .One = 13 });
+    try switchProngWithVarFn(SwitchProngWithVarEnum{ .Two = 13.0 });
+    try switchProngWithVarFn(SwitchProngWithVarEnum{ .Meh = {} });
+}
+fn switchProngWithVarFn(a: SwitchProngWithVarEnum) !void {
+    switch (a) {
+        SwitchProngWithVarEnum.One => |x| {
+            try expect(x == 13);
+        },
+        SwitchProngWithVarEnum.Two => |x| {
+            try expect(x == 13.0);
+        },
+        SwitchProngWithVarEnum.Meh => |x| {
+            const v: void = x;
+            _ = v;
+        },
+    }
+}
+
+test "switch on enum using pointer capture" {
+    try testSwitchEnumPtrCapture();
+    comptime try testSwitchEnumPtrCapture();
+}
+
+fn testSwitchEnumPtrCapture() !void {
+    var value = SwitchProngWithVarEnum{ .One = 1234 };
+    switch (value) {
+        SwitchProngWithVarEnum.One => |*x| x.* += 1,
+        else => unreachable,
+    }
+    switch (value) {
+        SwitchProngWithVarEnum.One => |x| try expect(x == 1235),
+        else => unreachable,
+    }
+}
test/behavior/switch_stage1.zig
@@ -3,48 +3,6 @@ const expect = std.testing.expect;
 const expectError = std.testing.expectError;
 const expectEqual = std.testing.expectEqual;
 
-test "switch prong with variable" {
-    try switchProngWithVarFn(SwitchProngWithVarEnum{ .One = 13 });
-    try switchProngWithVarFn(SwitchProngWithVarEnum{ .Two = 13.0 });
-    try switchProngWithVarFn(SwitchProngWithVarEnum{ .Meh = {} });
-}
-const SwitchProngWithVarEnum = union(enum) {
-    One: i32,
-    Two: f32,
-    Meh: void,
-};
-fn switchProngWithVarFn(a: SwitchProngWithVarEnum) !void {
-    switch (a) {
-        SwitchProngWithVarEnum.One => |x| {
-            try expect(x == 13);
-        },
-        SwitchProngWithVarEnum.Two => |x| {
-            try expect(x == 13.0);
-        },
-        SwitchProngWithVarEnum.Meh => |x| {
-            const v: void = x;
-            _ = v;
-        },
-    }
-}
-
-test "switch on enum using pointer capture" {
-    try testSwitchEnumPtrCapture();
-    comptime try testSwitchEnumPtrCapture();
-}
-
-fn testSwitchEnumPtrCapture() !void {
-    var value = SwitchProngWithVarEnum{ .One = 1234 };
-    switch (value) {
-        SwitchProngWithVarEnum.One => |*x| x.* += 1,
-        else => unreachable,
-    }
-    switch (value) {
-        SwitchProngWithVarEnum.One => |x| try expect(x == 1235),
-        else => unreachable,
-    }
-}
-
 test "switch handles all cases of number" {
     try testSwitchHandleAllCases();
     comptime try testSwitchHandleAllCases();