Commit 7e9b23e6dc

Andrew Kelley <andrew@ziglang.org>
2021-08-06 08:23:05
Sema: respect requiresComptime of function return types
When doing a function call, if the return type requires comptime, the function is analyzed as an inline/comptime call. There is an important TODO here. I will reproduce the comment from this commit: > In the case of a comptime/inline function call of a generic function, > the function return type needs to be the resolved return type based on > the function parameter type expressions being evaluated with comptime arguments > passed in. Otherwise, it ends up being .generic_poison and failing the > comptime/inline function call analysis.
1 parent c7dc451
Changed files (4)
src/Sema.zig
@@ -2461,7 +2461,8 @@ fn analyzeCall(
 
     const gpa = sema.gpa;
 
-    const is_comptime_call = block.is_comptime or modifier == .compile_time;
+    const is_comptime_call = block.is_comptime or modifier == .compile_time or
+        func_ty_info.return_type.requiresComptime();
     const is_inline_call = is_comptime_call or modifier == .always_inline or
         func_ty_info.cc == .Inline;
     const result: Air.Inst.Ref = if (is_inline_call) res: {
@@ -3609,6 +3610,8 @@ fn funcCommon(
             return mod.fail(&block.base, src, "TODO implement support for function prototypes to have alignment specified", .{});
         }
 
+        is_generic = is_generic or bare_return_type.requiresComptime();
+
         const return_type = if (!inferred_error_set or bare_return_type.tag() == .generic_poison)
             bare_return_type
         else blk: {
@@ -5334,18 +5337,18 @@ fn analyzeArithmetic(
 ) CompileError!Air.Inst.Ref {
     const lhs_ty = sema.typeOf(lhs);
     const rhs_ty = sema.typeOf(rhs);
-    if (lhs_ty.zigTypeTag() == .Vector and rhs_ty.zigTypeTag() == .Vector) {
+    const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison();
+    const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison();
+    if (lhs_zig_ty_tag == .Vector and rhs_zig_ty_tag == .Vector) {
         if (lhs_ty.arrayLen() != rhs_ty.arrayLen()) {
             return sema.mod.fail(&block.base, src, "vector length mismatch: {d} and {d}", .{
-                lhs_ty.arrayLen(),
-                rhs_ty.arrayLen(),
+                lhs_ty.arrayLen(), rhs_ty.arrayLen(),
             });
         }
         return sema.mod.fail(&block.base, src, "TODO implement support for vectors in zirBinOp", .{});
-    } else if (lhs_ty.zigTypeTag() == .Vector or rhs_ty.zigTypeTag() == .Vector) {
+    } else if (lhs_zig_ty_tag == .Vector or rhs_zig_ty_tag == .Vector) {
         return sema.mod.fail(&block.base, src, "mixed scalar and vector operands to binary expression: '{}' and '{}'", .{
-            lhs_ty,
-            rhs_ty,
+            lhs_ty, rhs_ty,
         });
     }
 
@@ -5365,7 +5368,9 @@ fn analyzeArithmetic(
     const is_float = scalar_tag == .Float or scalar_tag == .ComptimeFloat;
 
     if (!is_int and !(is_float and floatOpAllowed(zir_tag))) {
-        return sema.mod.fail(&block.base, src, "invalid operands to binary expression: '{s}' and '{s}'", .{ @tagName(lhs_ty.zigTypeTag()), @tagName(rhs_ty.zigTypeTag()) });
+        return sema.mod.fail(&block.base, src, "invalid operands to binary expression: '{s}' and '{s}'", .{
+            @tagName(lhs_zig_ty_tag), @tagName(rhs_zig_ty_tag),
+        });
     }
 
     if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| {
@@ -6164,6 +6169,10 @@ fn analyzeRet(
     const casted_operand = if (!need_coercion) operand else op: {
         const func = sema.func.?;
         const fn_ty = func.owner_decl.ty;
+        // TODO: In the case of a comptime/inline function call of a generic function,
+        // this needs to be the resolved return type based on the function parameter type
+        // expressions being evaluated with comptime arguments passed in. Otherwise, this
+        // ends up being .generic_poison and failing the comptime/inline function call analysis.
         const fn_ret_ty = fn_ty.fnReturnType();
         break :op try sema.coerce(block, fn_ret_ty, operand, src);
     };
@@ -9093,7 +9102,7 @@ fn typeHasOnePossibleValue(
 
         .inferred_alloc_const => unreachable,
         .inferred_alloc_mut => unreachable,
-        .generic_poison => unreachable,
+        .generic_poison => return error.GenericPoison,
     };
 }
 
src/type.zig
@@ -21,8 +21,14 @@ pub const Type = extern union {
     tag_if_small_enough: usize,
     ptr_otherwise: *Payload,
 
-    pub fn zigTypeTag(self: Type) std.builtin.TypeId {
-        switch (self.tag()) {
+    pub fn zigTypeTag(ty: Type) std.builtin.TypeId {
+        return ty.zigTypeTagOrPoison() catch unreachable;
+    }
+
+    pub fn zigTypeTagOrPoison(ty: Type) error{GenericPoison}!std.builtin.TypeId {
+        switch (ty.tag()) {
+            .generic_poison => return error.GenericPoison,
+
             .u1,
             .u8,
             .i8,
@@ -130,7 +136,6 @@ pub const Type = extern union {
             => return .Union,
 
             .var_args_param => unreachable, // can be any type
-            .generic_poison => unreachable, // must be handled earlier
         }
     }
 
@@ -1096,6 +1101,7 @@ pub const Type = extern union {
     }
 
     /// Anything that reports hasCodeGenBits() false returns false here as well.
+    /// `generic_poison` will return false.
     pub fn requiresComptime(ty: Type) bool {
         return switch (ty.tag()) {
             .u1,
@@ -1156,6 +1162,7 @@ pub const Type = extern union {
             .error_set_single,
             .error_set_inferred,
             .@"opaque",
+            .generic_poison,
             => false,
 
             .type,
@@ -1167,7 +1174,6 @@ pub const Type = extern union {
             .var_args_param => unreachable,
             .inferred_alloc_mut => unreachable,
             .inferred_alloc_const => unreachable,
-            .generic_poison => unreachable,
 
             .array_u8,
             .array_u8_sentinel_0,
src/Zir.zig
@@ -4501,12 +4501,16 @@ const Writer = struct {
         src: LazySrcLoc,
         src_locs: Zir.Inst.Func.SrcLocs,
     ) !void {
-        try stream.writeAll("ret_ty={\n");
-        self.indent += 2;
-        try self.writeBody(stream, ret_ty_body);
-        self.indent -= 2;
-        try stream.writeByteNTimes(' ', self.indent);
-        try stream.writeAll("}");
+        if (ret_ty_body.len == 0) {
+            try stream.writeAll("ret_ty=void");
+        } else {
+            try stream.writeAll("ret_ty={\n");
+            self.indent += 2;
+            try self.writeBody(stream, ret_ty_body);
+            self.indent -= 2;
+            try stream.writeByteNTimes(' ', self.indent);
+            try stream.writeAll("}");
+        }
 
         try self.writeOptionalInstRef(stream, ", cc=", cc);
         try self.writeOptionalInstRef(stream, ", align=", align_inst);
test/behavior/generics_stage1.zig
@@ -13,16 +13,16 @@ test {
     comptime try expect(max_f64(1.2, 3.4) == 3.4);
 }
 
-fn max_var(a: anytype, b: anytype) @TypeOf(a + b) {
+fn max_anytype(a: anytype, b: anytype) @TypeOf(a + b) {
     return if (a > b) a else b;
 }
 
 fn max_i32(a: i32, b: i32) i32 {
-    return max_var(a, b);
+    return max_anytype(a, b);
 }
 
 fn max_f64(a: f64, b: f64) f64 {
-    return max_var(a, b);
+    return max_anytype(a, b);
 }
 
 pub fn List(comptime T: type) type {