Commit 3c3be10a60

Andrew Kelley <superjoe30@gmail.com>
2015-12-07 07:09:46
add mutable local variables
1 parent dfb48a2
example/expressions/expressions.zig
@@ -1,3 +1,5 @@
+export executable "expressions";
+
 #link("c")
 extern {
     fn puts(s: *const u8) -> i32;
@@ -28,6 +30,8 @@ export fn _start() -> unreachable {
 
     void_fun(1, void, 2);
 
+    test_mutable_vars();
+
     other_exit();
 }
 
@@ -38,3 +42,15 @@ fn void_fun(a : i32, b : void, c : i32) -> void {
     let w : void = z; // void
     if (x + y == 4) { return w; }
 }
+
+fn test_mutable_vars() {
+    let mut i = 0;
+loop_start:
+    if i == 3 {
+        goto done;
+    }
+    puts("loop");
+    i = i + 1;
+    goto loop_start;
+done:
+}
src/analyze.cpp
@@ -362,6 +362,13 @@ static BlockContext *new_block_context(AstNode *node, BlockContext *parent) {
     else
         context->root = context;
     context->variable_table.init(8);
+
+    AstNode *fn_def_node = context->root->node;
+    assert(fn_def_node->type == NodeTypeFnDef);
+    assert(fn_def_node->codegen_node);
+    FnDefNode *fn_def_info = &fn_def_node->codegen_node->data.fn_def_node;
+    fn_def_info->all_block_contexts.append(context);
+
     return context;
 }
 
@@ -388,8 +395,13 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
                 return_type = g->builtin_types.entry_void;
                 for (int i = 0; i < node->data.block.statements.length; i += 1) {
                     AstNode *child = node->data.block.statements.at(i);
-                    if (child->type == NodeTypeLabel)
+                    if (child->type == NodeTypeLabel) {
+                        LabelTableEntry *label_entry = child->codegen_node->data.label_entry;
+                        assert(label_entry);
+                        label_entry->entered_from_fallthrough = (return_type != g->builtin_types.entry_unreachable);
+                        return_type = g->builtin_types.entry_void;
                         continue;
+                    }
                     if (return_type == g->builtin_types.entry_unreachable) {
                         if (child->type == NodeTypeVoid) {
                             // {unreachable;void;void} is allowed.
@@ -457,6 +469,8 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
                     LocalVariableTableEntry *variable_entry = allocate<LocalVariableTableEntry>(1);
                     buf_init_from_buf(&variable_entry->name, &variable_declaration->symbol);
                     variable_entry->type = type;
+                    variable_entry->is_const = variable_declaration->is_const;
+                    variable_entry->decl_node = node;
                     context->variable_table.put(&variable_entry->name, variable_entry);
                 }
                 return_type = g->builtin_types.entry_void;
@@ -482,6 +496,32 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
         case NodeTypeBinOpExpr:
             {
                 switch (node->data.bin_op_expr.bin_op) {
+                    case BinOpTypeAssign:
+                        {
+                            AstNode *lhs_node = node->data.bin_op_expr.op1;
+                            if (lhs_node->type == NodeTypeSymbol) {
+                                Buf *name = &lhs_node->data.symbol;
+                                LocalVariableTableEntry *var = find_local_variable(context, name);
+                                if (var) {
+                                    if (var->is_const) {
+                                        add_node_error(g, lhs_node,
+                                            buf_sprintf("cannot assign to constant variable"));
+                                    } else {
+                                        analyze_expression(g, import, context, var->type,
+                                                node->data.bin_op_expr.op2);
+                                    }
+                                } else {
+                                    add_node_error(g, lhs_node,
+                                            buf_sprintf("use of undeclared identifier '%s'", buf_ptr(name)));
+                                }
+
+                            } else {
+                                add_node_error(g, lhs_node,
+                                        buf_sprintf("expected a bare identifier"));
+                            }
+                            return_type = g->builtin_types.entry_void;
+                            break;
+                        }
                     case BinOpTypeBoolOr:
                     case BinOpTypeBoolAnd:
                         analyze_expression(g, import, context, g->builtin_types.entry_bool,
@@ -721,7 +761,10 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
                 AstNode *fn_proto_node = node->data.fn_def.fn_proto;
                 assert(fn_proto_node->type == NodeTypeFnProto);
 
+                assert(!node->codegen_node);
+                node->codegen_node = allocate<CodeGenNode>(1);
                 BlockContext *context = new_block_context(node, nullptr);
+                node->codegen_node->data.fn_def_node.block_context = context;
 
                 AstNodeFnProto *fn_proto = &fn_proto_node->data.fn_proto;
                 for (int i = 0; i < fn_proto->params.length; i += 1) {
@@ -736,6 +779,8 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
                     LocalVariableTableEntry *variable_entry = allocate<LocalVariableTableEntry>(1);
                     buf_init_from_buf(&variable_entry->name, &param_decl->name);
                     variable_entry->type = type;
+                    variable_entry->is_const = true;
+                    variable_entry->decl_node = param_decl_node;
 
                     LocalVariableTableEntry *existing_entry = find_local_variable(context, &variable_entry->name);
                     if (!existing_entry) {
@@ -756,9 +801,7 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
                 TypeTableEntry *expected_type = fn_proto->return_type->codegen_node->data.type_node.entry;
                 TypeTableEntry *block_return_type = analyze_expression(g, import, context, expected_type, node->data.fn_def.body);
 
-                node->codegen_node = allocate<CodeGenNode>(1);
                 node->codegen_node->data.fn_def_node.implicit_return_type = block_return_type;
-                node->codegen_node->data.fn_def_node.block_context = context;
 
                 {
                     FnTableEntry *fn_table_entry = fn_proto_node->codegen_node->data.fn_proto_node.fn_table_entry;
src/codegen.cpp
@@ -102,6 +102,7 @@ static int count_non_void_params(CodeGen *g, ZigList<AstNode *> *params) {
 }
 
 static void add_debug_source_node(CodeGen *g, AstNode *node) {
+    // TODO g->block_scopes.last() is not always correct and should probably integrate with BlockContext
     LLVMZigSetCurrentDebugLocation(g->builder, node->line + 1, node->column + 1, g->block_scopes.last());
 }
 
@@ -210,6 +211,8 @@ static LLVMValueRef gen_arithmetic_bin_op_expr(CodeGen *g, AstNode *node) {
     LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
 
     switch (node->data.bin_op_expr.bin_op) {
+        case BinOpTypeAssign:
+            zig_panic("TODO assignment");
         case BinOpTypeBinOr:
             add_debug_source_node(g, node);
             return LLVMBuildOr(g->builder, val1, val2, "");
@@ -358,8 +361,28 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
     return phi;
 }
 
+static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeBinOpExpr);
+
+    AstNode *symbol_node = node->data.bin_op_expr.op1;
+    assert(symbol_node->type == NodeTypeSymbol);
+
+    LocalVariableTableEntry *var = find_local_variable(node->codegen_node->expr_node.block_context,
+            &symbol_node->data.symbol);
+
+    // semantic checking ensures no variables are constant
+    assert(!var->is_const);
+
+    LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
+
+    add_debug_source_node(g, node);
+    return LLVMBuildStore(g->builder, value, var->value_ref);
+}
+
 static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) {
     switch (node->data.bin_op_expr.bin_op) {
+        case BinOpTypeAssign:
+            return gen_assign_expr(g, node);
         case BinOpTypeInvalid:
             zig_unreachable();
         case BinOpTypeBoolOr:
@@ -498,9 +521,20 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
         case NodeTypeVariableDeclaration:
             {
                 LocalVariableTableEntry *variable = find_local_variable(node->codegen_node->expr_node.block_context, &node->data.variable_declaration.symbol);
-                assert(node->data.variable_declaration.expr);
-                variable->value_ref = gen_expr(g, node->data.variable_declaration.expr);
-                return nullptr;
+                if (variable->is_const) {
+                    assert(node->data.variable_declaration.expr);
+                    variable->value_ref = gen_expr(g, node->data.variable_declaration.expr);
+                    return nullptr;
+                } else {
+                    if (node->data.variable_declaration.expr) {
+                        LLVMValueRef value = gen_expr(g, node->data.variable_declaration.expr);
+
+                        add_debug_source_node(g, node);
+                        return LLVMBuildStore(g->builder, value, variable->value_ref);
+                    } else {
+
+                    }
+                }
             }
         case NodeTypeCastExpr:
             return gen_cast_expr(g, node);
@@ -542,7 +576,11 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
         case NodeTypeSymbol:
             {
                 LocalVariableTableEntry *variable = find_local_variable(node->codegen_node->expr_node.block_context, &node->data.symbol);
-                return variable->value_ref;
+                if (variable->is_const) {
+                    return variable->value_ref;
+                } else {
+                    return LLVMBuildLoad(g->builder, variable->value_ref, "");
+                }
             }
         case NodeTypeBlock:
             return gen_block(g, node, nullptr);
@@ -551,11 +589,15 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
             return LLVMBuildBr(g->builder, node->codegen_node->data.label_entry->basic_block);
         case NodeTypeLabel:
             {
-                LLVMBasicBlockRef basic_block = node->codegen_node->data.label_entry->basic_block;
-                add_debug_source_node(g, node);
-                LLVMValueRef result = LLVMBuildBr(g->builder, basic_block);
+                LabelTableEntry *label_entry = node->codegen_node->data.label_entry;
+                assert(label_entry);
+                LLVMBasicBlockRef basic_block = label_entry->basic_block;
+                if (label_entry->entered_from_fallthrough) {
+                    add_debug_source_node(g, node);
+                    LLVMBuildBr(g->builder, basic_block);
+                }
                 LLVMPositionBuilderAtEnd(g->builder, basic_block);
-                return result;
+                return nullptr;
             }
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
@@ -696,6 +738,24 @@ static void do_code_gen(CodeGen *g) {
 
         build_label_blocks(g, fn_def_node->data.fn_def.body);
 
+        // allocate all local variables
+        for (int i = 0; i < codegen_fn_def->all_block_contexts.length; i += 1) {
+            BlockContext *block_context = codegen_fn_def->all_block_contexts.at(i);
+
+            auto it = block_context->variable_table.entry_iterator();
+            for (;;) {
+                auto *entry = it.next();
+                if (!entry)
+                    break;
+
+                LocalVariableTableEntry *var = entry->value;
+                if (!var->is_const) {
+                    add_debug_source_node(g, var->decl_node);
+                    var->value_ref = LLVMBuildAlloca(g->builder, var->type->type_ref, buf_ptr(&var->name));
+                }
+            }
+        }
+
         TypeTableEntry *implicit_return_type = codegen_fn_def->implicit_return_type;
         gen_block(g, fn_def_node->data.fn_def.body, implicit_return_type);
 
src/parser.cpp
@@ -33,6 +33,7 @@ static const char *bin_op_str(BinOpType bin_op) {
         case BinOpTypeMult:           return "*";
         case BinOpTypeDiv:            return "/";
         case BinOpTypeMod:            return "%";
+        case BinOpTypeAssign:         return "=";
     }
     zig_unreachable();
 }
@@ -1096,7 +1097,7 @@ static AstNode *ast_parse_return_expr(ParseContext *pc, int *token_index, bool m
 }
 
 /*
-VariableDeclaration : token(Let) token(Symbole) (token(Eq) Expression | token(Colon) Type option(token(Eq) Expression))
+VariableDeclaration : token(Let) option(token(Mut)) token(Symbol) (token(Eq) Expression | token(Colon) Type option(token(Eq) Expression))
 */
 static AstNode *ast_parse_variable_declaration_expr(ParseContext *pc, int *token_index, bool mandatory) {
     Token *let_tok = &pc->tokens->at(*token_index);
@@ -1104,9 +1105,21 @@ static AstNode *ast_parse_variable_declaration_expr(ParseContext *pc, int *token
         *token_index += 1;
         AstNode *node = ast_create_node(pc, NodeTypeVariableDeclaration, let_tok);
 
-        Token *name_token = &pc->tokens->at(*token_index);
+        Token *name_token;
+        Token *token = &pc->tokens->at(*token_index);
+        if (token->id == TokenIdKeywordMut) {
+            node->data.variable_declaration.is_const = false;
+            *token_index += 1;
+            name_token = &pc->tokens->at(*token_index);
+            ast_expect_token(pc, name_token, TokenIdSymbol);
+        } else if (token->id == TokenIdSymbol) {
+            node->data.variable_declaration.is_const = true;
+            name_token = token;
+        } else {
+            ast_invalid_token_error(pc, token);
+        }
+
         *token_index += 1;
-        ast_expect_token(pc, name_token, TokenIdSymbol);
         ast_buf_from_token(pc, name_token, &node->data.variable_declaration.symbol);
 
         Token *eq_or_colon = &pc->tokens->at(*token_index);
@@ -1178,7 +1191,30 @@ static AstNode *ast_parse_block_expr(ParseContext *pc, int *token_index, bool ma
 }
 
 /*
-NonBlockExpression : ReturnExpression | VariableDeclaration | BoolOrExpression
+AssignmentExpression : BoolOrExpression token(Equal) BoolOrExpression | BoolOrExpression
+*/
+static AstNode *ast_parse_ass_expr(ParseContext *pc, int *token_index, bool mandatory) {
+    AstNode *lhs = ast_parse_bool_or_expr(pc, token_index, mandatory);
+    if (!lhs)
+        return lhs;
+
+    Token *token = &pc->tokens->at(*token_index);
+    if (token->id != TokenIdEq)
+        return lhs;
+    *token_index += 1;
+
+    AstNode *rhs = ast_parse_bool_or_expr(pc, token_index, true);
+
+    AstNode *node = ast_create_node(pc, NodeTypeBinOpExpr, token);
+    node->data.bin_op_expr.op1 = lhs;
+    node->data.bin_op_expr.bin_op = BinOpTypeAssign;
+    node->data.bin_op_expr.op2 = rhs;
+
+    return node;
+}
+
+/*
+NonBlockExpression : ReturnExpression | VariableDeclaration | AssignmentExpression
 */
 static AstNode *ast_parse_non_block_expr(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
@@ -1191,9 +1227,10 @@ static AstNode *ast_parse_non_block_expr(ParseContext *pc, int *token_index, boo
     if (variable_declaration_expr)
         return variable_declaration_expr;
 
-    AstNode *bool_or_expr = ast_parse_bool_or_expr(pc, token_index, false);
-    if (bool_or_expr)
-        return bool_or_expr;
+
+    AstNode *ass_expr = ast_parse_ass_expr(pc, token_index, false);
+    if (ass_expr)
+        return ass_expr;
 
     if (mandatory)
         ast_invalid_token_error(pc, token);
src/parser.hpp
@@ -101,6 +101,7 @@ struct AstNodeReturnExpr {
 
 struct AstNodeVariableDeclaration {
     Buf symbol;
+    bool is_const;
     // one or both of type and expr will be non null
     AstNode *type;
     AstNode *expr;
@@ -108,7 +109,7 @@ struct AstNodeVariableDeclaration {
 
 enum BinOpType {
     BinOpTypeInvalid,
-    // TODO: include assignment?
+    BinOpTypeAssign,
     BinOpTypeBoolOr,
     BinOpTypeBoolAnd,
     BinOpTypeCmpEq,
src/semantic_info.hpp
@@ -42,6 +42,7 @@ struct LabelTableEntry {
     AstNode *label_node;
     LLVMBasicBlockRef basic_block;
     bool used;
+    bool entered_from_fallthrough;
 };
 
 struct FnTableEntry {
@@ -116,6 +117,8 @@ struct LocalVariableTableEntry {
     Buf name;
     TypeTableEntry *type;
     LLVMValueRef value_ref;
+    bool is_const;
+    AstNode *decl_node;
 };
 
 struct BlockContext {
@@ -137,6 +140,7 @@ struct FnDefNode {
     TypeTableEntry *implicit_return_type;
     BlockContext *block_context;
     bool skip;
+    ZigList<BlockContext *> all_block_contexts;
 };
 
 struct ExprNode {
@@ -146,12 +150,17 @@ struct ExprNode {
     BlockContext *block_context;
 };
 
+struct AssignNode {
+    LocalVariableTableEntry *var_entry;
+};
+
 struct CodeGenNode {
     union {
         TypeNode type_node; // for NodeTypeType
         FnDefNode fn_def_node; // for NodeTypeFnDef
         FnProtoNode fn_proto_node; // for NodeTypeFnProto
         LabelTableEntry *label_entry; // for NodeTypeGoto and NodeTypeLabel
+        AssignNode assign_node; // for NodeTypeBinOpExpr where op is BinOpTypeAssign
     } data;
     ExprNode expr_node; // for all the expression nodes
 };
test/run_tests.cpp
@@ -331,6 +331,27 @@ fn void_fun(a : i32, b : void, c : i32) {
     return vv;
 }
     )SOURCE", "OK\n");
+
+    add_simple_case("void parameters", R"SOURCE(
+#link("c")
+extern {
+    fn puts(s: *const u8) -> i32;
+    fn exit(code: i32) -> unreachable;
+}
+
+export fn _start() -> unreachable {
+    let mut i = 0;
+loop_start:
+    if i == 3 {
+        goto done;
+    }
+    puts("loop");
+    i = i + 1;
+    goto loop_start;
+done:
+    exit(0);
+}
+    )SOURCE", "loop\nloop\nloop\n");
 }
 
 static void add_compile_failure_test_cases(void) {
@@ -467,10 +488,29 @@ export fn f(a : void) {}
     )SOURCE", 1, ".tmp_source.zig:2:17: error: parameter of type 'void' not allowed on exported functions");
 
     add_compile_fail_case("unused label", R"SOURCE(
-export fn f() {
+fn f() {
 a_label:
 }
     )SOURCE", 1, ".tmp_source.zig:3:1: error: label 'a_label' defined but not used");
+
+    add_compile_fail_case("expected bare identifier", R"SOURCE(
+fn f() {
+    3 = 3;
+}
+    )SOURCE", 1, ".tmp_source.zig:3:5: error: expected a bare identifier");
+
+    add_compile_fail_case("assign to constant variable", R"SOURCE(
+fn f() {
+    let a = 3;
+    a = 4;
+}
+    )SOURCE", 1, ".tmp_source.zig:4:5: error: cannot assign to constant variable");
+
+    add_compile_fail_case("use of undeclared identifier", R"SOURCE(
+fn f() {
+    b = 3;
+}
+    )SOURCE", 1, ".tmp_source.zig:3:5: error: use of undeclared identifier 'b'");
 }
 
 static void print_compiler_invocation(TestCase *test_case, Buf *zig_stderr) {