Commit 06b1a88a15

Andrew Kelley <andrew@ziglang.org>
2022-03-03 04:17:09
Sema: implement cast from anon struct ptr to union ptr
1 parent ac7028f
Changed files (3)
src
test
behavior
src/Sema.zig
@@ -2733,7 +2733,7 @@ fn zirValidateStructInit(
         ),
         .Union => return sema.validateUnionInit(
             block,
-            agg_ty.cast(Type.Payload.Union).?.data,
+            agg_ty,
             init_src,
             instrs,
             object_ptr,
@@ -2746,12 +2746,14 @@ fn zirValidateStructInit(
 fn validateUnionInit(
     sema: *Sema,
     block: *Block,
-    union_obj: *Module.Union,
+    union_ty: Type,
     init_src: LazySrcLoc,
     instrs: []const Zir.Inst.Index,
     union_ptr: Air.Inst.Ref,
     is_comptime: bool,
 ) CompileError!void {
+    const union_obj = union_ty.cast(Type.Payload.Union).?.data;
+
     if (instrs.len != 1) {
         const msg = msg: {
             const msg = try sema.errMsg(
@@ -2767,7 +2769,7 @@ fn validateUnionInit(
                 const inst_src: LazySrcLoc = .{ .node_offset_back2tok = inst_data.src_node };
                 try sema.errNote(block, inst_src, msg, "additional initializer here", .{});
             }
-            try sema.mod.errNoteNonLazy(union_obj.srcLoc(), msg, "union declared here", .{});
+            try sema.addDeclaredHereNote(msg, union_ty);
             break :msg msg;
         };
         return sema.failWithOwnedErrorMsg(msg);
@@ -2783,9 +2785,7 @@ fn validateUnionInit(
     const field_src: LazySrcLoc = .{ .node_offset_back2tok = field_ptr_data.src_node };
     const field_ptr_extra = sema.code.extraData(Zir.Inst.Field, field_ptr_data.payload_index).data;
     const field_name = sema.code.nullTerminatedString(field_ptr_extra.field_name_start);
-    const field_index_big = union_obj.fields.getIndex(field_name) orelse
-        return sema.failWithBadUnionFieldAccess(block, union_obj, field_src, field_name);
-    const field_index = @intCast(u32, field_index_big);
+    const field_index = try sema.unionFieldIndex(block, union_ty, field_name, field_src);
     const air_tags = sema.air_instructions.items(.tag);
     const air_datas = sema.air_instructions.items(.data);
     const field_ptr_air_ref = sema.inst_map.get(field_ptr).?;
@@ -2844,7 +2844,6 @@ fn validateUnionInit(
             .tag = tag_val,
             .val = val,
         });
-        const union_ty = sema.typeOf(union_ptr).childType();
         const union_init = try sema.addConstant(union_ty, union_val);
         try sema.storePtr2(block, init_src, union_ptr, init_src, union_init, init_src, .store);
         return;
@@ -11494,10 +11493,7 @@ fn unionInit(
     field_name: []const u8,
     field_src: LazySrcLoc,
 ) CompileError!Air.Inst.Ref {
-    const union_obj = union_ty.cast(Type.Payload.Union).?.data;
-    const field_index_usize = union_obj.fields.getIndex(field_name) orelse
-        return sema.failWithBadUnionFieldAccess(block, union_obj, field_src, field_name);
-    const field_index = @intCast(u32, field_index_usize);
+    const field_index = try sema.unionFieldIndex(block, union_ty, field_name, field_src);
 
     if (try sema.resolveMaybeUndefVal(block, init_src, init)) |init_val| {
         const tag_val = try Value.Tag.enum_field_index.create(sema.arena, field_index);
@@ -11594,9 +11590,7 @@ fn zirStructInit(
             }
         }
         return sema.finishStructInit(block, src, field_inits, root_msg, struct_obj, resolved_ty, is_ref);
-    } else if (resolved_ty.cast(Type.Payload.Union)) |union_payload| {
-        const union_obj = union_payload.data;
-
+    } else if (resolved_ty.zigTypeTag() == .Union) {
         if (extra.data.fields_len != 1) {
             return sema.fail(block, src, "union initialization expects exactly one field", .{});
         }
@@ -11607,9 +11601,7 @@ fn zirStructInit(
         const field_src: LazySrcLoc = .{ .node_offset_back2tok = field_type_data.src_node };
         const field_type_extra = sema.code.extraData(Zir.Inst.FieldType, field_type_data.payload_index).data;
         const field_name = sema.code.nullTerminatedString(field_type_extra.name_start);
-        const field_index_usize = union_obj.fields.getIndex(field_name) orelse
-            return sema.failWithBadUnionFieldAccess(block, union_obj, field_src, field_name);
-        const field_index = @intCast(u32, field_index_usize);
+        const field_index = try sema.unionFieldIndex(block, resolved_ty, field_name, field_src);
 
         if (is_ref) {
             return sema.fail(block, src, "TODO: Sema.zirStructInit is_ref=true union", .{});
@@ -11732,7 +11724,11 @@ fn zirStructInitAnon(
 
     if (is_ref) {
         const target = sema.mod.getTarget();
-        const alloc = try block.addTy(.alloc, tuple_ty);
+        const alloc_ty = try Type.ptr(sema.arena, target, .{
+            .pointee_type = tuple_ty,
+            .@"addrspace" = target_util.defaultAddressSpace(target, .local),
+        });
+        const alloc = try block.addTy(.alloc, alloc_ty);
         var extra_index = extra.end;
         for (types) |field_ty, i_usize| {
             const i = @intCast(u32, i_usize);
@@ -11882,7 +11878,11 @@ fn zirArrayInitAnon(
 
     if (is_ref) {
         const target = sema.mod.getTarget();
-        const alloc = try block.addTy(.alloc, tuple_ty);
+        const alloc_ty = try Type.ptr(sema.arena, target, .{
+            .pointee_type = tuple_ty,
+            .@"addrspace" = target_util.defaultAddressSpace(target, .local),
+        });
+        const alloc = try block.addTy(.alloc, alloc_ty);
         for (operands) |operand, i_usize| {
             const i = @intCast(u32, i_usize);
             const field_ptr_ty = try Type.ptr(sema.arena, target, .{
@@ -15220,11 +15220,7 @@ fn unionFieldPtr(
     const union_ptr_ty = sema.typeOf(union_ptr);
     const union_ty = try sema.resolveTypeFields(block, src, unresolved_union_ty);
     const union_obj = union_ty.cast(Type.Payload.Union).?.data;
-
-    const field_index_big = union_obj.fields.getIndex(field_name) orelse
-        return sema.failWithBadUnionFieldAccess(block, union_obj, field_name_src, field_name);
-    const field_index = @intCast(u32, field_index_big);
-
+    const field_index = try sema.unionFieldIndex(block, union_ty, field_name, field_name_src);
     const field = union_obj.fields.values()[field_index];
     const target = sema.mod.getTarget();
     const ptr_field_ty = try Type.ptr(arena, target, .{
@@ -15286,10 +15282,7 @@ fn unionFieldVal(
 
     const union_ty = try sema.resolveTypeFields(block, src, unresolved_union_ty);
     const union_obj = union_ty.cast(Type.Payload.Union).?.data;
-
-    const field_index_usize = union_obj.fields.getIndex(field_name) orelse
-        return sema.failWithBadUnionFieldAccess(block, union_obj, field_name_src, field_name);
-    const field_index = @intCast(u32, field_index_usize);
+    const field_index = try sema.unionFieldIndex(block, union_ty, field_name, field_name_src);
     const field = union_obj.fields.values()[field_index];
 
     if (try sema.resolveMaybeUndefVal(block, src, union_byval)) |union_val| {
@@ -15799,6 +15792,15 @@ fn coerce(
                 return sema.coerceCompatiblePtrs(block, dest_ty, inst, inst_src);
             }
 
+            // cast from pointer to anonymous struct to pointer to union
+            if (dest_info.pointee_type.zigTypeTag() == .Union and
+                inst_ty.zigTypeTag() == .Pointer and
+                inst_ty.childType().tag() == .anon_struct and
+                !dest_info.mutable)
+            {
+                return sema.coerceAnonStructToUnionPtrs(block, dest_ty, dest_ty_src, inst, inst_src);
+            }
+
             // This will give an extra hint on top of what the bottom of this func would provide.
             try sema.checkPtrOperand(block, dest_ty_src, inst_ty);
         },
@@ -15956,8 +15958,8 @@ fn coerce(
         .Union => switch (inst_ty.zigTypeTag()) {
             .Enum, .EnumLiteral => return sema.coerceEnumToUnion(block, dest_ty, dest_ty_src, inst, inst_src),
             .Struct => {
-                if (inst_ty.castTag(.anon_struct)) |anon_struct| {
-                    return sema.coerceAnonStructToUnion(block, dest_ty, dest_ty_src, inst, inst_src, anon_struct.data);
+                if (inst_ty.isAnonStruct()) {
+                    return sema.coerceAnonStructToUnion(block, dest_ty, dest_ty_src, inst, inst_src);
                 }
             },
             else => {},
@@ -17047,8 +17049,9 @@ fn coerceAnonStructToUnion(
     union_ty_src: LazySrcLoc,
     inst: Air.Inst.Ref,
     inst_src: LazySrcLoc,
-    anon_struct: Type.Payload.AnonStruct.Data,
 ) !Air.Inst.Ref {
+    const inst_ty = sema.typeOf(inst);
+    const anon_struct = inst_ty.castTag(.anon_struct).?.data;
     if (anon_struct.types.len != 1) {
         const msg = msg: {
             const msg = try sema.errMsg(
@@ -17069,11 +17072,24 @@ fn coerceAnonStructToUnion(
     }
 
     const field_name = anon_struct.names[0];
-    const inst_ty = sema.typeOf(inst);
     const init = try sema.structFieldVal(block, inst_src, inst, field_name, inst_src, inst_ty);
     return sema.unionInit(block, init, inst_src, union_ty, union_ty_src, field_name, inst_src);
 }
 
+fn coerceAnonStructToUnionPtrs(
+    sema: *Sema,
+    block: *Block,
+    ptr_union_ty: Type,
+    union_ty_src: LazySrcLoc,
+    ptr_anon_struct: Air.Inst.Ref,
+    anon_struct_src: LazySrcLoc,
+) !Air.Inst.Ref {
+    const union_ty = ptr_union_ty.childType();
+    const anon_struct = try sema.analyzeLoad(block, anon_struct_src, ptr_anon_struct, anon_struct_src);
+    const union_inst = try sema.coerceAnonStructToUnion(block, union_ty, union_ty_src, anon_struct, anon_struct_src);
+    return sema.analyzeRef(block, union_ty_src, union_inst);
+}
+
 /// If the lengths match, coerces element-wise.
 fn coerceArrayLike(
     sema: *Sema,
@@ -20080,3 +20096,17 @@ pub fn fnHasRuntimeBits(sema: *Sema, block: *Block, src: LazySrcLoc, ty: Type) C
     }
     return true;
 }
+
+fn unionFieldIndex(
+    sema: *Sema,
+    block: *Block,
+    unresolved_union_ty: Type,
+    field_name: []const u8,
+    field_src: LazySrcLoc,
+) !u32 {
+    const union_ty = try sema.resolveTypeFields(block, field_src, unresolved_union_ty);
+    const union_obj = union_ty.cast(Type.Payload.Union).?.data;
+    const field_index_usize = union_obj.fields.getIndex(field_name) orelse
+        return sema.failWithBadUnionFieldAccess(block, union_obj, field_src, field_name);
+    return @intCast(u32, field_index_usize);
+}
src/type.zig
@@ -5036,6 +5036,13 @@ pub const Type = extern union {
         };
     }
 
+    pub fn isAnonStruct(ty: Type) bool {
+        return switch (ty.tag()) {
+            .anon_struct => true,
+            else => false,
+        };
+    }
+
     pub fn isTupleOrAnonStruct(ty: Type) bool {
         return switch (ty.tag()) {
             .tuple, .empty_struct_literal, .anon_struct => true,
test/behavior/union.zig
@@ -978,7 +978,11 @@ test "cast from anonymous struct to union" {
 }
 
 test "cast from pointer to anonymous struct to pointer to union" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
 
     const S = struct {
         const U = union(enum) {