Commit 722d4a11bb

John Schmidt <john.schmidt.h@gmail.com>
2022-02-04 20:21:15
stage2: implement @sqrt for f{16,32,64}
Support for f128, comptime_float, and c_longdouble require improvements to compiler_rt and will implemented in a later PR. Some of the code in this commit could be made more generic, for instance `llvm.airSqrt` could probably be `llvm.airUnaryMath`, but let's cross that bridge when we get to it.
1 parent dd49ed1
src/arch/aarch64/CodeGen.zig
@@ -528,6 +528,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .max             => try self.airMax(inst),
             .slice           => try self.airSlice(inst),
 
+            .sqrt            => try self.airUnaryMath(inst),
+
             .add_with_overflow => try self.airAddWithOverflow(inst),
             .sub_with_overflow => try self.airSubWithOverflow(inst),
             .mul_with_overflow => try self.airMulWithOverflow(inst),
@@ -1223,6 +1225,15 @@ fn airPopcount(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
+fn airUnaryMath(self: *Self, inst: Air.Inst.Index) !void {
+    const un_op = self.air.instructions.items(.data)[inst].un_op;
+    const result: MCValue = if (self.liveness.isUnused(inst))
+        .dead
+    else
+        return self.fail("TODO implement airUnaryMath for {}", .{self.target.cpu.arch});
+    return self.finishAir(inst, result, .{ un_op, .none, .none });
+}
+
 fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool {
     if (!self.liveness.operandDies(inst, op_index))
         return false;
src/arch/arm/CodeGen.zig
@@ -520,6 +520,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .max             => try self.airMax(inst),
             .slice           => try self.airSlice(inst),
 
+            .sqrt            => try self.airUnaryMath(inst),
+
             .add_with_overflow => try self.airAddWithOverflow(inst),
             .sub_with_overflow => try self.airSubWithOverflow(inst),
             .mul_with_overflow => try self.airMulWithOverflow(inst),
@@ -1377,6 +1379,15 @@ fn airPopcount(self: *Self, inst: Air.Inst.Index) !void {
     // return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
+fn airUnaryMath(self: *Self, inst: Air.Inst.Index) !void {
+    const un_op = self.air.instructions.items(.data)[inst].un_op;
+    const result: MCValue = if (self.liveness.isUnused(inst))
+        .dead
+    else
+        return self.fail("TODO implement airUnaryMath for {}", .{self.target.cpu.arch});
+    return self.finishAir(inst, result, .{ un_op, .none, .none });
+}
+
 fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool {
     if (!self.liveness.operandDies(inst, op_index))
         return false;
src/arch/riscv64/CodeGen.zig
@@ -507,6 +507,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .max             => try self.airMax(inst),
             .slice           => try self.airSlice(inst),
 
+            .sqrt            => try self.airUnaryMath(inst),
+
             .add_with_overflow => try self.airAddWithOverflow(inst),
             .sub_with_overflow => try self.airSubWithOverflow(inst),
             .mul_with_overflow => try self.airMulWithOverflow(inst),
@@ -1166,6 +1168,15 @@ fn airPopcount(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
+fn airUnaryMath(self: *Self, inst: Air.Inst.Index) !void {
+    const un_op = self.air.instructions.items(.data)[inst].un_op;
+    const result: MCValue = if (self.liveness.isUnused(inst))
+        .dead
+    else
+        return self.fail("TODO implement airUnaryMath for {}", .{self.target.cpu.arch});
+    return self.finishAir(inst, result, .{ un_op, .none, .none });
+}
+
 fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool {
     if (!self.liveness.operandDies(inst, op_index))
         return false;
src/arch/wasm/CodeGen.zig
@@ -1681,6 +1681,8 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
         .unwrap_errunion_payload_ptr,
         .unwrap_errunion_err_ptr,
 
+        .sqrt,
+
         .ptr_slice_len_ptr,
         .ptr_slice_ptr_ptr,
         .int_to_float,
src/arch/x86_64/CodeGen.zig
@@ -599,6 +599,8 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
             .max             => try self.airMax(inst),
             .slice           => try self.airSlice(inst),
 
+            .sqrt            => try self.airUnaryMath(inst),
+
             .add_with_overflow => try self.airAddWithOverflow(inst),
             .sub_with_overflow => try self.airSubWithOverflow(inst),
             .mul_with_overflow => try self.airMulWithOverflow(inst),
@@ -1578,6 +1580,15 @@ fn airPopcount(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
+fn airUnaryMath(self: *Self, inst: Air.Inst.Index) !void {
+    const un_op = self.air.instructions.items(.data)[inst].un_op;
+    const result: MCValue = if (self.liveness.isUnused(inst))
+        .dead
+    else
+        return self.fail("TODO implement airUnaryMath for {}", .{self.target.cpu.arch});
+    return self.finishAir(inst, result, .{ un_op, .none, .none });
+}
+
 fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool {
     if (!self.liveness.operandDies(inst, op_index))
         return false;
src/codegen/c.zig
@@ -1446,6 +1446,8 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
             .mul_sat => try airSatOp(f, inst, "muls_"),
             .shl_sat => try airSatOp(f, inst, "shls_"),
 
+            .sqrt => try airSqrt(f, inst),
+
             .add_with_overflow => try airAddWithOverflow(f, inst),
             .sub_with_overflow => try airSubWithOverflow(f, inst),
             .mul_with_overflow => try airMulWithOverflow(f, inst),
@@ -3393,6 +3395,12 @@ fn airPrefetch(f: *Function, inst: Air.Inst.Index) !CValue {
     return CValue.none;
 }
 
+fn airSqrt(f: *Function, inst: Air.Inst.Index) !CValue {
+    _ = f;
+    _ = inst;
+    return f.fail("TODO: C backend: implement sqrt", .{});
+}
+
 fn toMemoryOrder(order: std.builtin.AtomicOrder) [:0]const u8 {
     return switch (order) {
         .Unordered => "memory_order_relaxed",
src/codegen/llvm.zig
@@ -2050,6 +2050,8 @@ pub const FuncGen = struct {
                 .shr                => try self.airShr(inst, false),
                 .shr_exact          => try self.airShr(inst, true),
 
+                .sqrt => try self.airSqrt(inst),
+
                 .cmp_eq  => try self.airCmp(inst, .eq),
                 .cmp_gt  => try self.airCmp(inst, .gt),
                 .cmp_gte => try self.airCmp(inst, .gte),
@@ -4211,6 +4213,20 @@ pub const FuncGen = struct {
         }
     }
 
+    fn airSqrt(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const un_op = self.air.instructions.items(.data)[inst].un_op;
+        const operand = try self.resolveInst(un_op);
+        const operand_ty = self.air.typeOf(un_op);
+
+        const operand_llvm_ty = try self.dg.llvmType(operand_ty);
+        const fn_val = self.getIntrinsic("llvm.sqrt", &.{operand_llvm_ty});
+        const params = [_]*const llvm.Value{operand};
+
+        return self.builder.buildCall(fn_val, &params, params.len, .C, .Auto, "");
+    }
+
     fn airClzCtz(self: *FuncGen, inst: Air.Inst.Index, prefix: [*:0]const u8) !?*const llvm.Value {
         if (self.liveness.isUnused(inst)) return null;
 
src/Air.zig
@@ -237,6 +237,10 @@ pub const Inst = struct {
         /// Uses the `ty_op` field.
         popcount,
 
+        /// Computes the square root of a floating point number.
+        /// Uses the `un_op` field.
+        sqrt,
+
         /// `<`. Result type is always bool.
         /// Uses the `bin_op` field.
         cmp_lt,
@@ -749,6 +753,8 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
         .max,
         => return air.typeOf(datas[inst].bin_op.lhs),
 
+        .sqrt => return air.typeOf(datas[inst].un_op),
+
         .cmp_lt,
         .cmp_lte,
         .cmp_eq,
src/Liveness.zig
@@ -338,6 +338,7 @@ fn analyzeInst(
         .ret_load,
         .tag_name,
         .error_name,
+        .sqrt,
         => {
             const operand = inst_datas[inst].un_op;
             return trackOperands(a, new_set, inst, main_tomb, .{ operand, .none, .none });
src/print_air.zig
@@ -158,6 +158,7 @@ const Writer = struct {
             .ret_load,
             .tag_name,
             .error_name,
+            .sqrt,
             => try w.writeUnOp(s, inst),
 
             .breakpoint,
src/Sema.zig
@@ -745,19 +745,19 @@ fn analyzeBodyInner(
             .clz => try sema.zirClzCtz(block, inst, .clz, Value.clz),
             .ctz => try sema.zirClzCtz(block, inst, .ctz, Value.ctz),
 
-            .sqrt  => try sema.zirUnaryMath(block, inst),
-            .sin   => try sema.zirUnaryMath(block, inst),
-            .cos   => try sema.zirUnaryMath(block, inst),
-            .exp   => try sema.zirUnaryMath(block, inst),
-            .exp2  => try sema.zirUnaryMath(block, inst),
-            .log   => try sema.zirUnaryMath(block, inst),
-            .log2  => try sema.zirUnaryMath(block, inst),
-            .log10 => try sema.zirUnaryMath(block, inst),
-            .fabs  => try sema.zirUnaryMath(block, inst),
-            .floor => try sema.zirUnaryMath(block, inst),
-            .ceil  => try sema.zirUnaryMath(block, inst),
-            .trunc => try sema.zirUnaryMath(block, inst),
-            .round => try sema.zirUnaryMath(block, inst),
+            .sqrt  => try sema.zirUnaryMath(block, inst, .sqrt),
+            .sin   => try sema.zirUnaryMath(block, inst, .sin),
+            .cos   => try sema.zirUnaryMath(block, inst, .cos),
+            .exp   => try sema.zirUnaryMath(block, inst, .exp),
+            .exp2  => try sema.zirUnaryMath(block, inst, .exp2),
+            .log   => try sema.zirUnaryMath(block, inst, .log),
+            .log2  => try sema.zirUnaryMath(block, inst, .log2),
+            .log10 => try sema.zirUnaryMath(block, inst, .log10),
+            .fabs  => try sema.zirUnaryMath(block, inst, .fabs),
+            .floor => try sema.zirUnaryMath(block, inst, .floor),
+            .ceil  => try sema.zirUnaryMath(block, inst, .ceil),
+            .trunc => try sema.zirUnaryMath(block, inst, .trunc),
+            .round => try sema.zirUnaryMath(block, inst, .round),
 
             .error_set_decl      => try sema.zirErrorSetDecl(block, inst, .parent),
             .error_set_decl_anon => try sema.zirErrorSetDecl(block, inst, .anon),
@@ -11010,10 +11010,64 @@ fn zirErrorName(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
     return block.addUnOp(.error_name, operand);
 }
 
-fn zirUnaryMath(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
+fn zirUnaryMath(
+    sema: *Sema,
+    block: *Block,
+    inst: Zir.Inst.Index,
+    zir_tag: Zir.Inst.Tag,
+) CompileError!Air.Inst.Ref {
+    const tracy = trace(@src());
+    defer tracy.end();
+
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const src = inst_data.src();
-    return sema.fail(block, src, "TODO: Sema.zirUnaryMath", .{});
+    const operand = sema.resolveInst(inst_data.operand);
+    const operand_ty = sema.typeOf(operand);
+    const operand_zig_ty_tag = operand_ty.zigTypeTag();
+
+    const is_float = operand_zig_ty_tag == .Float or operand_zig_ty_tag == .ComptimeFloat;
+    if (!is_float) {
+        return sema.fail(block, src, "expected float type, found '{s}'", .{@tagName(operand_zig_ty_tag)});
+    }
+
+    switch (zir_tag) {
+        .sqrt => {
+            switch (operand_ty.tag()) {
+                .f128,
+                .comptime_float,
+                .c_longdouble,
+                => |t| return sema.fail(block, src, "TODO implement @sqrt for type '{s}'", .{@tagName(t)}),
+                else => {},
+            }
+
+            const maybe_operand_val = try sema.resolveMaybeUndefVal(block, src, operand);
+            if (maybe_operand_val) |val| {
+                if (val.isUndef())
+                    return sema.addConstUndef(operand_ty);
+                const result_val = try val.sqrt(operand_ty, sema.arena);
+                return sema.addConstant(operand_ty, result_val);
+            }
+
+            try sema.requireRuntimeBlock(block, src);
+            return block.addUnOp(.sqrt, operand);
+        },
+
+        .sin,
+        .cos,
+        .exp,
+        .exp2,
+        .log,
+        .log2,
+        .log10,
+        .fabs,
+        .floor,
+        .ceil,
+        .trunc,
+        .round,
+        => return sema.fail(block, src, "TODO: implement zirUnaryMath for ZIR tag '{s}'", .{@tagName(zir_tag)}),
+
+        else => unreachable,
+    }
 }
 
 fn zirTagName(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
src/value.zig
@@ -3265,6 +3265,28 @@ pub const Value = extern union {
         }
     }
 
+    pub fn sqrt(val: Value, float_type: Type, arena: Allocator) !Value {
+        switch (float_type.tag()) {
+            .f16 => {
+                const f = val.toFloat(f16);
+                return Value.Tag.float_16.create(arena, @sqrt(f));
+            },
+            .f32 => {
+                const f = val.toFloat(f32);
+                return Value.Tag.float_32.create(arena, @sqrt(f));
+            },
+            .f64 => {
+                const f = val.toFloat(f64);
+                return Value.Tag.float_64.create(arena, @sqrt(f));
+            },
+
+            // TODO: implement @sqrt for these types
+            .f128, .comptime_float, .c_longdouble => unreachable,
+
+            else => unreachable,
+        }
+    }
+
     /// This type is not copyable since it may contain pointers to its inner data.
     pub const Payload = struct {
         tag: Tag,
test/behavior/floatop.zig
@@ -72,3 +72,45 @@ test "negative f128 floatToInt at compile-time" {
     var b = @floatToInt(i64, a);
     try expect(@as(i64, -2) == b);
 }
+
+test "@sqrt" {
+    comptime try testSqrt();
+    try testSqrt();
+}
+
+fn testSqrt() !void {
+    {
+        var a: f16 = 4;
+        try expect(@sqrt(a) == 2);
+    }
+    {
+        var a: f32 = 9;
+        try expect(@sqrt(a) == 3);
+        var b: f32 = 1.1;
+        try expect(math.approxEqAbs(f32, @sqrt(b), 1.0488088481701516, epsilon));
+    }
+    {
+        var a: f64 = 25;
+        try expect(@sqrt(a) == 5);
+    }
+}
+
+test "more @sqrt f16 tests" {
+    // TODO these are not all passing at comptime
+    try expect(@sqrt(@as(f16, 0.0)) == 0.0);
+    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 2.0)), 1.414214, epsilon));
+    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 3.6)), 1.897367, epsilon));
+    try expect(@sqrt(@as(f16, 4.0)) == 2.0);
+    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 7.539840)), 2.745877, epsilon));
+    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 19.230934)), 4.385309, epsilon));
+    try expect(@sqrt(@as(f16, 64.0)) == 8.0);
+    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 64.1)), 8.006248, epsilon));
+    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 8942.230469)), 94.563370, epsilon));
+
+    // special cases
+    try expect(math.isPositiveInf(@sqrt(@as(f16, math.inf(f16)))));
+    try expect(@sqrt(@as(f16, 0.0)) == 0.0);
+    try expect(@sqrt(@as(f16, -0.0)) == -0.0);
+    try expect(math.isNan(@sqrt(@as(f16, -1.0))));
+    try expect(math.isNan(@sqrt(@as(f16, math.nan(f16)))));
+}
test/behavior/floatop_stage1.zig
@@ -14,20 +14,6 @@ test "@sqrt" {
 }
 
 fn testSqrt() !void {
-    {
-        var a: f16 = 4;
-        try expect(@sqrt(a) == 2);
-    }
-    {
-        var a: f32 = 9;
-        try expect(@sqrt(a) == 3);
-        var b: f32 = 1.1;
-        try expect(math.approxEqAbs(f32, @sqrt(b), 1.0488088481701516, epsilon));
-    }
-    {
-        var a: f64 = 25;
-        try expect(@sqrt(a) == 5);
-    }
     if (has_f80_rt) {
         var a: f80 = 25;
         try expect(@sqrt(a) == 5);
@@ -51,26 +37,6 @@ fn testSqrt() !void {
     }
 }
 
-test "more @sqrt f16 tests" {
-    // TODO these are not all passing at comptime
-    try expect(@sqrt(@as(f16, 0.0)) == 0.0);
-    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 2.0)), 1.414214, epsilon));
-    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 3.6)), 1.897367, epsilon));
-    try expect(@sqrt(@as(f16, 4.0)) == 2.0);
-    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 7.539840)), 2.745877, epsilon));
-    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 19.230934)), 4.385309, epsilon));
-    try expect(@sqrt(@as(f16, 64.0)) == 8.0);
-    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 64.1)), 8.006248, epsilon));
-    try expect(math.approxEqAbs(f16, @sqrt(@as(f16, 8942.230469)), 94.563370, epsilon));
-
-    // special cases
-    try expect(math.isPositiveInf(@sqrt(@as(f16, math.inf(f16)))));
-    try expect(@sqrt(@as(f16, 0.0)) == 0.0);
-    try expect(@sqrt(@as(f16, -0.0)) == -0.0);
-    try expect(math.isNan(@sqrt(@as(f16, -1.0))));
-    try expect(math.isNan(@sqrt(@as(f16, math.nan(f16)))));
-}
-
 test "@sin" {
     comptime try testSin();
     try testSin();
test/behavior/math.zig
@@ -792,8 +792,6 @@ fn remdiv(comptime T: type) !void {
 }
 
 test "@sqrt" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
-
     try testSqrt(f64, 12.0);
     comptime try testSqrt(f64, 12.0);
     try testSqrt(f32, 13.0);
@@ -801,10 +799,12 @@ test "@sqrt" {
     try testSqrt(f16, 13.0);
     comptime try testSqrt(f16, 13.0);
 
-    const x = 14.0;
-    const y = x * x;
-    const z = @sqrt(y);
-    comptime try expect(z == x);
+    if (builtin.zig_backend == .stage1) {
+        const x = 14.0;
+        const y = x * x;
+        const z = @sqrt(y);
+        comptime try expect(z == x);
+    }
 }
 
 fn testSqrt(comptime T: type, x: T) !void {