Commit 09e1f37cb6
Changed files (11)
src/codegen/c.zig
@@ -956,6 +956,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
.memset => try airMemset(f, inst),
.memcpy => try airMemcpy(f, inst),
.set_union_tag => try airSetUnionTag(f, inst),
+ .get_union_tag => try airGetUnionTag(f, inst),
.int_to_float,
.float_to_int,
@@ -2096,6 +2097,22 @@ fn airSetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue {
return CValue.none;
}
+fn airGetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue {
+ if (f.liveness.isUnused(inst))
+ return CValue.none;
+
+ const inst_ty = f.air.typeOfIndex(inst);
+ const local = try f.allocLocal(inst_ty, .Const);
+ const ty_op = f.air.instructions.items(.data)[inst].ty_op;
+ const writer = f.object.writer();
+ const operand = try f.resolveInst(ty_op.operand);
+
+ try writer.writeAll("get_union_tag(");
+ try f.writeCValue(writer, operand);
+ try writer.writeAll(");\n");
+ return local;
+}
+
fn toMemoryOrder(order: std.builtin.AtomicOrder) [:0]const u8 {
return switch (order) {
.Unordered => "memory_order_relaxed",
src/codegen/llvm.zig
@@ -1304,6 +1304,7 @@ pub const FuncGen = struct {
.memset => try self.airMemset(inst),
.memcpy => try self.airMemcpy(inst),
.set_union_tag => try self.airSetUnionTag(inst),
+ .get_union_tag => try self.airGetUnionTag(inst),
.atomic_store_unordered => try self.airAtomicStore(inst, .Unordered),
.atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic),
@@ -2557,6 +2558,18 @@ pub const FuncGen = struct {
return null;
}
+ fn airGetUnionTag(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+ if (self.liveness.isUnused(inst))
+ return null;
+
+ const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+ const un_ty = self.air.typeOf(ty_op.operand);
+ const un = try self.resolveInst(ty_op.operand);
+
+ _ = un_ty; // TODO handle when onlyTagHasCodegenBits() == true and other union forms
+ return self.builder.buildExtractValue(un, 1, "");
+ }
+
fn fieldPtr(
self: *FuncGen,
inst: Air.Inst.Index,
src/Air.zig
@@ -290,6 +290,9 @@ pub const Inst = struct {
/// Result type is always void.
/// Uses the `bin_op` field. LHS is union pointer, RHS is new tag value.
set_union_tag,
+ /// Given a tagged union value, get its tag value.
+ /// Uses the `ty_op` field.
+ get_union_tag,
/// Given a slice value, return the length.
/// Result type is always usize.
/// Uses the `ty_op` field.
@@ -630,6 +633,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
.array_to_slice,
.float_to_int,
.int_to_float,
+ .get_union_tag,
=> return air.getRefType(datas[inst].ty_op.ty),
.loop,
src/codegen.zig
@@ -890,6 +890,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
.memcpy => try self.airMemcpy(inst),
.memset => try self.airMemset(inst),
.set_union_tag => try self.airSetUnionTag(inst),
+ .get_union_tag => try self.airGetUnionTag(inst),
.atomic_store_unordered => try self.airAtomicStore(inst, .Unordered),
.atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic),
@@ -1552,6 +1553,14 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
}
+ fn airGetUnionTag(self: *Self, inst: Air.Inst.Index) !void {
+ const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+ const result: MCValue = if (self.liveness.isUnused(inst)) .dead else switch (arch) {
+ else => return self.fail("TODO implement airGetUnionTag for {}", .{self.target.cpu.arch}),
+ };
+ return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
+ }
+
fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool {
if (!self.liveness.operandDies(inst, op_index))
return false;
src/Liveness.zig
@@ -297,6 +297,7 @@ fn analyzeInst(
.array_to_slice,
.float_to_int,
.int_to_float,
+ .get_union_tag,
=> {
const o = inst_datas[inst].ty_op;
return trackOperands(a, new_set, inst, main_tomb, .{ o.operand, .none, .none });
src/print_air.zig
@@ -179,6 +179,7 @@ const Writer = struct {
.array_to_slice,
.int_to_float,
.float_to_int,
+ .get_union_tag,
=> try w.writeTyOp(s, inst),
.block,
src/Sema.zig
@@ -1349,7 +1349,13 @@ fn zirUnionDecl(
errdefer new_decl_arena.deinit();
const union_obj = try new_decl_arena.allocator.create(Module.Union);
- const union_ty = try Type.Tag.@"union".create(&new_decl_arena.allocator, union_obj);
+ const type_tag: Type.Tag = if (small.has_tag_type or small.auto_enum_tag) .union_tagged else .@"union";
+ const union_payload = try new_decl_arena.allocator.create(Type.Payload.Union);
+ union_payload.* = .{
+ .base = .{ .tag = type_tag },
+ .data = union_obj,
+ };
+ const union_ty = Type.initPayload(&union_payload.base);
const union_val = try Value.Tag.ty.create(&new_decl_arena.allocator, union_ty);
const type_name = try sema.createTypeName(block, small.name_strategy);
const new_decl = try sema.mod.createAnonymousDeclNamed(&block.base, .{
@@ -6477,10 +6483,11 @@ fn zirCmpEq(
const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty;
return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
}
- if (((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
- (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
- {
- return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
+ if (lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) {
+ return sema.analyzeCmpUnionTag(block, rhs, rhs_src, lhs, lhs_src, op);
+ }
+ if (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union) {
+ return sema.analyzeCmpUnionTag(block, lhs, lhs_src, rhs, rhs_src, op);
}
if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
const runtime_src: LazySrcLoc = src: {
@@ -6521,6 +6528,28 @@ fn zirCmpEq(
return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, true);
}
+fn analyzeCmpUnionTag(
+ sema: *Sema,
+ block: *Scope.Block,
+ un: Air.Inst.Ref,
+ un_src: LazySrcLoc,
+ tag: Air.Inst.Ref,
+ tag_src: LazySrcLoc,
+ op: std.math.CompareOperator,
+) CompileError!Air.Inst.Ref {
+ const union_ty = sema.typeOf(un);
+ const union_tag_ty = union_ty.unionTagType() orelse {
+ // TODO note at declaration site that says "union foo is not tagged"
+ return sema.mod.fail(&block.base, un_src, "comparison of union and enum literal is only valid for tagged union types", .{});
+ };
+ // Coerce both the union and the tag to the union's tag type, and then execute the
+ // enum comparison codepath.
+ const coerced_tag = try sema.coerce(block, union_tag_ty, tag, tag_src);
+ const coerced_union = try sema.coerce(block, union_tag_ty, un, un_src);
+
+ return sema.cmpSelf(block, coerced_union, coerced_tag, op, un_src, tag_src);
+}
+
/// Only called for non-equality operators. See also `zirCmpEq`.
fn zirCmp(
sema: *Sema,
@@ -6567,10 +6596,21 @@ fn analyzeCmp(
@tagName(op), resolved_type,
});
}
-
const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
+ return sema.cmpSelf(block, casted_lhs, casted_rhs, op, lhs_src, rhs_src);
+}
+fn cmpSelf(
+ sema: *Sema,
+ block: *Scope.Block,
+ casted_lhs: Air.Inst.Ref,
+ casted_rhs: Air.Inst.Ref,
+ op: std.math.CompareOperator,
+ lhs_src: LazySrcLoc,
+ rhs_src: LazySrcLoc,
+) CompileError!Air.Inst.Ref {
+ const resolved_type = sema.typeOf(casted_lhs);
const runtime_src: LazySrcLoc = src: {
if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| {
if (lhs_val.isUndef()) return sema.addConstUndef(resolved_type);
@@ -9919,9 +9959,9 @@ fn coerce(
}
}
},
- .Enum => {
- // enum literal to enum
- if (inst_ty.zigTypeTag() == .EnumLiteral) {
+ .Enum => switch (inst_ty.zigTypeTag()) {
+ .EnumLiteral => {
+ // enum literal to enum
const val = try sema.resolveConstValue(block, inst_src, inst);
const bytes = val.castTag(.enum_literal).?.data;
const resolved_dest_type = try sema.resolveTypeFields(block, inst_src, dest_type);
@@ -9948,7 +9988,15 @@ fn coerce(
resolved_dest_type,
try Value.Tag.enum_field_index.create(arena, @intCast(u32, field_index)),
);
- }
+ },
+ .Union => blk: {
+ // union to its own tag type
+ const union_tag_ty = inst_ty.unionTagType() orelse break :blk;
+ if (union_tag_ty.eql(dest_type)) {
+ return sema.unionToTag(block, dest_type, inst, inst_src);
+ }
+ },
+ else => {},
},
.ErrorUnion => {
// T to E!T or E to E!T
@@ -10802,6 +10850,20 @@ fn wrapErrorUnion(
}
}
+fn unionToTag(
+ sema: *Sema,
+ block: *Scope.Block,
+ dest_type: Type,
+ un: Air.Inst.Ref,
+ un_src: LazySrcLoc,
+) !Air.Inst.Ref {
+ if (try sema.resolveMaybeUndefVal(block, un_src, un)) |un_val| {
+ return sema.addConstant(dest_type, un_val.unionTag());
+ }
+ try sema.requireRuntimeBlock(block, un_src);
+ return block.addTyOp(.get_union_tag, dest_type, un);
+}
+
fn resolvePeerTypes(
sema: *Sema,
block: *Scope.Block,
src/type.zig
@@ -2487,6 +2487,12 @@ pub const Type = extern union {
};
}
+ pub fn unionFieldType(ty: Type, enum_tag: Value) Type {
+ const union_obj = ty.cast(Payload.Union).?.data;
+ const index = union_obj.tag_ty.enumTagFieldIndex(enum_tag).?;
+ return union_obj.fields.values()[index].ty;
+ }
+
/// Asserts that the type is an error union.
pub fn errorUnionPayload(self: Type) Type {
return switch (self.tag()) {
@@ -3801,6 +3807,8 @@ pub const Type = extern union {
};
};
+ pub const @"bool" = initTag(.bool);
+
pub fn ptr(arena: *Allocator, d: Payload.Pointer.Data) !Type {
assert(d.host_size == 0 or d.bit_offset < d.host_size * 8);
src/value.zig
@@ -1275,7 +1275,12 @@ pub const Value = extern union {
}
},
.Union => {
- @panic("TODO implement hashing union values");
+ const union_obj = val.castTag(.@"union").?.data;
+ if (ty.unionTagType()) |tag_ty| {
+ union_obj.tag.hash(tag_ty, hasher);
+ }
+ const active_field_ty = ty.unionFieldType(union_obj.tag);
+ union_obj.val.hash(active_field_ty, hasher);
},
.Fn => {
@panic("TODO implement hashing function values");
@@ -1431,6 +1436,14 @@ pub const Value = extern union {
}
}
+ pub fn unionTag(val: Value) Value {
+ switch (val.tag()) {
+ .undef => return val,
+ .@"union" => return val.castTag(.@"union").?.data.tag,
+ else => unreachable,
+ }
+ }
+
/// Returns a pointer to the element value at the index.
pub fn elemPtr(self: Value, allocator: *Allocator, index: usize) !Value {
if (self.castTag(.elem_ptr)) |elem_ptr| {
test/behavior/union.zig
@@ -14,3 +14,21 @@ test "basic unions" {
foo = Foo{ .float = 12.34 };
try expect(foo.float == 12.34);
}
+
+test "init union with runtime value" {
+ var foo: Foo = undefined;
+
+ setFloat(&foo, 12.34);
+ try expect(foo.float == 12.34);
+
+ setInt(&foo, 42);
+ try expect(foo.int == 42);
+}
+
+fn setFloat(foo: *Foo, x: f64) void {
+ foo.* = Foo{ .float = x };
+}
+
+fn setInt(foo: *Foo, x: i32) void {
+ foo.* = Foo{ .int = x };
+}
test/behavior/union_stage1.zig
@@ -49,24 +49,6 @@ test "comptime union field access" {
}
}
-test "init union with runtime value" {
- var foo: Foo = undefined;
-
- setFloat(&foo, 12.34);
- try expect(foo.float == 12.34);
-
- setInt(&foo, 42);
- try expect(foo.int == 42);
-}
-
-fn setFloat(foo: *Foo, x: f64) void {
- foo.* = Foo{ .float = x };
-}
-
-fn setInt(foo: *Foo, x: i32) void {
- foo.* = Foo{ .int = x };
-}
-
const FooExtern = extern union {
float: f64,
int: i32,
@@ -185,12 +167,13 @@ test "union field access gives the enum values" {
}
test "cast union to tag type of union" {
- try testCastUnionToTag(TheUnion{ .B = 1234 });
- comptime try testCastUnionToTag(TheUnion{ .B = 1234 });
+ try testCastUnionToTag();
+ comptime try testCastUnionToTag();
}
-fn testCastUnionToTag(x: TheUnion) !void {
- try expect(@as(TheTag, x) == TheTag.B);
+fn testCastUnionToTag() !void {
+ var u = TheUnion{ .B = 1234 };
+ try expect(@as(TheTag, u) == TheTag.B);
}
test "cast tag type of union to union" {