Commit 81a3910e44

Andrew Kelley <andrew@ziglang.org>
2021-12-29 04:17:34
Sema: improve union support
* reduce number of branches in zirCmpEq * implement equality comparison for enums and unions * fix coercion from union to its tag type resulting in the wrong type * fix method calls of unions * implement peer type resolution for unions, enums, and enum literals * fix union tag type memory in the wrong arena
1 parent 6229d37
src/Sema.zig
@@ -8523,28 +8523,27 @@ fn zirCmpEq(
             return Air.Inst.Ref.bool_false;
         }
     }
-    if (((lhs_ty_tag == .Null and rhs_ty_tag == .Optional) or
-        rhs_ty_tag == .Null and lhs_ty_tag == .Optional))
-    {
-        // comparing null with optionals
-        const opt_operand = if (lhs_ty_tag == .Null) rhs else lhs;
-        return sema.analyzeIsNull(block, src, opt_operand, op == .neq);
+
+    // comparing null with optionals
+    if (lhs_ty_tag == .Null and (rhs_ty_tag == .Optional or rhs_ty.isCPtr())) {
+        return sema.analyzeIsNull(block, src, rhs, op == .neq);
     }
-    if (((lhs_ty_tag == .Null and rhs_ty.isCPtr()) or (rhs_ty_tag == .Null and lhs_ty.isCPtr()))) {
-        // comparing null with C pointers
-        const opt_operand = if (lhs_ty_tag == .Null) rhs else lhs;
-        return sema.analyzeIsNull(block, src, opt_operand, op == .neq);
+    if (rhs_ty_tag == .Null and (lhs_ty_tag == .Optional or lhs_ty.isCPtr())) {
+        return sema.analyzeIsNull(block, src, lhs, op == .neq);
     }
+
     if (lhs_ty_tag == .Null or rhs_ty_tag == .Null) {
         const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty;
         return sema.fail(block, src, "comparison of '{}' with null", .{non_null_type});
     }
-    if (lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) {
-        return sema.analyzeCmpUnionTag(block, rhs, rhs_src, lhs, lhs_src, op);
-    }
-    if (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union) {
+
+    if (lhs_ty_tag == .Union and (rhs_ty_tag == .EnumLiteral or rhs_ty_tag == .Enum)) {
         return sema.analyzeCmpUnionTag(block, lhs, lhs_src, rhs, rhs_src, op);
     }
+    if (rhs_ty_tag == .Union and (lhs_ty_tag == .EnumLiteral or lhs_ty_tag == .Enum)) {
+        return sema.analyzeCmpUnionTag(block, rhs, rhs_src, lhs, lhs_src, op);
+    }
+
     if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) {
         const runtime_src: LazySrcLoc = src: {
             if (try sema.resolveMaybeUndefVal(block, lhs_src, lhs)) |lval| {
@@ -12174,7 +12173,14 @@ fn fieldCallBind(
                 const ptr_inst = try block.addStructFieldPtr(object_ptr, field_index, ptr_field_ty);
                 return sema.analyzeLoad(block, src, ptr_inst, src);
             },
-            .Union => return sema.fail(block, src, "TODO implement field calls on unions", .{}),
+            .Union => {
+                const union_ty = try sema.resolveTypeFields(block, src, concrete_ty);
+                const fields = union_ty.unionFields();
+                const field_index_usize = fields.getIndex(field_name) orelse break :find_field;
+
+                _ = field_index_usize;
+                return sema.fail(block, src, "TODO implement field calls on unions", .{});
+            },
             .Type => {
                 const namespace = try sema.analyzeLoad(block, src, object_ptr, src);
                 return sema.fieldVal(block, src, namespace, field_name, field_name_src);
@@ -12922,7 +12928,7 @@ fn coerce(
                 // union to its own tag type
                 const union_tag_ty = inst_ty.unionTagType() orelse break :blk;
                 if (union_tag_ty.eql(dest_ty)) {
-                    return sema.unionToTag(block, inst_ty, inst, inst_src);
+                    return sema.unionToTag(block, dest_ty, inst, inst_src);
                 }
             },
             else => {},
@@ -14589,10 +14595,19 @@ fn resolvePeerTypes(
                     chosen_i = candidate_i + 1;
                     continue;
                 },
+                .Union => continue,
                 else => {},
             },
             .EnumLiteral => switch (chosen_ty_tag) {
-                .Enum => continue,
+                .Enum, .Union => continue,
+                else => {},
+            },
+            .Union => switch (chosen_ty_tag) {
+                .Enum, .EnumLiteral => {
+                    chosen = candidate;
+                    chosen_i = candidate_i + 1;
+                    continue;
+                },
                 else => {},
             },
             .Pointer => {
@@ -15160,7 +15175,7 @@ fn semaUnionFields(mod: *Module, union_obj: *Module.Union) CompileError!void {
             enum_value_map = &union_obj.tag_ty.castTag(.enum_numbered).?.data.values;
         } else {
             // The provided type is the enum tag type.
-            union_obj.tag_ty = provided_ty;
+            union_obj.tag_ty = try provided_ty.copy(decl_arena_allocator);
         }
     } else {
         // If auto_enum_tag is false, this is an untagged union. However, for semantic analysis
src/value.zig
@@ -1781,7 +1781,7 @@ pub const Value = extern union {
 
     pub fn unionTag(val: Value) Value {
         switch (val.tag()) {
-            .undef => return val,
+            .undef, .enum_field_index => return val,
             .@"union" => return val.castTag(.@"union").?.data.tag,
             else => unreachable,
         }
test/behavior/union.zig
@@ -152,3 +152,94 @@ const AlignTestTaggedUnion = union(enum) {
     A: [9]u8,
     B: u64,
 };
+
+const Letter = enum { A, B, C };
+const Payload = union(Letter) {
+    A: i32,
+    B: f64,
+    C: bool,
+};
+
+test "union with specified enum tag" {
+    try doTest();
+    comptime try doTest();
+}
+
+fn doTest() error{TestUnexpectedResult}!void {
+    try expect((try bar(Payload{ .A = 1234 })) == -10);
+}
+
+fn bar(value: Payload) error{TestUnexpectedResult}!i32 {
+    try expect(@as(Letter, value) == Letter.A);
+    return switch (value) {
+        Payload.A => |x| return x - 1244,
+        Payload.B => |x| if (x == 12.34) @as(i32, 20) else 21,
+        Payload.C => |x| if (x) @as(i32, 30) else 31,
+    };
+}
+
+fn testComparison() !void {
+    var x = Payload{ .A = 42 };
+    try expect(x == .A);
+    try expect(x != .B);
+    try expect(x != .C);
+    try expect((x == .B) == false);
+    try expect((x == .C) == false);
+    try expect((x != .A) == false);
+}
+
+test "comparison between union and enum literal" {
+    try testComparison();
+    comptime try testComparison();
+}
+
+const TheTag = enum { A, B, C };
+const TheUnion = union(TheTag) {
+    A: i32,
+    B: i32,
+    C: i32,
+};
+test "cast union to tag type of union" {
+    try testCastUnionToTag();
+    comptime try testCastUnionToTag();
+}
+
+fn testCastUnionToTag() !void {
+    var u = TheUnion{ .B = 1234 };
+    try expect(@as(TheTag, u) == TheTag.B);
+}
+
+test "cast tag type of union to union" {
+    var x: Value2 = Letter2.B;
+    try expect(@as(Letter2, x) == Letter2.B);
+}
+const Letter2 = enum { A, B, C };
+const Value2 = union(Letter2) {
+    A: i32,
+    B,
+    C,
+};
+
+test "implicit cast union to its tag type" {
+    var x: Value2 = Letter2.B;
+    try expect(x == Letter2.B);
+    try giveMeLetterB(x);
+}
+fn giveMeLetterB(x: Letter2) !void {
+    try expect(x == Value2.B);
+}
+
+// TODO it looks like this test intended to test packed unions, but this is not a packed
+// union. go through git history and find out what happened.
+pub const PackThis = union(enum) {
+    Invalid: bool,
+    StringLiteral: u2,
+};
+
+test "constant packed union" {
+    try testConstPackedUnion(&[_]PackThis{PackThis{ .StringLiteral = 1 }});
+}
+
+fn testConstPackedUnion(expected_tokens: []const PackThis) !void {
+    try expect(expected_tokens[0].StringLiteral == 1);
+}
test/behavior/union_stage1.zig
@@ -10,11 +10,6 @@ const Payload = union(Letter) {
     C: bool,
 };
 
-test "union with specified enum tag" {
-    try doTest();
-    comptime try doTest();
-}
-
 fn doTest() error{TestUnexpectedResult}!void {
     try expect((try bar(Payload{ .A = 1234 })) == -10);
 }
@@ -28,6 +23,18 @@ fn bar(value: Payload) error{TestUnexpectedResult}!i32 {
     };
 }
 
+test "packed union generates correctly aligned LLVM type" {
+    const U = packed union {
+        f1: fn () error{TestUnexpectedResult}!void,
+        f2: u32,
+    };
+    var foo = [_]U{
+        U{ .f1 = doTest },
+        U{ .f2 = 0 },
+    };
+    try foo[0].f1();
+}
+
 const MultipleChoice = union(enum(u32)) {
     A = 20,
     B = 40,
@@ -100,51 +107,6 @@ test "union field access gives the enum values" {
     try expect(TheUnion.C == TheTag.C);
 }
 
-test "cast union to tag type of union" {
-    try testCastUnionToTag();
-    comptime try testCastUnionToTag();
-}
-
-fn testCastUnionToTag() !void {
-    var u = TheUnion{ .B = 1234 };
-    try expect(@as(TheTag, u) == TheTag.B);
-}
-
-test "cast tag type of union to union" {
-    var x: Value2 = Letter2.B;
-    try expect(@as(Letter2, x) == Letter2.B);
-}
-const Letter2 = enum { A, B, C };
-const Value2 = union(Letter2) {
-    A: i32,
-    B,
-    C,
-};
-
-test "implicit cast union to its tag type" {
-    var x: Value2 = Letter2.B;
-    try expect(x == Letter2.B);
-    try giveMeLetterB(x);
-}
-fn giveMeLetterB(x: Letter2) !void {
-    try expect(x == Value2.B);
-}
-
-// TODO it looks like this test intended to test packed unions, but this is not a packed
-// union. go through git history and find out what happened.
-pub const PackThis = union(enum) {
-    Invalid: bool,
-    StringLiteral: u2,
-};
-
-test "constant packed union" {
-    try testConstPackedUnion(&[_]PackThis{PackThis{ .StringLiteral = 1 }});
-}
-
-fn testConstPackedUnion(expected_tokens: []const PackThis) !void {
-    try expect(expected_tokens[0].StringLiteral == 1);
-}
-
 test "switch on union with only 1 field" {
     var r: PartialInst = undefined;
     r = PartialInst.Compiled;
@@ -355,33 +317,6 @@ test "union no tag with struct member" {
     u.foo();
 }
 
-fn testComparison() !void {
-    var x = Payload{ .A = 42 };
-    try expect(x == .A);
-    try expect(x != .B);
-    try expect(x != .C);
-    try expect((x == .B) == false);
-    try expect((x == .C) == false);
-    try expect((x != .A) == false);
-}
-
-test "comparison between union and enum literal" {
-    try testComparison();
-    comptime try testComparison();
-}
-
-test "packed union generates correctly aligned LLVM type" {
-    const U = packed union {
-        f1: fn () error{TestUnexpectedResult}!void,
-        f2: u32,
-    };
-    var foo = [_]U{
-        U{ .f1 = doTest },
-        U{ .f2 = 0 },
-    };
-    try foo[0].f1();
-}
-
 test "union with one member defaults to u0 tag type" {
     const U0 = union(enum) {
         X: u32,
test/behavior/union_with_members.zig
@@ -1,6 +1,7 @@
-const expect = @import("std").testing.expect;
-const mem = @import("std").mem;
-const fmt = @import("std").fmt;
+const std = @import("std");
+const expect = std.testing.expect;
+const mem = std.mem;
+const fmt = std.fmt;
 
 const ET = union(enum) {
     SINT: i32,