Commit 31aefa6a21

Andrew Kelley <superjoe30@gmail.com>
2018-06-07 23:26:41
fix structs that contain types which require comptime
Now, if a struct has any fields which require comptime, such as `type`, then the struct is marked as requiring comptime as well. Same goes for unions. This means that a function will implicitly be called at comptime if the return type is a struct which contains a field of type `type`. closes #586
1 parent b11c5d8
src/all_types.hpp
@@ -1037,6 +1037,10 @@ struct TypeTableEntryStruct {
     // whether we've finished resolving it
     bool complete;
 
+    // whether any of the fields require comptime
+    // the value is not valid until zero_bits_known == true
+    bool requires_comptime;
+
     bool zero_bits_loop_flag;
     bool zero_bits_known;
     uint32_t abi_alignment; // also figured out with zero_bits pass
@@ -1105,6 +1109,10 @@ struct TypeTableEntryUnion {
     // whether we've finished resolving it
     bool complete;
 
+    // whether any of the fields require comptime
+    // the value is not valid until zero_bits_known == true
+    bool requires_comptime;
+
     bool zero_bits_loop_flag;
     bool zero_bits_known;
     uint32_t abi_alignment; // also figured out with zero_bits pass
src/analyze.cpp
@@ -2533,6 +2533,10 @@ static void resolve_struct_zero_bits(CodeGen *g, TypeTableEntry *struct_type) {
             continue;
         }
 
+        if (type_requires_comptime(field_type)) {
+            struct_type->data.structure.requires_comptime = true;
+        }
+
         if (!type_has_bits(field_type))
             continue;
 
@@ -2724,6 +2728,11 @@ static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) {
         }
         union_field->type_entry = field_type;
 
+        if (type_requires_comptime(field_type)) {
+            union_type->data.unionation.requires_comptime = true;
+        }
+
+
         if (field_node->data.struct_field.value != nullptr && !decl_node->data.container_decl.auto_enum) {
             ErrorMsg *msg = add_node_error(g, field_node->data.struct_field.value,
                     buf_sprintf("non-enum union field assignment"));
@@ -4944,17 +4953,29 @@ bool type_requires_comptime(TypeTableEntry *type_entry) {
         case TypeTableEntryIdArgTuple:
             return true;
         case TypeTableEntryIdArray:
+            return type_requires_comptime(type_entry->data.array.child_type);
         case TypeTableEntryIdStruct:
+            assert(type_has_zero_bits_known(type_entry));
+            return type_entry->data.structure.requires_comptime;
         case TypeTableEntryIdUnion:
+            assert(type_has_zero_bits_known(type_entry));
+            return type_entry->data.unionation.requires_comptime;
         case TypeTableEntryIdMaybe:
+            return type_requires_comptime(type_entry->data.maybe.child_type);
         case TypeTableEntryIdErrorUnion:
+            return type_requires_comptime(type_entry->data.error_union.payload_type);
+        case TypeTableEntryIdPointer:
+            if (type_entry->data.pointer.child_type->id == TypeTableEntryIdOpaque) {
+                return false;
+            } else {
+                return type_requires_comptime(type_entry->data.pointer.child_type);
+            }
         case TypeTableEntryIdEnum:
         case TypeTableEntryIdErrorSet:
         case TypeTableEntryIdFn:
         case TypeTableEntryIdBool:
         case TypeTableEntryIdInt:
         case TypeTableEntryIdFloat:
-        case TypeTableEntryIdPointer:
         case TypeTableEntryIdVoid:
         case TypeTableEntryIdUnreachable:
         case TypeTableEntryIdPromise:
src/ir.cpp
@@ -11624,61 +11624,6 @@ static TypeTableEntry *ir_analyze_instruction_bin_op(IrAnalyze *ira, IrInstructi
     zig_unreachable();
 }
 
-enum VarClassRequired {
-    VarClassRequiredAny,
-    VarClassRequiredConst,
-    VarClassRequiredIllegal,
-};
-
-static VarClassRequired get_var_class_required(TypeTableEntry *type_entry) {
-    switch (type_entry->id) {
-        case TypeTableEntryIdInvalid:
-            zig_unreachable();
-        case TypeTableEntryIdUnreachable:
-            return VarClassRequiredIllegal;
-        case TypeTableEntryIdBool:
-        case TypeTableEntryIdInt:
-        case TypeTableEntryIdFloat:
-        case TypeTableEntryIdVoid:
-        case TypeTableEntryIdErrorSet:
-        case TypeTableEntryIdFn:
-        case TypeTableEntryIdPromise:
-            return VarClassRequiredAny;
-        case TypeTableEntryIdComptimeFloat:
-        case TypeTableEntryIdComptimeInt:
-        case TypeTableEntryIdUndefined:
-        case TypeTableEntryIdBlock:
-        case TypeTableEntryIdNull:
-        case TypeTableEntryIdOpaque:
-        case TypeTableEntryIdMetaType:
-        case TypeTableEntryIdNamespace:
-        case TypeTableEntryIdBoundFn:
-        case TypeTableEntryIdArgTuple:
-            return VarClassRequiredConst;
-
-        case TypeTableEntryIdPointer:
-            if (type_entry->data.pointer.child_type->id == TypeTableEntryIdOpaque) {
-                return VarClassRequiredAny;
-            } else {
-                return get_var_class_required(type_entry->data.pointer.child_type);
-            }
-        case TypeTableEntryIdArray:
-            return get_var_class_required(type_entry->data.array.child_type);
-        case TypeTableEntryIdMaybe:
-            return get_var_class_required(type_entry->data.maybe.child_type);
-        case TypeTableEntryIdErrorUnion:
-            return get_var_class_required(type_entry->data.error_union.payload_type);
-
-        case TypeTableEntryIdStruct:
-        case TypeTableEntryIdEnum:
-        case TypeTableEntryIdUnion:
-            // TODO check the fields of these things and make sure that they don't recursively
-            // contain any of the other variable classes
-            return VarClassRequiredAny;
-    }
-    zig_unreachable();
-}
-
 static TypeTableEntry *ir_analyze_instruction_decl_var(IrAnalyze *ira, IrInstructionDeclVar *decl_var_instruction) {
     VariableTableEntry *var = decl_var_instruction->var;
 
@@ -11713,36 +11658,41 @@ static TypeTableEntry *ir_analyze_instruction_decl_var(IrAnalyze *ira, IrInstruc
     if (type_is_invalid(result_type)) {
         result_type = ira->codegen->builtin_types.entry_invalid;
     } else {
-        switch (get_var_class_required(result_type)) {
-            case VarClassRequiredIllegal:
+        type_ensure_zero_bits_known(ira->codegen, result_type);
+        if (type_is_invalid(result_type)) {
+            result_type = ira->codegen->builtin_types.entry_invalid;
+        }
+    }
+
+    if (!type_is_invalid(result_type)) {
+        if (result_type->id == TypeTableEntryIdUnreachable ||
+            result_type->id == TypeTableEntryIdOpaque)
+        {
+            ir_add_error_node(ira, source_node,
+                buf_sprintf("variable of type '%s' not allowed", buf_ptr(&result_type->name)));
+            result_type = ira->codegen->builtin_types.entry_invalid;
+        } else if (type_requires_comptime(result_type)) {
+            var_class_requires_const = true;
+            if (!var->src_is_const && !is_comptime_var) {
                 ir_add_error_node(ira, source_node,
-                    buf_sprintf("variable of type '%s' not allowed", buf_ptr(&result_type->name)));
+                    buf_sprintf("variable of type '%s' must be const or comptime",
+                        buf_ptr(&result_type->name)));
                 result_type = ira->codegen->builtin_types.entry_invalid;
-                break;
-            case VarClassRequiredConst:
+            }
+        } else {
+            if (casted_init_value->value.special == ConstValSpecialStatic &&
+                casted_init_value->value.type->id == TypeTableEntryIdFn &&
+                casted_init_value->value.data.x_ptr.data.fn.fn_entry->fn_inline == FnInlineAlways)
+            {
                 var_class_requires_const = true;
                 if (!var->src_is_const && !is_comptime_var) {
-                    ir_add_error_node(ira, source_node,
-                        buf_sprintf("variable of type '%s' must be const or comptime",
-                            buf_ptr(&result_type->name)));
+                    ErrorMsg *msg = ir_add_error_node(ira, source_node,
+                        buf_sprintf("functions marked inline must be stored in const or comptime var"));
+                    AstNode *proto_node = casted_init_value->value.data.x_ptr.data.fn.fn_entry->proto_node;
+                    add_error_note(ira->codegen, msg, proto_node, buf_sprintf("declared here"));
                     result_type = ira->codegen->builtin_types.entry_invalid;
                 }
-                break;
-            case VarClassRequiredAny:
-                if (casted_init_value->value.special == ConstValSpecialStatic &&
-                    casted_init_value->value.type->id == TypeTableEntryIdFn &&
-                    casted_init_value->value.data.x_ptr.data.fn.fn_entry->fn_inline == FnInlineAlways)
-                {
-                    var_class_requires_const = true;
-                    if (!var->src_is_const && !is_comptime_var) {
-                        ErrorMsg *msg = ir_add_error_node(ira, source_node,
-                            buf_sprintf("functions marked inline must be stored in const or comptime var"));
-                        AstNode *proto_node = casted_init_value->value.data.x_ptr.data.fn.fn_entry->proto_node;
-                        add_error_note(ira->codegen, msg, proto_node, buf_sprintf("declared here"));
-                        result_type = ira->codegen->builtin_types.entry_invalid;
-                    }
-                }
-                break;
+            }
         }
     }
 
@@ -12623,6 +12573,10 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
                 inst_fn_type_id.return_type = specified_return_type;
             }
 
+            type_ensure_zero_bits_known(ira->codegen, specified_return_type);
+            if (type_is_invalid(specified_return_type))
+                return ira->codegen->builtin_types.entry_invalid;
+
             if (type_requires_comptime(specified_return_type)) {
                 // Throw out our work and call the function as if it were comptime.
                 return ir_analyze_fn_call(ira, call_instruction, fn_entry, fn_type, fn_ref, first_arg_ptr, true, FnInlineAuto);
test/cases/eval.zig
@@ -610,3 +610,16 @@ test "slice of type" {
         }
     }
 }
+
+const Wrapper = struct {
+    T: type,
+};
+
+fn wrap(comptime T: type) Wrapper {
+    return Wrapper{ .T = T };
+}
+
+test "function which returns struct with type field causes implicit comptime" {
+    const ty = wrap(i32).T;
+    assert(ty == i32);
+}
test/compile_errors.zig
@@ -3329,7 +3329,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
         ".tmp_source.zig:9:4: error: variable of type 'comptime_float' must be const or comptime",
         ".tmp_source.zig:10:4: error: variable of type '(block)' must be const or comptime",
         ".tmp_source.zig:11:4: error: variable of type '(null)' must be const or comptime",
-        ".tmp_source.zig:12:4: error: variable of type 'Opaque' must be const or comptime",
+        ".tmp_source.zig:12:4: error: variable of type 'Opaque' not allowed",
         ".tmp_source.zig:13:4: error: variable of type 'type' must be const or comptime",
         ".tmp_source.zig:14:4: error: variable of type '(namespace)' must be const or comptime",
         ".tmp_source.zig:15:4: error: variable of type '(bound fn(*const Foo) void)' must be const or comptime",