Commit 3c72b4d25e

Luuk de Gram <luuk@degram.dev>
2023-05-25 17:52:39
wasm: support and optimize for all packed unions
For packed unions where its abi size is less than or equal to 8 bytes we store it directly and don't pass it by reference. This means that when retrieving the field, we will perform shifts and bitcasts to ensure the correct type is returned. For larger packed unions, we either allocate a new stack value based on the field type when the field type is also passed by reference, or load it directly into a local if it's not.
1 parent 7cfc44d
Changed files (1)
src
arch
src/arch/wasm/CodeGen.zig
@@ -1716,7 +1716,14 @@ fn isByRef(ty: Type, target: std.Target) bool {
         .Array,
         .Frame,
         .Union,
-        => return ty.hasRuntimeBitsIgnoreComptime(),
+        => {
+            if (ty.castTag(.@"union")) |union_ty| {
+                if (union_ty.data.layout == .Packed) {
+                    return ty.abiSize(target) > 8;
+                }
+            }
+            return ty.hasRuntimeBitsIgnoreComptime();
+        },
         .Struct => {
             if (ty.castTag(.@"struct")) |struct_ty| {
                 const struct_obj = struct_ty.data;
@@ -3131,6 +3138,14 @@ fn lowerConstant(func: *CodeGen, arg_val: Value, ty: Type) InnerError!WValue {
             val.writeToMemory(ty, func.bin_file.base.options.module.?, &buf) catch unreachable;
             return func.storeSimdImmd(buf);
         },
+        .Union => {
+            // in this case we have a packed union which will not be passed by reference.
+            const union_ty = ty.cast(Type.Payload.Union).?.data;
+            const union_obj = val.castTag(.@"union").?.data;
+            const field_index = ty.unionTagFieldIndex(union_obj.tag, func.bin_file.base.options.module.?).?;
+            const field_ty = union_ty.fields.values()[field_index].ty;
+            return func.lowerConstant(union_obj.val, field_ty);
+        },
         else => |zig_type| return func.fail("Wasm TODO: LowerConstant for zigTypeTag {}", .{zig_type}),
     }
 }
@@ -3661,8 +3676,42 @@ fn airStructFieldVal(func: *CodeGen, inst: Air.Inst.Index) InnerError!void {
                 break :result try truncated.toLocal(func, field_ty);
             },
             .Union => result: {
-                const val = try func.load(operand, field_ty, 0);
-                break :result try val.toLocal(func, field_ty);
+                if (isByRef(struct_ty, func.target)) {
+                    if (!isByRef(field_ty, func.target)) {
+                        const val = try func.load(operand, field_ty, 0);
+                        break :result try val.toLocal(func, field_ty);
+                    } else {
+                        const new_stack_val = try func.allocStack(field_ty);
+                        try func.store(new_stack_val, operand, field_ty, 0);
+                        break :result new_stack_val;
+                    }
+                }
+
+                var payload: Type.Payload.Bits = .{
+                    .base = .{ .tag = .int_unsigned },
+                    .data = @intCast(u16, struct_ty.bitSize(func.target)),
+                };
+                const union_int_type = Type.initPayload(&payload.base);
+                if (field_ty.zigTypeTag() == .Float) {
+                    var int_payload: Type.Payload.Bits = .{
+                        .base = .{ .tag = .int_unsigned },
+                        .data = @intCast(u16, field_ty.bitSize(func.target)),
+                    };
+                    const int_type = Type.initPayload(&int_payload.base);
+                    const truncated = try func.trunc(operand, int_type, union_int_type);
+                    const bitcasted = try func.bitcast(field_ty, int_type, truncated);
+                    break :result try bitcasted.toLocal(func, field_ty);
+                } else if (field_ty.isPtrAtRuntime()) {
+                    var int_payload: Type.Payload.Bits = .{
+                        .base = .{ .tag = .int_unsigned },
+                        .data = @intCast(u16, field_ty.bitSize(func.target)),
+                    };
+                    const int_type = Type.initPayload(&int_payload.base);
+                    const truncated = try func.trunc(operand, int_type, union_int_type);
+                    break :result try truncated.toLocal(func, field_ty);
+                }
+                const truncated = try func.trunc(operand, field_ty, union_int_type);
+                break :result try truncated.toLocal(func, field_ty);
             },
             else => unreachable,
         },