Commit bcb673d94a
Changed files (3)
src
test
behavior
cases
compile_errors
src/Sema.zig
@@ -2392,6 +2392,34 @@ fn failWithOwnedErrorMsg(sema: *Sema, err_msg: *Module.ErrorMsg) CompileError {
return error.AnalysisFail;
}
+/// Given an ErrorMsg, modify its message and source location to the given values, turning the
+/// original message into a note. Notes on the original message are preserved as further notes.
+/// Reference trace is preserved.
+fn reparentOwnedErrorMsg(
+ sema: *Sema,
+ block: *Block,
+ src: LazySrcLoc,
+ msg: *Module.ErrorMsg,
+ comptime format: []const u8,
+ args: anytype,
+) !void {
+ const mod = sema.mod;
+ const src_decl = mod.declPtr(block.src_decl);
+ const resolved_src = src.toSrcLoc(src_decl, mod);
+ const msg_str = try std.fmt.allocPrint(mod.gpa, format, args);
+
+ const orig_notes = msg.notes.len;
+ msg.notes = try sema.gpa.realloc(msg.notes, orig_notes + 1);
+ std.mem.copyBackwards(Module.ErrorMsg, msg.notes[1..], msg.notes[0..orig_notes]);
+ msg.notes[0] = .{
+ .src_loc = msg.src_loc,
+ .msg = msg.msg,
+ };
+
+ msg.src_loc = resolved_src;
+ msg.msg = msg_str;
+}
+
const align_ty = Type.u29;
fn analyzeAsAlign(
@@ -10082,6 +10110,8 @@ const SwitchProngAnalysis = struct {
operand: Air.Inst.Ref,
/// May be `undefined` if no prong has a by-ref capture.
operand_ptr: Air.Inst.Ref,
+ /// The switch condition value. For unions, `operand` is the union and `cond` is its tag.
+ cond: Air.Inst.Ref,
/// If this switch is on an error set, this is the type to assign to the
/// `else` prong. If `null`, the prong should be unreachable.
else_error_ty: ?Type,
@@ -10315,61 +10345,245 @@ const SwitchProngAnalysis = struct {
const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, mod).?);
const first_field = union_obj.fields.values()[first_field_index];
- for (case_vals[1..], 0..) |item, i| {
+ const field_tys = try sema.arena.alloc(Type, case_vals.len);
+ for (case_vals, field_tys) |item, *field_ty| {
const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable;
+ const field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, sema.mod).?);
+ field_ty.* = union_obj.fields.values()[field_idx].ty;
+ }
- const field_index = operand_ty.unionTagFieldIndex(item_val, mod).?;
- const field = union_obj.fields.values()[field_index];
- if (!field.ty.eql(first_field.ty, mod)) {
- const msg = msg: {
- const capture_src = raw_capture_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .none);
+ // Fast path: if all the operands are the same type already, we don't need to hit
+ // PTR! This will also allow us to emit simpler code.
+ const same_types = for (field_tys[1..]) |field_ty| {
+ if (!field_ty.eql(field_tys[0], sema.mod)) break false;
+ } else true;
- const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{});
- errdefer msg.destroy(sema.gpa);
+ const capture_ty = if (same_types) field_tys[0] else capture_ty: {
+ // We need values to run PTR on, so make a bunch of undef constants.
+ const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len);
+ for (dummy_captures, field_tys) |*dummy, field_ty| {
+ dummy.* = try sema.addConstUndef(field_ty);
+ }
+
+ const case_srcs = try sema.arena.alloc(?LazySrcLoc, case_vals.len);
+ @memset(case_srcs, .unneeded);
+ break :capture_ty sema.resolvePeerTypes(block, .unneeded, dummy_captures, .{ .override = case_srcs }) catch |err| switch (err) {
+ error.NeededSourceLocation => {
// This must be a multi-prong so this must be a `multi_capture` src
const multi_idx = raw_capture_src.multi_capture;
+ const src_decl_ptr = sema.mod.declPtr(block.src_decl);
+ for (case_srcs, 0..) |*case_src, i| {
+ const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } };
+ case_src.* = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
+ }
+ const capture_src = raw_capture_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
+ _ = sema.resolvePeerTypes(block, capture_src, dummy_captures, .{ .override = case_srcs }) catch |err1| switch (err1) {
+ error.AnalysisFail => {
+ const msg = sema.err orelse return error.AnalysisFail;
+ try sema.reparentOwnedErrorMsg(block, capture_src, msg, "capture group with incompatible types", .{});
+ return error.AnalysisFail;
+ },
+ else => |e| return e,
+ };
+ unreachable;
+ },
+ else => |e| return e,
+ };
+ };
- const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } };
- const first_item_src = raw_first_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first);
- const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } };
- const item_src = raw_item_src.resolve(mod, mod.declPtr(block.src_decl), switch_node_offset, .first);
- try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(mod)});
- try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(mod)});
- break :msg msg;
- };
- return sema.failWithOwnedErrorMsg(msg);
- }
- }
-
+ // By-reference captures have some further restrictions which make them easier to emit
if (capture_byref) {
- const field_ty_ptr = try Type.ptr(sema.arena, mod, .{
- .pointee_type = first_field.ty,
- .@"addrspace" = .generic,
- .mutable = operand_ptr_ty.ptrIsMutable(mod),
+ const operand_ptr_info = operand_ptr_ty.ptrInfo(mod);
+ const capture_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{
+ .pointee_type = capture_ty,
+ .@"addrspace" = operand_ptr_info.@"addrspace",
+ .mutable = operand_ptr_info.mutable,
+ .@"volatile" = operand_ptr_info.@"volatile",
+ // TODO: alignment!
});
+ // By-ref captures of hetereogeneous types are only allowed if each field
+ // pointer type is in-memory coercible to the capture pointer type.
+ if (!same_types) {
+ for (field_tys, 0..) |field_ty, i| {
+ const field_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{
+ .pointee_type = field_ty,
+ .@"addrspace" = operand_ptr_info.@"addrspace",
+ .mutable = operand_ptr_info.mutable,
+ .@"volatile" = operand_ptr_info.@"volatile",
+ // TODO: alignment!
+ });
+ if (.ok != try sema.coerceInMemoryAllowed(block, capture_ptr_ty, field_ptr_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
+ const multi_idx = raw_capture_src.multi_capture;
+ const src_decl_ptr = sema.mod.declPtr(block.src_decl);
+ const capture_src = raw_capture_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
+ const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } };
+ const case_src = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
+ const msg = msg: {
+ const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{});
+ errdefer msg.destroy(sema.gpa);
+ try sema.errNote(block, case_src, msg, "pointer type child '{}' cannot cast into resolved pointer type child '{}'", .{
+ field_ty.fmt(sema.mod),
+ capture_ty.fmt(sema.mod),
+ });
+ try sema.errNote(block, capture_src, msg, "this coercion is only possible when capturing by value", .{});
+ break :msg msg;
+ };
+ return sema.failWithOwnedErrorMsg(msg);
+ }
+ }
+ }
+
if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| {
- return sema.addConstant(field_ty_ptr, (try mod.intern(.{ .ptr = .{
- .ty = field_ty_ptr.toIntern(),
- .addr = .{ .field = .{
- .base = op_ptr_val.toIntern(),
- .index = first_field_index,
- } },
- } })).toValue());
+ if (op_ptr_val.isUndef(mod)) return sema.addConstUndef(capture_ptr_ty);
+ return sema.addConstant(
+ capture_ptr_ty,
+ (try mod.intern(.{ .ptr = .{
+ .ty = capture_ptr_ty.toIntern(),
+ .addr = .{ .field = .{
+ .base = op_ptr_val.toIntern(),
+ .index = first_field_index,
+ } },
+ } })).toValue(),
+ );
}
+
try sema.requireRuntimeBlock(block, operand_src, null);
- return block.addStructFieldPtr(spa.operand_ptr, first_field_index, field_ty_ptr);
+ return block.addStructFieldPtr(spa.operand_ptr, first_field_index, capture_ptr_ty);
}
if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |operand_val| {
- return sema.addConstant(
- first_field.ty,
- mod.intern_pool.indexToKey(operand_val.toIntern()).un.val.toValue(),
- );
+ if (operand_val.isUndef(mod)) return sema.addConstUndef(capture_ty);
+ const union_val = mod.intern_pool.indexToKey(operand_val.toIntern()).un;
+ if (union_val.tag.toValue().isUndef(mod)) return sema.addConstUndef(capture_ty);
+ const active_field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(union_val.tag.toValue(), sema.mod).?);
+ const field_ty = union_obj.fields.values()[active_field_idx].ty;
+ const uncoerced = try sema.addConstant(field_ty, union_val.val.toValue());
+ return sema.coerce(block, capture_ty, uncoerced, operand_src);
}
+
try sema.requireRuntimeBlock(block, operand_src, null);
- return block.addStructFieldVal(spa.operand, first_field_index, first_field.ty);
+
+ if (same_types) {
+ return block.addStructFieldVal(spa.operand, first_field_index, capture_ty);
+ }
+
+ // We may have to emit a switch block which coerces the operand to the capture type.
+ // If we can, try to avoid that using in-memory coercions.
+ const first_non_imc = in_mem: {
+ for (field_tys, 0..) |field_ty, i| {
+ if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
+ break :in_mem i;
+ }
+ }
+ // All fields are in-memory coercible to the resolved type!
+ // Just take the first field and bitcast the result.
+ const uncoerced = try block.addStructFieldVal(spa.operand, first_field_index, first_field.ty);
+ return block.addBitCast(capture_ty, uncoerced);
+ };
+
+ // By-val capture with heterogeneous types which are not all in-memory coercible to
+ // the resolved capture type. We finally have to fall back to the ugly method.
+
+ // However, let's first track which operands are in-memory coercible. There may well
+ // be several, and we can squash all of these cases into the same switch prong using
+ // a simple bitcast. We'll make this the 'else' prong.
+
+ var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len);
+ in_mem_coercible.unset(first_non_imc);
+ {
+ const next = first_non_imc + 1;
+ for (field_tys[next..], next..) |field_ty, i| {
+ if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
+ in_mem_coercible.unset(i);
+ }
+ }
+ }
+
+ const capture_block_inst = try block.addInstAsIndex(.{
+ .tag = .block,
+ .data = .{
+ .ty_pl = .{
+ .ty = try sema.addType(capture_ty),
+ .payload = undefined, // updated below
+ },
+ },
+ });
+
+ const prong_count = field_tys.len - in_mem_coercible.count();
+
+ const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts
+ var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra);
+ defer cases_extra.deinit();
+
+ {
+ // Non-bitcast cases
+ var it = in_mem_coercible.iterator(.{ .kind = .unset });
+ while (it.next()) |idx| {
+ var coerce_block = block.makeSubBlock();
+ defer coerce_block.instructions.deinit(sema.gpa);
+
+ const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, idx), field_tys[idx]);
+ const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) {
+ error.NeededSourceLocation => {
+ const multi_idx = raw_capture_src.multi_capture;
+ const src_decl_ptr = sema.mod.declPtr(block.src_decl);
+ const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, idx) } };
+ const case_src = raw_case_src.resolve(mod, src_decl_ptr, switch_node_offset, .none);
+ _ = try sema.coerce(&coerce_block, capture_ty, uncoerced, case_src);
+ unreachable;
+ },
+ else => |e| return e,
+ };
+ _ = try coerce_block.addBr(capture_block_inst, coerced);
+
+ try cases_extra.ensureUnusedCapacity(3 + coerce_block.instructions.items.len);
+ cases_extra.appendAssumeCapacity(1); // items_len
+ cases_extra.appendAssumeCapacity(@intCast(u32, coerce_block.instructions.items.len)); // body_len
+ cases_extra.appendAssumeCapacity(@enumToInt(case_vals[idx])); // item
+ cases_extra.appendSliceAssumeCapacity(coerce_block.instructions.items); // body
+ }
+ }
+ const else_body_len = len: {
+ // 'else' prong uses a bitcast
+ var coerce_block = block.makeSubBlock();
+ defer coerce_block.instructions.deinit(sema.gpa);
+
+ const first_imc = in_mem_coercible.findFirstSet().?;
+ const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, first_imc), field_tys[first_imc]);
+ const coerced = try coerce_block.addBitCast(capture_ty, uncoerced);
+ _ = try coerce_block.addBr(capture_block_inst, coerced);
+
+ try cases_extra.appendSlice(coerce_block.instructions.items);
+ break :len coerce_block.instructions.items.len;
+ };
+
+ try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.SwitchBr).Struct.fields.len +
+ cases_extra.items.len +
+ @typeInfo(Air.Block).Struct.fields.len +
+ 1);
+
+ const switch_br_inst = @intCast(u32, sema.air_instructions.len);
+ try sema.air_instructions.append(sema.gpa, .{
+ .tag = .switch_br,
+ .data = .{ .pl_op = .{
+ .operand = spa.cond,
+ .payload = sema.addExtraAssumeCapacity(Air.SwitchBr{
+ .cases_len = @intCast(u32, prong_count),
+ .else_body_len = @intCast(u32, else_body_len),
+ }),
+ } },
+ });
+ sema.air_extra.appendSliceAssumeCapacity(cases_extra.items);
+
+ // Set up block body
+ sema.air_instructions.items(.data)[capture_block_inst].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{
+ .body_len = 1,
+ });
+ sema.air_extra.appendAssumeCapacity(switch_br_inst);
+
+ return Air.indexToRef(capture_block_inst);
},
.ErrorSet => {
if (capture_byref) {
@@ -11099,6 +11313,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
.parent_block = block,
.operand = raw_operand.val,
.operand_ptr = raw_operand.ptr,
+ .cond = operand,
.else_error_ty = else_error_ty,
.switch_block_inst = inst,
.tag_capture_inst = tag_capture_inst,
test/behavior/switch.zig
@@ -1,5 +1,6 @@
const builtin = @import("builtin");
const std = @import("std");
+const assert = std.debug.assert;
const expect = std.testing.expect;
const expectError = std.testing.expectError;
const expectEqual = std.testing.expectEqual;
@@ -717,3 +718,70 @@ test "comptime inline switch" {
try expectEqual(u32, value);
}
+
+test "switch capture peer type resolution" {
+ if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
+ const U = union(enum) {
+ a: u32,
+ b: u64,
+ fn innerVal(u: @This()) u64 {
+ switch (u) {
+ .a, .b => |x| return x,
+ }
+ }
+ };
+
+ try expectEqual(@as(u64, 100), U.innerVal(.{ .a = 100 }));
+ try expectEqual(@as(u64, 200), U.innerVal(.{ .b = 200 }));
+}
+
+test "switch capture peer type resolution for in-memory coercible payloads" {
+ if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
+ const T1 = c_int;
+ const T2 = @Type(@typeInfo(T1));
+
+ comptime assert(T1 != T2);
+
+ const U = union(enum) {
+ a: T1,
+ b: T2,
+ fn innerVal(u: @This()) c_int {
+ switch (u) {
+ .a, .b => |x| return x,
+ }
+ }
+ };
+
+ try expectEqual(@as(c_int, 100), U.innerVal(.{ .a = 100 }));
+ try expectEqual(@as(c_int, 200), U.innerVal(.{ .b = 200 }));
+}
+
+test "switch pointer capture peer type resolution" {
+ if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
+ const T1 = c_int;
+ const T2 = @Type(@typeInfo(T1));
+
+ comptime assert(T1 != T2);
+
+ const U = union(enum) {
+ a: T1,
+ b: T2,
+ fn innerVal(u: *@This()) *c_int {
+ switch (u.*) {
+ .a, .b => |*ptr| return ptr,
+ }
+ }
+ };
+
+ var ua: U = .{ .a = 100 };
+ var ub: U = .{ .b = 200 };
+
+ ua.innerVal().* = 111;
+ ub.innerVal().* = 222;
+
+ try expectEqual(U{ .a = 111 }, ua);
+ try expectEqual(U{ .b = 222 }, ub);
+}
test/cases/compile_errors/switch_capture_incompatible_types.zig
@@ -0,0 +1,27 @@
+export fn f() void {
+ const U = union(enum) { a: u32, b: *u8 };
+ var u: U = undefined;
+ switch (u) {
+ .a, .b => |val| _ = val,
+ }
+}
+
+export fn g() void {
+ const U = union(enum) { a: u64, b: u32 };
+ var u: U = undefined;
+ switch (u) {
+ .a, .b => |*ptr| _ = ptr,
+ }
+}
+
+// error
+// backend=stage2
+// target=native
+//
+// :5:20: error: capture group with incompatible types
+// :5:20: note: incompatible types: 'u32' and '*u8'
+// :5:10: note: type 'u32' here
+// :5:14: note: type '*u8' here
+// :13:20: error: capture group with incompatible types
+// :13:14: note: pointer type child 'u32' cannot cast into resolved pointer type child 'u64'
+// :13:20: note: this coercion is only possible when capturing by value