Commit f668c8bfd6

Jacob Young <jacobly0@users.noreply.github.com>
2024-04-06 09:45:23
x86_64: fix abi of nested structs
1 parent 34bb670
Changed files (3)
src
arch
x86_64
test
src/arch/x86_64/abi.zig
@@ -13,7 +13,7 @@ pub const Class = enum {
     integer_per_element,
 };
 
-pub fn classifyWindows(ty: Type, mod: *Module) Class {
+pub fn classifyWindows(ty: Type, zcu: *Zcu) Class {
     // https://docs.microsoft.com/en-gb/cpp/build/x64-calling-convention?view=vs-2017
     // "There's a strict one-to-one correspondence between a function call's arguments
     // and the registers used for those arguments. Any argument that doesn't fit in 8
@@ -22,7 +22,7 @@ pub fn classifyWindows(ty: Type, mod: *Module) Class {
     // "All floating point operations are done using the 16 XMM registers."
     // "Structs and unions of size 8, 16, 32, or 64 bits, and __m64 types, are passed
     // as if they were integers of the same size."
-    switch (ty.zigTypeTag(mod)) {
+    switch (ty.zigTypeTag(zcu)) {
         .Pointer,
         .Int,
         .Bool,
@@ -37,12 +37,12 @@ pub fn classifyWindows(ty: Type, mod: *Module) Class {
         .ErrorUnion,
         .AnyFrame,
         .Frame,
-        => switch (ty.abiSize(mod)) {
+        => switch (ty.abiSize(zcu)) {
             0 => unreachable,
             1, 2, 4, 8 => return .integer,
-            else => switch (ty.zigTypeTag(mod)) {
+            else => switch (ty.zigTypeTag(zcu)) {
                 .Int => return .win_i128,
-                .Struct, .Union => if (ty.containerLayout(mod) == .@"packed") {
+                .Struct, .Union => if (ty.containerLayout(zcu) == .@"packed") {
                     return .win_i128;
                 } else {
                     return .memory;
@@ -69,16 +69,16 @@ pub const Context = enum { ret, arg, field, other };
 
 /// There are a maximum of 8 possible return slots. Returned values are in
 /// the beginning of the array; unused slots are filled with .none.
-pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
-    const ip = &mod.intern_pool;
-    const target = mod.getTarget();
+pub fn classifySystemV(ty: Type, zcu: *Zcu, ctx: Context) [8]Class {
+    const ip = &zcu.intern_pool;
+    const target = zcu.getTarget();
     const memory_class = [_]Class{
         .memory, .none, .none, .none,
         .none,   .none, .none, .none,
     };
     var result = [1]Class{.none} ** 8;
-    switch (ty.zigTypeTag(mod)) {
-        .Pointer => switch (ty.ptrSize(mod)) {
+    switch (ty.zigTypeTag(zcu)) {
+        .Pointer => switch (ty.ptrSize(zcu)) {
             .Slice => {
                 result[0] = .integer;
                 result[1] = .integer;
@@ -90,7 +90,7 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
             },
         },
         .Int, .Enum, .ErrorSet => {
-            const bits = ty.intInfo(mod).bits;
+            const bits = ty.intInfo(zcu).bits;
             if (bits <= 64) {
                 result[0] = .integer;
                 return result;
@@ -160,8 +160,8 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
             else => unreachable,
         },
         .Vector => {
-            const elem_ty = ty.childType(mod);
-            const bits = elem_ty.bitSize(mod) * ty.arrayLen(mod);
+            const elem_ty = ty.childType(zcu);
+            const bits = elem_ty.bitSize(zcu) * ty.arrayLen(zcu);
             if (elem_ty.toIntern() == .bool_type) {
                 if (bits <= 32) return .{
                     .integer, .none, .none, .none,
@@ -225,7 +225,7 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
             return memory_class;
         },
         .Optional => {
-            if (ty.isPtrLikeOptional(mod)) {
+            if (ty.isPtrLikeOptional(zcu)) {
                 result[0] = .integer;
                 return result;
             }
@@ -236,9 +236,9 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
             // it contains unaligned fields, it has class MEMORY"
             // "If the size of the aggregate exceeds a single eightbyte, each is classified
             // separately.".
-            const struct_type = mod.typeToStruct(ty).?;
-            const ty_size = ty.abiSize(mod);
-            if (struct_type.layout == .@"packed") {
+            const loaded_struct = ip.loadStructType(ty.toIntern());
+            const ty_size = ty.abiSize(zcu);
+            if (loaded_struct.layout == .@"packed") {
                 assert(ty_size <= 16);
                 result[0] = .integer;
                 if (ty_size > 8) result[1] = .integer;
@@ -247,82 +247,8 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
             if (ty_size > 64)
                 return memory_class;
 
-            var result_i: usize = 0; // out of 8
-            var byte_i: usize = 0; // out of 8
-            for (struct_type.field_types.get(ip), 0..) |field_ty_ip, i| {
-                const field_ty = Type.fromInterned(field_ty_ip);
-                const field_align = struct_type.fieldAlign(ip, i);
-                if (field_align != .none and field_align.compare(.lt, field_ty.abiAlignment(mod)))
-                    return memory_class;
-                const field_size = field_ty.abiSize(mod);
-                const field_class_array = classifySystemV(field_ty, mod, .field);
-                const field_class = std.mem.sliceTo(&field_class_array, .none);
-                if (byte_i + field_size <= 8) {
-                    // Combine this field with the previous one.
-                    combine: {
-                        // "If both classes are equal, this is the resulting class."
-                        if (result[result_i] == field_class[0]) {
-                            if (result[result_i] == .float) {
-                                result[result_i] = .float_combine;
-                            }
-                            break :combine;
-                        }
-
-                        // "If one of the classes is NO_CLASS, the resulting class
-                        // is the other class."
-                        if (result[result_i] == .none) {
-                            result[result_i] = field_class[0];
-                            break :combine;
-                        }
-                        assert(field_class[0] != .none);
-
-                        // "If one of the classes is MEMORY, the result is the MEMORY class."
-                        if (result[result_i] == .memory or field_class[0] == .memory) {
-                            result[result_i] = .memory;
-                            break :combine;
-                        }
-
-                        // "If one of the classes is INTEGER, the result is the INTEGER."
-                        if (result[result_i] == .integer or field_class[0] == .integer) {
-                            result[result_i] = .integer;
-                            break :combine;
-                        }
-
-                        // "If one of the classes is X87, X87UP, COMPLEX_X87 class,
-                        // MEMORY is used as class."
-                        if (result[result_i] == .x87 or
-                            result[result_i] == .x87up or
-                            result[result_i] == .complex_x87 or
-                            field_class[0] == .x87 or
-                            field_class[0] == .x87up or
-                            field_class[0] == .complex_x87)
-                        {
-                            result[result_i] = .memory;
-                            break :combine;
-                        }
-
-                        // "Otherwise class SSE is used."
-                        result[result_i] = .sse;
-                    }
-                    byte_i += @as(usize, @intCast(field_size));
-                    if (byte_i == 8) {
-                        byte_i = 0;
-                        result_i += 1;
-                    }
-                } else {
-                    // Cannot combine this field with the previous one.
-                    if (byte_i != 0) {
-                        byte_i = 0;
-                        result_i += 1;
-                    }
-                    @memcpy(result[result_i..][0..field_class.len], field_class);
-                    result_i += field_class.len;
-                    // If there are any bytes leftover, we have to try to combine
-                    // the next field with them.
-                    byte_i = @as(usize, @intCast(field_size % 8));
-                    if (byte_i != 0) result_i -= 1;
-                }
-            }
+            var byte_offset: u64 = 0;
+            classifySystemVStruct(&result, &byte_offset, loaded_struct, zcu);
 
             // Post-merger cleanup
 
@@ -354,8 +280,8 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
             // it contains unaligned fields, it has class MEMORY"
             // "If the size of the aggregate exceeds a single eightbyte, each is classified
             // separately.".
-            const union_obj = mod.typeToUnion(ty).?;
-            const ty_size = mod.unionAbiSize(union_obj);
+            const union_obj = zcu.typeToUnion(ty).?;
+            const ty_size = zcu.unionAbiSize(union_obj);
             if (union_obj.getLayout(ip) == .@"packed") {
                 assert(ty_size <= 16);
                 result[0] = .integer;
@@ -368,12 +294,12 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
             for (union_obj.field_types.get(ip), 0..) |field_ty, field_index| {
                 const field_align = union_obj.fieldAlign(ip, @intCast(field_index));
                 if (field_align != .none and
-                    field_align.compare(.lt, Type.fromInterned(field_ty).abiAlignment(mod)))
+                    field_align.compare(.lt, Type.fromInterned(field_ty).abiAlignment(zcu)))
                 {
                     return memory_class;
                 }
                 // Combine this field with the previous one.
-                const field_class = classifySystemV(Type.fromInterned(field_ty), mod, .field);
+                const field_class = classifySystemV(Type.fromInterned(field_ty), zcu, .field);
                 for (&result, 0..) |*result_item, i| {
                     const field_item = field_class[i];
                     // "If both classes are equal, this is the resulting class."
@@ -447,7 +373,7 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
             return result;
         },
         .Array => {
-            const ty_size = ty.abiSize(mod);
+            const ty_size = ty.abiSize(zcu);
             if (ty_size <= 8) {
                 result[0] = .integer;
                 return result;
@@ -463,6 +389,82 @@ pub fn classifySystemV(ty: Type, mod: *Module, ctx: Context) [8]Class {
     }
 }
 
+fn classifySystemVStruct( // Merges the class of each field of `loaded_struct` into `result`, flattening nested non-packed structs.
+    result: *[8]Class, // indexed by eightbyte (byte offset / 8); updated in place
+    byte_offset: *u64, // running offset in bytes; shared with, and advanced by, recursive calls
+    loaded_struct: InternPool.LoadedStructType,
+    zcu: *Zcu,
+) void {
+    const ip = &zcu.intern_pool;
+    var field_it = loaded_struct.iterateRuntimeOrder(ip);
+    while (field_it.next()) |field_index| {
+        const field_ty = Type.fromInterned(loaded_struct.field_types.get(ip)[field_index]);
+        const field_align = loaded_struct.fieldAlign(ip, field_index);
+        byte_offset.* = std.mem.alignForward( // round the running offset up to this field's alignment
+            u64,
+            byte_offset.*,
+            field_align.toByteUnits() orelse field_ty.abiAlignment(zcu).toByteUnits().?, // the field's own alignment if it yields a value, else the type's ABI alignment
+        );
+        if (zcu.typeToStruct(field_ty)) |field_loaded_struct| {
+            if (field_loaded_struct.layout != .@"packed") {
+                classifySystemVStruct(result, byte_offset, field_loaded_struct, zcu); // classify the nested struct's fields starting at the current offset
+                continue; // NOTE(review): byte_offset stays where the nested call ended; tail padding of the nested struct does not appear to be added before the next field is aligned -- confirm
+            }
+        }
+        const field_class = std.mem.sliceTo(&classifySystemV(field_ty, zcu, .field), .none); // this field's classes up to the first .none; packed structs also take this path
+        const field_size = field_ty.abiSize(zcu);
+        combine: {
+            // Merge the field's first class into the eightbyte that contains byte_offset.
+            const result_class = &result[@intCast(byte_offset.* / 8)];
+            // "If both classes are equal, this is the resulting class."
+            if (result_class.* == field_class[0]) {
+                if (result_class.* == .float) {
+                    result_class.* = .float_combine;
+                }
+                break :combine;
+            }
+
+            // "If one of the classes is NO_CLASS, the resulting class
+            // is the other class."
+            if (result_class.* == .none) {
+                result_class.* = field_class[0];
+                break :combine;
+            }
+            assert(field_class[0] != .none);
+
+            // "If one of the classes is MEMORY, the result is the MEMORY class."
+            if (result_class.* == .memory or field_class[0] == .memory) {
+                result_class.* = .memory;
+                break :combine;
+            }
+
+            // "If one of the classes is INTEGER, the result is the INTEGER."
+            if (result_class.* == .integer or field_class[0] == .integer) {
+                result_class.* = .integer;
+                break :combine;
+            }
+
+            // "If one of the classes is X87, X87UP, COMPLEX_X87 class,
+            // MEMORY is used as class."
+            if (result_class.* == .x87 or
+                result_class.* == .x87up or
+                result_class.* == .complex_x87 or
+                field_class[0] == .x87 or
+                field_class[0] == .x87up or
+                field_class[0] == .complex_x87)
+            {
+                result_class.* = .memory;
+                break :combine;
+            }
+
+            // "Otherwise class SSE is used."
+            result_class.* = .sse;
+        }
+        @memcpy(result[@intCast(byte_offset.* / 8 + 1)..][0 .. field_class.len - 1], field_class[1..]); // any further classes of the field overwrite the following eightbytes (copied, not merged)
+        byte_offset.* += field_size;
+    }
+}
+
 pub const SysV = struct {
     /// Note that .rsp and .rbp also belong to this set, however, we never expect to use them
     /// for anything else but stack offset tracking therefore we exclude them from this set.
@@ -592,8 +594,9 @@ const std = @import("std");
 const assert = std.debug.assert;
 const testing = std.testing;
 
-const Module = @import("../../Module.zig");
+const InternPool = @import("../../InternPool.zig");
 const Register = @import("bits.zig").Register;
 const RegisterManagerFn = @import("../../register_manager.zig").RegisterManager;
 const Type = @import("../../type.zig").Type;
 const Value = @import("../../Value.zig");
+const Zcu = @import("../../Module.zig");
test/c_abi/cfuncs.c
@@ -227,6 +227,48 @@ void c_struct_u64_u64_8(size_t, size_t, size_t, size_t, size_t, size_t, size_t,
     assert_or_panic(s.b == 40);
 }
 
+struct Struct_f32f32_f32 { // nested struct of two floats followed by one float
+    struct {
+        float b, c;
+    } a;
+    float d;
+};
+
+struct Struct_f32f32_f32 zig_ret_struct_f32f32_f32(void); // exported by test/c_abi/main.zig
+
+void zig_struct_f32f32_f32(struct Struct_f32f32_f32); // exported by test/c_abi/main.zig
+
+struct Struct_f32f32_f32 c_ret_struct_f32f32_f32(void) { // returns the fixed value {1, 2}, 3 by value
+    return (struct Struct_f32f32_f32){ { 1.0f, 2.0f }, 3.0f };
+}
+
+void c_struct_f32f32_f32(struct Struct_f32f32_f32 s) { // checks that the by-value argument holds {1, 2}, 3
+    assert_or_panic(s.a.b == 1.0f);
+    assert_or_panic(s.a.c == 2.0f);
+    assert_or_panic(s.d == 3.0f);
+}
+
+struct Struct_f32_f32f32 { // one float followed by a nested struct of two floats
+    float a;
+    struct {
+        float c, d;
+    } b;
+};
+
+struct Struct_f32_f32f32 zig_ret_struct_f32_f32f32(void); // exported by test/c_abi/main.zig
+
+void zig_struct_f32_f32f32(struct Struct_f32_f32f32); // exported by test/c_abi/main.zig
+
+struct Struct_f32_f32f32 c_ret_struct_f32_f32f32(void) { // returns the fixed value 1, {2, 3} by value
+    return (struct Struct_f32_f32f32){ 1.0f, { 2.0f, 3.0f } };
+}
+
+void c_struct_f32_f32f32(struct Struct_f32_f32f32 s) { // checks that the by-value argument holds 1, {2, 3}
+    assert_or_panic(s.a == 1.0f);
+    assert_or_panic(s.b.c == 2.0f);
+    assert_or_panic(s.b.d == 3.0f);
+}
+
 struct BigStruct {
     uint64_t a;
     uint64_t b;
@@ -2603,9 +2645,25 @@ void run_c_tests(void) {
         zig_struct_u64_u64_7(0, 1, 2, 3, 4, 5, 6, (struct Struct_u64_u64){ .a = 17, .b = 18 });
         zig_struct_u64_u64_8(0, 1, 2, 3, 4, 5, 6, 7, (struct Struct_u64_u64){ .a = 19, .b = 20 });
     }
+
+#if !defined(ZIG_RISCV64)
+    {
+        struct Struct_f32f32_f32 s = zig_ret_struct_f32f32_f32();
+        assert_or_panic(s.a.b == 1.0f);
+        assert_or_panic(s.a.c == 2.0f);
+        assert_or_panic(s.d == 3.0f);
+        zig_struct_f32f32_f32((struct Struct_f32f32_f32){ { 1.0f, 2.0f }, 3.0f });
+    }
+
+    {
+        struct Struct_f32_f32f32 s = zig_ret_struct_f32_f32f32();
+        assert_or_panic(s.a == 1.0f);
+        assert_or_panic(s.b.c == 2.0f);
+        assert_or_panic(s.b.d == 3.0f);
+        zig_struct_f32_f32f32((struct Struct_f32_f32f32){ 1.0f, { 2.0f, 3.0f } });
+    }
 #endif
 
-#if !defined __mips__ && !defined ZIG_PPC32
     {
         struct BigStruct s = {1, 2, 3, 4, 5};
         zig_big_struct(s);
test/c_abi/main.zig
@@ -273,6 +273,7 @@ const Struct_u64_u64 = extern struct {
 export fn zig_ret_struct_u64_u64() Struct_u64_u64 {
     return .{ .a = 1, .b = 2 };
 }
+
 export fn zig_struct_u64_u64_0(s: Struct_u64_u64) void {
     expect(s.a == 3) catch @panic("test failure");
     expect(s.b == 4) catch @panic("test failure");
@@ -326,11 +327,9 @@ test "C ABI struct u64 u64" {
     if (builtin.cpu.arch.isMIPS()) return error.SkipZigTest;
     if (builtin.cpu.arch.isPPC()) return error.SkipZigTest;
 
-    {
-        const s = c_ret_struct_u64_u64();
-        try expect(s.a == 21);
-        try expect(s.b == 22);
-    }
+    const s = c_ret_struct_u64_u64();
+    try expect(s.a == 21);
+    try expect(s.b == 22);
     c_struct_u64_u64_0(.{ .a = 23, .b = 24 });
     c_struct_u64_u64_1(0, .{ .a = 25, .b = 26 });
     c_struct_u64_u64_2(0, 1, .{ .a = 27, .b = 28 });
@@ -342,6 +341,66 @@ test "C ABI struct u64 u64" {
     c_struct_u64_u64_8(0, 1, 2, 3, 4, 5, 6, 7, .{ .a = 39, .b = 40 });
 }
 
+const Struct_f32f32_f32 = extern struct { // counterpart of `struct Struct_f32f32_f32` in cfuncs.c
+    a: extern struct { b: f32, c: f32 },
+    d: f32,
+};
+
+export fn zig_ret_struct_f32f32_f32() Struct_f32f32_f32 { // called from run_c_tests in cfuncs.c, which checks the returned fields
+    return .{ .a = .{ .b = 1.0, .c = 2.0 }, .d = 3.0 };
+}
+
+export fn zig_struct_f32f32_f32(s: Struct_f32f32_f32) void { // called from run_c_tests in cfuncs.c with {1, 2}, 3
+    expect(s.a.b == 1.0) catch @panic("test failure");
+    expect(s.a.c == 2.0) catch @panic("test failure");
+    expect(s.d == 3.0) catch @panic("test failure");
+}
+
+extern fn c_ret_struct_f32f32_f32() Struct_f32f32_f32; // defined in cfuncs.c
+
+extern fn c_struct_f32f32_f32(Struct_f32f32_f32) void; // defined in cfuncs.c
+
+test "C ABI struct {f32,f32} f32" { // checks a value returned by C, then passes one to C by value
+    if (builtin.cpu.arch.isMIPS()) return error.SkipZigTest;
+    if (builtin.cpu.arch.isPPC()) return error.SkipZigTest;
+
+    const s = c_ret_struct_f32f32_f32();
+    try expect(s.a.b == 1.0);
+    try expect(s.a.c == 2.0);
+    try expect(s.d == 3.0);
+    c_struct_f32f32_f32(.{ .a = .{ .b = 1.0, .c = 2.0 }, .d = 3.0 });
+}
+
+const Struct_f32_f32f32 = extern struct { // counterpart of `struct Struct_f32_f32f32` in cfuncs.c
+    a: f32,
+    b: extern struct { c: f32, d: f32 },
+};
+
+export fn zig_ret_struct_f32_f32f32() Struct_f32_f32f32 { // called from run_c_tests in cfuncs.c, which checks the returned fields
+    return .{ .a = 1.0, .b = .{ .c = 2.0, .d = 3.0 } };
+}
+
+export fn zig_struct_f32_f32f32(s: Struct_f32_f32f32) void { // called from run_c_tests in cfuncs.c with 1, {2, 3}
+    expect(s.a == 1.0) catch @panic("test failure");
+    expect(s.b.c == 2.0) catch @panic("test failure");
+    expect(s.b.d == 3.0) catch @panic("test failure");
+}
+
+extern fn c_ret_struct_f32_f32f32() Struct_f32_f32f32; // defined in cfuncs.c
+
+extern fn c_struct_f32_f32f32(Struct_f32_f32f32) void; // defined in cfuncs.c
+
+test "C ABI struct f32 {f32,f32}" { // checks a value returned by C, then passes one to C by value
+    if (builtin.cpu.arch.isMIPS()) return error.SkipZigTest;
+    if (builtin.cpu.arch.isPPC()) return error.SkipZigTest;
+
+    const s = c_ret_struct_f32_f32f32();
+    try expect(s.a == 1.0);
+    try expect(s.b.c == 2.0);
+    try expect(s.b.d == 3.0);
+    c_struct_f32_f32f32(.{ .a = 1.0, .b = .{ .c = 2.0, .d = 3.0 } });
+}
+
 const BigStruct = extern struct {
     a: u64,
     b: u64,