Commit b5459eb987

Andrew Kelley <superjoe30@gmail.com>
2018-04-15 19:21:52
add @sqrt built-in function
See #767
1 parent 4a2bfec
doc/langref.html.in
@@ -4669,6 +4669,16 @@ pub const FloatMode = enum {
       The result is a target-specific compile time constant.
       </p>
       {#header_close#}
+      {#header_open|@sqrt#}
+      <pre><code class="zig">@sqrt(comptime T: type, value: T) -&gt; T</code></pre>
+      <p>
+      Performs the square root of a floating point number. Uses a dedicated hardware instruction
+      when available. Currently only supports f32 and f64 at runtime. f128 at runtime is TODO.
+      </p>
+      <p>
+      This is a low-level intrinsic. Most code can use <code>std.math.sqrt</code> instead.
+      </p>
+      {#header_close#}
       {#header_open|@subWithOverflow#}
       <pre><code class="zig">@subWithOverflow(comptime T: type, a: T, b: T, result: &T) -&gt; bool</code></pre>
       <p>
@@ -5991,7 +6001,7 @@ hljs.registerLanguage("zig", function(t) {
         a = t.IR + "\\s*\\(",
         c = {
             keyword: "const align var extern stdcallcc nakedcc volatile export pub noalias inline struct packed enum union break return try catch test continue unreachable comptime and or asm defer errdefer if else switch while for fn use bool f32 f64 void type noreturn error i8 u8 i16 u16 i32 u32 i64 u64 isize usize i8w u8w i16w i32w u32w i64w u64w isizew usizew c_short c_ushort c_int c_uint c_long c_ulong c_longlong c_ulonglong",
-            built_in: "breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchg fence divExact truncate atomicRmw",
+            built_in: "breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchg fence divExact truncate atomicRmw sqrt",
             literal: "true false null undefined"
         },
         n = [e, t.CLCM, t.CBCM, s, r];
src/all_types.hpp
@@ -1317,6 +1317,7 @@ enum BuiltinFnId {
     BuiltinFnIdDivFloor,
     BuiltinFnIdRem,
     BuiltinFnIdMod,
+    BuiltinFnIdSqrt,
     BuiltinFnIdTruncate,
     BuiltinFnIdIntType,
     BuiltinFnIdSetCold,
@@ -1413,6 +1414,7 @@ enum ZigLLVMFnId {
     ZigLLVMFnIdOverflowArithmetic,
     ZigLLVMFnIdFloor,
     ZigLLVMFnIdCeil,
+    ZigLLVMFnIdSqrt,
 };
 
 enum AddSubMul {
@@ -1433,7 +1435,7 @@ struct ZigLLVMFnKey {
         } clz;
         struct {
             uint32_t bit_count;
-        } floor_ceil;
+        } floating;
         struct {
             AddSubMul add_sub_mul;
             uint32_t bit_count;
@@ -2047,6 +2049,7 @@ enum IrInstructionId {
     IrInstructionIdAddImplicitReturnType,
     IrInstructionIdMergeErrRetTraces,
     IrInstructionIdMarkErrRetTracePtr,
+    IrInstructionIdSqrt,
 };
 
 struct IrInstruction {
@@ -3036,6 +3039,13 @@ struct IrInstructionMarkErrRetTracePtr {
     IrInstruction *err_ret_trace_ptr;
 };
 
+struct IrInstructionSqrt {
+    IrInstruction base;
+
+    IrInstruction *type;
+    IrInstruction *op;
+};
+
 static const size_t slice_ptr_index = 0;
 static const size_t slice_len_index = 1;
 
src/analyze.cpp
@@ -5801,9 +5801,11 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) {
         case ZigLLVMFnIdClz:
             return (uint32_t)(x.data.clz.bit_count) * (uint32_t)2428952817;
         case ZigLLVMFnIdFloor:
-            return (uint32_t)(x.data.floor_ceil.bit_count) * (uint32_t)1899859168;
+            return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1899859168;
         case ZigLLVMFnIdCeil:
-            return (uint32_t)(x.data.floor_ceil.bit_count) * (uint32_t)1953839089;
+            return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1953839089;
+        case ZigLLVMFnIdSqrt:
+            return (uint32_t)(x.data.floating.bit_count) * (uint32_t)2225366385;
         case ZigLLVMFnIdOverflowArithmetic:
             return ((uint32_t)(x.data.overflow_arithmetic.bit_count) * 87135777) +
                 ((uint32_t)(x.data.overflow_arithmetic.add_sub_mul) * 31640542) +
@@ -5822,7 +5824,8 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) {
             return a.data.clz.bit_count == b.data.clz.bit_count;
         case ZigLLVMFnIdFloor:
         case ZigLLVMFnIdCeil:
-            return a.data.floor_ceil.bit_count == b.data.floor_ceil.bit_count;
+        case ZigLLVMFnIdSqrt:
+            return a.data.floating.bit_count == b.data.floating.bit_count;
         case ZigLLVMFnIdOverflowArithmetic:
             return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) &&
                 (a.data.overflow_arithmetic.add_sub_mul == b.data.overflow_arithmetic.add_sub_mul) &&
src/bigfloat.cpp
@@ -181,3 +181,7 @@ bool bigfloat_has_fraction(const BigFloat *bigfloat) {
     f128M_roundToInt(&bigfloat->value, softfloat_round_minMag, false, &floored);
     return !f128M_eq(&floored, &bigfloat->value);
 }
+
+void bigfloat_sqrt(BigFloat *dest, const BigFloat *op) {
+    f128M_sqrt(&op->value, &dest->value);
+}
src/bigfloat.hpp
@@ -42,6 +42,7 @@ void bigfloat_div_trunc(BigFloat *dest, const BigFloat *op1, const BigFloat *op2
 void bigfloat_div_floor(BigFloat *dest, const BigFloat *op1, const BigFloat *op2);
 void bigfloat_rem(BigFloat *dest, const BigFloat *op1, const BigFloat *op2);
 void bigfloat_mod(BigFloat *dest, const BigFloat *op1, const BigFloat *op2);
+void bigfloat_sqrt(BigFloat *dest, const BigFloat *op);
 void bigfloat_append_buf(Buf *buf, const BigFloat *op);
 Cmp bigfloat_cmp(const BigFloat *op1, const BigFloat *op2);
 
src/codegen.cpp
@@ -717,12 +717,12 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, TypeTableEntry *type_entry,
     return fn_val;
 }
 
-static LLVMValueRef get_floor_ceil_fn(CodeGen *g, TypeTableEntry *type_entry, ZigLLVMFnId fn_id) {
+static LLVMValueRef get_float_fn(CodeGen *g, TypeTableEntry *type_entry, ZigLLVMFnId fn_id) {
     assert(type_entry->id == TypeTableEntryIdFloat);
 
     ZigLLVMFnKey key = {};
     key.id = fn_id;
-    key.data.floor_ceil.bit_count = (uint32_t)type_entry->data.floating.bit_count;
+    key.data.floating.bit_count = (uint32_t)type_entry->data.floating.bit_count;
 
     auto existing_entry = g->llvm_fn_table.maybe_get(key);
     if (existing_entry)
@@ -733,6 +733,8 @@ static LLVMValueRef get_floor_ceil_fn(CodeGen *g, TypeTableEntry *type_entry, Zi
         name = "floor";
     } else if (fn_id == ZigLLVMFnIdCeil) {
         name = "ceil";
+    } else if (fn_id == ZigLLVMFnIdSqrt) {
+        name = "sqrt";
     } else {
         zig_unreachable();
     }
@@ -1900,7 +1902,7 @@ static LLVMValueRef gen_floor(CodeGen *g, LLVMValueRef val, TypeTableEntry *type
     if (type_entry->id == TypeTableEntryIdInt)
         return val;
 
-    LLVMValueRef floor_fn = get_floor_ceil_fn(g, type_entry, ZigLLVMFnIdFloor);
+    LLVMValueRef floor_fn = get_float_fn(g, type_entry, ZigLLVMFnIdFloor);
     return LLVMBuildCall(g->builder, floor_fn, &val, 1, "");
 }
 
@@ -1908,7 +1910,7 @@ static LLVMValueRef gen_ceil(CodeGen *g, LLVMValueRef val, TypeTableEntry *type_
     if (type_entry->id == TypeTableEntryIdInt)
         return val;
 
-    LLVMValueRef ceil_fn = get_floor_ceil_fn(g, type_entry, ZigLLVMFnIdCeil);
+    LLVMValueRef ceil_fn = get_float_fn(g, type_entry, ZigLLVMFnIdCeil);
     return LLVMBuildCall(g->builder, ceil_fn, &val, 1, "");
 }
 
@@ -3247,10 +3249,12 @@ static LLVMValueRef get_int_builtin_fn(CodeGen *g, TypeTableEntry *int_type, Bui
         fn_name = "cttz";
         key.id = ZigLLVMFnIdCtz;
         key.data.ctz.bit_count = (uint32_t)int_type->data.integral.bit_count;
-    } else {
+    } else if (fn_id == BuiltinFnIdClz) {
         fn_name = "ctlz";
         key.id = ZigLLVMFnIdClz;
         key.data.clz.bit_count = (uint32_t)int_type->data.integral.bit_count;
+    } else {
+        zig_unreachable();
     }
 
     auto existing_entry = g->llvm_fn_table.maybe_get(key);
@@ -4402,6 +4406,13 @@ static LLVMValueRef ir_render_mark_err_ret_trace_ptr(CodeGen *g, IrExecutable *e
     return nullptr;
 }
 
+static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstructionSqrt *instruction) {
+    LLVMValueRef op = ir_llvm_value(g, instruction->op);
+    assert(instruction->base.value.type->id == TypeTableEntryIdFloat);
+    LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdSqrt);
+    return LLVMBuildCall(g->builder, fn_val, &op, 1, "");
+}
+
 static void set_debug_location(CodeGen *g, IrInstruction *instruction) {
     AstNode *source_node = instruction->source_node;
     Scope *scope = instruction->scope;
@@ -4623,6 +4634,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_merge_err_ret_traces(g, executable, (IrInstructionMergeErrRetTraces *)instruction);
         case IrInstructionIdMarkErrRetTracePtr:
             return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction);
+        case IrInstructionIdSqrt:
+            return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction);
     }
     zig_unreachable();
 }
@@ -6109,6 +6122,7 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn(g, BuiltinFnIdDivFloor, "divFloor", 2);
     create_builtin_fn(g, BuiltinFnIdRem, "rem", 2);
     create_builtin_fn(g, BuiltinFnIdMod, "mod", 2);
+    create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2);
     create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX);
     create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX);
     create_builtin_fn(g, BuiltinFnIdTypeId, "typeId", 1);
src/ir.cpp
@@ -733,6 +733,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionMarkErrRetTraceP
     return IrInstructionIdMarkErrRetTracePtr;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionSqrt *) {
+    return IrInstructionIdSqrt;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -2731,6 +2735,17 @@ static IrInstruction *ir_build_mark_err_ret_trace_ptr(IrBuilder *irb, Scope *sco
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_sqrt(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type, IrInstruction *op) {
+    IrInstructionSqrt *instruction = ir_build_instruction<IrInstructionSqrt>(irb, scope, source_node);
+    instruction->type = type;
+    instruction->op = op;
+
+    if (type != nullptr) ir_ref_instruction(type, irb->current_basic_block);
+    ir_ref_instruction(op, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
 static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) {
     results[ReturnKindUnconditional] = 0;
     results[ReturnKindError] = 0;
@@ -3845,6 +3860,20 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
 
                 return ir_build_bin_op(irb, scope, node, IrBinOpRemMod, arg0_value, arg1_value, true);
             }
+        case BuiltinFnIdSqrt:
+            {
+                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
+                IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
+                if (arg0_value == irb->codegen->invalid_instruction)
+                    return arg0_value;
+
+                AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
+                IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope);
+                if (arg1_value == irb->codegen->invalid_instruction)
+                    return arg1_value;
+
+                return ir_build_sqrt(irb, scope, node, arg0_value, arg1_value);
+            }
         case BuiltinFnIdTruncate:
             {
                 AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -18031,6 +18060,68 @@ static TypeTableEntry *ir_analyze_instruction_mark_err_ret_trace_ptr(IrAnalyze *
     return result->value.type;
 }
 
+static TypeTableEntry *ir_analyze_instruction_sqrt(IrAnalyze *ira, IrInstructionSqrt *instruction) {
+    TypeTableEntry *float_type = ir_resolve_type(ira, instruction->type->other);
+    if (type_is_invalid(float_type))
+        return ira->codegen->builtin_types.entry_invalid;
+
+    IrInstruction *op = instruction->op->other;
+    if (type_is_invalid(op->value.type))
+        return ira->codegen->builtin_types.entry_invalid;
+
+    bool ok_type = float_type->id == TypeTableEntryIdNumLitFloat || float_type->id == TypeTableEntryIdFloat;
+    if (!ok_type) {
+        ir_add_error(ira, instruction->type, buf_sprintf("@sqrt does not support type '%s'", buf_ptr(&float_type->name)));
+        return ira->codegen->builtin_types.entry_invalid;
+    }
+
+    IrInstruction *casted_op = ir_implicit_cast(ira, op, float_type);
+    if (type_is_invalid(casted_op->value.type))
+        return ira->codegen->builtin_types.entry_invalid;
+
+    if (instr_is_comptime(casted_op)) {
+        ConstExprValue *val = ir_resolve_const(ira, casted_op, UndefBad);
+        if (!val)
+            return ira->codegen->builtin_types.entry_invalid;
+
+        ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base);
+
+        if (float_type->id == TypeTableEntryIdNumLitFloat) {
+            bigfloat_sqrt(&out_val->data.x_bigfloat, &val->data.x_bigfloat);
+        } else if (float_type->id == TypeTableEntryIdFloat) {
+            switch (float_type->data.floating.bit_count) {
+                case 32:
+                    out_val->data.x_f32 = sqrtf(val->data.x_f32);
+                    break;
+                case 64:
+                    out_val->data.x_f64 = sqrt(val->data.x_f64);
+                    break;
+                case 128:
+                    f128M_sqrt(&val->data.x_f128, &out_val->data.x_f128);
+                    break;
+                default:
+                    zig_unreachable();
+            }
+        } else {
+            zig_unreachable();
+        }
+
+        return float_type;
+    }
+
+    assert(float_type->id == TypeTableEntryIdFloat);
+    if (float_type->data.floating.bit_count != 32 && float_type->data.floating.bit_count != 64) {
+        ir_add_error(ira, instruction->type, buf_sprintf("compiler TODO: add implementation of sqrt for '%s'", buf_ptr(&float_type->name)));
+        return ira->codegen->builtin_types.entry_invalid;
+    }
+
+    IrInstruction *result = ir_build_sqrt(&ira->new_irb, instruction->base.scope,
+            instruction->base.source_node, nullptr, casted_op);
+    ir_link_new_instruction(result, &instruction->base);
+    result->value.type = float_type;
+    return result->value.type;
+}
+
 static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstruction *instruction) {
     switch (instruction->id) {
         case IrInstructionIdInvalid:
@@ -18278,6 +18369,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
             return ir_analyze_instruction_merge_err_ret_traces(ira, (IrInstructionMergeErrRetTraces *)instruction);
         case IrInstructionIdMarkErrRetTracePtr:
             return ir_analyze_instruction_mark_err_ret_trace_ptr(ira, (IrInstructionMarkErrRetTracePtr *)instruction);
+        case IrInstructionIdSqrt:
+            return ir_analyze_instruction_sqrt(ira, (IrInstructionSqrt *)instruction);
     }
     zig_unreachable();
 }
@@ -18490,6 +18583,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdCoroFree:
         case IrInstructionIdCoroPromise:
         case IrInstructionIdPromiseResultType:
+        case IrInstructionIdSqrt:
             return false;
 
         case IrInstructionIdAsm:
src/ir_print.cpp
@@ -1204,6 +1204,18 @@ static void ir_print_mark_err_ret_trace_ptr(IrPrint *irp, IrInstructionMarkErrRe
     fprintf(irp->f, ")");
 }
 
+static void ir_print_sqrt(IrPrint *irp, IrInstructionSqrt *instruction) {
+    fprintf(irp->f, "@sqrt(");
+    if (instruction->type != nullptr) {
+        ir_print_other_instruction(irp, instruction->type);
+    } else {
+        fprintf(irp->f, "null");
+    }
+    fprintf(irp->f, ",");
+    ir_print_other_instruction(irp, instruction->op);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -1590,6 +1602,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdMarkErrRetTracePtr:
             ir_print_mark_err_ret_trace_ptr(irp, (IrInstructionMarkErrRetTracePtr *)instruction);
             break;
+        case IrInstructionIdSqrt:
+            ir_print_sqrt(irp, (IrInstructionSqrt *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
std/math/x86_64/sqrt.zig
@@ -1,15 +0,0 @@
-pub fn sqrt32(x: f32) f32 {
-    return asm (
-        \\sqrtss %%xmm0, %%xmm0
-        : [ret] "={xmm0}" (-> f32)
-        : [x] "{xmm0}" (x)
-    );
-}
-
-pub fn sqrt64(x: f64) f64 {
-    return asm (
-        \\sqrtsd %%xmm0, %%xmm0
-        : [ret] "={xmm0}" (-> f64)
-        : [x] "{xmm0}" (x)
-    );
-}
std/math/sqrt.zig
@@ -14,26 +14,8 @@ const TypeId = builtin.TypeId;
 pub fn sqrt(x: var) (if (@typeId(@typeOf(x)) == TypeId.Int) @IntType(false, @typeOf(x).bit_count / 2) else @typeOf(x)) {
     const T = @typeOf(x);
     switch (@typeId(T)) {
-        TypeId.FloatLiteral => {
-            return T(sqrt64(x));
-        },
-        TypeId.Float => {
-            switch (T) {
-                f32 => {
-                    switch (builtin.arch) {
-                        builtin.Arch.x86_64 => return @import("x86_64/sqrt.zig").sqrt32(x),
-                        else => return sqrt32(x),
-                    }
-                },
-                f64 => {
-                    switch (builtin.arch) {
-                        builtin.Arch.x86_64 => return @import("x86_64/sqrt.zig").sqrt64(x),
-                        else => return sqrt64(x),
-                    }
-                },
-                else => @compileError("sqrt not implemented for " ++ @typeName(T)),
-            }
-        },
+        TypeId.FloatLiteral => return T(@sqrt(f64, x)), // TODO upgrade to f128
+        TypeId.Float => return @sqrt(T, x),
         TypeId.IntLiteral => comptime {
             if (x > @maxValue(u128)) {
                 @compileError("sqrt not implemented for comptime_int greater than 128 bits");
@@ -43,269 +25,58 @@ pub fn sqrt(x: var) (if (@typeId(@typeOf(x)) == TypeId.Int) @IntType(false, @typ
             }
             return T(sqrt_int(u128, x));
         },
-        TypeId.Int => {
-            return sqrt_int(T, x);
-        },
+        TypeId.Int => return sqrt_int(T, x),
         else => @compileError("sqrt not implemented for " ++ @typeName(T)),
     }
 }
 
-fn sqrt32(x: f32) f32 {
-    const tiny: f32 = 1.0e-30;
-    const sign: i32 = @bitCast(i32, u32(0x80000000));
-    var ix: i32 = @bitCast(i32, x);
-
-    if ((ix & 0x7F800000) == 0x7F800000) {
-        return x * x + x;   // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = snan
-    }
-
-    // zero
-    if (ix <= 0) {
-        if (ix & ~sign == 0) {
-            return x;       // sqrt (+-0) = +-0
-        }
-        if (ix < 0) {
-            return math.snan(f32);
-        }
-    }
-
-    // normalize
-    var m = ix >> 23;
-    if (m == 0) {
-        // subnormal
-        var i: i32 = 0;
-        while (ix & 0x00800000 == 0) : (i += 1) {
-            ix <<= 1;
-        }
-        m -= i - 1;
-    }
-
-    m -= 127;               // unbias exponent
-    ix = (ix & 0x007FFFFF) | 0x00800000;
-
-    if (m & 1 != 0) {       // odd m, double x to even
-        ix += ix;
-    }
-
-    m >>= 1;                // m = [m / 2]
-
-    // sqrt(x) bit by bit
-    ix += ix;
-    var q: i32 = 0;              // q = sqrt(x)
-    var s: i32 = 0;
-    var r: i32 = 0x01000000;     // r = moving bit right -> left
-
-    while (r != 0) {
-        const t = s + r;
-        if (t <= ix) {
-            s = t + r;
-            ix -= t;
-            q += r;
-        }
-        ix += ix;
-        r >>= 1;
-    }
-
-    // floating add to find rounding direction
-    if (ix != 0) {
-        var z = 1.0 - tiny;     // inexact
-        if (z >= 1.0) {
-            z = 1.0 + tiny;
-            if (z > 1.0) {
-                q += 2;
-            } else {
-                if (q & 1 != 0) {
-                    q += 1;
-                }
-            }
-        }
-    }
-
-    ix = (q >> 1) + 0x3f000000;
-    ix += m << 23;
-    return @bitCast(f32, ix);
-}
-
-// NOTE: The original code is full of implicit signed -> unsigned assumptions and u32 wraparound
-// behaviour. Most intermediate i32 values are changed to u32 where appropriate but there are
-// potentially some edge cases remaining that are not handled in the same way.
-fn sqrt64(x: f64) f64 {
-    const tiny: f64 = 1.0e-300;
-    const sign: u32 = 0x80000000;
-    const u = @bitCast(u64, x);
-
-    var ix0 = u32(u >> 32);
-    var ix1 = u32(u & 0xFFFFFFFF);
-
-    // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = nan
-    if (ix0 & 0x7FF00000 == 0x7FF00000) {
-        return x * x + x;
-    }
-
-    // sqrt(+-0) = +-0
-    if (x == 0.0) {
-        return x;
-    }
-    // sqrt(-ve) = snan
-    if (ix0 & sign != 0) {
-        return math.snan(f64);
-    }
-
-    // normalize x
-    var m = i32(ix0 >> 20);
-    if (m == 0) {
-        // subnormal
-        while (ix0 == 0) {
-            m -= 21;
-            ix0 |= ix1 >> 11;
-            ix1 <<= 21;
-        }
-
-        // subnormal
-        var i: u32 = 0;
-        while (ix0 & 0x00100000 == 0) : (i += 1) {
-            ix0 <<= 1;
-        }
-        m -= i32(i) - 1;
-        ix0 |= ix1 >> u5(32 - i);
-        ix1 <<= u5(i);
-    }
-
-    // unbias exponent
-    m -= 1023;
-    ix0 = (ix0 & 0x000FFFFF) | 0x00100000;
-    if (m & 1 != 0) {
-        ix0 += ix0 + (ix1 >> 31);
-        ix1 = ix1 +% ix1;
-    }
-    m >>= 1;
-
-    // sqrt(x) bit by bit
-    ix0 += ix0 + (ix1 >> 31);
-    ix1 = ix1 +% ix1;
-
-    var q: u32 = 0;
-    var q1: u32 = 0;
-    var s0: u32 = 0;
-    var s1: u32 = 0;
-    var r: u32 = 0x00200000;
-    var t: u32 = undefined;
-    var t1: u32 = undefined;
-
-    while (r != 0) {
-        t = s0 +% r;
-        if (t <= ix0) {
-            s0 = t + r;
-            ix0 -= t;
-            q += r;
-        }
-        ix0 = ix0 +% ix0 +% (ix1 >> 31);
-        ix1 = ix1 +% ix1;
-        r >>= 1;
-    }
-
-    r = sign;
-    while (r != 0) {
-        t = s1 +% r;
-        t = s0;
-        if (t < ix0 or (t == ix0 and t1 <= ix1)) {
-            s1 = t1 +% r;
-            if (t1 & sign == sign and s1 & sign == 0) {
-                s0 += 1;
-            }
-            ix0 -= t;
-            if (ix1 < t1) {
-                ix0 -= 1;
-            }
-            ix1 = ix1 -% t1;
-            q1 += r;
-        }
-        ix0 = ix0 +% ix0 +% (ix1 >> 31);
-        ix1 = ix1 +% ix1;
-        r >>= 1;
-    }
-
-    // rounding direction
-    if (ix0 | ix1 != 0) {
-        var z = 1.0 - tiny;   // raise inexact
-        if (z >= 1.0) {
-            z = 1.0 + tiny;
-            if (q1 == 0xFFFFFFFF) {
-                q1 = 0;
-                q += 1;
-            } else if (z > 1.0) {
-                if (q1 == 0xFFFFFFFE) {
-                    q += 1;
-                }
-                q1 += 2;
-            } else {
-                q1 += q1 & 1;
-            }
-        }
-    }
-
-    ix0 = (q >> 1) + 0x3FE00000;
-    ix1 = q1 >> 1;
-    if (q & 1 != 0) {
-        ix1 |= 0x80000000;
-    }
-
-    // NOTE: musl here appears to rely on signed twos-complement wraparound. +% has the same
-    // behaviour at least.
-    var iix0 = i32(ix0);
-    iix0 = iix0 +% (m << 20);
-
-    const uz = (u64(iix0) << 32) | ix1;
-    return @bitCast(f64, uz);
-}
-
 test "math.sqrt" {
-    assert(sqrt(f32(0.0)) == sqrt32(0.0));
-    assert(sqrt(f64(0.0)) == sqrt64(0.0));
+    assert(sqrt(f32(0.0)) == @sqrt(f32, 0.0));
+    assert(sqrt(f64(0.0)) == @sqrt(f64, 0.0));
 }
 
 test "math.sqrt32" {
     const epsilon = 0.000001;
 
-    assert(sqrt32(0.0) == 0.0);
-    assert(math.approxEq(f32, sqrt32(2.0), 1.414214, epsilon));
-    assert(math.approxEq(f32, sqrt32(3.6), 1.897367, epsilon));
-    assert(sqrt32(4.0) == 2.0);
-    assert(math.approxEq(f32, sqrt32(7.539840), 2.745877, epsilon));
-    assert(math.approxEq(f32, sqrt32(19.230934), 4.385309, epsilon));
-    assert(sqrt32(64.0) == 8.0);
-    assert(math.approxEq(f32, sqrt32(64.1), 8.006248, epsilon));
-    assert(math.approxEq(f32, sqrt32(8942.230469), 94.563370, epsilon));
+    assert(@sqrt(f32, 0.0) == 0.0);
+    assert(math.approxEq(f32, @sqrt(f32, 2.0), 1.414214, epsilon));
+    assert(math.approxEq(f32, @sqrt(f32, 3.6), 1.897367, epsilon));
+    assert(@sqrt(f32, 4.0) == 2.0);
+    assert(math.approxEq(f32, @sqrt(f32, 7.539840), 2.745877, epsilon));
+    assert(math.approxEq(f32, @sqrt(f32, 19.230934), 4.385309, epsilon));
+    assert(@sqrt(f32, 64.0) == 8.0);
+    assert(math.approxEq(f32, @sqrt(f32, 64.1), 8.006248, epsilon));
+    assert(math.approxEq(f32, @sqrt(f32, 8942.230469), 94.563370, epsilon));
 }
 
 test "math.sqrt64" {
     const epsilon = 0.000001;
 
-    assert(sqrt64(0.0) == 0.0);
-    assert(math.approxEq(f64, sqrt64(2.0), 1.414214, epsilon));
-    assert(math.approxEq(f64, sqrt64(3.6), 1.897367, epsilon));
-    assert(sqrt64(4.0) == 2.0);
-    assert(math.approxEq(f64, sqrt64(7.539840), 2.745877, epsilon));
-    assert(math.approxEq(f64, sqrt64(19.230934), 4.385309, epsilon));
-    assert(sqrt64(64.0) == 8.0);
-    assert(math.approxEq(f64, sqrt64(64.1), 8.006248, epsilon));
-    assert(math.approxEq(f64, sqrt64(8942.230469), 94.563367, epsilon));
+    assert(@sqrt(f64, 0.0) == 0.0);
+    assert(math.approxEq(f64, @sqrt(f64, 2.0), 1.414214, epsilon));
+    assert(math.approxEq(f64, @sqrt(f64, 3.6), 1.897367, epsilon));
+    assert(@sqrt(f64, 4.0) == 2.0);
+    assert(math.approxEq(f64, @sqrt(f64, 7.539840), 2.745877, epsilon));
+    assert(math.approxEq(f64, @sqrt(f64, 19.230934), 4.385309, epsilon));
+    assert(@sqrt(f64, 64.0) == 8.0);
+    assert(math.approxEq(f64, @sqrt(f64, 64.1), 8.006248, epsilon));
+    assert(math.approxEq(f64, @sqrt(f64, 8942.230469), 94.563367, epsilon));
 }
 
 test "math.sqrt32.special" {
-    assert(math.isPositiveInf(sqrt32(math.inf(f32))));
-    assert(sqrt32(0.0) == 0.0);
-    assert(sqrt32(-0.0) == -0.0);
-    assert(math.isNan(sqrt32(-1.0)));
-    assert(math.isNan(sqrt32(math.nan(f32))));
+    assert(math.isPositiveInf(@sqrt(f32, math.inf(f32))));
+    assert(@sqrt(f32, 0.0) == 0.0);
+    assert(@sqrt(f32, -0.0) == -0.0);
+    assert(math.isNan(@sqrt(f32, -1.0)));
+    assert(math.isNan(@sqrt(f32, math.nan(f32))));
 }
 
 test "math.sqrt64.special" {
-    assert(math.isPositiveInf(sqrt64(math.inf(f64))));
-    assert(sqrt64(0.0) == 0.0);
-    assert(sqrt64(-0.0) == -0.0);
-    assert(math.isNan(sqrt64(-1.0)));
-    assert(math.isNan(sqrt64(math.nan(f64))));
+    assert(math.isPositiveInf(@sqrt(f64, math.inf(f64))));
+    assert(@sqrt(f64, 0.0) == 0.0);
+    assert(@sqrt(f64, -0.0) == -0.0);
+    assert(math.isNan(@sqrt(f64, -1.0)));
+    assert(math.isNan(@sqrt(f64, math.nan(f64))));
 }
 
 fn sqrt_int(comptime T: type, value: T) @IntType(false, T.bit_count / 2) {
std/special/builtin.zig
@@ -194,3 +194,212 @@ fn isNan(comptime T: type, bits: T) bool {
         unreachable;
     }
 }
+
+// NOTE: The original code is full of implicit signed -> unsigned assumptions and u32 wraparound
+// behaviour. Most intermediate i32 values are changed to u32 where appropriate but there are
+// potentially some edge cases remaining that are not handled in the same way.
+export fn sqrt(x: f64) f64 {
+    const tiny: f64 = 1.0e-300;
+    const sign: u32 = 0x80000000;
+    const u = @bitCast(u64, x);
+
+    var ix0 = u32(u >> 32);
+    var ix1 = u32(u & 0xFFFFFFFF);
+
+    // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = nan
+    if (ix0 & 0x7FF00000 == 0x7FF00000) {
+        return x * x + x;
+    }
+
+    // sqrt(+-0) = +-0
+    if (x == 0.0) {
+        return x;
+    }
+    // sqrt(-ve) = snan
+    if (ix0 & sign != 0) {
+        return math.snan(f64);
+    }
+
+    // normalize x
+    var m = i32(ix0 >> 20);
+    if (m == 0) {
+        // subnormal
+        while (ix0 == 0) {
+            m -= 21;
+            ix0 |= ix1 >> 11;
+            ix1 <<= 21;
+        }
+
+        // subnormal
+        var i: u32 = 0;
+        while (ix0 & 0x00100000 == 0) : (i += 1) {
+            ix0 <<= 1;
+        }
+        m -= i32(i) - 1;
+        ix0 |= ix1 >> u5(32 - i);
+        ix1 <<= u5(i);
+    }
+
+    // unbias exponent
+    m -= 1023;
+    ix0 = (ix0 & 0x000FFFFF) | 0x00100000;
+    if (m & 1 != 0) {
+        ix0 += ix0 + (ix1 >> 31);
+        ix1 = ix1 +% ix1;
+    }
+    m >>= 1;
+
+    // sqrt(x) bit by bit
+    ix0 += ix0 + (ix1 >> 31);
+    ix1 = ix1 +% ix1;
+
+    var q: u32 = 0;
+    var q1: u32 = 0;
+    var s0: u32 = 0;
+    var s1: u32 = 0;
+    var r: u32 = 0x00200000;
+    var t: u32 = undefined;
+    var t1: u32 = undefined;
+
+    while (r != 0) {
+        t = s0 +% r;
+        if (t <= ix0) {
+            s0 = t + r;
+            ix0 -= t;
+            q += r;
+        }
+        ix0 = ix0 +% ix0 +% (ix1 >> 31);
+        ix1 = ix1 +% ix1;
+        r >>= 1;
+    }
+
+    r = sign;
+    while (r != 0) {
+        t = s1 +% r;
+        t = s0;
+        if (t < ix0 or (t == ix0 and t1 <= ix1)) {
+            s1 = t1 +% r;
+            if (t1 & sign == sign and s1 & sign == 0) {
+                s0 += 1;
+            }
+            ix0 -= t;
+            if (ix1 < t1) {
+                ix0 -= 1;
+            }
+            ix1 = ix1 -% t1;
+            q1 += r;
+        }
+        ix0 = ix0 +% ix0 +% (ix1 >> 31);
+        ix1 = ix1 +% ix1;
+        r >>= 1;
+    }
+
+    // rounding direction
+    if (ix0 | ix1 != 0) {
+        var z = 1.0 - tiny;   // raise inexact
+        if (z >= 1.0) {
+            z = 1.0 + tiny;
+            if (q1 == 0xFFFFFFFF) {
+                q1 = 0;
+                q += 1;
+            } else if (z > 1.0) {
+                if (q1 == 0xFFFFFFFE) {
+                    q += 1;
+                }
+                q1 += 2;
+            } else {
+                q1 += q1 & 1;
+            }
+        }
+    }
+
+    ix0 = (q >> 1) + 0x3FE00000;
+    ix1 = q1 >> 1;
+    if (q & 1 != 0) {
+        ix1 |= 0x80000000;
+    }
+
+    // NOTE: musl here appears to rely on signed twos-complement wraparound. +% has the same
+    // behaviour at least.
+    var iix0 = i32(ix0);
+    iix0 = iix0 +% (m << 20);
+
+    const uz = (u64(iix0) << 32) | ix1;
+    return @bitCast(f64, uz);
+}
+
+export fn sqrtf(x: f32) f32 {
+    const tiny: f32 = 1.0e-30;
+    const sign: i32 = @bitCast(i32, u32(0x80000000));
+    var ix: i32 = @bitCast(i32, x);
+
+    if ((ix & 0x7F800000) == 0x7F800000) {
+        return x * x + x;   // sqrt(nan) = nan, sqrt(+inf) = +inf, sqrt(-inf) = snan
+    }
+
+    // zero
+    if (ix <= 0) {
+        if (ix & ~sign == 0) {
+            return x;       // sqrt (+-0) = +-0
+        }
+        if (ix < 0) {
+            return math.snan(f32);
+        }
+    }
+
+    // normalize
+    var m = ix >> 23;
+    if (m == 0) {
+        // subnormal
+        var i: i32 = 0;
+        while (ix & 0x00800000 == 0) : (i += 1) {
+            ix <<= 1;
+        }
+        m -= i - 1;
+    }
+
+    m -= 127;               // unbias exponent
+    ix = (ix & 0x007FFFFF) | 0x00800000;
+
+    if (m & 1 != 0) {       // odd m, double x to even
+        ix += ix;
+    }
+
+    m >>= 1;                // m = [m / 2]
+
+    // sqrt(x) bit by bit
+    ix += ix;
+    var q: i32 = 0;              // q = sqrt(x)
+    var s: i32 = 0;
+    var r: i32 = 0x01000000;     // r = moving bit right -> left
+
+    while (r != 0) {
+        const t = s + r;
+        if (t <= ix) {
+            s = t + r;
+            ix -= t;
+            q += r;
+        }
+        ix += ix;
+        r >>= 1;
+    }
+
+    // floating add to find rounding direction
+    if (ix != 0) {
+        var z = 1.0 - tiny;     // inexact
+        if (z >= 1.0) {
+            z = 1.0 + tiny;
+            if (z > 1.0) {
+                q += 2;
+            } else {
+                if (q & 1 != 0) {
+                    q += 1;
+                }
+            }
+        }
+    }
+
+    ix = (q >> 1) + 0x3f000000;
+    ix += m << 23;
+    return @bitCast(f32, ix);
+}
test/cases/math.zig
@@ -402,3 +402,19 @@ test "comptime float rem int" {
         assert(x == 1.0);
     }
 }
+
+test "@sqrt" {
+    testSqrt(f64, 12.0);
+    comptime testSqrt(f64, 12.0);
+    testSqrt(f32, 13.0);
+    comptime testSqrt(f32, 13.0);
+
+    const x = 14.0;
+    const y = x * x;
+    const z = @sqrt(@typeOf(y), y);
+    comptime assert(z == x);
+}
+
+fn testSqrt(comptime T: type, x: T) void {
+    assert(@sqrt(T, x * x) == x);
+}
CMakeLists.txt
@@ -498,7 +498,6 @@ set(ZIG_STD_FILES
     "math/tan.zig"
     "math/tanh.zig"
     "math/trunc.zig"
-    "math/x86_64/sqrt.zig"
     "mem.zig"
     "net.zig"
     "os/child_process.zig"