Commit 73a0b5441b

David Rubin <david@vortan.dev>
2025-08-28 16:46:12
AstGen: forward result type through unary float builtins
Uses a new `float_op_result_ty` ZIR instruction tag.
1 parent a31950a
lib/std/zig/AstGen.zig
@@ -9390,24 +9390,25 @@ fn builtinCall(
         .embed_file            => return simpleUnOp(gz, scope, ri, node, .{ .rl = .{ .coerced_ty = .slice_const_u8_type } },   params[0], .embed_file),
         .error_name            => return simpleUnOp(gz, scope, ri, node, .{ .rl = .{ .coerced_ty = .anyerror_type } },         params[0], .error_name),
         .set_runtime_safety    => return simpleUnOp(gz, scope, ri, node, coerced_bool_ri,                                      params[0], .set_runtime_safety),
-        .sqrt                  => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .sqrt),
-        .sin                   => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .sin),
-        .cos                   => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .cos),
-        .tan                   => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .tan),
-        .exp                   => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .exp),
-        .exp2                  => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .exp2),
-        .log                   => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .log),
-        .log2                  => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .log2),
-        .log10                 => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .log10),
         .abs                   => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .abs),
-        .floor                 => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .floor),
-        .ceil                  => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .ceil),
-        .trunc                 => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .trunc),
-        .round                 => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .round),
         .tag_name              => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .tag_name),
         .type_name             => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .type_name),
         .Frame                 => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none },                                     params[0], .frame_type),
 
+        .sqrt  => return floatUnOp(gz, scope, ri, node, params[0], .sqrt),
+        .sin   => return floatUnOp(gz, scope, ri, node, params[0], .sin),
+        .cos   => return floatUnOp(gz, scope, ri, node, params[0], .cos),
+        .tan   => return floatUnOp(gz, scope, ri, node, params[0], .tan),
+        .exp   => return floatUnOp(gz, scope, ri, node, params[0], .exp),
+        .exp2  => return floatUnOp(gz, scope, ri, node, params[0], .exp2),
+        .log   => return floatUnOp(gz, scope, ri, node, params[0], .log),
+        .log2  => return floatUnOp(gz, scope, ri, node, params[0], .log2),
+        .log10 => return floatUnOp(gz, scope, ri, node, params[0], .log10),
+        .floor => return floatUnOp(gz, scope, ri, node, params[0], .floor),
+        .ceil  => return floatUnOp(gz, scope, ri, node, params[0], .ceil),
+        .trunc => return floatUnOp(gz, scope, ri, node, params[0], .trunc),
+        .round => return floatUnOp(gz, scope, ri, node, params[0], .round),
+
         .int_from_float => return typeCast(gz, scope, ri, node, params[0], .int_from_float, builtin_name),
         .float_from_int => return typeCast(gz, scope, ri, node, params[0], .float_from_int, builtin_name),
         .ptr_from_int   => return typeCast(gz, scope, ri, node, params[0], .ptr_from_int, builtin_name),
@@ -9860,6 +9861,26 @@ fn simpleUnOp(
     return rvalue(gz, ri, result, node);
 }
 
+fn floatUnOp(
+    gz: *GenZir,
+    scope: *Scope,
+    ri: ResultInfo,
+    node: Ast.Node.Index,
+    operand_node: Ast.Node.Index,
+    tag: Zir.Inst.Tag,
+) InnerError!Zir.Inst.Ref {
+    const result_type = try ri.rl.resultType(gz, node);
+    const operand_ri: ResultInfo.Loc = if (result_type) |rt| .{
+        .ty = try gz.addExtendedPayload(.float_op_result_ty, Zir.Inst.UnNode{
+            .node = gz.nodeIndexToRelative(node),
+            .operand = rt,
+        }),
+    } else .none;
+    const operand = try expr(gz, scope, .{ .rl = operand_ri }, operand_node);
+    const result = try gz.addUnNode(tag, operand, node);
+    return rvalue(gz, ri, result, node);
+}
+
 fn negation(
     gz: *GenZir,
     scope: *Scope,
lib/std/zig/AstRlAnnotate.zig
@@ -889,6 +889,19 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast.
         .frame_address => return true,
         // These builtins take a single argument with a known result type, but do not consume their
         // result pointer.
+        .sqrt,
+        .sin,
+        .cos,
+        .tan,
+        .exp,
+        .exp2,
+        .log,
+        .log2,
+        .log10,
+        .floor,
+        .ceil,
+        .trunc,
+        .round,
         .size_of,
         .bit_size_of,
         .align_of,
@@ -918,20 +931,7 @@ fn builtinCall(astrl: *AstRlAnnotate, block: ?*Block, ri: ResultInfo, node: Ast.
         // result pointer.
         .int_from_ptr,
         .int_from_enum,
-        .sqrt,
-        .sin,
-        .cos,
-        .tan,
-        .exp,
-        .exp2,
-        .log,
-        .log2,
-        .log10,
         .abs,
-        .floor,
-        .ceil,
-        .trunc,
-        .round,
         .tag_name,
         .type_name,
         .Frame,
lib/std/zig/Zir.zig
@@ -2111,6 +2111,11 @@ pub const Inst = struct {
         /// This instruction is always `noreturn`, however, it is not considered as such by ZIR-level queries. This allows AstGen to assume that
         /// any code may have gone here, avoiding false-positive "unreachable code" errors.
         astgen_error,
+        /// Given a type, strips away any error unions or optionals stacked
+        /// on top and returns the base type. That base type must be a float.
+        /// For example: Provided with error{Foo}!?f64, returns f64.
+        /// `operand` is `operand: Air.Inst.Ref`.
+        float_op_result_ty,
 
         pub const InstData = struct {
             opcode: Extended,
@@ -4436,6 +4441,7 @@ fn findTrackableInner(
                 .tuple_decl,
                 .dbg_empty_stmt,
                 .astgen_error,
+                .float_op_result_ty,
                 => return,
 
                 // `@TypeOf` has a body.
src/Air.zig
@@ -1154,6 +1154,10 @@ pub const Inst = struct {
         pub fn fromValue(v: Value) Ref {
             return .fromIntern(v.toIntern());
         }
+
+        pub fn fromType(t: Type) Ref {
+            return .fromIntern(t.toIntern());
+        }
     };
 
     /// All instructions have an 8-byte payload, which is contained within
src/print_zir.zig
@@ -567,6 +567,7 @@ const Writer = struct {
             .work_group_size,
             .work_group_id,
             .branch_hint,
+            .float_op_result_ty,
             => {
                 const inst_data = self.code.extraData(Zir.Inst.UnNode, extended.operand).data;
                 try self.writeInstRef(stream, inst_data.operand);
src/Sema.zig
@@ -1468,6 +1468,7 @@ fn analyzeBodyInner(
                         continue;
                     },
                     .astgen_error => return error.AnalysisFail,
+                    .float_op_result_ty => try sema.zirFloatOpResultType(block, extended),
                 };
             },
 
@@ -25922,6 +25923,28 @@ fn zirBranchHint(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
     }
 }
 
+fn zirFloatOpResultType(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
+    const pt = sema.pt;
+    const zcu = pt.zcu;
+    const extra = sema.code.extraData(Zir.Inst.UnNode, extended.operand).data;
+    const operand_src = block.builtinCallArgSrc(extra.node, 0);
+
+    const raw_ty = try sema.resolveTypeOrPoison(block, operand_src, extra.operand) orelse return .generic_poison_type;
+    const float_ty = raw_ty.optEuBaseType(zcu);
+
+    switch (float_ty.scalarType(zcu).zigTypeTag(zcu)) {
+        .float, .comptime_float => {},
+        else => return sema.fail(
+            block,
+            operand_src,
+            "expected vector of floats or float type, found '{f}'",
+            .{float_ty.fmt(sema.pt)},
+        ),
+    }
+
+    return .fromType(float_ty);
+}
+
 fn requireRuntimeBlock(sema: *Sema, block: *Block, src: LazySrcLoc, runtime_src: ?LazySrcLoc) !void {
     if (block.isComptime()) {
         const msg, const fail_block = msg: {
test/behavior/floatop.zig
@@ -1737,3 +1737,31 @@ test "comptime calls are only memoized when float arguments are bit-for-bit equa
     try comptime testMemoization();
     try comptime testVectorMemoization(@Vector(4, f32));
 }
+
+test "result location forwarded through unary float builtins" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;
+
+    const S = struct {
+        var x: u32 = 10;
+    };
+
+    var y: f64 = 0.0;
+    y = @sqrt(@floatFromInt(S.x));
+    y = @sin(@floatFromInt(S.x));
+    y = @cos(@floatFromInt(S.x));
+    y = @tan(@floatFromInt(S.x));
+    y = @exp(@floatFromInt(S.x));
+    y = @exp2(@floatFromInt(S.x));
+    y = @log(@floatFromInt(S.x));
+    y = @log2(@floatFromInt(S.x));
+    y = @log10(@floatFromInt(S.x));
+    y = @floor(@floatFromInt(S.x));
+    y = @ceil(@floatFromInt(S.x));
+    y = @trunc(@floatFromInt(S.x));
+    y = @round(@floatFromInt(S.x));
+}