Commit 3d2752cc36

Andrew Kelley <superjoe30@gmail.com>
2018-11-24 22:15:58
refactor type_requires_comptime to have possible error
fixes a compiler crash when building https://github.com/AndreaOrru/zen
1 parent 56a8f2b
src/analyze.cpp
@@ -1619,13 +1619,16 @@ static ZigType *analyze_fn_type(CodeGen *g, AstNode *proto_node, Scope *child_sc
             case ZigTypeIdUnion:
             case ZigTypeIdFn:
             case ZigTypeIdPromise:
-                if ((err = type_resolve(g, type_entry, ResolveStatusZeroBitsKnown)))
-                    return g->builtin_types.entry_invalid;
-                if (type_requires_comptime(type_entry)) {
-                    add_node_error(g, param_node->data.param_decl.type,
-                        buf_sprintf("parameter of type '%s' must be declared comptime",
-                        buf_ptr(&type_entry->name)));
-                    return g->builtin_types.entry_invalid;
+                switch (type_requires_comptime(g, type_entry)) {
+                    case ReqCompTimeNo:
+                        break;
+                    case ReqCompTimeYes:
+                        add_node_error(g, param_node->data.param_decl.type,
+                            buf_sprintf("parameter of type '%s' must be declared comptime",
+                            buf_ptr(&type_entry->name)));
+                        return g->builtin_types.entry_invalid;
+                    case ReqCompTimeInvalid:
+                        return g->builtin_types.entry_invalid;
                 }
                 break;
         }
@@ -1711,10 +1714,13 @@ static ZigType *analyze_fn_type(CodeGen *g, AstNode *proto_node, Scope *child_sc
         case ZigTypeIdUnion:
         case ZigTypeIdFn:
         case ZigTypeIdPromise:
-            if ((err = type_resolve(g, fn_type_id.return_type, ResolveStatusZeroBitsKnown)))
-                return g->builtin_types.entry_invalid;
-            if (type_requires_comptime(fn_type_id.return_type)) {
-                return get_generic_fn_type(g, &fn_type_id);
+            switch (type_requires_comptime(g, fn_type_id.return_type)) {
+                case ReqCompTimeInvalid:
+                    return g->builtin_types.entry_invalid;
+                case ReqCompTimeYes:
+                    return get_generic_fn_type(g, &fn_type_id);
+                case ReqCompTimeNo:
+                    break;
             }
             break;
     }
@@ -2560,8 +2566,6 @@ static Error resolve_enum_zero_bits(CodeGen *g, ZigType *enum_type) {
 static Error resolve_struct_zero_bits(CodeGen *g, ZigType *struct_type) {
     assert(struct_type->id == ZigTypeIdStruct);
 
-    Error err;
-
     if (struct_type->data.structure.resolve_status == ResolveStatusInvalid)
         return ErrorSemanticAnalyzeFail;
     if (struct_type->data.structure.resolve_status >= ResolveStatusZeroBitsKnown)
@@ -2619,13 +2623,15 @@ static Error resolve_struct_zero_bits(CodeGen *g, ZigType *struct_type) {
                     buf_sprintf("enums, not structs, support field assignment"));
         }
 
-        if ((err = type_resolve(g, field_type, ResolveStatusZeroBitsKnown))) {
-            struct_type->data.structure.resolve_status = ResolveStatusInvalid;
-            continue;
-        }
-
-        if (type_requires_comptime(field_type)) {
-            struct_type->data.structure.requires_comptime = true;
+        switch (type_requires_comptime(g, field_type)) {
+            case ReqCompTimeYes:
+                struct_type->data.structure.requires_comptime = true;
+                break;
+            case ReqCompTimeInvalid:
+                struct_type->data.structure.resolve_status = ResolveStatusInvalid;
+                continue;
+            case ReqCompTimeNo:
+                break;
         }
 
         if (!type_has_bits(field_type))
@@ -2890,11 +2896,17 @@ static Error resolve_union_zero_bits(CodeGen *g, ZigType *union_type) {
         }
         union_field->type_entry = field_type;
 
-        if (type_requires_comptime(field_type)) {
-            union_type->data.unionation.requires_comptime = true;
+        switch (type_requires_comptime(g, field_type)) {
+            case ReqCompTimeInvalid:
+                union_type->data.unionation.is_invalid = true;
+                continue;
+            case ReqCompTimeYes:
+                union_type->data.unionation.requires_comptime = true;
+                break;
+            case ReqCompTimeNo:
+                break;
         }
 
-
         if (field_node->data.struct_field.value != nullptr && !decl_node->data.container_decl.auto_enum) {
             ErrorMsg *msg = add_node_error(g, field_node->data.struct_field.value,
                     buf_sprintf("non-enum union field assignment"));
@@ -5089,7 +5101,10 @@ bool type_has_bits(ZigType *type_entry) {
     return !type_entry->zero_bits;
 }
 
-bool type_requires_comptime(ZigType *type_entry) {
+ReqCompTime type_requires_comptime(CodeGen *g, ZigType *type_entry) {
+    Error err;
+    if ((err = type_resolve(g, type_entry, ResolveStatusZeroBitsKnown)))
+        return ReqCompTimeInvalid;
     switch (type_entry->id) {
         case ZigTypeIdInvalid:
         case ZigTypeIdOpaque:
@@ -5102,27 +5117,25 @@ bool type_requires_comptime(ZigType *type_entry) {
         case ZigTypeIdNamespace:
         case ZigTypeIdBoundFn:
         case ZigTypeIdArgTuple:
-            return true;
+            return ReqCompTimeYes;
         case ZigTypeIdArray:
-            return type_requires_comptime(type_entry->data.array.child_type);
+            return type_requires_comptime(g, type_entry->data.array.child_type);
         case ZigTypeIdStruct:
-            assert(type_is_resolved(type_entry, ResolveStatusZeroBitsKnown));
-            return type_entry->data.structure.requires_comptime;
+            return type_entry->data.structure.requires_comptime ? ReqCompTimeYes : ReqCompTimeNo;
         case ZigTypeIdUnion:
-            assert(type_is_resolved(type_entry, ResolveStatusZeroBitsKnown));
-            return type_entry->data.unionation.requires_comptime;
+            return type_entry->data.unionation.requires_comptime ? ReqCompTimeYes : ReqCompTimeNo;
         case ZigTypeIdOptional:
-            return type_requires_comptime(type_entry->data.maybe.child_type);
+            return type_requires_comptime(g, type_entry->data.maybe.child_type);
         case ZigTypeIdErrorUnion:
-            return type_requires_comptime(type_entry->data.error_union.payload_type);
+            return type_requires_comptime(g, type_entry->data.error_union.payload_type);
         case ZigTypeIdPointer:
             if (type_entry->data.pointer.child_type->id == ZigTypeIdOpaque) {
-                return false;
+                return ReqCompTimeNo;
             } else {
-                return type_requires_comptime(type_entry->data.pointer.child_type);
+                return type_requires_comptime(g, type_entry->data.pointer.child_type);
             }
         case ZigTypeIdFn:
-            return type_entry->data.fn.is_generic;
+            return type_entry->data.fn.is_generic ? ReqCompTimeYes : ReqCompTimeNo;
         case ZigTypeIdEnum:
         case ZigTypeIdErrorSet:
         case ZigTypeIdBool:
@@ -5131,7 +5144,7 @@ bool type_requires_comptime(ZigType *type_entry) {
         case ZigTypeIdVoid:
         case ZigTypeIdUnreachable:
         case ZigTypeIdPromise:
-            return false;
+            return ReqCompTimeNo;
     }
     zig_unreachable();
 }
src/analyze.hpp
@@ -87,7 +87,6 @@ ZigFn *create_fn(CodeGen *g, AstNode *proto_node);
 ZigFn *create_fn_raw(CodeGen *g, FnInline inline_value);
 void init_fn_type_id(FnTypeId *fn_type_id, AstNode *proto_node, size_t param_count_alloc);
 AstNode *get_param_decl_node(ZigFn *fn_entry, size_t index);
-bool type_requires_comptime(ZigType *type_entry);
 Error ATTRIBUTE_MUST_USE ensure_complete_type(CodeGen *g, ZigType *type_entry);
 Error ATTRIBUTE_MUST_USE type_resolve(CodeGen *g, ZigType *type_entry, ResolveStatus status);
 void complete_enum(CodeGen *g, ZigType *enum_type);
@@ -216,4 +215,11 @@ bool want_first_arg_sret(CodeGen *g, FnTypeId *fn_type_id);
 
 uint32_t get_host_int_bytes(CodeGen *g, ZigType *struct_type, TypeStructField *field);
 
+enum ReqCompTime {
+    ReqCompTimeInvalid,
+    ReqCompTimeNo,
+    ReqCompTimeYes,
+};
+ReqCompTime type_requires_comptime(CodeGen *g, ZigType *type_entry);
+
 #endif
src/codegen.cpp
@@ -6281,8 +6281,14 @@ static void do_code_gen(CodeGen *g) {
             }
             if (ir_get_var_is_comptime(var))
                 continue;
-            if (type_requires_comptime(var->value->type))
-                continue;
+            switch (type_requires_comptime(g, var->value->type)) {
+                case ReqCompTimeInvalid:
+                    zig_unreachable();
+                case ReqCompTimeYes:
+                    continue;
+                case ReqCompTimeNo:
+                    break;
+            }
 
             if (var->src_arg_index == SIZE_MAX) {
                 var->value_ref = build_alloca(g, var->value->type, buf_ptr(&var->name), var->align_bytes);
src/ir.cpp
@@ -11276,7 +11276,6 @@ static bool optional_value_is_null(ConstExprValue *val) {
 }
 
 static IrInstruction *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp *bin_op_instruction) {
-    Error err;
     IrInstruction *op1 = bin_op_instruction->op1->child;
     if (type_is_invalid(op1->value.type))
         return ira->codegen->invalid_instruction;
@@ -11470,10 +11469,19 @@ static IrInstruction *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp *
     if (casted_op2 == ira->codegen->invalid_instruction)
         return ira->codegen->invalid_instruction;
 
-    if ((err = type_resolve(ira->codegen, resolved_type, ResolveStatusZeroBitsKnown)))
-        return ira->codegen->invalid_instruction;
+    bool requires_comptime;
+    switch (type_requires_comptime(ira->codegen, resolved_type)) {
+        case ReqCompTimeYes:
+            requires_comptime = true;
+            break;
+        case ReqCompTimeNo:
+            requires_comptime = false;
+            break;
+        case ReqCompTimeInvalid:
+            return ira->codegen->invalid_instruction;
+    }
 
-    bool one_possible_value = !type_requires_comptime(resolved_type) && !type_has_bits(resolved_type);
+    bool one_possible_value = !requires_comptime && !type_has_bits(resolved_type);
     if (one_possible_value || (instr_is_comptime(casted_op1) && instr_is_comptime(casted_op2))) {
         ConstExprValue *op1_val = one_possible_value ? &casted_op1->value : ir_resolve_const(ira, casted_op1, UndefBad);
         if (op1_val == nullptr)
@@ -12406,42 +12414,41 @@ static IrInstruction *ir_analyze_instruction_decl_var(IrAnalyze *ira, IrInstruct
     ZigType *result_type = casted_init_value->value.type;
     if (type_is_invalid(result_type)) {
         result_type = ira->codegen->builtin_types.entry_invalid;
-    } else {
-        if ((err = type_resolve(ira->codegen, result_type, ResolveStatusZeroBitsKnown))) {
-            result_type = ira->codegen->builtin_types.entry_invalid;
-        }
+    } else if (result_type->id == ZigTypeIdUnreachable || result_type->id == ZigTypeIdOpaque) {
+        ir_add_error_node(ira, source_node,
+            buf_sprintf("variable of type '%s' not allowed", buf_ptr(&result_type->name)));
+        result_type = ira->codegen->builtin_types.entry_invalid;
     }
 
-    if (!type_is_invalid(result_type)) {
-        if (result_type->id == ZigTypeIdUnreachable ||
-            result_type->id == ZigTypeIdOpaque)
-        {
+    switch (type_requires_comptime(ira->codegen, result_type)) {
+    case ReqCompTimeInvalid:
+        result_type = ira->codegen->builtin_types.entry_invalid;
+        break;
+    case ReqCompTimeYes: {
+        var_class_requires_const = true;
+        if (!var->gen_is_const && !is_comptime_var) {
             ir_add_error_node(ira, source_node,
-                buf_sprintf("variable of type '%s' not allowed", buf_ptr(&result_type->name)));
+                buf_sprintf("variable of type '%s' must be const or comptime",
+                    buf_ptr(&result_type->name)));
             result_type = ira->codegen->builtin_types.entry_invalid;
-        } else if (type_requires_comptime(result_type)) {
+        }
+        break;
+    }
+    case ReqCompTimeNo:
+        if (casted_init_value->value.special == ConstValSpecialStatic &&
+            casted_init_value->value.type->id == ZigTypeIdFn &&
+            casted_init_value->value.data.x_ptr.data.fn.fn_entry->fn_inline == FnInlineAlways)
+        {
             var_class_requires_const = true;
-            if (!var->gen_is_const && !is_comptime_var) {
-                ir_add_error_node(ira, source_node,
-                    buf_sprintf("variable of type '%s' must be const or comptime",
-                        buf_ptr(&result_type->name)));
+            if (!var->src_is_const && !is_comptime_var) {
+                ErrorMsg *msg = ir_add_error_node(ira, source_node,
+                    buf_sprintf("functions marked inline must be stored in const or comptime var"));
+                AstNode *proto_node = casted_init_value->value.data.x_ptr.data.fn.fn_entry->proto_node;
+                add_error_note(ira->codegen, msg, proto_node, buf_sprintf("declared here"));
                 result_type = ira->codegen->builtin_types.entry_invalid;
             }
-        } else {
-            if (casted_init_value->value.special == ConstValSpecialStatic &&
-                casted_init_value->value.type->id == ZigTypeIdFn &&
-                casted_init_value->value.data.x_ptr.data.fn.fn_entry->fn_inline == FnInlineAlways)
-            {
-                var_class_requires_const = true;
-                if (!var->src_is_const && !is_comptime_var) {
-                    ErrorMsg *msg = ir_add_error_node(ira, source_node,
-                        buf_sprintf("functions marked inline must be stored in const or comptime var"));
-                    AstNode *proto_node = casted_init_value->value.data.x_ptr.data.fn.fn_entry->proto_node;
-                    add_error_note(ira->codegen, msg, proto_node, buf_sprintf("declared here"));
-                    result_type = ira->codegen->builtin_types.entry_invalid;
-                }
-            }
         }
+        break;
     }
 
     if (var->value->type != nullptr && !is_comptime_var) {
@@ -12912,10 +12919,15 @@ static bool ir_analyze_fn_call_generic_arg(IrAnalyze *ira, AstNode *fn_proto_nod
     }
 
     if (!comptime_arg) {
-        if (type_requires_comptime(casted_arg->value.type)) {
+        switch (type_requires_comptime(ira->codegen, casted_arg->value.type)) {
+        case ReqCompTimeYes:
             ir_add_error(ira, casted_arg,
                 buf_sprintf("parameter of type '%s' requires comptime", buf_ptr(&casted_arg->value.type->name)));
             return false;
+        case ReqCompTimeInvalid:
+            return false;
+        case ReqCompTimeNo:
+            break;
         }
 
         casted_args[fn_type_id->param_count] = casted_arg;
@@ -13388,12 +13400,15 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *call
                 inst_fn_type_id.return_type = specified_return_type;
             }
 
-            if ((err = type_resolve(ira->codegen, specified_return_type, ResolveStatusZeroBitsKnown)))
-                return ira->codegen->invalid_instruction;
-
-            if (type_requires_comptime(specified_return_type)) {
+            switch (type_requires_comptime(ira->codegen, specified_return_type)) {
+            case ReqCompTimeYes:
                 // Throw out our work and call the function as if it were comptime.
-                return ir_analyze_fn_call(ira, call_instruction, fn_entry, fn_type, fn_ref, first_arg_ptr, true, FnInlineAuto);
+                return ir_analyze_fn_call(ira, call_instruction, fn_entry, fn_type, fn_ref, first_arg_ptr,
+                        true, FnInlineAuto);
+            case ReqCompTimeInvalid:
+                return ira->codegen->invalid_instruction;
+            case ReqCompTimeNo:
+                break;
             }
         }
         IrInstruction *async_allocator_inst = nullptr;
@@ -14334,11 +14349,16 @@ static IrInstruction *ir_analyze_instruction_elem_ptr(IrAnalyze *ira, IrInstruct
 
     } else {
         // runtime known element index
-        if (type_requires_comptime(return_type)) {
+        switch (type_requires_comptime(ira->codegen, return_type)) {
+        case ReqCompTimeYes:
             ir_add_error(ira, elem_index,
                 buf_sprintf("values of type '%s' must be comptime known, but index value is runtime known",
                     buf_ptr(&return_type->data.pointer.child_type->name)));
             return ira->codegen->invalid_instruction;
+        case ReqCompTimeInvalid:
+            return ira->codegen->invalid_instruction;
+        case ReqCompTimeNo:
+            break;
         }
         if (ptr_align < abi_align) {
             if (elem_size >= ptr_align && elem_size % ptr_align == 0) {
@@ -19390,7 +19410,6 @@ static IrInstruction *ir_analyze_instruction_unwrap_err_payload(IrAnalyze *ira,
 }
 
 static IrInstruction *ir_analyze_instruction_fn_proto(IrAnalyze *ira, IrInstructionFnProto *instruction) {
-    Error err;
     AstNode *proto_node = instruction->base.source_node;
     assert(proto_node->type == NodeTypeFnProto);
 
@@ -19429,11 +19448,8 @@ static IrInstruction *ir_analyze_instruction_fn_proto(IrAnalyze *ira, IrInstruct
             if (type_is_invalid(param_type_value->value.type))
                 return ira->codegen->invalid_instruction;
             ZigType *param_type = ir_resolve_type(ira, param_type_value);
-            if (type_is_invalid(param_type))
-                return ira->codegen->invalid_instruction;
-            if ((err = type_resolve(ira->codegen, param_type, ResolveStatusZeroBitsKnown)))
-                return ira->codegen->invalid_instruction;
-            if (type_requires_comptime(param_type)) {
+            switch (type_requires_comptime(ira->codegen, param_type)) {
+            case ReqCompTimeYes:
                 if (!calling_convention_allows_zig_types(fn_type_id.cc)) {
                     ir_add_error(ira, param_type_value,
                         buf_sprintf("parameter of type '%s' not allowed in function with calling convention '%s'",
@@ -19443,6 +19459,10 @@ static IrInstruction *ir_analyze_instruction_fn_proto(IrAnalyze *ira, IrInstruct
                 param_info->type = param_type;
                 fn_type_id.next_param_index += 1;
                 return ir_const_type(ira, &instruction->base, get_generic_fn_type(ira->codegen, &fn_type_id));
+            case ReqCompTimeInvalid:
+                return ira->codegen->invalid_instruction;
+            case ReqCompTimeNo:
+                break;
             }
             if (!type_has_bits(param_type) && !calling_convention_allows_zig_types(fn_type_id.cc)) {
                 ir_add_error(ira, param_type_value,