Commit b784f64a6e

dweiller <4678790+dweiller@users.noreply.github.com>
2023-11-22 04:59:02
sema: refactor error set switch logic
1 parent 4136097
Changed files (1)
src/Sema.zig
@@ -11212,16 +11212,8 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
     var case_vals = try std.ArrayListUnmanaged(Air.Inst.Ref).initCapacity(gpa, scalar_cases_len + 2 * multi_cases_len);
     defer case_vals.deinit(gpa);
 
-    const Special = struct {
-        body: []const Zir.Inst.Index,
-        end: usize,
-        capture: Zir.Inst.SwitchBlock.ProngInfo.Capture,
-        is_inline: bool,
-        has_tag_capture: bool,
-    };
-
     const special_prong = extra.data.bits.specialProng();
-    const special: Special = switch (special_prong) {
+    const special: SpecialProng = switch (special_prong) {
         .none => .{
             .body = &.{},
             .end = header_extra_index,
@@ -11401,150 +11393,18 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
                 );
             }
         },
-        .ErrorSet => {
-            var extra_index: usize = special.end;
-            {
-                var scalar_i: u32 = 0;
-                while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
-                    const item_ref: Zir.Inst.Ref = @enumFromInt(sema.code.extra[extra_index]);
-                    extra_index += 1;
-                    const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
-                    extra_index += 1 + info.body_len;
-
-                    case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
-                        block,
-                        &seen_errors,
-                        item_ref,
-                        operand_ty,
-                        src_node_offset,
-                        .{ .scalar = scalar_i },
-                    ));
-                }
-            }
-            {
-                var multi_i: u32 = 0;
-                while (multi_i < multi_cases_len) : (multi_i += 1) {
-                    const items_len = sema.code.extra[extra_index];
-                    extra_index += 1;
-                    const ranges_len = sema.code.extra[extra_index];
-                    extra_index += 1;
-                    const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
-                    extra_index += 1;
-                    const items = sema.code.refSlice(extra_index, items_len);
-                    extra_index += items_len + info.body_len;
-
-                    try case_vals.ensureUnusedCapacity(gpa, items.len);
-                    for (items, 0..) |item_ref, item_i| {
-                        case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
-                            block,
-                            &seen_errors,
-                            item_ref,
-                            operand_ty,
-                            src_node_offset,
-                            .{ .multi = .{ .prong = multi_i, .item = @intCast(item_i) } },
-                        ));
-                    }
-
-                    try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset);
-                }
-            }
-
-            switch (try sema.resolveInferredErrorSetTy(block, src, operand_ty.toIntern())) {
-                .anyerror_type => {
-                    if (special_prong != .@"else") {
-                        return sema.fail(
-                            block,
-                            src,
-                            "else prong required when switching on type 'anyerror'",
-                            .{},
-                        );
-                    }
-                    else_error_ty = Type.anyerror;
-                },
-                else => |err_set_ty_index| else_validation: {
-                    const error_names = ip.indexToKey(err_set_ty_index).error_set_type.names;
-                    var maybe_msg: ?*Module.ErrorMsg = null;
-                    errdefer if (maybe_msg) |msg| msg.destroy(sema.gpa);
-
-                    for (error_names.get(ip)) |error_name| {
-                        if (!seen_errors.contains(error_name) and special_prong != .@"else") {
-                            const msg = maybe_msg orelse blk: {
-                                maybe_msg = try sema.errMsg(
-                                    block,
-                                    src,
-                                    "switch must handle all possibilities",
-                                    .{},
-                                );
-                                break :blk maybe_msg.?;
-                            };
-
-                            try sema.errNote(
-                                block,
-                                src,
-                                msg,
-                                "unhandled error value: 'error.{}'",
-                                .{error_name.fmt(ip)},
-                            );
-                        }
-                    }
-
-                    if (maybe_msg) |msg| {
-                        maybe_msg = null;
-                        try sema.addDeclaredHereNote(msg, operand_ty);
-                        return sema.failWithOwnedErrorMsg(block, msg);
-                    }
-
-                    if (special_prong == .@"else" and
-                        seen_errors.count() == error_names.len)
-                    {
-                        // In order to enable common patterns for generic code allow simple else bodies
-                        // else => unreachable,
-                        // else => return,
-                        // else => |e| return e,
-                        // even if all the possible errors were already handled.
-                        const tags = sema.code.instructions.items(.tag);
-                        for (special.body) |else_inst| switch (tags[@intFromEnum(else_inst)]) {
-                            .dbg_block_begin,
-                            .dbg_block_end,
-                            .dbg_stmt,
-                            .dbg_var_val,
-                            .ret_type,
-                            .as_node,
-                            .ret_node,
-                            .@"unreachable",
-                            .@"defer",
-                            .defer_err_code,
-                            .err_union_code,
-                            .ret_err_value_code,
-                            .restore_err_ret_index,
-                            .is_non_err,
-                            .ret_is_non_err,
-                            .condbr,
-                            => {},
-                            else => break,
-                        } else break :else_validation;
-
-                        return sema.fail(
-                            block,
-                            special_prong_src,
-                            "unreachable else prong; all cases already handled",
-                            .{},
-                        );
-                    }
-
-                    var names: InferredErrorSet.NameMap = .{};
-                    try names.ensureUnusedCapacity(sema.arena, error_names.len);
-                    for (error_names.get(ip)) |error_name| {
-                        if (seen_errors.contains(error_name)) continue;
-
-                        names.putAssumeCapacityNoClobber(error_name, {});
-                    }
-                    // No need to keep the hash map metadata correct; here we
-                    // extract the (sorted) keys only.
-                    else_error_ty = try mod.errorSetFromUnsortedNames(names.keys());
-                },
-            }
-        },
+        .ErrorSet => else_error_ty = try validateErrSetSwitch(
+            sema,
+            block,
+            &seen_errors,
+            &case_vals,
+            operand_ty,
+            inst_data,
+            scalar_cases_len,
+            multi_cases_len,
+            .{ .body = special.body, .end = special.end, .src = special_prong_src },
+            special_prong == .@"else",
+        ),
         .Int, .ComptimeInt => {
             var extra_index: usize = special.end;
             {
@@ -11840,114 +11700,19 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
     defer merges.deinit(gpa);
 
     if (try sema.resolveDefinedValue(&child_block, src, operand)) |operand_val| {
-        const resolved_operand_val = try sema.resolveLazyValue(operand_val);
-        var extra_index: usize = special.end;
-        {
-            var scalar_i: usize = 0;
-            while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
-                extra_index += 1;
-                const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
-                extra_index += 1;
-                const body = sema.code.bodySlice(extra_index, info.body_len);
-                extra_index += info.body_len;
-
-                const item = case_vals.items[scalar_i];
-                const item_val = sema.resolveConstDefinedValue(&child_block, .unneeded, item, undefined) catch unreachable;
-                if (operand_val.eql(item_val, operand_ty, sema.mod)) {
-                    if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
-                    return spa.resolveProngComptime(
-                        &child_block,
-                        .normal,
-                        body,
-                        info.capture,
-                        .{ .scalar_capture = @intCast(scalar_i) },
-                        &.{item},
-                        if (info.is_inline) operand else .none,
-                        info.has_tag_capture,
-                        merges,
-                    );
-                }
-            }
-        }
-        {
-            var multi_i: usize = 0;
-            var case_val_idx: usize = scalar_cases_len;
-            while (multi_i < multi_cases_len) : (multi_i += 1) {
-                const items_len = sema.code.extra[extra_index];
-                extra_index += 1;
-                const ranges_len = sema.code.extra[extra_index];
-                extra_index += 1;
-                const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
-                extra_index += 1 + items_len;
-                const body = sema.code.bodySlice(extra_index + 2 * ranges_len, info.body_len);
-
-                const items = case_vals.items[case_val_idx..][0..items_len];
-                case_val_idx += items_len;
-
-                for (items) |item| {
-                    // Validation above ensured these will succeed.
-                    const item_val = sema.resolveConstDefinedValue(&child_block, .unneeded, item, undefined) catch unreachable;
-                    if (operand_val.eql(item_val, operand_ty, sema.mod)) {
-                        if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
-                        return spa.resolveProngComptime(
-                            &child_block,
-                            .normal,
-                            body,
-                            info.capture,
-                            .{ .multi_capture = @intCast(multi_i) },
-                            items,
-                            if (info.is_inline) operand else .none,
-                            info.has_tag_capture,
-                            merges,
-                        );
-                    }
-                }
-
-                var range_i: usize = 0;
-                while (range_i < ranges_len) : (range_i += 1) {
-                    const range_items = case_vals.items[case_val_idx..][0..2];
-                    extra_index += 2;
-                    case_val_idx += 2;
-
-                    // Validation above ensured these will succeed.
-                    const first_val = sema.resolveConstDefinedValue(&child_block, .unneeded, range_items[0], undefined) catch unreachable;
-                    const last_val = sema.resolveConstDefinedValue(&child_block, .unneeded, range_items[1], undefined) catch unreachable;
-                    if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and
-                        (try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty)))
-                    {
-                        if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, body, operand);
-                        return spa.resolveProngComptime(
-                            &child_block,
-                            .normal,
-                            body,
-                            info.capture,
-                            .{ .multi_capture = @intCast(multi_i) },
-                            undefined, // case_vals may be undefined for ranges
-                            if (info.is_inline) operand else .none,
-                            info.has_tag_capture,
-                            merges,
-                        );
-                    }
-                }
-
-                extra_index += info.body_len;
-            }
-        }
-        if (err_set) try sema.maybeErrorUnwrapComptime(&child_block, special.body, operand);
-        if (empty_enum) {
-            return .void_value;
-        }
-
-        return spa.resolveProngComptime(
+        return resolveSwitchComptime(
+            sema,
+            spa,
             &child_block,
-            .special,
-            special.body,
-            special.capture,
-            .special_capture,
-            undefined, // case_vals may be undefined for special prongs
-            if (special.is_inline) operand else .none,
-            special.has_tag_capture,
-            merges,
+            operand,
+            operand_val,
+            operand_ty,
+            special,
+            case_vals,
+            scalar_cases_len,
+            multi_cases_len,
+            err_set,
+            empty_enum,
         );
     }
 
@@ -12593,6 +12358,140 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
     return sema.analyzeBlockBody(block, src, &child_block, merges);
 }
 
+const SpecialProng = struct {
+    body: []const Zir.Inst.Index,
+    end: usize,
+    capture: Zir.Inst.SwitchBlock.ProngInfo.Capture,
+    is_inline: bool,
+    has_tag_capture: bool,
+};
+
+fn resolveSwitchComptime(
+    sema: *Sema,
+    spa: SwitchProngAnalysis,
+    child_block: *Block,
+    cond_operand: Air.Inst.Ref,
+    operand_val: Value,
+    operand_ty: Type,
+    special: SpecialProng,
+    case_vals: std.ArrayListUnmanaged(Air.Inst.Ref),
+    scalar_cases_len: u32,
+    multi_cases_len: u32,
+    err_set: bool,
+    empty_enum: bool,
+) CompileError!Air.Inst.Ref {
+    const merges = &child_block.label.?.merges;
+    const resolved_operand_val = try sema.resolveLazyValue(operand_val);
+    var extra_index: usize = special.end;
+    {
+        var scalar_i: usize = 0;
+        while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
+            extra_index += 1;
+            const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
+            extra_index += 1;
+            const body = sema.code.bodySlice(extra_index, info.body_len);
+            extra_index += info.body_len;
+
+            const item = case_vals.items[scalar_i];
+            const item_val = sema.resolveConstDefinedValue(child_block, .unneeded, item, undefined) catch unreachable;
+            if (operand_val.eql(item_val, operand_ty, sema.mod)) {
+                if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand);
+                return spa.resolveProngComptime(
+                    child_block,
+                    .normal,
+                    body,
+                    info.capture,
+                    .{ .scalar_capture = @intCast(scalar_i) },
+                    &.{item},
+                    if (info.is_inline) cond_operand else .none,
+                    info.has_tag_capture,
+                    merges,
+                );
+            }
+        }
+    }
+    {
+        var multi_i: usize = 0;
+        var case_val_idx: usize = scalar_cases_len;
+        while (multi_i < multi_cases_len) : (multi_i += 1) {
+            const items_len = sema.code.extra[extra_index];
+            extra_index += 1;
+            const ranges_len = sema.code.extra[extra_index];
+            extra_index += 1;
+            const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
+            extra_index += 1 + items_len;
+            const body = sema.code.bodySlice(extra_index + 2 * ranges_len, info.body_len);
+
+            const items = case_vals.items[case_val_idx..][0..items_len];
+            case_val_idx += items_len;
+
+            for (items) |item| {
+                // Validation above ensured these will succeed.
+                const item_val = sema.resolveConstDefinedValue(child_block, .unneeded, item, undefined) catch unreachable;
+                if (operand_val.eql(item_val, operand_ty, sema.mod)) {
+                    if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand);
+                    return spa.resolveProngComptime(
+                        child_block,
+                        .normal,
+                        body,
+                        info.capture,
+                        .{ .multi_capture = @intCast(multi_i) },
+                        items,
+                        if (info.is_inline) cond_operand else .none,
+                        info.has_tag_capture,
+                        merges,
+                    );
+                }
+            }
+
+            var range_i: usize = 0;
+            while (range_i < ranges_len) : (range_i += 1) {
+                const range_items = case_vals.items[case_val_idx..][0..2];
+                extra_index += 2;
+                case_val_idx += 2;
+
+                // Validation above ensured these will succeed.
+                const first_val = sema.resolveConstDefinedValue(child_block, .unneeded, range_items[0], undefined) catch unreachable;
+                const last_val = sema.resolveConstDefinedValue(child_block, .unneeded, range_items[1], undefined) catch unreachable;
+                if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and
+                    (try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty)))
+                {
+                    if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand);
+                    return spa.resolveProngComptime(
+                        child_block,
+                        .normal,
+                        body,
+                        info.capture,
+                        .{ .multi_capture = @intCast(multi_i) },
+                        undefined, // case_vals may be undefined for ranges
+                        if (info.is_inline) cond_operand else .none,
+                        info.has_tag_capture,
+                        merges,
+                    );
+                }
+            }
+
+            extra_index += info.body_len;
+        }
+    }
+    if (err_set) try sema.maybeErrorUnwrapComptime(child_block, special.body, cond_operand);
+    if (empty_enum) {
+        return .void_value;
+    }
+
+    return spa.resolveProngComptime(
+        child_block,
+        .special,
+        special.body,
+        special.capture,
+        .special_capture,
+        undefined, // case_vals may be undefined for special prongs
+        if (special.is_inline) cond_operand else .none,
+        special.has_tag_capture,
+        merges,
+    );
+}
+
 const RangeSetUnhandledIterator = struct {
     mod: *Module,
     cur: ?InternPool.Index,
@@ -12710,6 +12609,168 @@ fn resolveSwitchItemVal(
     return .{ .ref = new_item, .val = val.toIntern() };
 }
 
+fn validateErrSetSwitch(
+    sema: *Sema,
+    block: *Block,
+    seen_errors: *SwitchErrorSet,
+    case_vals: *std.ArrayListUnmanaged(Air.Inst.Ref),
+    operand_ty: Type,
+    inst_data: std.meta.FieldType(Zir.Inst.Data, .pl_node),
+    scalar_cases_len: u32,
+    multi_cases_len: u32,
+    else_case: struct { body: []const Zir.Inst.Index, end: usize, src: LazySrcLoc },
+    has_else: bool,
+) CompileError!?Type {
+    const gpa = sema.gpa;
+    const mod = sema.mod;
+    const ip = &mod.intern_pool;
+
+    const src_node_offset = inst_data.src_node;
+    const src = inst_data.src();
+
+    var extra_index: usize = else_case.end;
+    {
+        var scalar_i: u32 = 0;
+        while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
+            const item_ref: Zir.Inst.Ref = @enumFromInt(sema.code.extra[extra_index]);
+            extra_index += 1;
+            const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
+            extra_index += 1 + info.body_len;
+
+            case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
+                block,
+                seen_errors,
+                item_ref,
+                operand_ty,
+                src_node_offset,
+                .{ .scalar = scalar_i },
+            ));
+        }
+    }
+    {
+        var multi_i: u32 = 0;
+        while (multi_i < multi_cases_len) : (multi_i += 1) {
+            const items_len = sema.code.extra[extra_index];
+            extra_index += 1;
+            const ranges_len = sema.code.extra[extra_index];
+            extra_index += 1;
+            const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]);
+            extra_index += 1;
+            const items = sema.code.refSlice(extra_index, items_len);
+            extra_index += items_len + info.body_len;
+
+            try case_vals.ensureUnusedCapacity(gpa, items.len);
+            for (items, 0..) |item_ref, item_i| {
+                case_vals.appendAssumeCapacity(try sema.validateSwitchItemError(
+                    block,
+                    seen_errors,
+                    item_ref,
+                    operand_ty,
+                    src_node_offset,
+                    .{ .multi = .{ .prong = multi_i, .item = @intCast(item_i) } },
+                ));
+            }
+
+            try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset);
+        }
+    }
+
+    switch (try sema.resolveInferredErrorSetTy(block, src, operand_ty.toIntern())) {
+        .anyerror_type => {
+            if (!has_else) {
+                return sema.fail(
+                    block,
+                    src,
+                    "else prong required when switching on type 'anyerror'",
+                    .{},
+                );
+            }
+            return Type.anyerror;
+        },
+        else => |err_set_ty_index| else_validation: {
+            const error_names = ip.indexToKey(err_set_ty_index).error_set_type.names;
+            var maybe_msg: ?*Module.ErrorMsg = null;
+            errdefer if (maybe_msg) |msg| msg.destroy(sema.gpa);
+
+            for (error_names.get(ip)) |error_name| {
+                if (!seen_errors.contains(error_name) and !has_else) {
+                    const msg = maybe_msg orelse blk: {
+                        maybe_msg = try sema.errMsg(
+                            block,
+                            src,
+                            "switch must handle all possibilities",
+                            .{},
+                        );
+                        break :blk maybe_msg.?;
+                    };
+
+                    try sema.errNote(
+                        block,
+                        src,
+                        msg,
+                        "unhandled error value: 'error.{}'",
+                        .{error_name.fmt(ip)},
+                    );
+                }
+            }
+
+            if (maybe_msg) |msg| {
+                maybe_msg = null;
+                try sema.addDeclaredHereNote(msg, operand_ty);
+                return sema.failWithOwnedErrorMsg(block, msg);
+            }
+
+            if (has_else and seen_errors.count() == error_names.len) {
+                // In order to enable common patterns for generic code allow simple else bodies
+                // else => unreachable,
+                // else => return,
+                // else => |e| return e,
+                // even if all the possible errors were already handled.
+                const tags = sema.code.instructions.items(.tag);
+                for (else_case.body) |else_inst| switch (tags[@intFromEnum(else_inst)]) {
+                    .dbg_block_begin,
+                    .dbg_block_end,
+                    .dbg_stmt,
+                    .dbg_var_val,
+                    .ret_type,
+                    .as_node,
+                    .ret_node,
+                    .@"unreachable",
+                    .@"defer",
+                    .defer_err_code,
+                    .err_union_code,
+                    .ret_err_value_code,
+                    .restore_err_ret_index,
+                    .is_non_err,
+                    .ret_is_non_err,
+                    .condbr,
+                    => {},
+                    else => break,
+                } else break :else_validation;
+
+                return sema.fail(
+                    block,
+                    else_case.src,
+                    "unreachable else prong; all cases already handled",
+                    .{},
+                );
+            }
+
+            var names: InferredErrorSet.NameMap = .{};
+            try names.ensureUnusedCapacity(sema.arena, error_names.len);
+            for (error_names.get(ip)) |error_name| {
+                if (seen_errors.contains(error_name)) continue;
+
+                names.putAssumeCapacityNoClobber(error_name, {});
+            }
+            // No need to keep the hash map metadata correct; here we
+            // extract the (sorted) keys only.
+            return try mod.errorSetFromUnsortedNames(names.keys());
+        },
+    }
+    return null;
+}
+
 fn validateSwitchRange(
     sema: *Sema,
     block: *Block,