Commit 09e1f37cb6

Andrew Kelley <andrew@ziglang.org>
2021-09-28 08:11:00
stage2: implement union coercion to its own tag
* AIR: add `get_union_tag` instruction - implement in LLVM backend * Sema: implement == and != for union and enum literal - Also implement coercion from union to its own tag type * Value: implement hashing for union values The motivating example is this snippet: `comptime assert(@typeInfo(T) == .Float);` This was the next blocker for stage2 building compiler-rt. The next blocker is now `switch` at compile time on an integer.
1 parent c2a7542
src/codegen/c.zig
@@ -956,6 +956,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
             .memset           => try airMemset(f, inst),
             .memcpy           => try airMemcpy(f, inst),
             .set_union_tag    => try airSetUnionTag(f, inst),
+            .get_union_tag    => try airGetUnionTag(f, inst),
 
             .int_to_float,
             .float_to_int,
@@ -2096,6 +2097,22 @@ fn airSetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue {
     return CValue.none;
 }
 
+/// Lowers the AIR `get_union_tag` instruction to C.
+///
+/// Allocates a `.Const` local of the instruction's result type, initializes it with
+/// a call of the form `get_union_tag(<operand>)`, and returns that local as the
+/// instruction's value. Returns `CValue.none` without emitting anything when
+/// liveness reports the result as unused.
+fn airGetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue {
+    if (f.liveness.isUnused(inst))
+        return CValue.none;
+
+    const inst_ty = f.air.typeOfIndex(inst);
+    const local = try f.allocLocal(inst_ty, .Const);
+    const ty_op = f.air.instructions.items(.data)[inst].ty_op;
+    const writer = f.object.writer();
+    const operand = try f.resolveInst(ty_op.operand);
+
+    // The call result must be bound to `local`, otherwise the returned local is
+    // never given a value by the emitted code.
+    // NOTE(review): assumes `allocLocal` writes only the declaration (no initializer) - confirm.
+    // NOTE(review): no C-side definition of `get_union_tag` is visible here - confirm it exists.
+    try writer.writeAll(" = get_union_tag(");
+    try f.writeCValue(writer, operand);
+    try writer.writeAll(");\n");
+    return local;
+}
+
 fn toMemoryOrder(order: std.builtin.AtomicOrder) [:0]const u8 {
     return switch (order) {
         .Unordered => "memory_order_relaxed",
src/codegen/llvm.zig
@@ -1304,6 +1304,7 @@ pub const FuncGen = struct {
                 .memset         => try self.airMemset(inst),
                 .memcpy         => try self.airMemcpy(inst),
                 .set_union_tag  => try self.airSetUnionTag(inst),
+                .get_union_tag  => try self.airGetUnionTag(inst),
 
                 .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered),
                 .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic),
@@ -2557,6 +2558,18 @@ pub const FuncGen = struct {
         return null;
     }
 
+    /// Lowers the AIR `get_union_tag` instruction.
+    ///
+    /// Resolves the union operand and extracts element 1 of that value as the result.
+    /// Returns null when liveness reports the result as unused.
+    fn airGetUnionTag(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const operand_ref = self.air.instructions.items(.data)[inst].ty_op.operand;
+        const union_ty = self.air.typeOf(operand_ref);
+        const union_val = try self.resolveInst(operand_ref);
+
+        // TODO handle when onlyTagHasCodegenBits() == true and other union forms.
+        // Until then the operand type is not consulted and the element index is fixed.
+        _ = union_ty;
+        const tag_index = 1;
+        return self.builder.buildExtractValue(union_val, tag_index, "");
+    }
+
     fn fieldPtr(
         self: *FuncGen,
         inst: Air.Inst.Index,
src/Air.zig
@@ -290,6 +290,9 @@ pub const Inst = struct {
         /// Result type is always void.
         /// Uses the `bin_op` field. LHS is union pointer, RHS is new tag value.
         set_union_tag,
+        /// Given a tagged union value, get its tag value.
+        /// Uses the `ty_op` field.
+        get_union_tag,
         /// Given a slice value, return the length.
         /// Result type is always usize.
         /// Uses the `ty_op` field.
@@ -630,6 +633,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
         .array_to_slice,
         .float_to_int,
         .int_to_float,
+        .get_union_tag,
         => return air.getRefType(datas[inst].ty_op.ty),
 
         .loop,
src/codegen.zig
@@ -890,6 +890,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                     .memcpy          => try self.airMemcpy(inst),
                     .memset          => try self.airMemset(inst),
                     .set_union_tag   => try self.airSetUnionTag(inst),
+                    .get_union_tag   => try self.airGetUnionTag(inst),
 
                     .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered),
                     .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic),
@@ -1552,6 +1553,14 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
             return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none });
         }
 
+        /// Lowers the AIR `get_union_tag` instruction.
+        ///
+        /// No architecture implements this yet: a used result fails with a TODO error,
+        /// while an unused result is finished as `.dead`.
+        fn airGetUnionTag(self: *Self, inst: Air.Inst.Index) !void {
+            const operand = self.air.instructions.items(.data)[inst].ty_op.operand;
+            const is_unused = self.liveness.isUnused(inst);
+            const result: MCValue = if (is_unused)
+                .dead
+            else switch (arch) {
+                else => return self.fail("TODO implement airGetUnionTag for {}", .{self.target.cpu.arch}),
+            };
+            return self.finishAir(inst, result, .{ operand, .none, .none });
+        }
+
         fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool {
             if (!self.liveness.operandDies(inst, op_index))
                 return false;
src/Liveness.zig
@@ -297,6 +297,7 @@ fn analyzeInst(
         .array_to_slice,
         .float_to_int,
         .int_to_float,
+        .get_union_tag,
         => {
             const o = inst_datas[inst].ty_op;
             return trackOperands(a, new_set, inst, main_tomb, .{ o.operand, .none, .none });
src/print_air.zig
@@ -179,6 +179,7 @@ const Writer = struct {
             .array_to_slice,
             .int_to_float,
             .float_to_int,
+            .get_union_tag,
             => try w.writeTyOp(s, inst),
 
             .block,
src/Sema.zig
@@ -1349,7 +1349,13 @@ fn zirUnionDecl(
     errdefer new_decl_arena.deinit();
 
     const union_obj = try new_decl_arena.allocator.create(Module.Union);
-    const union_ty = try Type.Tag.@"union".create(&new_decl_arena.allocator, union_obj);
+    const type_tag: Type.Tag = if (small.has_tag_type or small.auto_enum_tag) .union_tagged else .@"union";
+    const union_payload = try new_decl_arena.allocator.create(Type.Payload.Union);
+    union_payload.* = .{
+        .base = .{ .tag = type_tag },
+        .data = union_obj,
+    };
+    const union_ty = Type.initPayload(&union_payload.base);
     const union_val = try Value.Tag.ty.create(&new_decl_arena.allocator, union_ty);
     const type_name = try sema.createTypeName(block, small.name_strategy);
     const new_decl = try sema.mod.createAnonymousDeclNamed(&block.base, .{
@@ -6477,10 +6483,11 @@ fn zirCmpEq(
         const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty;
         return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type});
     }
-    if (((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or
-        (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union)))
-    {
-        return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{});
+    if (lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) {
+        return sema.analyzeCmpUnionTag(block, rhs, rhs_src, lhs, lhs_src, op);
+    }
+    if (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union) {
+        return sema.analyzeCmpUnionTag(block, lhs, lhs_src, rhs, rhs_src, op);
     }
     if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
         const runtime_src: LazySrcLoc = src: {
@@ -6521,6 +6528,28 @@ fn zirCmpEq(
     return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, true);
 }
 
+/// Compares a union value against an enum literal; called from `zirCmpEq`.
+///
+/// Both operands are coerced to the union's tag type and then handed to `cmpSelf`.
+/// Fails with a compile error located at `un_src` when the union type has no tag type.
+fn analyzeCmpUnionTag(
+    sema: *Sema,
+    block: *Scope.Block,
+    un: Air.Inst.Ref,
+    un_src: LazySrcLoc,
+    tag: Air.Inst.Ref,
+    tag_src: LazySrcLoc,
+    op: std.math.CompareOperator,
+) CompileError!Air.Inst.Ref {
+    const tag_ty = sema.typeOf(un).unionTagType() orelse {
+        // TODO note at declaration site that says "union foo is not tagged"
+        return sema.mod.fail(&block.base, un_src, "comparison of union and enum literal is only valid for tagged union types", .{});
+    };
+
+    // The literal is coerced before the union; the union's tag is always passed
+    // as the left-hand side of the comparison.
+    const tag_as_enum = try sema.coerce(block, tag_ty, tag, tag_src);
+    const union_as_enum = try sema.coerce(block, tag_ty, un, un_src);
+    return sema.cmpSelf(block, union_as_enum, tag_as_enum, op, un_src, tag_src);
+}
+
 /// Only called for non-equality operators. See also `zirCmpEq`.
 fn zirCmp(
     sema: *Sema,
@@ -6567,10 +6596,21 @@ fn analyzeCmp(
             @tagName(op), resolved_type,
         });
     }
-
     const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
     const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
+    return sema.cmpSelf(block, casted_lhs, casted_rhs, op, lhs_src, rhs_src);
+}
 
+fn cmpSelf(
+    sema: *Sema,
+    block: *Scope.Block,
+    casted_lhs: Air.Inst.Ref,
+    casted_rhs: Air.Inst.Ref,
+    op: std.math.CompareOperator,
+    lhs_src: LazySrcLoc,
+    rhs_src: LazySrcLoc,
+) CompileError!Air.Inst.Ref {
+    const resolved_type = sema.typeOf(casted_lhs);
     const runtime_src: LazySrcLoc = src: {
         if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| {
             if (lhs_val.isUndef()) return sema.addConstUndef(resolved_type);
@@ -9919,9 +9959,9 @@ fn coerce(
                 }
             }
         },
-        .Enum => {
-            // enum literal to enum
-            if (inst_ty.zigTypeTag() == .EnumLiteral) {
+        .Enum => switch (inst_ty.zigTypeTag()) {
+            .EnumLiteral => {
+                // enum literal to enum
                 const val = try sema.resolveConstValue(block, inst_src, inst);
                 const bytes = val.castTag(.enum_literal).?.data;
                 const resolved_dest_type = try sema.resolveTypeFields(block, inst_src, dest_type);
@@ -9948,7 +9988,15 @@ fn coerce(
                     resolved_dest_type,
                     try Value.Tag.enum_field_index.create(arena, @intCast(u32, field_index)),
                 );
-            }
+            },
+            .Union => blk: {
+                // union to its own tag type
+                const union_tag_ty = inst_ty.unionTagType() orelse break :blk;
+                if (union_tag_ty.eql(dest_type)) {
+                    return sema.unionToTag(block, dest_type, inst, inst_src);
+                }
+            },
+            else => {},
         },
         .ErrorUnion => {
             // T to E!T or E to E!T
@@ -10802,6 +10850,20 @@ fn wrapErrorUnion(
     }
 }
 
+/// Produces the tag of the union `un` as a value of type `dest_type`.
+///
+/// When the union's value resolves at compile time, its tag is returned as a constant.
+/// Otherwise a runtime block is required and a `get_union_tag` instruction is added.
+fn unionToTag(
+    sema: *Sema,
+    block: *Scope.Block,
+    dest_type: Type,
+    un: Air.Inst.Ref,
+    un_src: LazySrcLoc,
+) !Air.Inst.Ref {
+    const maybe_un_val = try sema.resolveMaybeUndefVal(block, un_src, un);
+    const un_val = maybe_un_val orelse {
+        // Value is not known at compile time: extract the tag at runtime.
+        try sema.requireRuntimeBlock(block, un_src);
+        return block.addTyOp(.get_union_tag, dest_type, un);
+    };
+    return sema.addConstant(dest_type, un_val.unionTag());
+}
+
 fn resolvePeerTypes(
     sema: *Sema,
     block: *Scope.Block,
src/type.zig
@@ -2487,6 +2487,12 @@ pub const Type = extern union {
         };
     }
 
+    /// Returns the type of the union field selected by `enum_tag`.
+    /// Asserts that `ty` carries a `Payload.Union` and that `enum_tag` maps to a
+    /// field index of the union's tag type.
+    pub fn unionFieldType(ty: Type, enum_tag: Value) Type {
+        const payload = ty.cast(Payload.Union).?;
+        const field_index = payload.data.tag_ty.enumTagFieldIndex(enum_tag).?;
+        return payload.data.fields.values()[field_index].ty;
+    }
+
     /// Asserts that the type is an error union.
     pub fn errorUnionPayload(self: Type) Type {
         return switch (self.tag()) {
@@ -3801,6 +3807,8 @@ pub const Type = extern union {
         };
     };
 
+    pub const @"bool" = initTag(.bool);
+
     pub fn ptr(arena: *Allocator, d: Payload.Pointer.Data) !Type {
         assert(d.host_size == 0 or d.bit_offset < d.host_size * 8);
 
src/value.zig
@@ -1275,7 +1275,12 @@ pub const Value = extern union {
                 }
             },
             .Union => {
-                @panic("TODO implement hashing union values");
+                const union_obj = val.castTag(.@"union").?.data;
+                if (ty.unionTagType()) |tag_ty| {
+                    union_obj.tag.hash(tag_ty, hasher);
+                }
+                const active_field_ty = ty.unionFieldType(union_obj.tag);
+                union_obj.val.hash(active_field_ty, hasher);
             },
             .Fn => {
                 @panic("TODO implement hashing function values");
@@ -1431,6 +1436,14 @@ pub const Value = extern union {
         }
     }
 
+    /// Returns the tag stored in a union value.
+    /// An `undef` value is returned unchanged; any other value tag is unreachable.
+    pub fn unionTag(val: Value) Value {
+        return switch (val.tag()) {
+            .undef => val,
+            .@"union" => val.castTag(.@"union").?.data.tag,
+            else => unreachable,
+        };
+    }
+
     /// Returns a pointer to the element value at the index.
     pub fn elemPtr(self: Value, allocator: *Allocator, index: usize) !Value {
         if (self.castTag(.elem_ptr)) |elem_ptr| {
test/behavior/union.zig
@@ -14,3 +14,21 @@ test "basic unions" {
     foo = Foo{ .float = 12.34 };
     try expect(foo.float == 12.34);
 }
+
+test "init union with runtime value" {
+    var foo: Foo = undefined;
+
+    setFloat(&foo, 12.34);
+    try expect(foo.float == 12.34);
+
+    setInt(&foo, 42);
+    try expect(foo.int == 42);
+}
+
+fn setFloat(foo: *Foo, x: f64) void {
+    foo.* = Foo{ .float = x };
+}
+
+fn setInt(foo: *Foo, x: i32) void {
+    foo.* = Foo{ .int = x };
+}
test/behavior/union_stage1.zig
@@ -49,24 +49,6 @@ test "comptime union field access" {
     }
 }
 
-test "init union with runtime value" {
-    var foo: Foo = undefined;
-
-    setFloat(&foo, 12.34);
-    try expect(foo.float == 12.34);
-
-    setInt(&foo, 42);
-    try expect(foo.int == 42);
-}
-
-fn setFloat(foo: *Foo, x: f64) void {
-    foo.* = Foo{ .float = x };
-}
-
-fn setInt(foo: *Foo, x: i32) void {
-    foo.* = Foo{ .int = x };
-}
-
 const FooExtern = extern union {
     float: f64,
     int: i32,
@@ -185,12 +167,13 @@ test "union field access gives the enum values" {
 }
 
 test "cast union to tag type of union" {
-    try testCastUnionToTag(TheUnion{ .B = 1234 });
-    comptime try testCastUnionToTag(TheUnion{ .B = 1234 });
+    try testCastUnionToTag();
+    comptime try testCastUnionToTag();
 }
 
-fn testCastUnionToTag(x: TheUnion) !void {
-    try expect(@as(TheTag, x) == TheTag.B);
+fn testCastUnionToTag() !void {
+    var u = TheUnion{ .B = 1234 };
+    try expect(@as(TheTag, u) == TheTag.B);
 }
 
 test "cast tag type of union to union" {