Commit c764640e92

Andrew Kelley <andrew@ziglang.org>
2022-08-17 21:55:08
Sema: fix generics with struct literal coerced to tagged union
The `Value.eql` function has to test for value equality *as-if* the lhs value parameter is coerced into the type of the rhs. For tagged unions, there was a problematic case when the lhs was an anonymous struct, because in such case the value is empty_struct_value and the type contains all the value information. But the only type available in the function was the rhs type. So the fix involved making `Value.eqlAdvanced` also accept the lhs type, and then enhancing the logic to handle the case of the `.anon_struct` tag. closes #12418 Tests run locally: * test-behavior * test-cases
1 parent a12abc6
Changed files (3)
src/Sema.zig
@@ -1495,7 +1495,8 @@ pub fn resolveInst(sema: *Sema, zir_ref: Zir.Inst.Ref) !Air.Inst.Ref {
 
     // Finally, the last section of indexes refers to the map of ZIR=>AIR.
     const inst = sema.inst_map.get(@intCast(u32, i)).?;
-    if (sema.typeOf(inst).tag() == .generic_poison) return error.GenericPoison;
+    const ty = sema.typeOf(inst);
+    if (ty.tag() == .generic_poison) return error.GenericPoison;
     return inst;
 }
 
@@ -5570,11 +5571,15 @@ const GenericCallAdapter = struct {
     generic_fn: *Module.Fn,
     precomputed_hash: u64,
     func_ty_info: Type.Payload.Function.Data,
-    /// Unlike comptime_args, the Type here is not always present.
-    /// .generic_poison is used to communicate non-anytype parameters.
-    comptime_tvs: []const TypedValue,
+    args: []const Arg,
     module: *Module,
 
+    const Arg = struct {
+        ty: Type,
+        val: Value,
+        is_anytype: bool,
+    };
+
     pub fn eql(ctx: @This(), adapted_key: void, other_key: *Module.Fn) bool {
         _ = adapted_key;
         // The generic function Decl is guaranteed to be the first dependency
@@ -5585,10 +5590,10 @@ const GenericCallAdapter = struct {
 
         const other_comptime_args = other_key.comptime_args.?;
         for (other_comptime_args[0..ctx.func_ty_info.param_types.len]) |other_arg, i| {
-            const this_arg = ctx.comptime_tvs[i];
+            const this_arg = ctx.args[i];
             const this_is_comptime = this_arg.val.tag() != .generic_poison;
             const other_is_comptime = other_arg.val.tag() != .generic_poison;
-            const this_is_anytype = this_arg.ty.tag() != .generic_poison;
+            const this_is_anytype = this_arg.is_anytype;
             const other_is_anytype = other_key.isAnytypeParam(ctx.module, @intCast(u32, i));
 
             if (other_is_anytype != this_is_anytype) return false;
@@ -5607,7 +5612,17 @@ const GenericCallAdapter = struct {
                 }
             } else if (this_is_comptime) {
                 // Both are comptime parameters but not anytype parameters.
-                if (!this_arg.val.eql(other_arg.val, other_arg.ty, ctx.module)) {
+                // We assert no error is possible here because any lazy values must be resolved
+                // before inserting into the generic function hash map.
+                const is_eql = Value.eqlAdvanced(
+                    this_arg.val,
+                    this_arg.ty,
+                    other_arg.val,
+                    other_arg.ty,
+                    ctx.module,
+                    null,
+                ) catch unreachable;
+                if (!is_eql) {
                     return false;
                 }
             }
@@ -6258,8 +6273,7 @@ fn instantiateGenericCall(
     var hasher = std.hash.Wyhash.init(0);
     std.hash.autoHash(&hasher, @ptrToInt(module_fn));
 
-    const comptime_tvs = try sema.arena.alloc(TypedValue, func_ty_info.param_types.len);
-
+    const generic_args = try sema.arena.alloc(GenericCallAdapter.Arg, func_ty_info.param_types.len);
     {
         var i: usize = 0;
         for (fn_info.param_body) |inst| {
@@ -6283,8 +6297,9 @@ fn instantiateGenericCall(
                 else => continue,
             }
 
+            const arg_ty = sema.typeOf(uncasted_args[i]);
+
             if (is_comptime) {
-                const arg_ty = sema.typeOf(uncasted_args[i]);
                 const arg_val = sema.analyzeGenericCallArgVal(block, .unneeded, uncasted_args[i]) catch |err| switch (err) {
                     error.NeededSourceLocation => {
                         const decl = sema.mod.declPtr(block.src_decl);
@@ -6297,27 +6312,30 @@ fn instantiateGenericCall(
                 arg_val.hash(arg_ty, &hasher, mod);
                 if (is_anytype) {
                     arg_ty.hashWithHasher(&hasher, mod);
-                    comptime_tvs[i] = .{
+                    generic_args[i] = .{
                         .ty = arg_ty,
                         .val = arg_val,
+                        .is_anytype = true,
                     };
                 } else {
-                    comptime_tvs[i] = .{
-                        .ty = Type.initTag(.generic_poison),
+                    generic_args[i] = .{
+                        .ty = arg_ty,
                         .val = arg_val,
+                        .is_anytype = false,
                     };
                 }
             } else if (is_anytype) {
-                const arg_ty = sema.typeOf(uncasted_args[i]);
                 arg_ty.hashWithHasher(&hasher, mod);
-                comptime_tvs[i] = .{
+                generic_args[i] = .{
                     .ty = arg_ty,
                     .val = Value.initTag(.generic_poison),
+                    .is_anytype = true,
                 };
             } else {
-                comptime_tvs[i] = .{
-                    .ty = Type.initTag(.generic_poison),
+                generic_args[i] = .{
+                    .ty = arg_ty,
                     .val = Value.initTag(.generic_poison),
+                    .is_anytype = false,
                 };
             }
 
@@ -6331,7 +6349,7 @@ fn instantiateGenericCall(
         .generic_fn = module_fn,
         .precomputed_hash = precomputed_hash,
         .func_ty_info = func_ty_info,
-        .comptime_tvs = comptime_tvs,
+        .args = generic_args,
         .module = mod,
     };
     const gop = try mod.monomorphed_funcs.getOrPutAdapted(gpa, {}, adapter);
@@ -30124,7 +30142,7 @@ fn valuesEqual(
     rhs: Value,
     ty: Type,
 ) CompileError!bool {
-    return Value.eqlAdvanced(lhs, rhs, ty, sema.mod, sema.kit(block, src));
+    return Value.eqlAdvanced(lhs, ty, rhs, ty, sema.mod, sema.kit(block, src));
 }
 
 /// Asserts the values are comparable vectors of type `ty`.
src/value.zig
@@ -2004,6 +2004,10 @@ pub const Value = extern union {
         return (try orderAgainstZeroAdvanced(lhs, sema_kit)).compare(op);
     }
 
+    pub fn eql(a: Value, b: Value, ty: Type, mod: *Module) bool {
+        return eqlAdvanced(a, ty, b, ty, mod, null) catch unreachable;
+    }
+
     /// This function is used by hash maps and so treats floating-point NaNs as equal
     /// to each other, and not equal to other floating-point values.
     /// Similarly, it treats `undef` as a distinct value from all other values.
@@ -2012,13 +2016,10 @@ pub const Value = extern union {
     /// for `a`. This function must act *as if* `a` has been coerced to `ty`. This complication
     /// is required in order to make generic function instantiation efficient - specifically
     /// the insertion into the monomorphized function table.
-    pub fn eql(a: Value, b: Value, ty: Type, mod: *Module) bool {
-        return eqlAdvanced(a, b, ty, mod, null) catch unreachable;
-    }
-
     /// If `null` is provided for `sema_kit` then it is guaranteed no error will be returned.
     pub fn eqlAdvanced(
         a: Value,
+        a_ty: Type,
         b: Value,
         ty: Type,
         mod: *Module,
@@ -2044,33 +2045,34 @@ pub const Value = extern union {
                 const a_payload = a.castTag(.opt_payload).?.data;
                 const b_payload = b.castTag(.opt_payload).?.data;
                 var buffer: Type.Payload.ElemType = undefined;
-                return eqlAdvanced(a_payload, b_payload, ty.optionalChild(&buffer), mod, sema_kit);
+                const payload_ty = ty.optionalChild(&buffer);
+                return eqlAdvanced(a_payload, payload_ty, b_payload, payload_ty, mod, sema_kit);
             },
             .slice => {
                 const a_payload = a.castTag(.slice).?.data;
                 const b_payload = b.castTag(.slice).?.data;
-                if (!(try eqlAdvanced(a_payload.len, b_payload.len, Type.usize, mod, sema_kit))) {
+                if (!(try eqlAdvanced(a_payload.len, Type.usize, b_payload.len, Type.usize, mod, sema_kit))) {
                     return false;
                 }
 
                 var ptr_buf: Type.SlicePtrFieldTypeBuffer = undefined;
                 const ptr_ty = ty.slicePtrFieldType(&ptr_buf);
 
-                return eqlAdvanced(a_payload.ptr, b_payload.ptr, ptr_ty, mod, sema_kit);
+                return eqlAdvanced(a_payload.ptr, ptr_ty, b_payload.ptr, ptr_ty, mod, sema_kit);
             },
             .elem_ptr => {
                 const a_payload = a.castTag(.elem_ptr).?.data;
                 const b_payload = b.castTag(.elem_ptr).?.data;
                 if (a_payload.index != b_payload.index) return false;
 
-                return eqlAdvanced(a_payload.array_ptr, b_payload.array_ptr, ty, mod, sema_kit);
+                return eqlAdvanced(a_payload.array_ptr, ty, b_payload.array_ptr, ty, mod, sema_kit);
             },
             .field_ptr => {
                 const a_payload = a.castTag(.field_ptr).?.data;
                 const b_payload = b.castTag(.field_ptr).?.data;
                 if (a_payload.field_index != b_payload.field_index) return false;
 
-                return eqlAdvanced(a_payload.container_ptr, b_payload.container_ptr, ty, mod, sema_kit);
+                return eqlAdvanced(a_payload.container_ptr, ty, b_payload.container_ptr, ty, mod, sema_kit);
             },
             .@"error" => {
                 const a_name = a.castTag(.@"error").?.data.name;
@@ -2080,7 +2082,8 @@ pub const Value = extern union {
             .eu_payload => {
                 const a_payload = a.castTag(.eu_payload).?.data;
                 const b_payload = b.castTag(.eu_payload).?.data;
-                return eqlAdvanced(a_payload, b_payload, ty.errorUnionPayload(), mod, sema_kit);
+                const payload_ty = ty.errorUnionPayload();
+                return eqlAdvanced(a_payload, payload_ty, b_payload, payload_ty, mod, sema_kit);
             },
             .eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
             .opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
@@ -2098,7 +2101,7 @@ pub const Value = extern union {
                     const types = ty.tupleFields().types;
                     assert(types.len == a_field_vals.len);
                     for (types) |field_ty, i| {
-                        if (!(try eqlAdvanced(a_field_vals[i], b_field_vals[i], field_ty, mod, sema_kit))) {
+                        if (!(try eqlAdvanced(a_field_vals[i], field_ty, b_field_vals[i], field_ty, mod, sema_kit))) {
                             return false;
                         }
                     }
@@ -2109,7 +2112,7 @@ pub const Value = extern union {
                     const fields = ty.structFields().values();
                     assert(fields.len == a_field_vals.len);
                     for (fields) |field, i| {
-                        if (!(try eqlAdvanced(a_field_vals[i], b_field_vals[i], field.ty, mod, sema_kit))) {
+                        if (!(try eqlAdvanced(a_field_vals[i], field.ty, b_field_vals[i], field.ty, mod, sema_kit))) {
                             return false;
                         }
                     }
@@ -2120,7 +2123,7 @@ pub const Value = extern union {
                 for (a_field_vals) |a_elem, i| {
                     const b_elem = b_field_vals[i];
 
-                    if (!(try eqlAdvanced(a_elem, b_elem, elem_ty, mod, sema_kit))) {
+                    if (!(try eqlAdvanced(a_elem, elem_ty, b_elem, elem_ty, mod, sema_kit))) {
                         return false;
                     }
                 }
@@ -2132,7 +2135,7 @@ pub const Value = extern union {
                 switch (ty.containerLayout()) {
                     .Packed, .Extern => {
                         const tag_ty = ty.unionTagTypeHypothetical();
-                        if (!(try a_union.tag.eqlAdvanced(b_union.tag, tag_ty, mod, sema_kit))) {
+                        if (!(try eqlAdvanced(a_union.tag, tag_ty, b_union.tag, tag_ty, mod, sema_kit))) {
                             // In this case, we must disregard mismatching tags and compare
                             // based on the in-memory bytes of the payloads.
                             @panic("TODO comptime comparison of extern union values with mismatching tags");
@@ -2140,13 +2143,13 @@ pub const Value = extern union {
                     },
                     .Auto => {
                         const tag_ty = ty.unionTagTypeHypothetical();
-                        if (!(try a_union.tag.eqlAdvanced(b_union.tag, tag_ty, mod, sema_kit))) {
+                        if (!(try eqlAdvanced(a_union.tag, tag_ty, b_union.tag, tag_ty, mod, sema_kit))) {
                             return false;
                         }
                     },
                 }
                 const active_field_ty = ty.unionFieldType(a_union.tag, mod);
-                return a_union.val.eqlAdvanced(b_union.val, active_field_ty, mod, sema_kit);
+                return eqlAdvanced(a_union.val, active_field_ty, b_union.val, active_field_ty, mod, sema_kit);
             },
             else => {},
         } else if (a_tag == .null_value or b_tag == .null_value) {
@@ -2180,7 +2183,7 @@ pub const Value = extern union {
                 const b_val = b.enumToInt(ty, &buf_b);
                 var buf_ty: Type.Payload.Bits = undefined;
                 const int_ty = ty.intTagType(&buf_ty);
-                return eqlAdvanced(a_val, b_val, int_ty, mod, sema_kit);
+                return eqlAdvanced(a_val, int_ty, b_val, int_ty, mod, sema_kit);
             },
             .Array, .Vector => {
                 const len = ty.arrayLen();
@@ -2191,17 +2194,44 @@ pub const Value = extern union {
                 while (i < len) : (i += 1) {
                     const a_elem = elemValueBuffer(a, mod, i, &a_buf);
                     const b_elem = elemValueBuffer(b, mod, i, &b_buf);
-                    if (!(try eqlAdvanced(a_elem, b_elem, elem_ty, mod, sema_kit))) {
+                    if (!(try eqlAdvanced(a_elem, elem_ty, b_elem, elem_ty, mod, sema_kit))) {
                         return false;
                     }
                 }
                 return true;
             },
             .Struct => {
-                // A tuple can be represented with .empty_struct_value,
-                // the_one_possible_value, .aggregate in which case we could
-                // end up here and the values are equal if the type has zero fields.
-                return ty.isTupleOrAnonStruct() and ty.structFieldCount() != 0;
+                // A struct can be represented with one of:
+                //   .empty_struct_value,
+                //   .the_one_possible_value,
+                //   .aggregate,
+                // Note that we already checked above for matching tags, e.g. both .aggregate.
+                return ty.onePossibleValue() != null;
+            },
+            .Union => {
+                // Here we have to check for value equality, as-if `a` has been coerced to `ty`.
+                if (ty.onePossibleValue() != null) {
+                    return true;
+                }
+                if (a_ty.castTag(.anon_struct)) |payload| {
+                    const tuple = payload.data;
+                    if (tuple.values.len != 1) {
+                        return false;
+                    }
+                    const field_name = tuple.names[0];
+                    const union_obj = ty.cast(Type.Payload.Union).?.data;
+                    const field_index = union_obj.fields.getIndex(field_name) orelse return false;
+                    const tag_and_val = b.castTag(.@"union").?.data;
+                    var field_tag_buf: Value.Payload.U32 = .{
+                        .base = .{ .tag = .enum_field_index },
+                        .data = @intCast(u32, field_index),
+                    };
+                    const field_tag = Value.initPayload(&field_tag_buf.base);
+                    const tag_matches = tag_and_val.tag.eql(field_tag, union_obj.tag_ty, mod);
+                    if (!tag_matches) return false;
+                    return eqlAdvanced(tag_and_val.val, union_obj.tag_ty, tuple.values[0], tuple.types[0], mod, sema_kit);
+                }
+                return false;
             },
             .Float => {
                 switch (ty.floatBits(target)) {
@@ -2230,7 +2260,8 @@ pub const Value = extern union {
                         .base = .{ .tag = .opt_payload },
                         .data = a,
                     };
-                    return eqlAdvanced(Value.initPayload(&buffer.base), b, ty, mod, sema_kit);
+                    const opt_val = Value.initPayload(&buffer.base);
+                    return eqlAdvanced(opt_val, ty, b, ty, mod, sema_kit);
                 }
             },
             else => {},
test/behavior/generics.zig
@@ -323,3 +323,22 @@ test "generic function instantiation non-duplicates" {
     S.copy(u8, &buffer, "hello");
     S.copy(u8, &buffer, "hello2");
 }
+
+test "generic instantiation of tagged union with only one field" {
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.os.tag == .wasi) return error.SkipZigTest;
+
+    const S = struct {
+        const U = union(enum) {
+            s: []const u8,
+        };
+
+        fn foo(comptime u: U) usize {
+            return u.s.len;
+        }
+    };
+
+    try expect(S.foo(.{ .s = "a" }) == 1);
+    try expect(S.foo(.{ .s = "ab" }) == 2);
+}