Commit 5f2d0d414d

Luuk de Gram <luuk@degram.dev>
2022-04-24 21:49:12
wasm: Implement codegen for C-ABI
This implements passing arguments and storing return values correctly for the C-ABI as specified by the tool-convention: https://github.com/WebAssembly/tool-conventions/blob/main/BasicCABI.md There's definitely room for better codegen in follow-up commits.
1 parent cb49af6
Changed files (2)
src
src/arch/wasm/abi.zig
@@ -52,6 +52,7 @@ pub fn classifyType(ty: Type, target: Target) [2]Class {
             return memory;
         },
         .Bool => return direct,
+        .Array => return memory,
         .ErrorUnion => {
             const has_tag = ty.errorUnionSet().hasRuntimeBitsIgnoreComptime();
             const has_pl = ty.errorUnionPayload().hasRuntimeBitsIgnoreComptime();
@@ -73,16 +74,13 @@ pub fn classifyType(ty: Type, target: Target) [2]Class {
             if (ty.isSlice()) return memory;
             return direct;
         },
-        .Array => {
-            if (ty.arrayLen() == 1) return direct;
-            return memory;
-        },
         .Union => {
             const layout = ty.unionGetLayout(target);
             if (layout.payload_size == 0 and layout.tag_size != 0) {
                 return classifyType(ty.unionTagType().?, target);
             }
-            return classifyType(ty.errorUnionPayload(), target);
+            if (ty.unionFields().count() > 1) return memory;
+            return classifyType(ty.unionFields().values()[0].ty, target);
         },
         .AnyFrame, .Frame => return direct,
 
@@ -100,3 +98,24 @@ pub fn classifyType(ty: Type, target: Target) [2]Class {
         => unreachable,
     }
 }
+
+/// Returns the scalar type a given type can represent.
+/// Asserts given type can be represented as scalar, such as
+/// a struct with a single scalar field.
+pub fn scalarType(ty: Type, target: std.Target) Type {
+    switch (ty.zigTypeTag()) {
+        .Struct => {
+            std.debug.assert(ty.structFieldCount() == 1);
+            return scalarType(ty.structFieldType(0), target);
+        },
+        .Union => {
+            const layout = ty.unionGetLayout(target);
+            if (layout.payload_size == 0 and layout.tag_size != 0) {
+                return scalarType(ty.unionTagType().?, target);
+            }
+            std.debug.assert(ty.unionFields().count() == 1);
+            return scalarType(ty.unionFields().values()[0].ty, target);
+        },
+        else => return ty,
+    }
+}
src/arch/wasm/CodeGen.zig
@@ -21,6 +21,7 @@ const Air = @import("../../Air.zig");
 const Liveness = @import("../../Liveness.zig");
 const Mir = @import("Mir.zig");
 const Emit = @import("Emit.zig");
+const abi = @import("abi.zig");
 
 /// Wasm Value, created when generating an instruction
 const WValue = union(enum) {
@@ -722,18 +723,15 @@ fn typeToValtype(ty: Type, target: std.Target) wasm.Valtype {
             const bits = ty.floatBits(target);
             if (bits == 16 or bits == 32) break :blk wasm.Valtype.f32;
             if (bits == 64) break :blk wasm.Valtype.f64;
+            if (bits == 128) break :blk wasm.Valtype.i64;
             return wasm.Valtype.i32; // represented as pointer to stack
         },
-        .Int => blk: {
+        .Int, .Enum => blk: {
             const info = ty.intInfo(target);
             if (info.bits <= 32) break :blk wasm.Valtype.i32;
-            if (info.bits > 32 and info.bits <= 64) break :blk wasm.Valtype.i64;
+            if (info.bits > 32 and info.bits <= 128) break :blk wasm.Valtype.i64;
             break :blk wasm.Valtype.i32; // represented as pointer to stack
         },
-        .Enum => {
-            var buf: Type.Payload.Bits = undefined;
-            return typeToValtype(ty.intTagType(&buf), target);
-        },
         else => wasm.Valtype.i32, // all represented as reference/immediate
     };
 }
@@ -787,33 +785,46 @@ fn allocLocal(self: *Self, ty: Type) InnerError!WValue {
 
 /// Generates a `wasm.Type` from a given function type.
 /// Memory is owned by the caller.
-fn genFunctype(gpa: Allocator, fn_ty: Type, target: std.Target) !wasm.Type {
+fn genFunctype(gpa: Allocator, fn_info: Type.Payload.Function.Data, target: std.Target) !wasm.Type {
     var params = std.ArrayList(wasm.Valtype).init(gpa);
     defer params.deinit();
     var returns = std.ArrayList(wasm.Valtype).init(gpa);
     defer returns.deinit();
-    const return_type = fn_ty.fnReturnType();
-
-    const want_sret = isByRef(return_type, target);
 
-    if (want_sret) {
-        try params.append(typeToValtype(return_type, target));
+    if (firstParamSRet(fn_info, target)) {
+        try params.append(typeToValtype(fn_info.return_type, target));
+    } else if (fn_info.return_type.hasRuntimeBitsIgnoreComptime()) {
+        if (fn_info.cc == .C) {
+            const res_classes = abi.classifyType(fn_info.return_type, target);
+            assert(res_classes[0] == .direct and res_classes[1] == .none);
+            const scalar_type = abi.scalarType(fn_info.return_type, target);
+            try returns.append(typeToValtype(scalar_type, target));
+        } else {
+            try returns.append(typeToValtype(fn_info.return_type, target));
+        }
     }
 
     // param types
-    if (fn_ty.fnParamLen() != 0) {
-        const fn_params = try gpa.alloc(Type, fn_ty.fnParamLen());
-        defer gpa.free(fn_params);
-        fn_ty.fnParamTypes(fn_params);
-        for (fn_params) |param_type| {
+    if (fn_info.param_types.len != 0) {
+        for (fn_info.param_types) |param_type| {
             if (!param_type.hasRuntimeBitsIgnoreComptime()) continue;
-            try params.append(typeToValtype(param_type, target));
-        }
-    }
 
-    // return type
-    if (!want_sret and return_type.hasRuntimeBitsIgnoreComptime()) {
-        try returns.append(typeToValtype(return_type, target));
+            switch (fn_info.cc) {
+                .C => {
+                    const param_classes = abi.classifyType(param_type, target);
+                    for (param_classes) |class| {
+                        if (class == .none) continue;
+                        if (class == .direct) {
+                            const scalar_type = abi.scalarType(param_type, target);
+                            try params.append(typeToValtype(scalar_type, target));
+                        } else {
+                            try params.append(typeToValtype(param_type, target));
+                        }
+                    }
+                },
+                else => try params.append(typeToValtype(param_type, target)),
+            }
+        }
     }
 
     return wasm.Type{
@@ -857,7 +868,7 @@ pub fn generate(
 }
 
 fn genFunc(self: *Self) InnerError!void {
-    var func_type = try genFunctype(self.gpa, self.decl.ty, self.target);
+    var func_type = try genFunctype(self.gpa, self.decl.ty.fnInfo(), self.target);
     defer func_type.deinit(self.gpa);
     self.decl.fn_link.wasm.type_index = try self.bin_file.putOrGetFuncType(func_type);
 
@@ -957,21 +968,22 @@ fn resolveCallingConventionValues(self: *Self, fn_ty: Type) InnerError!CallWValu
         .args = &.{},
         .return_value = .none,
     };
+    if (cc == .Naked) return result;
+
     var args = std.ArrayList(WValue).init(self.gpa);
     defer args.deinit();
 
-    const ret_ty = fn_ty.fnReturnType();
     // Check if we store the result as a pointer to the stack rather than
     // by value
-    if (isByRef(ret_ty, self.target)) {
+    if (firstParamSRet(fn_ty.fnInfo(), self.target)) {
         // the sret arg will be passed as first argument, therefore we
         // set the `return_value` before allocating locals for regular args.
         result.return_value = .{ .local = self.local_index };
         self.local_index += 1;
     }
+
     switch (cc) {
-        .Naked => return result,
-        .Unspecified, .C => {
+        .Unspecified => {
             for (param_types) |ty| {
                 if (!ty.hasRuntimeBitsIgnoreComptime()) {
                     continue;
@@ -981,12 +993,105 @@ fn resolveCallingConventionValues(self: *Self, fn_ty: Type) InnerError!CallWValu
                 self.local_index += 1;
             }
         },
-        else => return self.fail("TODO implement function parameters for cc '{}' on wasm", .{cc}),
+        .C => {
+            for (param_types) |ty| {
+                const ty_classes = abi.classifyType(ty, self.target);
+                for (ty_classes) |class| {
+                    if (class == .none) continue;
+                    try args.append(.{ .local = self.local_index });
+                    self.local_index += 1;
+                }
+            }
+        },
+        else => return self.fail("calling convention '{s}' not supported for Wasm", .{@tagName(cc)}),
     }
     result.args = args.toOwnedSlice();
     return result;
 }
 
+fn firstParamSRet(fn_info: Type.Payload.Function.Data, target: std.Target) bool {
+    switch (fn_info.cc) {
+        .Unspecified, .Inline => return isByRef(fn_info.return_type, target),
+        .C => {
+            const ty_classes = abi.classifyType(fn_info.return_type, target);
+            if (ty_classes[0] == .indirect) return true;
+            if (ty_classes[0] == .direct and ty_classes[1] == .direct) return true;
+            return false;
+        },
+        else => return false,
+    }
+}
+
+/// Lowers a Zig type and its value based on a given calling convention to ensure
+/// it matches the ABI.
+fn lowerArg(self: *Self, cc: std.builtin.CallingConvention, ty: Type, value: WValue) !void {
+    if (cc != .C) {
+        return self.lowerToStack(value);
+    }
+
+    const ty_classes = abi.classifyType(ty, self.target);
+    assert(ty_classes[0] != .none);
+    switch (ty.zigTypeTag()) {
+        .Struct, .Union => {
+            if (ty_classes[0] == .indirect) {
+                return self.lowerToStack(value);
+            }
+            assert(ty_classes[0] == .direct);
+            const scalar_type = abi.scalarType(ty, self.target);
+            const abi_size = scalar_type.abiSize(self.target);
+            const opcode = buildOpcode(.{
+                .op = .load,
+                .width = @intCast(u8, abi_size),
+                .signedness = if (scalar_type.isSignedInt()) .signed else .unsigned,
+                .valtype1 = typeToValtype(scalar_type, self.target),
+            });
+            try self.emitWValue(value);
+            try self.addMemArg(Mir.Inst.Tag.fromOpcode(opcode), .{
+                .offset = value.offset(),
+                .alignment = scalar_type.abiAlignment(self.target),
+            });
+        },
+        .Int, .Float => {
+            if (ty_classes[1] == .none) {
+                return self.lowerToStack(value);
+            }
+            assert(ty_classes[0] == .direct and ty_classes[1] == .direct);
+            assert(ty.abiSize(self.target) == 16);
+            // in this case we have an integer or float that must be lowered as 2 i64's.
+            try self.emitWValue(value);
+            try self.addMemArg(.i64_load, .{ .offset = value.offset(), .alignment = 16 });
+            try self.emitWValue(value);
+            try self.addMemArg(.i64_load, .{ .offset = value.offset() + 8, .alignment = 16 });
+        },
+        else => return self.lowerToStack(value),
+    }
+}
+
+/// Lowers a `WValue` to the stack. This means when the `value` results in
+/// `.stack_offset` we calculate the pointer of this offset and use that.
+/// The value is left on the stack, and not stored in any temporary.
+fn lowerToStack(self: *Self, value: WValue) !void {
+    switch (value) {
+        .stack_offset => |offset| {
+            try self.emitWValue(value);
+            if (offset > 0) {
+                switch (self.arch()) {
+                    .wasm32 => {
+                        try self.addImm32(@bitCast(i32, offset));
+                        try self.addTag(.i32_add);
+                    },
+                    .wasm64 => {
+                        try self.addImm64(offset);
+                        try self.addTag(.i64_add);
+                    },
+                    else => unreachable,
+                }
+            }
+        },
+        else => try self.emitWValue(value),
+    }
+}
+
 /// Creates a local for the initial stack value
 /// Asserts `initial_stack_value` is `.none`
 fn initializeStack(self: *Self) !void {
@@ -1489,11 +1594,31 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 fn airRet(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const operand = try self.resolveInst(un_op);
+    const ret_ty = self.decl.ty.fnReturnType();
 
     // result must be stored in the stack and we return a pointer
     // to the stack instead
     if (self.return_value != .none) {
         try self.store(self.return_value, operand, self.decl.ty.fnReturnType(), 0);
+    } else if (self.decl.ty.fnInfo().cc == .C and ret_ty.hasRuntimeBitsIgnoreComptime()) {
+        switch (ret_ty.zigTypeTag()) {
+            // Aggregate types can be lowered as a singular value
+            .Struct, .Union => {
+                const scalar_type = abi.scalarType(ret_ty, self.target);
+                try self.emitWValue(operand);
+                const opcode = buildOpcode(.{
+                    .op = .load,
+                    .width = @intCast(u8, scalar_type.abiSize(self.target)),
+                    .signedness = if (scalar_type.isSignedInt()) .signed else .unsigned,
+                    .valtype1 = typeToValtype(scalar_type, self.target),
+                });
+                try self.addMemArg(Mir.Inst.Tag.fromOpcode(opcode), .{
+                    .offset = operand.offset(),
+                    .alignment = scalar_type.abiAlignment(self.target),
+                });
+            },
+            else => try self.emitWValue(operand),
+        }
     } else {
         try self.emitWValue(operand);
     }
@@ -1509,9 +1634,10 @@ fn airRetPtr(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         return self.allocStack(Type.usize); // create pointer to void
     }
 
-    if (isByRef(child_type, self.target)) {
+    if (firstParamSRet(self.decl.ty.fnInfo(), self.target)) {
         return self.return_value;
     }
+
     return self.allocStackPtr(inst);
 }
 
@@ -1521,7 +1647,7 @@ fn airRetLoad(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     const ret_ty = self.air.typeOf(un_op).childType();
     if (!ret_ty.hasRuntimeBitsIgnoreComptime()) return WValue.none;
 
-    if (!isByRef(ret_ty, self.target)) {
+    if (!firstParamSRet(self.decl.ty.fnInfo(), self.target)) {
         const result = try self.load(operand, ret_ty, 0);
         try self.emitWValue(result);
     }
@@ -1544,7 +1670,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         else => unreachable,
     };
     const ret_ty = fn_ty.fnReturnType();
-    const first_param_sret = isByRef(ret_ty, self.target);
+    const first_param_sret = firstParamSRet(fn_ty.fnInfo(), self.target);
 
     const callee: ?*Decl = blk: {
         const func_val = self.air.value(pl_op.operand) orelse break :blk null;
@@ -1554,7 +1680,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
             break :blk module.declPtr(func.data.owner_decl);
         } else if (func_val.castTag(.extern_fn)) |extern_fn| {
             const ext_decl = module.declPtr(extern_fn.data.owner_decl);
-            var func_type = try genFunctype(self.gpa, ext_decl.ty, self.target);
+            var func_type = try genFunctype(self.gpa, ext_decl.ty.fnInfo(), self.target);
             defer func_type.deinit(self.gpa);
             ext_decl.fn_link.wasm.type_index = try self.bin_file.putOrGetFuncType(func_type);
             try self.bin_file.addOrUpdateImport(ext_decl);
@@ -1579,10 +1705,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         const arg_ty = self.air.typeOf(arg_ref);
         if (!arg_ty.hasRuntimeBitsIgnoreComptime()) continue;
 
-        switch (arg_val) {
-            .stack_offset => try self.emitWValue(try self.buildPointerOffset(arg_val, 0, .new)),
-            else => try self.emitWValue(arg_val),
-        }
+        try self.lowerArg(fn_ty.fnInfo().cc, arg_ty, arg_val);
     }
 
     if (callee) |direct| {
@@ -1594,7 +1717,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         const operand = try self.resolveInst(pl_op.operand);
         try self.emitWValue(operand);
 
-        var fn_type = try genFunctype(self.gpa, fn_ty, self.target);
+        var fn_type = try genFunctype(self.gpa, fn_ty.fnInfo(), self.target);
         defer fn_type.deinit(self.gpa);
 
         const fn_type_index = try self.bin_file.putOrGetFuncType(fn_type);
@@ -1608,6 +1731,14 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         return WValue.none;
     } else if (first_param_sret) {
         return sret;
+        // TODO: Make this less fragile and optimize
+    } else if (fn_ty.fnInfo().cc == .C and ret_ty.zigTypeTag() == .Struct or ret_ty.zigTypeTag() == .Union) {
+        const result_local = try self.allocLocal(ret_ty);
+        try self.addLabel(.local_set, result_local.local);
+        const scalar_type = abi.scalarType(ret_ty, self.target);
+        const result = try self.allocStack(scalar_type);
+        try self.store(result, result_local, scalar_type, 0);
+        return result;
     } else {
         const result_local = try self.allocLocal(ret_ty);
         try self.addLabel(.local_set, result_local.local);
@@ -1749,9 +1880,20 @@ fn load(self: *Self, operand: WValue, ty: Type, offset: u32) InnerError!WValue {
 }
 
 fn airArg(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
-    _ = inst;
-    defer self.arg_index += 1;
-    return self.args[self.arg_index];
+    const arg = self.args[self.arg_index];
+    const cc = self.decl.ty.fnInfo().cc;
+    if (cc == .C) {
+        const ty = self.air.typeOfIndex(inst);
+        const arg_classes = abi.classifyType(ty, self.target);
+        for (arg_classes) |class| {
+            if (class != .none) {
+                self.arg_index += 1;
+            }
+        }
+    } else {
+        self.arg_index += 1;
+    }
+    return arg;
 }
 
 fn airBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue {