Commit 0224ad19b8

Andrew Kelley <andrew@ziglang.org>
2022-06-03 01:50:33
AstGen: introduce `try` instruction
This introduces two ZIR instructions: `try` and `try_inline`. This is part of an effort to implement #11772.
1 parent 33826a6
src/AstGen.zig
@@ -2425,6 +2425,8 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: Ast.Node.Index) Inner
             .param_type,
             .ret_ptr,
             .ret_type,
+            .@"try",
+            .try_inline,
             => break :b false,
 
             .extended => switch (gz.astgen.instructions.items(.data)[inst].extended.opcode) {
@@ -4871,68 +4873,30 @@ fn tryExpr(
 
     if (parent_gz.in_defer) return astgen.failNode(node, "'try' not allowed inside defer expression", .{});
 
-    var block_scope = parent_gz.makeSubBlock(scope);
-    block_scope.setBreakResultLoc(rl);
-    defer block_scope.unstack();
-
-    const operand_rl: ResultLoc = switch (block_scope.break_result_loc) {
+    const operand_rl: ResultLoc = switch (rl) {
         .ref => .ref,
         else => .none,
     };
-    const err_ops = switch (operand_rl) {
-        // zig fmt: off
-        .ref => [3]Zir.Inst.Tag{ .is_non_err_ptr, .err_union_code_ptr, .err_union_payload_unsafe_ptr },
-        else => [3]Zir.Inst.Tag{ .is_non_err,     .err_union_code,     .err_union_payload_unsafe },
-        // zig fmt: on
-    };
-    // This could be a pointer or value depending on the `operand_rl` parameter.
-    // We cannot use `block_scope.break_result_loc` because that has the bare
-    // type, whereas this expression has the optional type. Later we make
-    // up for this fact by calling rvalue on the else branch.
-    const operand = try expr(&block_scope, &block_scope.base, operand_rl, operand_node);
-    const cond = try block_scope.addUnNode(err_ops[0], operand, node);
-    const condbr = try block_scope.addCondBr(.condbr, node);
+    // This could be a pointer or value depending on the `rl` parameter.
+    const operand = try expr(parent_gz, scope, operand_rl, operand_node);
+    const is_inline = parent_gz.force_comptime;
+    const block_tag: Zir.Inst.Tag = if (is_inline) .try_inline else .@"try";
+    const try_inst = try parent_gz.makeBlockInst(block_tag, node);
+    try parent_gz.instructions.append(astgen.gpa, try_inst);
 
-    const block = try parent_gz.makeBlockInst(.block, node);
-    try block_scope.setBlockBody(block);
-    // block_scope unstacked now, can add new instructions to parent_gz
-    try parent_gz.instructions.append(astgen.gpa, block);
-
-    var then_scope = parent_gz.makeSubBlock(scope);
-    defer then_scope.unstack();
-
-    block_scope.break_count += 1;
-    // This could be a pointer or value depending on `err_ops[2]`.
-    const unwrapped_payload = try then_scope.addUnNode(err_ops[2], operand, node);
-    const then_result = switch (rl) {
-        .ref => unwrapped_payload,
-        else => try rvalue(&then_scope, block_scope.break_result_loc, unwrapped_payload, node),
-    };
-
-    // else_scope will be stacked on then_scope as both are stacked on parent_gz
     var else_scope = parent_gz.makeSubBlock(scope);
     defer else_scope.unstack();
 
-    const err_code = try else_scope.addUnNode(err_ops[1], operand, node);
+    const err_tag = switch (rl) {
+        .ref => Zir.Inst.Tag.err_union_code_ptr,
+        else => Zir.Inst.Tag.err_union_code,
+    };
+    const err_code = try else_scope.addUnNode(err_tag, operand, node);
     try genDefers(&else_scope, &fn_block.base, scope, .{ .both = err_code });
-    const else_result = try else_scope.addUnNode(.ret_node, err_code, node);
+    _ = try else_scope.addUnNode(.ret_node, err_code, node);
 
-    const break_tag: Zir.Inst.Tag = if (parent_gz.force_comptime) .break_inline else .@"break";
-    return finishThenElseBlock(
-        parent_gz,
-        rl,
-        node,
-        &block_scope,
-        &then_scope,
-        &else_scope,
-        condbr,
-        cond,
-        then_result,
-        else_result,
-        block,
-        block,
-        break_tag,
-    );
+    try else_scope.setTryBody(try_inst, operand);
+    return indexToRef(try_inst);
 }
 
 fn orelseCatchExpr(
@@ -10011,6 +9975,22 @@ const GenZir = struct {
         gz.unstack();
     }
 
+    /// Assumes nothing stacked on `gz`. Unstacks `gz`. Stores a `Zir.Inst.Try` payload (operand + body length) for `inst` in `extra`, followed by the instructions of `gz` as the body.
+    fn setTryBody(gz: *GenZir, inst: Zir.Inst.Index, operand: Zir.Inst.Ref) !void {
+        const gpa = gz.astgen.gpa;
+        const body = gz.instructionsSlice();
+        try gz.astgen.extra.ensureUnusedCapacity(gpa, @typeInfo(Zir.Inst.Try).Struct.fields.len + body.len); // reserve room for the `Try` fields plus the body so the "AssumeCapacity" calls below are valid
+        const zir_datas = gz.astgen.instructions.items(.data);
+        zir_datas[inst].pl_node.payload_index = gz.astgen.addExtraAssumeCapacity(
+            Zir.Inst.Try{
+                .operand = operand,
+                .body_len = @intCast(u32, body.len),
+            },
+        );
+        gz.astgen.extra.appendSliceAssumeCapacity(body); // body indexes directly follow the `Try` fields
+        gz.unstack();
+    }
+    }
+
     /// Must be called with the following stack set up:
     ///  * gz (bottom)
     ///  * align_gz
src/print_zir.zig
@@ -374,17 +374,21 @@ const Writer = struct {
             .validate_array_init_comptime,
             .c_import,
             .typeof_builtin,
-            => try self.writePlNodeBlock(stream, inst),
+            => try self.writeBlock(stream, inst),
 
             .condbr,
             .condbr_inline,
-            => try self.writePlNodeCondBr(stream, inst),
+            => try self.writeCondBr(stream, inst),
+
+            .@"try",
+            .try_inline,
+            => try self.writeTry(stream, inst),
 
             .error_set_decl => try self.writeErrorSetDecl(stream, inst, .parent),
             .error_set_decl_anon => try self.writeErrorSetDecl(stream, inst, .anon),
             .error_set_decl_func => try self.writeErrorSetDecl(stream, inst, .func),
 
-            .switch_block => try self.writePlNodeSwitchBlock(stream, inst),
+            .switch_block => try self.writeSwitchBlock(stream, inst),
 
             .field_ptr,
             .field_val,
@@ -1171,7 +1175,7 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
-    fn writePlNodeBlock(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
+    fn writeBlock(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[inst].pl_node;
         try self.writePlNodeBlockWithoutSrc(stream, inst);
         try self.writeSrc(stream, inst_data.src());
@@ -1185,7 +1189,7 @@ const Writer = struct {
         try stream.writeAll(") ");
     }
 
-    fn writePlNodeCondBr(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
+    fn writeCondBr(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[inst].pl_node;
         const extra = self.code.extraData(Zir.Inst.CondBr, inst_data.payload_index);
         const then_body = self.code.extra[extra.end..][0..extra.data.then_body_len];
@@ -1199,6 +1203,17 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
+    fn writeTry(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void { // prints a `try`/`try_inline` instruction: operand, braced body, then source location
+        const inst_data = self.code.instructions.items(.data)[inst].pl_node;
+        const extra = self.code.extraData(Zir.Inst.Try, inst_data.payload_index);
+        const body = self.code.extra[extra.end..][0..extra.data.body_len]; // body indexes trail the `Try` payload in `extra`
+        try self.writeInstRef(stream, extra.data.operand);
+        try stream.writeAll(", ");
+        try self.writeBracedBody(stream, body);
+        try stream.writeAll(") "); // NOTE(review): presumably closes a parenthesis opened by the caller; not visible here
+        try self.writeSrc(stream, inst_data.src());
+    }
+
     fn writeStructDecl(self: *Writer, stream: anytype, extended: Zir.Inst.Extended.InstData) !void {
         const small = @bitCast(Zir.Inst.StructDecl.Small, extended.small);
 
@@ -1746,7 +1761,7 @@ const Writer = struct {
         try self.writeSrc(stream, inst_data.src());
     }
 
-    fn writePlNodeSwitchBlock(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
+    fn writeSwitchBlock(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
         const inst_data = self.code.instructions.items(.data)[inst].pl_node;
         const extra = self.code.extraData(Zir.Inst.SwitchBlock, inst_data.payload_index);
 
src/Sema.zig
@@ -1322,6 +1322,13 @@ fn analyzeBodyInner(
                     break break_data.inst;
                 }
             },
+            .@"try" => blk: {
+                if (!block.is_comptime) break :blk try sema.zirTry(block, inst);
+                @panic("TODO");
+            },
+            .try_inline => {
+                @panic("TODO");
+            },
         };
         if (sema.typeOf(air_inst).isNoReturn())
             break always_noreturn;
@@ -6415,32 +6422,43 @@ fn zirErrUnionPayload(
     const src = inst_data.src();
     const operand = try sema.resolveInst(inst_data.operand);
     const operand_src = src;
-    const operand_ty = sema.typeOf(operand);
-    if (operand_ty.zigTypeTag() != .ErrorUnion) {
+    const err_union_ty = sema.typeOf(operand);
+    if (err_union_ty.zigTypeTag() != .ErrorUnion) {
         return sema.fail(block, operand_src, "expected error union type, found '{}'", .{
-            operand_ty.fmt(sema.mod),
+            err_union_ty.fmt(sema.mod),
         });
     }
+    return sema.analyzeErrUnionPayload(block, src, err_union_ty, operand, operand_src, safety_check);
+}
 
-    const result_ty = operand_ty.errorUnionPayload();
-    if (try sema.resolveDefinedValue(block, src, operand)) |val| {
+fn analyzeErrUnionPayload(
+    sema: *Sema,
+    block: *Block,
+    src: LazySrcLoc,
+    err_union_ty: Type, // type of `operand`; callers in this file check it is an error union before calling
+    operand: Zir.Inst.Ref,
+    operand_src: LazySrcLoc,
+    safety_check: bool, // when true (and the block wants safety) an unwrap-error panic check is emitted
+) CompileError!Air.Inst.Ref {
+    const payload_ty = err_union_ty.errorUnionPayload();
+    if (try sema.resolveDefinedValue(block, operand_src, operand)) |val| { // operand value is known: yield the payload as a constant, or fail if it holds an error
         if (val.getError()) |name| {
             return sema.fail(block, src, "caught unexpected error '{s}'", .{name});
         }
         const data = val.castTag(.eu_payload).?.data;
-        return sema.addConstant(result_ty, data);
+        return sema.addConstant(payload_ty, data);
     }
 
     try sema.requireRuntimeBlock(block, src);
 
     // If the error set has no fields then no safety check is needed.
     if (safety_check and block.wantSafety() and
-        operand_ty.errorUnionSet().errorSetCardinality() != .zero)
+        err_union_ty.errorUnionSet().errorSetCardinality() != .zero)
     {
         try sema.panicUnwrapError(block, src, operand, .unwrap_errunion_err, .is_non_err);
     }
 
-    return block.addTyOp(.unwrap_errunion_payload, result_ty, operand);
+    return block.addTyOp(.unwrap_errunion_payload, payload_ty, operand); // value not known here: emit the unwrap instruction into `block`
 }
 
 /// Pointer in, pointer out.
@@ -12958,6 +12976,43 @@ fn zirCondbr(
     return always_noreturn;
 }
 
+fn zirTry(sema: *Sema, parent_block: *Block, inst: Zir.Inst.Index) CompileError!Zir.Inst.Ref { // analyzes a `try` ZIR instruction; only the case where the error check resolves to a known value is implemented so far
+    const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
+    const src = inst_data.src();
+    const operand_src: LazySrcLoc = .{ .node_offset_bin_lhs = inst_data.src_node };
+    const extra = sema.code.extraData(Zir.Inst.Try, inst_data.payload_index);
+    const body = sema.code.extra[extra.end..][0..extra.data.body_len]; // the error-path body trails the `Try` payload in `extra`
+    const operand = try sema.resolveInst(extra.data.operand);
+    const is_ptr = sema.typeOf(operand).zigTypeTag() == .Pointer; // a pointer operand is loaded first so the error union itself can be inspected
+    const err_union = if (is_ptr)
+        try sema.analyzeLoad(parent_block, src, operand, operand_src)
+    else
+        operand;
+    const err_union_ty = sema.typeOf(err_union);
+    if (err_union_ty.zigTypeTag() != .ErrorUnion) {
+        return sema.fail(parent_block, operand_src, "expected error union type, found '{}'", .{
+            err_union_ty.fmt(sema.mod),
+        });
+    }
+    const is_non_err = try sema.analyzeIsNonErr(parent_block, operand_src, err_union);
+
+    if (try sema.resolveDefinedValue(parent_block, operand_src, is_non_err)) |is_non_err_val| {
+        if (is_non_err_val.toBool()) {
+            if (is_ptr) {
+                return sema.analyzeErrUnionPayloadPtr(parent_block, src, operand, false, false);
+            } else {
+                return sema.analyzeErrUnionPayload(parent_block, src, err_union_ty, operand, operand_src, false); // safety_check is false: the value is already known to be non-error
+            }
+        }
+        // We can analyze the body directly in the parent block because we know there are
+        // no breaks from the body possible, and that the body is noreturn.
+        return sema.resolveBody(parent_block, body, inst);
+    }
+    _ = body;
+    _ = is_non_err;
+    @panic("TODO"); // reached when the error check has no resolved value; that path is not implemented yet
+}
+
 // A `break` statement is inside a runtime condition, but trying to
 // break from an inline loop. In such case we must convert it to
 // a runtime break.
src/Zir.zig
@@ -319,6 +319,19 @@ pub const Inst = struct {
         /// only the taken branch is analyzed. The then block and else block must
         /// terminate with an "inline" variant of a noreturn instruction.
         condbr_inline,
+        /// Given an operand which is an error union, splits control flow. In
+        /// case of error, control flow goes into the block that is part of this
+        /// instruction, which is guaranteed to end with a return instruction
+        /// and never breaks out of the block.
+        /// In the case of non-error, control flow proceeds to the next instruction
+        /// after the `try`, with the result of this instruction being the unwrapped
+        /// payload value, as if `err_union_payload_unsafe` was executed on the operand.
+        /// Uses the `pl_node` union field. Payload is `Try`.
+        @"try",
+        /// Same as `try` except the operand is coerced to a comptime value, and
+        /// only the taken branch is analyzed. The block must terminate with an "inline"
+        /// variant of a noreturn instruction.
+        try_inline,
         /// An error set type definition. Contains a list of field names.
         /// Uses the `pl_node` union field. Payload is `ErrorSetDecl`.
         error_set_decl,
@@ -1231,6 +1244,8 @@ pub const Inst = struct {
                 .closure_capture,
                 .ret_ptr,
                 .ret_type,
+                .@"try",
+                .try_inline,
                 => false,
 
                 .@"break",
@@ -1509,6 +1524,8 @@ pub const Inst = struct {
                 .repeat,
                 .repeat_inline,
                 .panic,
+                .@"try",
+                .try_inline,
                 => false,
 
                 .extended => switch (data.extended.opcode) {
@@ -1569,6 +1586,8 @@ pub const Inst = struct {
                 .coerce_result_ptr = .bin,
                 .condbr = .pl_node,
                 .condbr_inline = .pl_node,
+                .@"try" = .pl_node,
+                .try_inline = .pl_node,
                 .error_set_decl = .pl_node,
                 .error_set_decl_anon = .pl_node,
                 .error_set_decl_func = .pl_node,
@@ -2803,6 +2822,14 @@ pub const Inst = struct {
         else_body_len: u32,
     };
 
+    /// Payload of the `try` and `try_inline` instructions. This data is stored inside extra, trailed by:
+    /// * 0. body: Index //  for each `body_len`.
+    pub const Try = struct {
+        /// The error union to unwrap.
+        operand: Ref,
+        body_len: u32, // number of instruction indexes in the trailing body
+    };
+
     /// Stored in extra. Depending on the flags in Data, there will be up to 5
     /// trailing Ref fields:
     /// 0. sentinel: Ref // if `has_sentinel` flag is set
@@ -3739,6 +3766,12 @@ fn findDeclsInner(
             try zir.findDeclsBody(list, then_body);
             try zir.findDeclsBody(list, else_body);
         },
+        .@"try", .try_inline => {
+            const inst_data = datas[inst].pl_node;
+            const extra = zir.extraData(Inst.Try, inst_data.payload_index);
+            const body = zir.extra[extra.end..][0..extra.data.body_len];
+            try zir.findDeclsBody(list, body);
+        },
         .switch_block => return findDeclsSwitch(zir, list, inst),
 
         .suspend_block => @panic("TODO iterate suspend block"),