Commit 646d927c79

Veikka Tuominen <git@vexu.eu>
2022-10-20 12:29:58
stage2: fix handling of aarch64 C ABI float array like structs
Closes #11702 Closes #13125
1 parent 07b6173
Changed files (4)
src
arch
aarch64
codegen
test
src/arch/aarch64/abi.zig
@@ -5,29 +5,21 @@ const Register = bits.Register;
 const RegisterManagerFn = @import("../../register_manager.zig").RegisterManager;
 const Type = @import("../../type.zig").Type;
 
-pub const Class = enum { memory, integer, none, float_array };
+pub const Class = enum(u8) { memory, integer, none, float_array, _ };
 
+/// For `float_array` the second element will be the amount of floats.
 pub fn classifyType(ty: Type, target: std.Target) [2]Class {
+    var maybe_float_bits: ?u16 = null;
+    const float_count = countFloats(ty, target, &maybe_float_bits);
+    if (float_count <= sret_float_count) return .{ .float_array, @intToEnum(Class, float_count) };
+    return classifyTypeInner(ty, target);
+}
+
+fn classifyTypeInner(ty: Type, target: std.Target) [2]Class {
     if (!ty.hasRuntimeBitsIgnoreComptime()) return .{ .none, .none };
     switch (ty.zigTypeTag()) {
         .Struct => {
             if (ty.containerLayout() == .Packed) return .{ .integer, .none };
-
-            if (ty.structFieldCount() <= 4) {
-                const fields = ty.structFields();
-                var float_size: ?u64 = null;
-                for (fields.values()) |field| {
-                    if (field.ty.zigTypeTag() != .Float) break;
-                    const field_size = field.ty.bitSize(target);
-                    const prev_size = float_size orelse {
-                        float_size = field_size;
-                        continue;
-                    };
-                    if (field_size != prev_size) break;
-                } else {
-                    return .{ .float_array, .none };
-                }
-            }
             const bit_size = ty.bitSize(target);
             if (bit_size > 128) return .{ .memory, .none };
             if (bit_size > 64) return .{ .integer, .integer };
@@ -67,6 +59,70 @@ pub fn classifyType(ty: Type, target: std.Target) [2]Class {
     }
 }
 
+const sret_float_count = 4;
+fn countFloats(ty: Type, target: std.Target, maybe_float_bits: *?u16) u32 {
+    const invalid = std.math.maxInt(u32);
+    switch (ty.zigTypeTag()) {
+        .Union => {
+            const fields = ty.unionFields();
+            var max_count: u32 = 0;
+            for (fields.values()) |field| {
+                const field_count = countFloats(field.ty, target, maybe_float_bits);
+                if (field_count == invalid) return invalid;
+                if (field_count > max_count) max_count = field_count;
+                if (max_count > sret_float_count) return invalid;
+            }
+            return max_count;
+        },
+        .Struct => {
+            const fields_len = ty.structFieldCount();
+            var count: u32 = 0;
+            var i: u32 = 0;
+            while (i < fields_len) : (i += 1) {
+                const field_ty = ty.structFieldType(i);
+                const field_count = countFloats(field_ty, target, maybe_float_bits);
+                if (field_count == invalid) return invalid;
+                count += field_count;
+                if (count > sret_float_count) return invalid;
+            }
+            return count;
+        },
+        .Float => {
+            const float_bits = maybe_float_bits.* orelse {
+                maybe_float_bits.* = ty.floatBits(target);
+                return 1;
+            };
+            if (ty.floatBits(target) == float_bits) return 1;
+            return invalid;
+        },
+        .Void => return 0,
+        else => return invalid,
+    }
+}
+
+pub fn getFloatArrayType(ty: Type) ?Type {
+    switch (ty.zigTypeTag()) {
+        .Union => {
+            const fields = ty.unionFields();
+            for (fields.values()) |field| {
+                if (getFloatArrayType(field.ty)) |some| return some;
+            }
+            return null;
+        },
+        .Struct => {
+            const fields_len = ty.structFieldCount();
+            var i: u32 = 0;
+            while (i < fields_len) : (i += 1) {
+                const field_ty = ty.structFieldType(i);
+                if (getFloatArrayType(field_ty)) |some| return some;
+            }
+            return null;
+        },
+        .Float => return ty,
+        else => return null,
+    }
+}
+
 const callee_preserved_regs_impl = if (builtin.os.tag.isDarwin()) struct {
     pub const callee_preserved_regs = [_]Register{
         .x20, .x21, .x22, .x23,
src/codegen/llvm.zig
@@ -3125,10 +3125,10 @@ pub const DeclGen = struct {
             .as_u16 => {
                 try llvm_params.append(dg.context.intType(16));
             },
-            .float_array => {
+            .float_array => |count| {
                 const param_ty = fn_info.param_types[it.zig_index - 1];
-                const float_ty = try dg.lowerType(param_ty.structFieldType(0));
-                const field_count = @intCast(c_uint, param_ty.structFieldCount());
+                const float_ty = try dg.lowerType(aarch64_c_abi.getFloatArrayType(param_ty).?);
+                const field_count = @intCast(c_uint, count);
                 const arr_ty = float_ty.arrayType(field_count);
                 try llvm_params.append(arr_ty);
             },
@@ -4801,7 +4801,7 @@ pub const FuncGen = struct {
                 const casted = self.builder.buildBitCast(llvm_arg, self.dg.context.intType(16), "");
                 try llvm_args.append(casted);
             },
-            .float_array => {
+            .float_array => |count| {
                 const arg = args[it.zig_index - 1];
                 const arg_ty = self.air.typeOf(arg);
                 var llvm_arg = try self.resolveInst(arg);
@@ -4812,9 +4812,8 @@ pub const FuncGen = struct {
                     llvm_arg = store_inst;
                 }
 
-                const float_ty = try self.dg.lowerType(arg_ty.structFieldType(0));
-                const field_count = @intCast(u32, arg_ty.structFieldCount());
-                const array_llvm_ty = float_ty.arrayType(field_count);
+                const float_ty = try self.dg.lowerType(aarch64_c_abi.getFloatArrayType(arg_ty).?);
+                const array_llvm_ty = float_ty.arrayType(count);
 
                 const casted = self.builder.buildBitCast(llvm_arg, array_llvm_ty.pointerType(0), "");
                 const alignment = arg_ty.abiAlignment(target);
@@ -10214,7 +10213,7 @@ const ParamTypeIterator = struct {
     llvm_types_buffer: [8]u16,
     byval_attr: bool,
 
-    const Lowering = enum {
+    const Lowering = union(enum) {
         no_bits,
         byval,
         byref,
@@ -10223,7 +10222,7 @@ const ParamTypeIterator = struct {
         multiple_llvm_float,
         slice,
         as_u16,
-        float_array,
+        float_array: u8,
     };
 
     pub fn next(it: *ParamTypeIterator) ?Lowering {
@@ -10400,7 +10399,7 @@ const ParamTypeIterator = struct {
                             return .byref;
                         }
                         if (classes[0] == .float_array) {
-                            return .float_array;
+                            return Lowering{ .float_array = @enumToInt(classes[1]) };
                         }
                         if (classes[1] == .none) {
                             it.llvm_types_len = 1;
test/c_abi/cfuncs.c
@@ -650,3 +650,30 @@ void c_struct_with_array(StructWithArray x) {
 StructWithArray c_ret_struct_with_array() {
     return (StructWithArray) { 4, {}, 155 };
 }
+
+typedef struct {
+    struct Point {
+        double x;
+        double y;
+    } origin;
+    struct Size {
+        double width;
+        double height;
+    } size;
+} FloatArrayStruct;
+
+void c_float_array_struct(FloatArrayStruct x) {
+    assert_or_panic(x.origin.x == 5);
+    assert_or_panic(x.origin.y == 6);
+    assert_or_panic(x.size.width == 7);
+    assert_or_panic(x.size.height == 8);
+}
+
+FloatArrayStruct c_ret_float_array_struct() {
+    FloatArrayStruct x;
+    x.origin.x = 1;
+    x.origin.y = 2;
+    x.size.width = 3;
+    x.size.height = 4;
+    return x;
+}
test/c_abi/main.zig
@@ -700,3 +700,36 @@ test "Struct with array as padding." {
     try std.testing.expect(x.a == 4);
     try std.testing.expect(x.b == 155);
 }
+
+const FloatArrayStruct = extern struct {
+    origin: extern struct {
+        x: f64,
+        y: f64,
+    },
+    size: extern struct {
+        width: f64,
+        height: f64,
+    },
+};
+
+extern fn c_float_array_struct(FloatArrayStruct) void;
+extern fn c_ret_float_array_struct() FloatArrayStruct;
+
+test "Float array like struct" {
+    c_float_array_struct(.{
+        .origin = .{
+            .x = 5,
+            .y = 6,
+        },
+        .size = .{
+            .width = 7,
+            .height = 8,
+        },
+    });
+
+    var x = c_ret_float_array_struct();
+    try std.testing.expect(x.origin.x == 1);
+    try std.testing.expect(x.origin.y == 2);
+    try std.testing.expect(x.size.width == 3);
+    try std.testing.expect(x.size.height == 4);
+}