master
  1const std = @import("std");
  2const assert = std.debug.assert;
  3const RegisterManagerFn = @import("../../register_manager.zig").RegisterManager;
  4const Type = @import("../../Type.zig");
  5const Zcu = @import("../../Zcu.zig");
  6
  7pub const Class = union(enum) {
  8    memory,
  9    byval,
 10    i32_array: u8,
 11    i64_array: u8,
 12
 13    fn arrSize(total_size: u64, arr_size: u64) Class {
 14        const count = @as(u8, @intCast(std.mem.alignForward(u64, total_size, arr_size) / arr_size));
 15        if (arr_size == 32) {
 16            return .{ .i32_array = count };
 17        } else {
 18            return .{ .i64_array = count };
 19        }
 20    }
 21};
 22
 23pub const Context = enum { ret, arg };
 24
 25pub fn classifyType(ty: Type, zcu: *Zcu, ctx: Context) Class {
 26    assert(ty.hasRuntimeBitsIgnoreComptime(zcu));
 27
 28    var maybe_float_bits: ?u16 = null;
 29    const max_byval_size = 512;
 30    const ip = &zcu.intern_pool;
 31    switch (ty.zigTypeTag(zcu)) {
 32        .@"struct" => {
 33            const bit_size = ty.bitSize(zcu);
 34            if (ty.containerLayout(zcu) == .@"packed") {
 35                if (bit_size > 64) return .memory;
 36                return .byval;
 37            }
 38            if (bit_size > max_byval_size) return .memory;
 39            const float_count = countFloats(ty, zcu, &maybe_float_bits);
 40            if (float_count <= byval_float_count) return .byval;
 41
 42            const fields = ty.structFieldCount(zcu);
 43            var i: u32 = 0;
 44            while (i < fields) : (i += 1) {
 45                const field_ty = ty.fieldType(i, zcu);
 46                const field_alignment = ty.fieldAlignment(i, zcu);
 47                const field_size = field_ty.bitSize(zcu);
 48                if (field_size > 32 or field_alignment.compare(.gt, .@"32")) {
 49                    return Class.arrSize(bit_size, 64);
 50                }
 51            }
 52            return Class.arrSize(bit_size, 32);
 53        },
 54        .@"union" => {
 55            const bit_size = ty.bitSize(zcu);
 56            const union_obj = zcu.typeToUnion(ty).?;
 57            if (union_obj.flagsUnordered(ip).layout == .@"packed") {
 58                if (bit_size > 64) return .memory;
 59                return .byval;
 60            }
 61            if (bit_size > max_byval_size) return .memory;
 62            const float_count = countFloats(ty, zcu, &maybe_float_bits);
 63            if (float_count <= byval_float_count) return .byval;
 64
 65            for (union_obj.field_types.get(ip), 0..) |field_ty, field_index| {
 66                if (Type.fromInterned(field_ty).bitSize(zcu) > 32 or
 67                    ty.fieldAlignment(field_index, zcu).compare(.gt, .@"32"))
 68                {
 69                    return Class.arrSize(bit_size, 64);
 70                }
 71            }
 72            return Class.arrSize(bit_size, 32);
 73        },
 74        .bool, .float => return .byval,
 75        .int => {
 76            // TODO this is incorrect for _BitInt(128) but implementing
 77            // this correctly makes implementing compiler-rt impossible.
 78            // const bit_size = ty.bitSize(zcu);
 79            // if (bit_size > 64) return .memory;
 80            return .byval;
 81        },
 82        .@"enum", .error_set => {
 83            const bit_size = ty.bitSize(zcu);
 84            if (bit_size > 64) return .memory;
 85            return .byval;
 86        },
 87        .vector => {
 88            const bit_size = ty.bitSize(zcu);
 89            // TODO is this controlled by a cpu feature?
 90            if (ctx == .ret and bit_size > 128) return .memory;
 91            if (bit_size > 512) return .memory;
 92            return .byval;
 93        },
 94        .optional => {
 95            assert(ty.isPtrLikeOptional(zcu));
 96            return .byval;
 97        },
 98        .pointer => {
 99            assert(!ty.isSlice(zcu));
100            return .byval;
101        },
102        .error_union,
103        .frame,
104        .@"anyframe",
105        .noreturn,
106        .void,
107        .type,
108        .comptime_float,
109        .comptime_int,
110        .undefined,
111        .null,
112        .@"fn",
113        .@"opaque",
114        .enum_literal,
115        .array,
116        => unreachable,
117    }
118}
119
/// Maximum number of floating-point leaves a non-packed struct or union may
/// contain and still be classified as `.byval` by `classifyType`.
const byval_float_count = 4;

/// Counts the floating-point leaves of `ty`, recursing into structs (field
/// counts are summed) and unions (the largest field count wins). A `void`
/// leaf contributes zero.
///
/// `maybe_float_bits` records the width of the first float encountered
/// (callers start it at `null`). That width must be 32 or 64, and every
/// later float must match it.
///
/// Returns `maxInt(u32)` if `ty` contains any other kind of leaf, mixes
/// float widths, uses an unsupported width, or exceeds `byval_float_count`.
fn countFloats(ty: Type, zcu: *Zcu, maybe_float_bits: *?u16) u32 {
    const invalid = std.math.maxInt(u32);
    switch (ty.zigTypeTag(zcu)) {
        .void => return 0,
        .float => {
            const bits = ty.floatBits(zcu.getTarget());
            if (maybe_float_bits.*) |expected_bits| {
                // A previous float fixed the width; this one must agree.
                return if (bits == expected_bits) 1 else invalid;
            }
            // First float seen: only 32- and 64-bit widths qualify.
            if (bits != 32 and bits != 64) return invalid;
            maybe_float_bits.* = bits;
            return 1;
        },
        .@"struct" => {
            var total: u32 = 0;
            for (0..ty.structFieldCount(zcu)) |field_index| {
                const field_count = countFloats(ty.fieldType(field_index, zcu), zcu, maybe_float_bits);
                if (field_count == invalid) return invalid;
                total += field_count;
                if (total > byval_float_count) return invalid;
            }
            return total;
        },
        .@"union" => {
            const ip = &zcu.intern_pool;
            const union_obj = zcu.typeToUnion(ty).?;
            var widest: u32 = 0;
            for (union_obj.field_types.get(ip)) |field_ty| {
                const field_count = countFloats(Type.fromInterned(field_ty), zcu, maybe_float_bits);
                if (field_count == invalid) return invalid;
                widest = @max(widest, field_count);
                if (widest > byval_float_count) return invalid;
            }
            return widest;
        },
        else => return invalid,
    }
}