Commit 6ae898b244

Luuk de Gram <luuk@degram.dev>
2022-06-13 20:48:53
wasm: more f16 support and cleanup of intrinsics
`genFunctype` now accepts the calling convention, parameter types, and return type as part of its function signature rather than `fnData`. This means we no longer have to create a dummy for our intrinsic call abstraction. This also adds f16 support for division (`@divFloor`) and for builtins such as `@ceil`, `@floor`, and `@trunc`.
1 parent ba37bc8
Changed files (2)
src
arch
link
src/arch/wasm/CodeGen.zig
@@ -417,22 +417,24 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
             .f64 => return .f64_neg,
         },
         .ceil => switch (args.valtype1.?) {
-            .i32, .i64 => unreachable,
+            .i64 => unreachable,
+            .i32 => return .f32_ceil, // when valtype is f16, we store it in i32.
             .f32 => return .f32_ceil,
             .f64 => return .f64_ceil,
         },
         .floor => switch (args.valtype1.?) {
-            .i32, .i64 => unreachable,
+            .i64 => unreachable,
+            .i32 => return .f32_floor, // when valtype is f16, we store it in i32.
             .f32 => return .f32_floor,
             .f64 => return .f64_floor,
         },
         .trunc => switch (args.valtype1.?) {
-            .i32 => switch (args.valtype2.?) {
+            .i32 => if (args.valtype2) |valty| switch (valty) {
                 .i32 => unreachable,
                 .i64 => unreachable,
                 .f32 => if (args.signedness.? == .signed) return .i32_trunc_f32_s else return .i32_trunc_f32_u,
                 .f64 => if (args.signedness.? == .signed) return .i32_trunc_f64_s else return .i32_trunc_f64_u,
-            },
+            } else return .f32_trunc, // when no valtype2, it's an f16 instead which is stored in an i32.
             .i64 => unreachable,
             .f32 => return .f32_trunc,
             .f64 => return .f64_trunc,
@@ -788,55 +790,53 @@ fn allocLocal(self: *Self, ty: Type) InnerError!WValue {
 
 /// Generates a `wasm.Type` from a given function type.
 /// Memory is owned by the caller.
-fn genFunctype(gpa: Allocator, fn_info: Type.Payload.Function.Data, target: std.Target) !wasm.Type {
-    var params = std.ArrayList(wasm.Valtype).init(gpa);
-    defer params.deinit();
+fn genFunctype(gpa: Allocator, cc: std.builtin.CallingConvention, params: []const Type, return_type: Type, target: std.Target) !wasm.Type {
+    var temp_params = std.ArrayList(wasm.Valtype).init(gpa);
+    defer temp_params.deinit();
     var returns = std.ArrayList(wasm.Valtype).init(gpa);
     defer returns.deinit();
 
-    if (firstParamSRet(fn_info.cc, fn_info.return_type, target)) {
-        try params.append(.i32); // memory address is always a 32-bit handle
-    } else if (fn_info.return_type.hasRuntimeBitsIgnoreComptime()) {
-        if (fn_info.cc == .C) {
-            const res_classes = abi.classifyType(fn_info.return_type, target);
+    if (firstParamSRet(cc, return_type, target)) {
+        try temp_params.append(.i32); // memory address is always a 32-bit handle
+    } else if (return_type.hasRuntimeBitsIgnoreComptime()) {
+        if (cc == .C) {
+            const res_classes = abi.classifyType(return_type, target);
             assert(res_classes[0] == .direct and res_classes[1] == .none);
-            const scalar_type = abi.scalarType(fn_info.return_type, target);
+            const scalar_type = abi.scalarType(return_type, target);
             try returns.append(typeToValtype(scalar_type, target));
         } else {
-            try returns.append(typeToValtype(fn_info.return_type, target));
+            try returns.append(typeToValtype(return_type, target));
         }
-    } else if (fn_info.return_type.isError()) {
+    } else if (return_type.isError()) {
         try returns.append(.i32);
     }
 
     // param types
-    if (fn_info.param_types.len != 0) {
-        for (fn_info.param_types) |param_type| {
-            if (!param_type.hasRuntimeBitsIgnoreComptime()) continue;
-
-            switch (fn_info.cc) {
-                .C => {
-                    const param_classes = abi.classifyType(param_type, target);
-                    for (param_classes) |class| {
-                        if (class == .none) continue;
-                        if (class == .direct) {
-                            const scalar_type = abi.scalarType(param_type, target);
-                            try params.append(typeToValtype(scalar_type, target));
-                        } else {
-                            try params.append(typeToValtype(param_type, target));
-                        }
+    for (params) |param_type| {
+        if (!param_type.hasRuntimeBitsIgnoreComptime()) continue;
+
+        switch (cc) {
+            .C => {
+                const param_classes = abi.classifyType(param_type, target);
+                for (param_classes) |class| {
+                    if (class == .none) continue;
+                    if (class == .direct) {
+                        const scalar_type = abi.scalarType(param_type, target);
+                        try temp_params.append(typeToValtype(scalar_type, target));
+                    } else {
+                        try temp_params.append(typeToValtype(param_type, target));
                     }
-                },
-                else => if (isByRef(param_type, target))
-                    try params.append(.i32)
-                else
-                    try params.append(typeToValtype(param_type, target)),
-            }
+                }
+            },
+            else => if (isByRef(param_type, target))
+                try temp_params.append(.i32)
+            else
+                try temp_params.append(typeToValtype(param_type, target)),
         }
     }
 
     return wasm.Type{
-        .params = params.toOwnedSlice(),
+        .params = temp_params.toOwnedSlice(),
         .returns = returns.toOwnedSlice(),
     };
 }
@@ -877,7 +877,8 @@ pub fn generate(
 }
 
 fn genFunc(self: *Self) InnerError!void {
-    var func_type = try genFunctype(self.gpa, self.decl.ty.fnInfo(), self.target);
+    const fn_info = self.decl.ty.fnInfo();
+    var func_type = try genFunctype(self.gpa, fn_info.cc, fn_info.param_types, fn_info.return_type, self.target);
     defer func_type.deinit(self.gpa);
     self.decl.fn_link.wasm.type_index = try self.bin_file.putOrGetFuncType(func_type);
 
@@ -1733,7 +1734,8 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
             break :blk module.declPtr(func.data.owner_decl);
         } else if (func_val.castTag(.extern_fn)) |extern_fn| {
             const ext_decl = module.declPtr(extern_fn.data.owner_decl);
-            var func_type = try genFunctype(self.gpa, ext_decl.ty.fnInfo(), self.target);
+            const ext_info = ext_decl.ty.fnInfo();
+            var func_type = try genFunctype(self.gpa, ext_info.cc, ext_info.param_types, ext_info.return_type, self.target);
             defer func_type.deinit(self.gpa);
             ext_decl.fn_link.wasm.type_index = try self.bin_file.putOrGetFuncType(func_type);
             try self.bin_file.addOrUpdateImport(
@@ -1774,7 +1776,7 @@ fn airCall(self: *Self, inst: Air.Inst.Index, modifier: std.builtin.CallOptions.
         const operand = try self.resolveInst(pl_op.operand);
         try self.emitWValue(operand);
 
-        var fn_type = try genFunctype(self.gpa, fn_ty.fnInfo(), self.target);
+        var fn_type = try genFunctype(self.gpa, fn_info.cc, fn_info.param_types, fn_info.return_type, self.target);
         defer fn_type.deinit(self.gpa);
 
         const fn_type_index = try self.bin_file.putOrGetFuncType(fn_type);
@@ -4883,12 +4885,38 @@ fn airDivFloor(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
         try self.emitWValue(rem_result);
         try self.addTag(.select);
     } else {
-        const div_result = try self.binOp(lhs, rhs, ty, .div);
-        try self.emitWValue(div_result);
-        switch (ty.floatBits(self.target)) {
-            32 => try self.addTag(.f32_floor),
-            64 => try self.addTag(.f64_floor),
-            else => |bit_size| return self.fail("TODO: `@divFloor` for floats with bitsize: {d}", .{bit_size}),
+        const float_bits = ty.floatBits(self.target);
+        if (float_bits > 64) {
+            return self.fail("TODO: `@divFloor` for floats with bitsize: {d}", .{float_bits});
+        }
+        const is_f16 = float_bits == 16;
+
+        const lhs_operand = if (is_f16) blk: {
+            break :blk try self.fpext(lhs, Type.f16, Type.f32);
+        } else lhs;
+        const rhs_operand = if (is_f16) blk: {
+            break :blk try self.fpext(rhs, Type.f16, Type.f32);
+        } else rhs;
+
+        try self.emitWValue(lhs_operand);
+        try self.emitWValue(rhs_operand);
+
+        switch (float_bits) {
+            16, 32 => {
+                try self.addTag(.f32_div);
+                try self.addTag(.f32_floor);
+            },
+            64 => {
+                try self.addTag(.f64_div);
+                try self.addTag(.f64_floor);
+            },
+            else => unreachable,
+        }
+
+        if (is_f16) {
+            // we can re-use temporary local
+            try self.addLabel(.local_set, lhs_operand.local);
+            return self.fptrunc(lhs_operand, Type.f32, Type.f16);
         }
     }
 
@@ -4961,22 +4989,28 @@ fn airCeilFloorTrunc(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValu
 
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const ty = self.air.typeOfIndex(inst);
+    const float_bits = ty.floatBits(self.target);
+    const is_f16 = float_bits == 16;
 
     if (ty.zigTypeTag() == .Vector) {
         return self.fail("TODO: Implement `@ceil` for vectors", .{});
     }
+    if (float_bits > 64) {
+        return self.fail("TODO: implement `@ceil`, `@trunc`, `@floor` for floats larger than 64bits", .{});
+    }
 
     const operand = try self.resolveInst(un_op);
-    try self.emitWValue(operand);
-    switch (ty.floatBits(self.target)) {
-        32, 64 => {
-            const opcode = buildOpcode(.{
-                .op = op,
-                .valtype1 = typeToValtype(ty, self.target),
-            });
-            try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
-        },
-        else => |bit_size| return self.fail("TODO: Implement `@ceil` for floats with bitsize {d}", .{bit_size}),
+    const op_to_lower = if (is_f16) blk: {
+        break :blk try self.fpext(operand, Type.f16, Type.f32);
+    } else operand;
+    try self.emitWValue(op_to_lower);
+    const opcode = buildOpcode(.{ .op = op, .valtype1 = typeToValtype(ty, self.target) });
+    try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
+
+    if (is_f16) {
+        // re-use temporary to save locals
+        try self.addLabel(.local_set, op_to_lower.local);
+        return self.fptrunc(op_to_lower, Type.f32, Type.f16);
     }
 
     const result = try self.allocLocal(ty);
@@ -5212,19 +5246,8 @@ fn callIntrinsic(
         return self.fail("Could not find or create global symbol '{s}'", .{@errorName(err)});
     };
 
-    // TODO: have genFunctype accept individual params so we don't,
-    // need to initialize a fake Fn.Data instance.
-    var pt_tmp = try self.gpa.dupe(Type, param_types);
-    defer self.gpa.free(pt_tmp);
-    var func_type = try genFunctype(self.gpa, .{
-        .param_types = pt_tmp,
-        .comptime_params = undefined,
-        .return_type = return_type,
-        .alignment = 0,
-        .cc = .C,
-        .is_var_args = false,
-        .is_generic = false,
-    }, self.target);
+    // Always pass over C-ABI
+    var func_type = try genFunctype(self.gpa, .C, param_types, return_type, self.target);
     defer func_type.deinit(self.gpa);
     const func_type_index = try self.bin_file.putOrGetFuncType(func_type);
     try self.bin_file.addOrUpdateImport(name, symbol_index, null, func_type_index);
src/link/Wasm.zig
@@ -1840,7 +1840,7 @@ pub fn flushModule(self: *Wasm, comp: *Compilation, prog_node: *std.Progress.Nod
         try positionals.append(c_object.status.success.object_path);
     }
 
-    if (comp.compiler_rt_static_lib) |lib| {
+    if (comp.compiler_rt_lib) |lib| {
         try positionals.append(lib.full_object_path);
     }