Commit 7d0de54ad4

Andrew Kelley <andrew@ziglang.org>
2021-08-07 04:40:55
stage2: fix return pointer result locations
* Introduce a `ret_load` ZIR instruction which implements return semantics based on a corresponding `ret_ptr` instruction. If the function's return value is stored through the `ret_ptr` result pointer, it simply returns. However, if the function returns by value, it loads the return value from the `ret_ptr` allocation and returns that.
* AstGen: improve `finishThenElseBlock` to not emit break instructions after a return instruction in the same block.
* Sema: the `ret_ptr` instruction now works correctly in comptime contexts. Same with `alloc_mut`.

The test case with a recursive inline function having an implicitly comptime return value now has a runtime return value, because it calls a function in a non-comptime context.
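
For illustration only (an editorial sketch, not part of the commit): the test workarounds removed below describe the symptom being fixed, where stage2 emitted AIR that allocated a result value, stored to it, but then returned void instead of the result. Code along these lines exercises the new `ret_ptr`/`ret_load` pairing:

    fn max(a: i32, b: i32) i32 {
        // The if-expression's result is written through the function's return
        // pointer (`ret_ptr`); the return then lowers to `ret_load`, which loads
        // from that allocation when the function returns by value.
        return if (a > b) a else b;
    }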
1 parent e974d4c
Changed files (5)
src/AstGen.zig
@@ -2131,6 +2131,7 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: ast.Node.Index) Inner
             .condbr_inline,
             .compile_error,
             .ret_node,
+            .ret_load,
             .ret_coerce,
             .ret_err_value,
             .@"unreachable",
@@ -4791,11 +4792,10 @@ fn finishThenElseBlock(
     const strat = rl.strategy(block_scope);
     switch (strat.tag) {
         .break_void => {
-            if (!parent_gz.refIsNoReturn(then_result)) {
+            if (!then_scope.endsWithNoReturn()) {
                 _ = try then_scope.addBreak(break_tag, then_break_block, .void_value);
             }
-            const elide_else = if (else_result != .none) parent_gz.refIsNoReturn(else_result) else false;
-            if (!elide_else) {
+            if (!else_scope.endsWithNoReturn()) {
                 _ = try else_scope.addBreak(break_tag, main_block, .void_value);
             }
             assert(!strat.elide_store_to_block_ptr_instructions);
@@ -4803,11 +4803,11 @@ fn finishThenElseBlock(
             return indexToRef(main_block);
         },
         .break_operand => {
-            if (!parent_gz.refIsNoReturn(then_result)) {
+            if (!then_scope.endsWithNoReturn()) {
                 _ = try then_scope.addBreak(break_tag, then_break_block, then_result);
             }
             if (else_result != .none) {
-                if (!parent_gz.refIsNoReturn(else_result)) {
+                if (!else_scope.endsWithNoReturn()) {
                     _ = try else_scope.addBreak(break_tag, main_block, else_result);
                 }
             } else {
@@ -6236,7 +6236,7 @@ fn ret(gz: *GenZir, scope: *Scope, node: ast.Node.Index) InnerError!Zir.Inst.Ref
             // Value is always an error. Emit both error defers and regular defers.
             const err_code = try gz.addUnNode(.err_union_code, operand, node);
             try genDefers(gz, defer_outer, scope, .{ .both = err_code });
-            _ = try gz.addUnNode(.ret_node, operand, node);
+            try gz.addRet(rl, operand, node);
             return Zir.Inst.Ref.unreachable_value;
         },
         .maybe => {
@@ -6244,7 +6244,7 @@ fn ret(gz: *GenZir, scope: *Scope, node: ast.Node.Index) InnerError!Zir.Inst.Ref
             if (!defer_counts.have_err) {
                 // Only regular defers; no branch needed.
                 try genDefers(gz, defer_outer, scope, .normal_only);
-                _ = try gz.addUnNode(.ret_node, operand, node);
+                try gz.addRet(rl, operand, node);
                 return Zir.Inst.Ref.unreachable_value;
             }
 
@@ -6256,7 +6256,7 @@ fn ret(gz: *GenZir, scope: *Scope, node: ast.Node.Index) InnerError!Zir.Inst.Ref
             defer then_scope.instructions.deinit(astgen.gpa);
 
             try genDefers(&then_scope, defer_outer, scope, .normal_only);
-            _ = try then_scope.addUnNode(.ret_node, operand, node);
+            try then_scope.addRet(rl, operand, node);
 
             var else_scope = gz.makeSubBlock(scope);
             defer else_scope.instructions.deinit(astgen.gpa);
@@ -6265,7 +6265,7 @@ fn ret(gz: *GenZir, scope: *Scope, node: ast.Node.Index) InnerError!Zir.Inst.Ref
                 .both = try else_scope.addUnNode(.err_union_code, operand, node),
             };
             try genDefers(&else_scope, defer_outer, scope, which_ones);
-            _ = try else_scope.addUnNode(.ret_node, operand, node);
+            try else_scope.addRet(rl, operand, node);
 
             try setCondBrPayload(condbr, is_non_err, &then_scope, &else_scope);
 
@@ -9003,6 +9003,14 @@ const GenZir = struct {
         used: bool = false,
     };
 
+    fn endsWithNoReturn(gz: GenZir) bool {
+        const tags = gz.astgen.instructions.items(.tag);
+        if (gz.instructions.items.len == 0) return false;
+        const last_inst = gz.instructions.items[gz.instructions.items.len - 1];
+        return tags[last_inst].isNoReturn();
+    }
+
+    /// TODO all uses of this should be replaced with uses of `endsWithNoReturn`.
     fn refIsNoReturn(gz: GenZir, inst_ref: Zir.Inst.Ref) bool {
         if (inst_ref == .unreachable_value) return true;
         if (refToIndex(inst_ref)) |inst_index| {
@@ -9977,6 +9985,14 @@ const GenZir = struct {
         gz.instructions.appendAssumeCapacity(new_index);
         return new_index;
     }
+
+    fn addRet(gz: *GenZir, rl: ResultLoc, operand: Zir.Inst.Ref, node: ast.Node.Index) !void {
+        switch (rl) {
+            .ptr => |ret_ptr| _ = try gz.addUnNode(.ret_load, ret_ptr, node),
+            .ty => _ = try gz.addUnNode(.ret_node, operand, node),
+            else => unreachable,
+        }
+    }
 };
 
 /// This can only be for short-lived references; the memory becomes invalidated
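
A rough sketch (editorial, simplified) of what the AstGen changes above mean at the source level: each branch below ends in a return instruction, which is noreturn, so `endsWithNoReturn` lets `finishThenElseBlock` skip the trailing break it used to emit, while `addRet` picks `ret_load` vs. `ret_node` from the result location:

    fn pick(c: bool) i32 {
        // Both branches end in a return (`ret_node` for this by-value result
        // location), a noreturn instruction, so no `break` is appended after it.
        if (c) {
            return 1;
        } else {
            return 2;
        }
    }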
src/Sema.zig
@@ -366,6 +366,7 @@ pub fn analyzeBody(
             .compile_error  => return sema.zirCompileError(block, inst),
             .ret_coerce     => return sema.zirRetCoerce(block, inst),
             .ret_node       => return sema.zirRetNode(block, inst),
+            .ret_load       => return sema.zirRetLoad(block, inst),
             .ret_err_value  => return sema.zirRetErrValue(block, inst),
             .@"unreachable" => return sema.zirUnreachable(block, inst),
             .repeat         => return sema.zirRepeat(block, inst),
@@ -718,8 +719,8 @@ fn resolveMaybeUndefValAllowVariables(
     if (try sema.typeHasOnePossibleValue(block, src, sema.typeOf(inst))) |opv| {
         return opv;
     }
-
-    switch (sema.air_instructions.items(.tag)[i]) {
+    const air_tags = sema.air_instructions.items(.tag);
+    switch (air_tags[i]) {
         .constant => {
             const ty_pl = sema.air_instructions.items(.data)[i].ty_pl;
             return sema.air_values.items[ty_pl.payload];
@@ -1248,6 +1249,11 @@ fn zirRetPtr(
 
     const src: LazySrcLoc = .{ .node_offset = @bitCast(i32, extended.operand) };
     try sema.requireFunctionBlock(block, src);
+
+    if (block.is_comptime) {
+        return sema.analyzeComptimeAlloc(block, sema.fn_ret_ty);
+    }
+
     const ptr_type = try Module.simplePtrType(sema.arena, sema.fn_ret_ty, true, .One);
     return block.addTy(.alloc, ptr_type);
 }
@@ -1375,21 +1381,7 @@ fn zirAllocComptime(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) Comp
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const ty_src: LazySrcLoc = .{ .node_offset_var_decl_ty = inst_data.src_node };
     const var_type = try sema.resolveType(block, ty_src, inst_data.operand);
-    const ptr_type = try Module.simplePtrType(sema.arena, var_type, true, .One);
-
-    var anon_decl = try block.startAnonDecl();
-    defer anon_decl.deinit();
-    const decl = try anon_decl.finish(
-        try var_type.copy(anon_decl.arena()),
-        // AstGen guarantees there will be a store before the first load, so we put a value
-        // here indicating there is no valid value.
-        Value.initTag(.unreachable_value),
-    );
-    try sema.mod.declareDeclDependency(sema.owner_decl, decl);
-    return sema.addConstant(ptr_type, try Value.Tag.decl_ref_mut.create(sema.arena, .{
-        .runtime_index = block.runtime_index,
-        .decl = decl,
-    }));
+    return sema.analyzeComptimeAlloc(block, var_type);
 }
 
 fn zirAllocInferredComptime(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -1419,6 +1411,9 @@ fn zirAllocMut(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileEr
     const var_decl_src = inst_data.src();
     const ty_src: LazySrcLoc = .{ .node_offset_var_decl_ty = inst_data.src_node };
     const var_type = try sema.resolveType(block, ty_src, inst_data.operand);
+    if (block.is_comptime) {
+        return sema.analyzeComptimeAlloc(block, var_type);
+    }
     try sema.validateVarType(block, ty_src, var_type);
     const ptr_type = try Module.simplePtrType(sema.arena, var_type, true, .One);
     try sema.requireRuntimeBlock(block, var_decl_src);
@@ -6280,6 +6275,21 @@ fn zirRetNode(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileErr
     return sema.analyzeRet(block, operand, src, false);
 }
 
+fn zirRetLoad(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileError!Zir.Inst.Index {
+    const tracy = trace(@src());
+    defer tracy.end();
+
+    const inst_data = sema.code.instructions.items(.data)[inst].un_node;
+    const src = inst_data.src();
+    // TODO: when implementing functions that accept a result location pointer,
+    // this logic will be updated to only do a load when the function's return
+    // type in fact does not need a result location pointer. Until then we assume
+    // the `ret_ptr` is the same as an `alloc` and do a load here.
+    const ret_ptr = sema.resolveInst(inst_data.operand);
+    const operand = try sema.analyzeLoad(block, src, ret_ptr, src);
+    return sema.analyzeRet(block, operand, src, false);
+}
+
 fn analyzeRet(
     sema: *Sema,
     block: *Scope.Block,
@@ -9416,3 +9426,25 @@ fn isComptimeKnown(
 ) !bool {
     return (try sema.resolveMaybeUndefVal(block, src, inst)) != null;
 }
+
+fn analyzeComptimeAlloc(
+    sema: *Sema,
+    block: *Scope.Block,
+    var_type: Type,
+) CompileError!Air.Inst.Ref {
+    const ptr_type = try Module.simplePtrType(sema.arena, var_type, true, .One);
+
+    var anon_decl = try block.startAnonDecl();
+    defer anon_decl.deinit();
+    const decl = try anon_decl.finish(
+        try var_type.copy(anon_decl.arena()),
+        // AstGen guarantees there will be a store before the first load, so we put a value
+        // here indicating there is no valid value.
+        Value.initTag(.unreachable_value),
+    );
+    try sema.mod.declareDeclDependency(sema.owner_decl, decl);
+    return sema.addConstant(ptr_type, try Value.Tag.decl_ref_mut.create(sema.arena, .{
+        .runtime_index = block.runtime_index,
+        .decl = decl,
+    }));
+}
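
A minimal sketch (assumed example, not from the diff) of the comptime path these Sema changes add: a mutable local is an `alloc_mut`, and inside a comptime scope it now routes through `analyzeComptimeAlloc`, the same helper used by `ret_ptr` and `alloc_comptime`, instead of tripping the runtime-block requirement:

    fn double(x: i32) i32 {
        var y = x; // `alloc_mut`; in a comptime scope, a comptime allocation
        y += x;
        return y;
    }

    comptime {
        if (double(21) != 42) @compileError("comptime alloc_mut regression");
    }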
src/Zir.zig
@@ -465,6 +465,11 @@ pub const Inst = struct {
         /// Uses the `un_node` union field.
         ret_node,
         /// Sends control flow back to the function's caller.
+        /// The operand is a `ret_ptr` instruction, where the return value can be found.
+        /// Includes an AST node source location.
+        /// Uses the `un_node` union field.
+        ret_load,
+        /// Sends control flow back to the function's caller.
         /// Includes an operand as the return value.
         /// Includes a token source location.
         /// Uses the `un_tok` union field.
@@ -1231,6 +1236,7 @@ pub const Inst = struct {
                 .condbr_inline,
                 .compile_error,
                 .ret_node,
+                .ret_load,
                 .ret_coerce,
                 .ret_err_value,
                 .@"unreachable",
@@ -1335,6 +1341,7 @@ pub const Inst = struct {
                 .param_type = .param_type,
                 .ref = .un_tok,
                 .ret_node = .un_node,
+                .ret_load = .un_node,
                 .ret_coerce = .un_tok,
                 .ret_err_value = .str_tok,
                 .ret_err_value_code = .str_tok,
@@ -2912,6 +2919,7 @@ const Writer = struct {
             .ensure_result_used,
             .ensure_result_non_error,
             .ret_node,
+            .ret_load,
             .resolve_inferred_alloc,
             .optional_type,
             .optional_payload_safe,
test/behavior/generics.zig
@@ -28,16 +28,7 @@ test "simple generic fn" {
 }
 
 fn max(comptime T: type, a: T, b: T) T {
-    if (!builtin.zig_is_stage2) {
-        // TODO: stage2 is incorrectly emitting AIR that allocates a result
-        // value, stores to it, but then returns void instead of the result.
-        return if (a > b) a else b;
-    }
-    if (a > b) {
-        return a;
-    } else {
-        return b;
-    }
+    return if (a > b) a else b;
 }
 
 fn add(comptime a: i32, b: i32) i32 {
@@ -70,29 +61,14 @@ test "fn with comptime args" {
 test "anytype params" {
     try expect(max_i32(12, 34) == 34);
     try expect(max_f64(1.2, 3.4) == 3.4);
-    if (!builtin.zig_is_stage2) {
-        // TODO: stage2 is incorrectly hitting the following problem:
-        // error: unable to resolve comptime value
-        //     return max_anytype(a, b);
-        //                       ^
-        comptime {
-            try expect(max_i32(12, 34) == 34);
-            try expect(max_f64(1.2, 3.4) == 3.4);
-        }
+    comptime {
+        try expect(max_i32(12, 34) == 34);
+        try expect(max_f64(1.2, 3.4) == 3.4);
     }
 }
 
 fn max_anytype(a: anytype, b: anytype) @TypeOf(a, b) {
-    if (!builtin.zig_is_stage2) {
-        // TODO: stage2 is incorrectly emitting AIR that allocates a result
-        // value, stores to it, but then returns void instead of the result.
-        return if (a > b) a else b;
-    }
-    if (a > b) {
-        return a;
-    } else {
-        return b;
-    }
+    return if (a > b) a else b;
 }
 
 fn max_i32(a: i32, b: i32) i32 {
test/stage2/cbe.zig
@@ -240,6 +240,10 @@ pub fn addCases(ctx: *TestContext) !void {
     if (host_supports_custom_stack_size) {
         var case = ctx.exeFromCompiledC("@setEvalBranchQuota", .{});
 
+        // TODO when adding result location support to function calls, revisit this
+        // test case. It can go back to what it was before, with `y` being comptime
+        // known, because the ret_ptr will be passed in with the inline fn call,
+        // there will be only 1 store to it, and it will be comptime known.
         case.addCompareOutput(
             \\pub export fn main() i32 {
             \\    @setEvalBranchQuota(1001);
@@ -247,7 +251,7 @@ pub fn addCases(ctx: *TestContext) !void {
             \\    return y - 1;
             \\}
             \\
-            \\fn rec(n: usize) callconv(.Inline) usize {
+            \\inline fn rec(n: i32) i32 {
             \\    if (n <= 1) return n;
             \\    return rec(n - 1);
             \\}