Commit abbc395701

Josh Wolfe <thejoshwolfe@gmail.com>
2015-12-01 02:43:45
implement basics of type checking
1 parent ef482ec
Changed files (1)
src/analyze.cpp
@@ -10,6 +10,12 @@
 #include "error.hpp"
 #include "zig_llvm.hpp"
 
+struct BlockContext {
+    AstNode *node;
+    BlockContext *root;
+    BlockContext *parent;
+};
+
 static void add_node_error(CodeGen *g, AstNode *node, Buf *msg) {
     g->errors.add_one();
     ErrorMsg *last_msg = &g->errors.last();
@@ -229,6 +235,155 @@ static void preview_function_declarations(CodeGen *g, AstNode *node) {
     }
 }
 
+static TypeTableEntry * get_return_type(BlockContext *context) {
+    AstNode *fn_def_node = context->root->node;
+    assert(fn_def_node->type == NodeTypeFnDef);
+    AstNode *fn_proto_node = fn_def_node->data.fn_def.fn_proto;
+    assert(fn_proto_node->type == NodeTypeFnProto);
+    AstNode *return_type_node = fn_proto_node->data.fn_proto.return_type;
+    assert(return_type_node->codegen_node);
+    return return_type_node->codegen_node->data.type_node.entry;
+}
+
+static void check_type_compatibility(CodeGen *g, AstNode *node, TypeTableEntry *expected_type, TypeTableEntry *actual_type) {
+    if (expected_type == actual_type)
+        return; // good
+    if (expected_type == g->builtin_types.entry_invalid || actual_type == g->builtin_types.entry_invalid)
+        return; // already complained
+    if (actual_type == g->builtin_types.entry_unreachable)
+        return; // TODO: is this true?
+
+    // TODO better error message
+    add_node_error(g, node, buf_sprintf("type mismatch."));
+}
+
+static TypeTableEntry * analyze_expression(CodeGen *g, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) {
+    switch (node->type) {
+        case NodeTypeBlock:
+            {
+                // TODO: nested block scopes
+                TypeTableEntry *return_type = g->builtin_types.entry_void;
+                for (int i = 0; i < node->data.block.statements.length; i += 1) {
+                    AstNode *child = node->data.block.statements.at(i);
+                    if (return_type == g->builtin_types.entry_unreachable) {
+                        add_node_error(g, child,
+                                buf_sprintf("unreachable code"));
+                        break;
+                    }
+                    return_type = analyze_expression(g, context, nullptr, child);
+                }
+                return return_type;
+            }
+
+        case NodeTypeReturnExpr:
+            {
+                TypeTableEntry *expected_return_type = get_return_type(context);
+                TypeTableEntry *actual_return_type;
+                if (node->data.return_expr.expr) {
+                    actual_return_type = analyze_expression(g, context, expected_return_type, node->data.return_expr.expr);
+                } else {
+                    actual_return_type = g->builtin_types.entry_void;
+                }
+
+                if (actual_return_type == g->builtin_types.entry_unreachable) {
+                    // "return exit(0)" should just be "exit(0)".
+                    add_node_error(g, node, buf_sprintf("returning is unreachable."));
+                    actual_return_type = g->builtin_types.entry_invalid;
+                }
+
+                check_type_compatibility(g, node, expected_return_type, actual_return_type);
+                return g->builtin_types.entry_unreachable;
+            }
+
+        case NodeTypeBinOpExpr:
+            {
+                // TODO: think about expected types
+                analyze_expression(g, context, expected_type, node->data.bin_op_expr.op1);
+                analyze_expression(g, context, expected_type, node->data.bin_op_expr.op2);
+                return expected_type;
+            }
+
+        case NodeTypeFnCallExpr:
+            {
+                Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
+
+                auto entry = g->fn_table.maybe_get(name);
+                if (!entry) {
+                    add_node_error(g, node,
+                            buf_sprintf("undefined function: '%s'", buf_ptr(name)));
+                    // still analyze the parameters, even though we don't know what to expect
+                    for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
+                        AstNode *child = node->data.fn_call_expr.params.at(i);
+                        analyze_expression(g, context, nullptr, child);
+                    }
+
+                    return g->builtin_types.entry_invalid;
+                } else {
+                    FnTableEntry *fn_table_entry = entry->value;
+                    assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
+                    AstNodeFnProto *fn_proto = &fn_table_entry->proto_node->data.fn_proto;
+
+                    // count parameters
+                    int expected_param_count = fn_proto->params.length;
+                    int actual_param_count = node->data.fn_call_expr.params.length;
+                    if (expected_param_count != actual_param_count) {
+                        add_node_error(g, node,
+                                buf_sprintf("wrong number of arguments. Expected %d, got %d.",
+                                    expected_param_count, actual_param_count));
+                    }
+
+                    // analyze each parameter
+                    for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
+                        AstNode *child = node->data.fn_call_expr.params.at(i);
+                        // determine the expected type for each parameter
+                        TypeTableEntry *expected_param_type = nullptr;
+                        if (i < fn_proto->params.length) {
+                            AstNode *param_decl_node = fn_proto->params.at(i);
+                            assert(param_decl_node->type == NodeTypeParamDecl);
+                            AstNode *param_type_node = param_decl_node->data.param_decl.type;
+                            if (param_type_node->codegen_node)
+                                expected_param_type = param_type_node->codegen_node->data.type_node.entry;
+                        }
+                        analyze_expression(g, context, expected_param_type, child);
+                    }
+
+                    TypeTableEntry *return_type = fn_proto->return_type->codegen_node->data.type_node.entry;
+                    check_type_compatibility(g, node, expected_type, return_type);
+                    return return_type;
+                }
+            }
+
+        case NodeTypeNumberLiteral:
+            // TODO: generic literal int type
+            return g->builtin_types.entry_i32;
+
+        case NodeTypeStringLiteral:
+            zig_panic("TODO");
+
+        case NodeTypeUnreachable:
+            return g->builtin_types.entry_unreachable;
+
+        case NodeTypeSymbol:
+            // look up symbol in symbol table
+            zig_panic("TODO");
+
+        case NodeTypeCastExpr:
+        case NodeTypePrefixOpExpr:
+            zig_panic("TODO");
+        case NodeTypeDirective:
+        case NodeTypeFnDecl:
+        case NodeTypeFnProto:
+        case NodeTypeParamDecl:
+        case NodeTypeType:
+        case NodeTypeRoot:
+        case NodeTypeRootExportDecl:
+        case NodeTypeExternBlock:
+        case NodeTypeFnDef:
+            zig_unreachable();
+    }
+    zig_unreachable();
+}
+
 static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
     // Follow the execution flow and make sure the code returns appropriately.
     // * A `return` statement in an unreachable type function should be an error.
@@ -282,74 +437,6 @@ static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
     }
 }
 
-static void analyze_expression(CodeGen *g, AstNode *node) {
-    switch (node->type) {
-        case NodeTypeBlock:
-            for (int i = 0; i < node->data.block.statements.length; i += 1) {
-                AstNode *child = node->data.block.statements.at(i);
-                analyze_expression(g, child);
-            }
-            break;
-        case NodeTypeReturnExpr:
-            if (node->data.return_expr.expr) {
-                analyze_expression(g, node->data.return_expr.expr);
-            }
-            break;
-        case NodeTypeBinOpExpr:
-            analyze_expression(g, node->data.bin_op_expr.op1);
-            analyze_expression(g, node->data.bin_op_expr.op2);
-            break;
-        case NodeTypeFnCallExpr:
-            {
-                Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
-
-                auto entry = g->fn_table.maybe_get(name);
-                if (!entry) {
-                    add_node_error(g, node,
-                            buf_sprintf("undefined function: '%s'", buf_ptr(name)));
-                } else {
-                    FnTableEntry *fn_table_entry = entry->value;
-                    assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
-                    int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
-                    int actual_param_count = node->data.fn_call_expr.params.length;
-                    if (expected_param_count != actual_param_count) {
-                        add_node_error(g, node,
-                                buf_sprintf("wrong number of arguments. Expected %d, got %d.",
-                                    expected_param_count, actual_param_count));
-                    }
-                }
-
-                for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
-                    AstNode *child = node->data.fn_call_expr.params.at(i);
-                    analyze_expression(g, child);
-                }
-                break;
-            }
-        case NodeTypeCastExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypePrefixOpExpr:
-            zig_panic("TODO");
-            break;
-        case NodeTypeNumberLiteral:
-        case NodeTypeStringLiteral:
-        case NodeTypeUnreachable:
-        case NodeTypeSymbol:
-            // nothing to do
-            break;
-        case NodeTypeDirective:
-        case NodeTypeFnDecl:
-        case NodeTypeFnProto:
-        case NodeTypeParamDecl:
-        case NodeTypeType:
-        case NodeTypeRoot:
-        case NodeTypeRootExportDecl:
-        case NodeTypeExternBlock:
-        case NodeTypeFnDef:
-            zig_unreachable();
-    }
-}
-
 static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
     switch (node->type) {
         case NodeTypeFnDef:
@@ -371,7 +458,13 @@ static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
                 }
 
                 check_fn_def_control_flow(g, node);
-                analyze_expression(g, node->data.fn_def.body);
+
+                BlockContext context;
+                context.node = node;
+                context.root = &context;
+                context.parent = nullptr;
+                TypeTableEntry *expected_type = fn_proto->return_type->codegen_node->data.type_node.entry;
+                analyze_expression(g, &context, expected_type, node->data.fn_def.body);
             }
             break;
 
@@ -424,6 +517,12 @@ static void analyze_root(CodeGen *g, AstNode *node) {
 }
 
 static void define_primitive_types(CodeGen *g) {
+    {
+        // if this type is anywhere in the AST, we should never hit codegen.
+        TypeTableEntry *entry = allocate<TypeTableEntry>(1);
+        buf_init_from_str(&entry->name, "(invalid)");
+        g->builtin_types.entry_invalid = entry;
+    }
     {
         TypeTableEntry *entry = allocate<TypeTableEntry>(1);
         entry->type_ref = LLVMInt8Type();
@@ -450,15 +549,12 @@ static void define_primitive_types(CodeGen *g) {
                 LLVMZigEncoding_DW_ATE_unsigned());
         g->type_table.put(&entry->name, entry);
         g->builtin_types.entry_void = entry;
-
-        // invalid types are void
-        g->builtin_types.entry_invalid = entry;
     }
     {
         TypeTableEntry *entry = allocate<TypeTableEntry>(1);
         entry->type_ref = LLVMVoidType();
         buf_init_from_str(&entry->name, "unreachable");
-        entry->di_type = g->builtin_types.entry_invalid->di_type;
+        entry->di_type = g->builtin_types.entry_void->di_type;
         g->type_table.put(&entry->name, entry);
         g->builtin_types.entry_unreachable = entry;
     }