Commit ae19f699ab

dweiller <4678790+dweiller@users.noreply.github.com>
2023-11-22 05:01:18
sema: implement switch_block_err_union on comptime operands
1 parent 2cf648f
Changed files (1)
src/Sema.zig
@@ -1097,7 +1097,7 @@ fn analyzeBodyInner(
             .str                          => try sema.zirStr(inst),
             .switch_block                 => try sema.zirSwitchBlock(block, inst, false),
             .switch_block_ref             => try sema.zirSwitchBlock(block, inst, true),
-            .switch_block_err_union       => @panic("TODO: implement lowering of switch_block_err_union"),
+            .switch_block_err_union       => try sema.zirSwitchBlockErrUnion(block, inst),
             .type_info                    => try sema.zirTypeInfo(block, inst),
             .size_of                      => try sema.zirSizeOf(block, inst),
             .bit_size_of                  => try sema.zirBitSizeOf(block, inst),
@@ -11160,6 +11160,195 @@ fn switchCond(
 
 const SwitchErrorSet = std.AutoHashMap(InternPool.NullTerminatedString, Module.SwitchProngSrc);
 
+fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
+    const tracy = trace(@src());
+    defer tracy.end();
+
+    const mod = sema.mod;
+    const gpa = sema.gpa;
+    const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
+    const src = inst_data.src();
+    const src_node_offset = inst_data.src_node;
+    const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = src_node_offset };
+    const else_prong_src: LazySrcLoc = .{ .node_offset_switch_special_prong = src_node_offset };
+    const extra = sema.code.extraData(Zir.Inst.SwitchBlockErrUnion, inst_data.payload_index);
+
+    const raw_operand_val = try sema.resolveInst(extra.data.operand);
+    assert(sema.typeOf(raw_operand_val).zigTypeTag(mod) == .ErrorUnion);
+
+    // AstGen guarantees that the instruction immediately preceding
+    // switch_block_err_union is a dbg_stmt
+    const cond_dbg_node_index: Zir.Inst.Index = @enumFromInt(@intFromEnum(inst) - 1);
+    _ = cond_dbg_node_index;
+
+    var header_extra_index: usize = extra.end;
+
+    const scalar_cases_len = extra.data.bits.scalar_cases_len;
+    const multi_cases_len = if (extra.data.bits.has_multi_cases) blk: {
+        const multi_cases_len = sema.code.extra[header_extra_index];
+        header_extra_index += 1;
+        break :blk multi_cases_len;
+    } else 0;
+
+    var case_vals = try std.ArrayListUnmanaged(Air.Inst.Ref).initCapacity(gpa, scalar_cases_len + 2 * multi_cases_len);
+    defer case_vals.deinit(gpa);
+
+    const NonError = struct {
+        body: []const Zir.Inst.Index,
+        end: usize,
+    };
+
+    const non_error_case: NonError = non_error: {
+        const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[header_extra_index]);
+        const extra_body_start = header_extra_index + 1;
+        break :non_error .{
+            .body = sema.code.bodySlice(extra_body_start, info.body_len),
+            .end = extra_body_start + info.body_len,
+        };
+    };
+
+    const Else = struct {
+        body: []const Zir.Inst.Index,
+        end: usize,
+        is_inline: bool,
+        has_capture: bool,
+    };
+
+    const else_case: Else = if (!extra.data.bits.has_else) .{
+        .body = &.{},
+        .end = non_error_case.end,
+        .is_inline = false,
+        .has_capture = false,
+    } else special: {
+        const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[non_error_case.end]);
+        const extra_body_start = non_error_case.end + 1;
+        assert(info.capture != .by_ref);
+        assert(!info.has_tag_capture);
+        break :special .{
+            .body = sema.code.bodySlice(extra_body_start, info.body_len),
+            .end = extra_body_start + info.body_len,
+            .is_inline = info.is_inline,
+            .has_capture = info.capture == .by_val,
+        };
+    };
+
+    var seen_errors = SwitchErrorSet.init(gpa);
+    defer seen_errors.deinit();
+
+    const operand_ty = sema.typeOf(raw_operand_val);
+    const operand_err_set_ty = operand_ty.errorUnionSet(mod);
+
+    const else_error_ty: ?Type = try validateErrSetSwitch(
+        sema,
+        block,
+        &seen_errors,
+        &case_vals,
+        operand_err_set_ty,
+        inst_data,
+        scalar_cases_len,
+        multi_cases_len,
+        .{ .body = else_case.body, .end = else_case.end, .src = else_prong_src },
+        extra.data.bits.has_else,
+    );
+
+    var spa: SwitchProngAnalysis = .{
+        .sema = sema,
+        .parent_block = block,
+        .operand = raw_operand_val,
+        .operand_ptr = .none,
+        .cond = raw_operand_val,
+        .else_error_ty = else_error_ty,
+        .switch_block_inst = inst,
+        .tag_capture_inst = undefined,
+    };
+
+    const block_inst: Air.Inst.Index = @enumFromInt(sema.air_instructions.len);
+    try sema.air_instructions.append(gpa, .{
+        .tag = .block,
+        .data = undefined,
+    });
+    var label: Block.Label = .{
+        .zir_block = inst,
+        .merges = .{
+            .src_locs = .{},
+            .results = .{},
+            .br_list = .{},
+            .block_inst = block_inst,
+        },
+    };
+
+    var child_block: Block = .{
+        .parent = block,
+        .sema = sema,
+        .src_decl = block.src_decl,
+        .namespace = block.namespace,
+        .wip_capture_scope = block.wip_capture_scope,
+        .instructions = .{},
+        .label = &label,
+        .inlining = block.inlining,
+        .is_comptime = block.is_comptime,
+        .comptime_reason = block.comptime_reason,
+        .is_typeof = block.is_typeof,
+        .c_import_buf = block.c_import_buf,
+        .runtime_cond = block.runtime_cond,
+        .runtime_loop = block.runtime_loop,
+        .runtime_index = block.runtime_index,
+        .error_return_trace_index = block.error_return_trace_index,
+    };
+    const merges = &child_block.label.?.merges;
+    defer child_block.instructions.deinit(gpa);
+    defer merges.deinit(gpa);
+
+    if (try sema.resolveDefinedValue(&child_block, src, raw_operand_val)) |operand_val| {
+        if (operand_val.errorUnionIsPayload(mod)) {
+            return sema.resolveBlockBody(block, operand_src, &child_block, non_error_case.body, inst, merges);
+        } else {
+            const err_val = Value.fromInterned(try mod.intern(.{
+                .err = .{
+                    .ty = operand_err_set_ty.toIntern(),
+                    .name = operand_val.getErrorName(mod).unwrap().?,
+                },
+            }));
+            spa.operand = try sema.analyzeErrUnionCode(block, operand_src, raw_operand_val);
+            return resolveSwitchComptime(
+                sema,
+                spa,
+                &child_block,
+                try sema.switchCond(block, operand_src, spa.operand),
+                err_val,
+                operand_err_set_ty,
+                .{
+                    .body = else_case.body,
+                    .end = else_case.end,
+                    .capture = if (else_case.has_capture) .by_val else .none,
+                    .is_inline = else_case.is_inline,
+                    .has_tag_capture = false,
+                },
+                case_vals,
+                scalar_cases_len,
+                multi_cases_len,
+                true,
+                false,
+            );
+        }
+    }
+
+    if (scalar_cases_len + multi_cases_len == 0) {
+        if (else_error_ty) |ty| if (ty.errorSetIsEmpty(mod)) {
+            return sema.resolveBlockBody(block, operand_src, &child_block, non_error_case.body, inst, merges);
+        };
+    }
+
+    if (child_block.is_comptime) {
+        _ = try sema.resolveConstDefinedValue(&child_block, operand_src, operand, .{
+            .needed_comptime_reason = "condition in comptime switch must be comptime-known",
+            .block_comptime_reason = child_block.comptime_reason,
+        });
+        unreachable;
+    }
+    return sema.fail(block, src, "TODO: implement more of switch_block_err_union", .{});
+}
+
 fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_ref: bool) CompileError!Air.Inst.Ref {
     const tracy = trace(@src());
     defer tracy.end();