Commit abbc395701
Changed files (1)
src
src/analyze.cpp
@@ -10,6 +10,12 @@
#include "error.hpp"
#include "zig_llvm.hpp"
+struct BlockContext {
+ AstNode *node;
+ BlockContext *root;
+ BlockContext *parent;
+};
+
static void add_node_error(CodeGen *g, AstNode *node, Buf *msg) {
g->errors.add_one();
ErrorMsg *last_msg = &g->errors.last();
@@ -229,6 +235,155 @@ static void preview_function_declarations(CodeGen *g, AstNode *node) {
}
}
+static TypeTableEntry * get_return_type(BlockContext *context) {
+ AstNode *fn_def_node = context->root->node;
+ assert(fn_def_node->type == NodeTypeFnDef);
+ AstNode *fn_proto_node = fn_def_node->data.fn_def.fn_proto;
+ assert(fn_proto_node->type == NodeTypeFnProto);
+ AstNode *return_type_node = fn_proto_node->data.fn_proto.return_type;
+ assert(return_type_node->codegen_node);
+ return return_type_node->codegen_node->data.type_node.entry;
+}
+
+static void check_type_compatibility(CodeGen *g, AstNode *node, TypeTableEntry *expected_type, TypeTableEntry *actual_type) {
+ if (expected_type == actual_type)
+ return; // good
+ if (expected_type == g->builtin_types.entry_invalid || actual_type == g->builtin_types.entry_invalid)
+ return; // already complained
+ if (actual_type == g->builtin_types.entry_unreachable)
+ return; // TODO: is this true?
+
+ // TODO better error message
+ add_node_error(g, node, buf_sprintf("type mismatch."));
+}
+
+static TypeTableEntry * analyze_expression(CodeGen *g, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) {
+ switch (node->type) {
+ case NodeTypeBlock:
+ {
+ // TODO: nested block scopes
+ TypeTableEntry *return_type = g->builtin_types.entry_void;
+ for (int i = 0; i < node->data.block.statements.length; i += 1) {
+ AstNode *child = node->data.block.statements.at(i);
+ if (return_type == g->builtin_types.entry_unreachable) {
+ add_node_error(g, child,
+ buf_sprintf("unreachable code"));
+ break;
+ }
+ return_type = analyze_expression(g, context, nullptr, child);
+ }
+ return return_type;
+ }
+
+ case NodeTypeReturnExpr:
+ {
+ TypeTableEntry *expected_return_type = get_return_type(context);
+ TypeTableEntry *actual_return_type;
+ if (node->data.return_expr.expr) {
+ actual_return_type = analyze_expression(g, context, expected_return_type, node->data.return_expr.expr);
+ } else {
+ actual_return_type = g->builtin_types.entry_void;
+ }
+
+ if (actual_return_type == g->builtin_types.entry_unreachable) {
+ // "return exit(0)" should just be "exit(0)".
+ add_node_error(g, node, buf_sprintf("returning is unreachable."));
+ actual_return_type = g->builtin_types.entry_invalid;
+ }
+
+ check_type_compatibility(g, node, expected_return_type, actual_return_type);
+ return g->builtin_types.entry_unreachable;
+ }
+
+ case NodeTypeBinOpExpr:
+ {
+ // TODO: think about expected types
+ analyze_expression(g, context, expected_type, node->data.bin_op_expr.op1);
+ analyze_expression(g, context, expected_type, node->data.bin_op_expr.op2);
+ return expected_type;
+ }
+
+ case NodeTypeFnCallExpr:
+ {
+ Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
+
+ auto entry = g->fn_table.maybe_get(name);
+ if (!entry) {
+ add_node_error(g, node,
+ buf_sprintf("undefined function: '%s'", buf_ptr(name)));
+ // still analyze the parameters, even though we don't know what to expect
+ for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
+ AstNode *child = node->data.fn_call_expr.params.at(i);
+ analyze_expression(g, context, nullptr, child);
+ }
+
+ return g->builtin_types.entry_invalid;
+ } else {
+ FnTableEntry *fn_table_entry = entry->value;
+ assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
+ AstNodeFnProto *fn_proto = &fn_table_entry->proto_node->data.fn_proto;
+
+ // count parameters
+ int expected_param_count = fn_proto->params.length;
+ int actual_param_count = node->data.fn_call_expr.params.length;
+ if (expected_param_count != actual_param_count) {
+ add_node_error(g, node,
+ buf_sprintf("wrong number of arguments. Expected %d, got %d.",
+ expected_param_count, actual_param_count));
+ }
+
+ // analyze each parameter
+ for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
+ AstNode *child = node->data.fn_call_expr.params.at(i);
+ // determine the expected type for each parameter
+ TypeTableEntry *expected_param_type = nullptr;
+ if (i < fn_proto->params.length) {
+ AstNode *param_decl_node = fn_proto->params.at(i);
+ assert(param_decl_node->type == NodeTypeParamDecl);
+ AstNode *param_type_node = param_decl_node->data.param_decl.type;
+ if (param_type_node->codegen_node)
+ expected_param_type = param_type_node->codegen_node->data.type_node.entry;
+ }
+ analyze_expression(g, context, expected_param_type, child);
+ }
+
+ TypeTableEntry *return_type = fn_proto->return_type->codegen_node->data.type_node.entry;
+ check_type_compatibility(g, node, expected_type, return_type);
+ return return_type;
+ }
+ }
+
+ case NodeTypeNumberLiteral:
+ // TODO: generic literal int type
+ return g->builtin_types.entry_i32;
+
+ case NodeTypeStringLiteral:
+ zig_panic("TODO");
+
+ case NodeTypeUnreachable:
+ return g->builtin_types.entry_unreachable;
+
+ case NodeTypeSymbol:
+ // look up symbol in symbol table
+ zig_panic("TODO");
+
+ case NodeTypeCastExpr:
+ case NodeTypePrefixOpExpr:
+ zig_panic("TODO");
+ case NodeTypeDirective:
+ case NodeTypeFnDecl:
+ case NodeTypeFnProto:
+ case NodeTypeParamDecl:
+ case NodeTypeType:
+ case NodeTypeRoot:
+ case NodeTypeRootExportDecl:
+ case NodeTypeExternBlock:
+ case NodeTypeFnDef:
+ zig_unreachable();
+ }
+ zig_unreachable();
+}
+
static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
// Follow the execution flow and make sure the code returns appropriately.
// * A `return` statement in an unreachable type function should be an error.
@@ -282,74 +437,6 @@ static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
}
}
-static void analyze_expression(CodeGen *g, AstNode *node) {
- switch (node->type) {
- case NodeTypeBlock:
- for (int i = 0; i < node->data.block.statements.length; i += 1) {
- AstNode *child = node->data.block.statements.at(i);
- analyze_expression(g, child);
- }
- break;
- case NodeTypeReturnExpr:
- if (node->data.return_expr.expr) {
- analyze_expression(g, node->data.return_expr.expr);
- }
- break;
- case NodeTypeBinOpExpr:
- analyze_expression(g, node->data.bin_op_expr.op1);
- analyze_expression(g, node->data.bin_op_expr.op2);
- break;
- case NodeTypeFnCallExpr:
- {
- Buf *name = hack_get_fn_call_name(g, node->data.fn_call_expr.fn_ref_expr);
-
- auto entry = g->fn_table.maybe_get(name);
- if (!entry) {
- add_node_error(g, node,
- buf_sprintf("undefined function: '%s'", buf_ptr(name)));
- } else {
- FnTableEntry *fn_table_entry = entry->value;
- assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
- int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
- int actual_param_count = node->data.fn_call_expr.params.length;
- if (expected_param_count != actual_param_count) {
- add_node_error(g, node,
- buf_sprintf("wrong number of arguments. Expected %d, got %d.",
- expected_param_count, actual_param_count));
- }
- }
-
- for (int i = 0; i < node->data.fn_call_expr.params.length; i += 1) {
- AstNode *child = node->data.fn_call_expr.params.at(i);
- analyze_expression(g, child);
- }
- break;
- }
- case NodeTypeCastExpr:
- zig_panic("TODO");
- break;
- case NodeTypePrefixOpExpr:
- zig_panic("TODO");
- break;
- case NodeTypeNumberLiteral:
- case NodeTypeStringLiteral:
- case NodeTypeUnreachable:
- case NodeTypeSymbol:
- // nothing to do
- break;
- case NodeTypeDirective:
- case NodeTypeFnDecl:
- case NodeTypeFnProto:
- case NodeTypeParamDecl:
- case NodeTypeType:
- case NodeTypeRoot:
- case NodeTypeRootExportDecl:
- case NodeTypeExternBlock:
- case NodeTypeFnDef:
- zig_unreachable();
- }
-}
-
static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
switch (node->type) {
case NodeTypeFnDef:
@@ -371,7 +458,13 @@ static void analyze_top_level_declaration(CodeGen *g, AstNode *node) {
}
check_fn_def_control_flow(g, node);
- analyze_expression(g, node->data.fn_def.body);
+
+ BlockContext context;
+ context.node = node;
+ context.root = &context;
+ context.parent = nullptr;
+ TypeTableEntry *expected_type = fn_proto->return_type->codegen_node->data.type_node.entry;
+ analyze_expression(g, &context, expected_type, node->data.fn_def.body);
}
break;
@@ -424,6 +517,12 @@ static void analyze_root(CodeGen *g, AstNode *node) {
}
static void define_primitive_types(CodeGen *g) {
+ {
+ // if this type is anywhere in the AST, we should never hit codegen.
+ TypeTableEntry *entry = allocate<TypeTableEntry>(1);
+ buf_init_from_str(&entry->name, "(invalid)");
+ g->builtin_types.entry_invalid = entry;
+ }
{
TypeTableEntry *entry = allocate<TypeTableEntry>(1);
entry->type_ref = LLVMInt8Type();
@@ -450,15 +549,12 @@ static void define_primitive_types(CodeGen *g) {
LLVMZigEncoding_DW_ATE_unsigned());
g->type_table.put(&entry->name, entry);
g->builtin_types.entry_void = entry;
-
- // invalid types are void
- g->builtin_types.entry_invalid = entry;
}
{
TypeTableEntry *entry = allocate<TypeTableEntry>(1);
entry->type_ref = LLVMVoidType();
buf_init_from_str(&entry->name, "unreachable");
- entry->di_type = g->builtin_types.entry_invalid->di_type;
+ entry->di_type = g->builtin_types.entry_void->di_type;
g->type_table.put(&entry->name, entry);
g->builtin_types.entry_unreachable = entry;
}