Commit fb85d3a0a2

Andrew Kelley <superjoe30@gmail.com>
2016-01-26 00:37:45
codegen: get rid of cur_block_context
1 parent deb3586
src/all_types.hpp
@@ -86,9 +86,6 @@ struct ConstExprValue {
 
 struct Expr {
     TypeTableEntry *type_entry;
-    // the context in which this expression is evaluated.
-    // for blocks, this points to the containing scope, not the block's own scope for its children.
-    BlockContext *block_context;
 
     LLVMValueRef const_llvm_val;
     ConstExprValue const_val;
@@ -254,6 +251,7 @@ struct AstNodeVariableDeclaration {
     // populated by semantic analyzer
     TopLevelDecl top_level_decl;
     Expr resolved_expr;
+    VariableTableEntry *variable;
 };
 
 struct AstNodeErrorValueDecl {
@@ -440,7 +438,6 @@ struct AstNodeIfVarExpr {
 
     // populated by semantic analyzer
     TypeTableEntry *type;
-    BlockContext *block_context;
     Expr resolved_expr;
 };
 
@@ -464,7 +461,6 @@ struct AstNodeForExpr {
     // populated by semantic analyzer
     bool contains_break;
     Expr resolved_expr;
-    BlockContext *block_context;
     VariableTableEntry *elem_var;
     VariableTableEntry *index_var;
 };
@@ -684,6 +680,9 @@ struct AstNode {
     uint32_t create_index; // for determinism purposes
     ImportTableEntry *owner;
     AstNode **parent_field; // for AST rewriting
+    // the context in which this expression/node is evaluated.
+    // for blocks, this points to the containing scope, not the block's own scope for its children.
+    BlockContext *block_context;
     union {
         AstNodeRoot root;
         AstNodeRootExportDecl root_export_decl;
@@ -1007,8 +1006,6 @@ struct CodeGen {
     OutType out_type;
     FnTableEntry *cur_fn;
     LLVMValueRef cur_ret_ptr;
-    // TODO remove this in favor of get_resolved_expr(expr_node)->context
-    BlockContext *cur_block_context;
     ZigList<LLVMBasicBlockRef> break_block_stack;
     ZigList<LLVMBasicBlockRef> continue_block_stack;
     bool c_stdint_used;
src/analyze.cpp
@@ -1663,6 +1663,8 @@ static TypeTableEntry *analyze_container_init_expr(CodeGen *g, ImportTableEntry
             AstNode *val_field_node = container_init_expr->entries.at(i);
             assert(val_field_node->type == NodeTypeStructValueField);
 
+            val_field_node->block_context = context;
+
             TypeStructField *type_field = find_struct_type_field(container_type,
                     &val_field_node->data.struct_val_field.name);
 
@@ -2158,6 +2160,7 @@ static TypeTableEntry *analyze_lvalue(CodeGen *g, ImportTableEntry *import, Bloc
         AstNode *lhs_node, LValPurpose purpose, bool is_ptr_const)
 {
     TypeTableEntry *expected_rhs_type = nullptr;
+    lhs_node->block_context = block_context;
     if (lhs_node->type == NodeTypeSymbol) {
         Buf *name = &lhs_node->data.symbol_expr.symbol;
         if (purpose == LValPurposeAddressOf) {
@@ -2520,6 +2523,7 @@ static TypeTableEntry *analyze_unwrap_error_expr(CodeGen *g, ImportTableEntry *i
         BlockContext *child_context;
         if (var_node) {
             child_context = new_block_context(node, parent_context);
+            var_node->block_context = child_context;
             Buf *var_name = &var_node->data.symbol_expr.symbol;
             node->data.unwrap_err_expr.var = add_local_var(g, var_node, child_context, var_name,
                     g->builtin_types.entry_pure_error, true);
@@ -2601,6 +2605,8 @@ static VariableTableEntry *analyze_variable_declaration_raw(CodeGen *g, ImportTa
     VariableTableEntry *var = add_local_var(g, source_node, context,
             &variable_declaration->symbol, type, is_const);
 
+    variable_declaration->variable = var;
+
 
     bool is_pub = (variable_declaration->visib_mod != VisibModPrivate);
     if (is_pub) {
@@ -2785,15 +2791,16 @@ static TypeTableEntry *analyze_for_expr(CodeGen *g, ImportTableEntry *import, Bl
     }
 
     BlockContext *child_context = new_block_context(node, context);
-    node->data.for_expr.block_context = child_context;
 
     AstNode *elem_var_node = node->data.for_expr.elem_node;
+    elem_var_node->block_context = child_context;
     Buf *elem_var_name = &elem_var_node->data.symbol_expr.symbol;
     node->data.for_expr.elem_var = add_local_var(g, elem_var_node, child_context, elem_var_name, child_type, true);
 
     AstNode *index_var_node = node->data.for_expr.index_node;
     if (index_var_node) {
         Buf *index_var_name = &index_var_node->data.symbol_expr.symbol;
+        index_var_node->block_context = child_context;
         node->data.for_expr.index_var = add_local_var(g, index_var_node, child_context, index_var_name,
                 g->builtin_types.entry_isize, true);
     } else {
@@ -2872,7 +2879,6 @@ static TypeTableEntry *analyze_if_var_expr(CodeGen *g, ImportTableEntry *import,
     assert(node->type == NodeTypeIfVarExpr);
 
     BlockContext *child_context = new_block_context(node, context);
-    node->data.if_var_expr.block_context = child_context;
 
     analyze_variable_declaration_raw(g, import, child_context, node, &node->data.if_var_expr.var_decl, true);
 
@@ -3412,6 +3418,7 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
     }
 
     if (fn_ref_expr->type == NodeTypeFieldAccessExpr) {
+        fn_ref_expr->block_context = context;
         AstNode *first_param_expr = fn_ref_expr->data.field_access_expr.struct_expr;
         TypeTableEntry *struct_type = analyze_expression(g, import, context, nullptr, first_param_expr);
         Buf *name = &fn_ref_expr->data.field_access_expr.field_name;
@@ -3724,6 +3731,7 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
             if (var_node) {
                 assert(var_node->type == NodeTypeSymbol);
                 Buf *var_name = &var_node->data.symbol_expr.symbol;
+                var_node->block_context = child_context;
                 prong_node->data.switch_prong.var = add_local_var(g, var_node, child_context, var_name,
                         var_type, true);
             }
@@ -3802,6 +3810,7 @@ static TypeTableEntry *analyze_block_expr(CodeGen *g, ImportTableEntry *import,
     for (int i = 0; i < node->data.block.statements.length; i += 1) {
         AstNode *child = node->data.block.statements.at(i);
         if (child->type == NodeTypeLabel) {
+            child->block_context = child_context;
             LabelTableEntry *label_entry = child->data.label.label_entry;
             assert(label_entry);
             label_entry->entered_from_fallthrough = (return_type->id != TypeTableEntryIdUnreachable);
@@ -3987,7 +3996,7 @@ static TypeTableEntry *analyze_expression(CodeGen *g, ImportTableEntry *import,
 
     Expr *expr = get_resolved_expr(node);
     expr->type_entry = return_type;
-    expr->block_context = context;
+    node->block_context = context;
 
     add_global_const_expr(g, expr);
 
src/codegen.cpp
@@ -68,7 +68,7 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node);
 static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *expr_node, AstNode *node, TypeTableEntry **out_type_entry);
 static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lvalue);
 static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVariableDeclaration *var_decl,
-        BlockContext *block_context, bool unwrap_maybe, LLVMValueRef *init_val);
+        bool unwrap_maybe, LLVMValueRef *init_val);
 static LLVMValueRef gen_assign_raw(CodeGen *g, AstNode *source_node, BinOpType bin_op,
         LLVMValueRef target_ref, LLVMValueRef value,
         TypeTableEntry *op1_type, TypeTableEntry *op2_type);
@@ -82,10 +82,8 @@ static TypeTableEntry *get_type_for_type_node(AstNode *node) {
 }
 
 static void add_debug_source_node(CodeGen *g, AstNode *node) {
-    if (!g->cur_block_context)
-        return;
-    LLVMZigSetCurrentDebugLocation(g->builder, node->line + 1, node->column + 1,
-            g->cur_block_context->di_scope);
+    assert(node->block_context);
+    LLVMZigSetCurrentDebugLocation(g->builder, node->line + 1, node->column + 1, node->block_context->di_scope);
 }
 
 static TypeTableEntry *get_expr_type(AstNode *node) {
@@ -557,7 +555,7 @@ static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node, TypeTableEntry **ou
 
     LLVMValueRef struct_ptr;
     if (struct_expr_node->type == NodeTypeSymbol) {
-        VariableTableEntry *var = find_variable(get_resolved_expr(struct_expr_node)->block_context,
+        VariableTableEntry *var = find_variable(struct_expr_node->block_context,
                 &struct_expr_node->data.symbol_expr.symbol);
         assert(var);
 
@@ -745,7 +743,7 @@ static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *expr_node, AstNode *node,
     LLVMValueRef target_ref;
 
     if (node->type == NodeTypeSymbol) {
-        VariableTableEntry *var = find_variable(get_resolved_expr(expr_node)->block_context,
+        VariableTableEntry *var = find_variable(expr_node->block_context,
                 &node->data.symbol_expr.symbol);
         assert(var);
         // semantic checking ensures no variables are constant
@@ -1522,33 +1520,24 @@ static LLVMValueRef gen_if_var_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeIfVarExpr);
     assert(node->data.if_var_expr.var_decl.expr);
 
-    BlockContext *old_block_context = g->cur_block_context;
-    BlockContext *new_block_context = node->data.if_var_expr.block_context;
-
     LLVMValueRef init_val;
-    gen_var_decl_raw(g, node, &node->data.if_var_expr.var_decl, new_block_context, true, &init_val);
+    gen_var_decl_raw(g, node, &node->data.if_var_expr.var_decl, true, &init_val);
 
     // test if value is the maybe state
     add_debug_source_node(g, node);
     LLVMValueRef maybe_field_ptr = LLVMBuildStructGEP(g->builder, init_val, 1, "");
     LLVMValueRef cond_value = LLVMBuildLoad(g->builder, maybe_field_ptr, "");
 
-    g->cur_block_context = new_block_context;
-
     LLVMValueRef return_value = gen_if_bool_expr_raw(g, node, cond_value,
             node->data.if_var_expr.then_block,
             node->data.if_var_expr.else_node);
 
-    g->cur_block_context = old_block_context;
     return return_value;
 }
 
 static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *implicit_return_type) {
     assert(block_node->type == NodeTypeBlock);
 
-    BlockContext *old_block_context = g->cur_block_context;
-    g->cur_block_context = block_node->data.block.block_context;
-
     LLVMValueRef return_value;
     for (int i = 0; i < block_node->data.block.statements.length; i += 1) {
         AstNode *statement_node = block_node->data.block.statements.at(i);
@@ -1556,12 +1545,10 @@ static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *i
     }
 
     if (implicit_return_type && implicit_return_type->id != TypeTableEntryIdUnreachable) {
-        gen_return(g, block_node, return_value);
+        return gen_return(g, block_node, return_value);
+    } else {
+        return return_value;
     }
-
-    g->cur_block_context = old_block_context;
-
-    return return_value;
 }
 
 static int find_asm_index(CodeGen *g, AstNode *node, AsmToken *tok) {
@@ -1646,9 +1633,7 @@ static LLVMValueRef gen_asm_expr(CodeGen *g, AstNode *node) {
         }
 
         if (!is_return) {
-            VariableTableEntry *variable = find_variable(
-                    get_resolved_expr(node)->block_context,
-                    &asm_output->variable_name);
+            VariableTableEntry *variable = find_variable( node->block_context, &asm_output->variable_name);
             assert(variable);
             param_types[param_index] = LLVMTypeOf(variable->value_ref);
             param_values[param_index] = variable->value_ref;
@@ -1763,13 +1748,10 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
     assert(node->data.while_expr.condition);
     assert(node->data.while_expr.body);
 
-    BlockContext *old_block_context = g->cur_block_context;
-
     bool condition_always_true = node->data.while_expr.condition_always_true;
     bool contains_break = node->data.while_expr.contains_break;
     if (condition_always_true) {
         // generate a forever loop
-        g->cur_block_context = node->data.while_expr.block_context;
 
         LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody");
         LLVMBasicBlockRef end_block = nullptr;
@@ -1806,7 +1788,6 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
         LLVMBuildBr(g->builder, cond_block);
 
         LLVMPositionBuilderAtEnd(g->builder, cond_block);
-        g->cur_block_context = old_block_context;
         LLVMValueRef cond_val = gen_expr(g, node->data.while_expr.condition);
         add_debug_source_node(g, node->data.while_expr.condition);
         LLVMBuildCondBr(g->builder, cond_val, body_block, end_block);
@@ -1814,7 +1795,6 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
         LLVMPositionBuilderAtEnd(g->builder, body_block);
         g->break_block_stack.append(end_block);
         g->continue_block_stack.append(cond_block);
-        g->cur_block_context = node->data.while_expr.block_context;
         gen_expr(g, node->data.while_expr.body);
         g->break_block_stack.pop();
         g->continue_block_stack.pop();
@@ -1826,7 +1806,6 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
         LLVMPositionBuilderAtEnd(g->builder, end_block);
     }
 
-    g->cur_block_context = old_block_context;
     return nullptr;
 }
 
@@ -1845,8 +1824,6 @@ static LLVMValueRef gen_for_expr(CodeGen *g, AstNode *node) {
     LLVMValueRef index_ptr = index_var->value_ref;
     LLVMValueRef one_const = LLVMConstInt(g->builtin_types.entry_isize->type_ref, 1, false);
 
-    BlockContext *old_block_context = g->cur_block_context;
-
     LLVMBasicBlockRef cond_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "ForCond");
     LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "ForBody");
     LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "ForEnd");
@@ -1884,7 +1861,6 @@ static LLVMValueRef gen_for_expr(CodeGen *g, AstNode *node) {
             elem_var->type, child_type);
     g->break_block_stack.append(end_block);
     g->continue_block_stack.append(cond_block);
-    g->cur_block_context = node->data.for_expr.block_context;
     gen_expr(g, node->data.for_expr.body);
     g->break_block_stack.pop();
     g->continue_block_stack.pop();
@@ -1896,7 +1872,6 @@ static LLVMValueRef gen_for_expr(CodeGen *g, AstNode *node) {
     }
 
     LLVMPositionBuilderAtEnd(g->builder, end_block);
-    g->cur_block_context = old_block_context;
     return nullptr;
 }
 
@@ -1917,9 +1892,9 @@ static LLVMValueRef gen_continue(CodeGen *g, AstNode *node) {
 }
 
 static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVariableDeclaration *var_decl,
-        BlockContext *block_context, bool unwrap_maybe, LLVMValueRef *init_value)
+        bool unwrap_maybe, LLVMValueRef *init_value)
 {
-    VariableTableEntry *variable = find_variable(block_context, &var_decl->symbol);
+    VariableTableEntry *variable = var_decl->variable;
 
     assert(variable);
     assert(variable->is_ptr);
@@ -2010,7 +1985,7 @@ static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVa
     }
 
     LLVMZigDILocation *debug_loc = LLVMZigGetDebugLoc(source_node->line + 1, source_node->column + 1,
-            g->cur_block_context->di_scope);
+            source_node->block_context->di_scope);
     LLVMZigInsertDeclareAtEnd(g->dbuilder, variable->value_ref, variable->di_loc_var, debug_loc,
             LLVMGetInsertBlock(g->builder));
     return nullptr;
@@ -2028,8 +2003,7 @@ static LLVMValueRef gen_var_decl_expr(CodeGen *g, AstNode *node) {
     }
 
     LLVMValueRef init_val;
-    return gen_var_decl_raw(g, node, &node->data.variable_declaration,
-            get_resolved_expr(node)->block_context, false, &init_val);
+    return gen_var_decl_raw(g, node, &node->data.variable_declaration, false, &init_val);
 }
 
 static LLVMValueRef gen_symbol(CodeGen *g, AstNode *node) {
@@ -2498,8 +2472,6 @@ static void do_code_gen(CodeGen *g) {
                 block_context->di_scope = LLVMZigLexicalBlockToScope(di_block);
             }
 
-            g->cur_block_context = block_context;
-
             for (int var_i = 0; var_i < block_context->variable_list.length; var_i += 1) {
                 VariableTableEntry *var = block_context->variable_list.at(var_i);