Commit 05da5b32a8

Isaac Freund <mail@isaacfreund.com>
2023-02-20 23:31:48
Sema: implement @fieldParentPtr for unions
1 parent dc1f50e
Changed files (7)
src/arch/arm/CodeGen.zig
@@ -2973,6 +2973,11 @@ fn airFieldParentPtr(self: *Self, inst: Air.Inst.Index) !void {
     const result: MCValue = if (self.liveness.isUnused(inst)) .dead else result: {
         const field_ptr = try self.resolveInst(extra.field_ptr);
         const struct_ty = self.air.getRefType(ty_pl.ty).childType();
+
+        if (struct_ty.zigTypeTag() == .Union) {
+            return self.fail("TODO implement @fieldParentPtr codegen for unions", .{});
+        }
+
         const struct_field_offset = @intCast(u32, struct_ty.structFieldOffset(extra.field_index, self.target.*));
         switch (field_ptr) {
             .ptr_stack_offset => |off| {
src/arch/wasm/CodeGen.zig
@@ -4944,8 +4944,8 @@ fn airFieldParentPtr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
     if (func.liveness.isUnused(inst)) return func.finishAir(inst, .none, &.{extra.field_ptr});
 
     const field_ptr = try func.resolveInst(extra.field_ptr);
-    const struct_ty = func.air.getRefType(ty_pl.ty).childType();
-    const field_offset = struct_ty.structFieldOffset(extra.field_index, func.target);
+    const parent_ty = func.air.getRefType(ty_pl.ty).childType();
+    const field_offset = parent_ty.structFieldOffset(extra.field_index, func.target);
 
     const result = if (field_offset != 0) result: {
         const base = try func.buildPointerOffset(field_ptr, 0, .new);
src/codegen/c.zig
@@ -5367,12 +5367,18 @@ fn airFieldParentPtr(f: *Function, inst: Air.Inst.Index) !CValue {
     }
 
     const struct_ptr_ty = f.air.typeOfIndex(inst);
+
     const field_ptr_ty = f.air.typeOf(extra.field_ptr);
     const field_ptr_val = try f.resolveInst(extra.field_ptr);
     try reap(f, inst, &.{extra.field_ptr});
 
     const target = f.object.dg.module.getTarget();
     const struct_ty = struct_ptr_ty.childType();
+
+    if (struct_ty.zigTypeTag() == .Union) {
+        return f.fail("TODO: CBE: @fieldParentPtr for unions", .{});
+    }
+
     const field_offset = struct_ty.structFieldOffset(extra.field_index, target);
 
     var field_offset_pl = Value.Payload.I64{
src/codegen/llvm.zig
@@ -6020,8 +6020,8 @@ pub const FuncGen = struct {
         const field_ptr = try self.resolveInst(extra.field_ptr);
 
         const target = self.dg.module.getTarget();
-        const struct_ty = self.air.getRefType(ty_pl.ty).childType();
-        const field_offset = struct_ty.structFieldOffset(extra.field_index, target);
+        const parent_ty = self.air.getRefType(ty_pl.ty).childType();
+        const field_offset = parent_ty.structFieldOffset(extra.field_index, target);
 
         const res_ty = try self.dg.lowerType(self.air.getRefType(ty_pl.ty));
         if (field_offset == 0) {
src/Sema.zig
@@ -21482,24 +21482,32 @@ fn zirFieldParentPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileEr
     const name_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
     const ptr_src: LazySrcLoc = .{ .node_offset_builtin_call_arg2 = inst_data.src_node };
 
-    const struct_ty = try sema.resolveType(block, ty_src, extra.parent_type);
+    const parent_ty = try sema.resolveType(block, ty_src, extra.parent_type);
     const field_name = try sema.resolveConstString(block, name_src, extra.field_name, "field name must be comptime-known");
     const field_ptr = try sema.resolveInst(extra.field_ptr);
     const field_ptr_ty = sema.typeOf(field_ptr);
 
-    if (struct_ty.zigTypeTag() != .Struct) {
-        return sema.fail(block, ty_src, "expected struct type, found '{}'", .{struct_ty.fmt(sema.mod)});
+    if (parent_ty.zigTypeTag() != .Struct and parent_ty.zigTypeTag() != .Union) {
+        return sema.fail(block, ty_src, "expected struct or union type, found '{}'", .{parent_ty.fmt(sema.mod)});
     }
-    try sema.resolveTypeLayout(struct_ty);
+    try sema.resolveTypeLayout(parent_ty);
 
-    const field_index = if (struct_ty.isTuple()) blk: {
-        if (mem.eql(u8, field_name, "len")) {
-            return sema.fail(block, src, "cannot get @fieldParentPtr of 'len' field of tuple", .{});
-        }
-        break :blk try sema.tupleFieldIndex(block, struct_ty, field_name, name_src);
-    } else try sema.structFieldIndex(block, struct_ty, field_name, name_src);
+    const field_index = switch (parent_ty.zigTypeTag()) {
+        .Struct => blk: {
+            if (parent_ty.isTuple()) {
+                if (mem.eql(u8, field_name, "len")) {
+                    return sema.fail(block, src, "cannot get @fieldParentPtr of 'len' field of tuple", .{});
+                }
+                break :blk try sema.tupleFieldIndex(block, parent_ty, field_name, name_src);
+            } else {
+                break :blk try sema.structFieldIndex(block, parent_ty, field_name, name_src);
+            }
+        },
+        .Union => try sema.unionFieldIndex(block, parent_ty, field_name, name_src),
+        else => unreachable,
+    };
 
-    if (struct_ty.structFieldIsComptime(field_index)) {
+    if (parent_ty.zigTypeTag() == .Struct and parent_ty.structFieldIsComptime(field_index)) {
         return sema.fail(block, src, "cannot get @fieldParentPtr of a comptime field", .{});
     }
 
@@ -21507,23 +21515,29 @@ fn zirFieldParentPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileEr
     const field_ptr_ty_info = field_ptr_ty.ptrInfo().data;
 
     var ptr_ty_data: Type.Payload.Pointer.Data = .{
-        .pointee_type = struct_ty.structFieldType(field_index),
+        .pointee_type = parent_ty.structFieldType(field_index),
         .mutable = field_ptr_ty_info.mutable,
         .@"addrspace" = field_ptr_ty_info.@"addrspace",
     };
 
-    if (struct_ty.containerLayout() == .Packed) {
-        return sema.fail(block, src, "TODO handle packed structs with @fieldParentPtr", .{});
+    if (parent_ty.containerLayout() == .Packed) {
+        return sema.fail(block, src, "TODO handle packed structs/unions with @fieldParentPtr", .{});
     } else {
-        ptr_ty_data.@"align" = if (struct_ty.castTag(.@"struct")) |struct_obj| b: {
-            break :b struct_obj.data.fields.values()[field_index].abi_align;
-        } else 0;
+        ptr_ty_data.@"align" = blk: {
+            if (parent_ty.castTag(.@"struct")) |struct_obj| {
+                break :blk struct_obj.data.fields.values()[field_index].abi_align;
+            } else if (parent_ty.cast(Type.Payload.Union)) |union_obj| {
+                break :blk union_obj.data.fields.values()[field_index].abi_align;
+            } else {
+                break :blk 0;
+            }
+        };
     }
 
     const actual_field_ptr_ty = try Type.ptr(sema.arena, sema.mod, ptr_ty_data);
     const casted_field_ptr = try sema.coerce(block, actual_field_ptr_ty, field_ptr, ptr_src);
 
-    ptr_ty_data.pointee_type = struct_ty;
+    ptr_ty_data.pointee_type = parent_ty;
     const result_ptr = try Type.ptr(sema.arena, sema.mod, ptr_ty_data);
 
     if (try sema.resolveDefinedValue(block, src, casted_field_ptr)) |field_ptr_val| {
@@ -21540,11 +21554,11 @@ fn zirFieldParentPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileEr
                         field_name,
                         field_index,
                         payload.data.field_index,
-                        struct_ty.fmt(sema.mod),
+                        parent_ty.fmt(sema.mod),
                     },
                 );
                 errdefer msg.destroy(sema.gpa);
-                try sema.addDeclaredHereNote(msg, struct_ty);
+                try sema.addDeclaredHereNote(msg, parent_ty);
                 break :msg msg;
             };
             return sema.failWithOwnedErrorMsg(msg);
test/behavior/field_parent_ptr.zig
@@ -44,3 +44,84 @@ fn testParentFieldPtrFirst(a: *const bool) !void {
     try expect(base == &foo);
     try expect(&base.a == a);
 }
+
+test "@fieldParentPtr untagged union" {
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    try testFieldParentPtrUnion(&bar.c);
+    comptime try testFieldParentPtrUnion(&bar.c);
+}
+
+const Bar = union(enum) {
+    a: bool,
+    b: f32,
+    c: i32,
+    d: i32,
+};
+
+const bar = Bar{ .c = 42 };
+
+fn testFieldParentPtrUnion(c: *const i32) !void {
+    try expect(c == &bar.c);
+
+    const base = @fieldParentPtr(Bar, "c", c);
+    try expect(base == &bar);
+    try expect(&base.c == c);
+}
+
+test "@fieldParentPtr tagged union" {
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    try testFieldParentPtrTaggedUnion(&bar_tagged.c);
+    comptime try testFieldParentPtrTaggedUnion(&bar_tagged.c);
+}
+
+const BarTagged = union(enum) {
+    a: bool,
+    b: f32,
+    c: i32,
+    d: i32,
+};
+
+const bar_tagged = BarTagged{ .c = 42 };
+
+fn testFieldParentPtrTaggedUnion(c: *const i32) !void {
+    try expect(c == &bar_tagged.c);
+
+    const base = @fieldParentPtr(BarTagged, "c", c);
+    try expect(base == &bar_tagged);
+    try expect(&base.c == c);
+}
+
+test "@fieldParentPtr extern union" {
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    try testFieldParentPtrExternUnion(&bar_extern.c);
+    comptime try testFieldParentPtrExternUnion(&bar_extern.c);
+}
+
+const BarExtern = extern union {
+    a: bool,
+    b: f32,
+    c: i32,
+    d: i32,
+};
+
+const bar_extern = BarExtern{ .c = 42 };
+
+fn testFieldParentPtrExternUnion(c: *const i32) !void {
+    try expect(c == &bar_extern.c);
+
+    const base = @fieldParentPtr(BarExtern, "c", c);
+    try expect(base == &bar_extern);
+    try expect(&base.c == c);
+}
test/cases/compile_errors/fieldParentPtr-non_struct.zig
@@ -7,4 +7,4 @@ export fn foo(a: *i32) *Foo {
 // backend=llvm
 // target=native
 //
-// :3:28: error: expected struct type, found 'i32'
+// :3:28: error: expected struct or union type, found 'i32'