Commit 431170d981

Andrew Kelley <superjoe30@gmail.com>
2015-12-22 21:22:40
codegen: fix struct pointer field access
1 parent 437e9b9
Changed files (5)
doc/targets.md
@@ -0,0 +1,11 @@
+# How to Add Support For More Targets
+
+Create bootstrap code in std/bootstrap.zig and add conditional compilation
+logic. This code is responsible for the real executable entry point, calling
+main(argc, argv, env) and making the exit syscall when main returns.
+
+How to pass a byvalue struct parameter in the C calling convention is
+target-specific. Add logic for how to do function prototypes and function calls
+for the target when an exported or external function has a byvalue struct.
+
+Write the target-specific code in std.zig.
example/structs/structs.zig
@@ -2,7 +2,7 @@ export executable "structs";
 
 use "std.zig";
 
-export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
+pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     var foo : Foo;
 
     foo.a = foo.a + 1;
@@ -30,10 +30,14 @@ struct Foo {
 }
 
 struct Node {
-    val: i32,
+    val: Val,
     next: &Node,
 }
 
+struct Val {
+    x: i32,
+}
+
 fn test_foo(foo : Foo) {
     if !foo.b {
         print_str("BAD\n");
@@ -46,13 +50,15 @@ fn modify_foo(foo : &Foo) {
 
 fn test_point_to_self() {
     var root : Node;
-    root.val = 1;
+    root.val.x = 1;
 
     var node : Node;
     node.next = &root;
-    node.val = 2;
+    node.val.x = 2;
+
+    root.next = &node;
 
-    if node.next.val != 1 {
+    if node.next.next.next.val.x != 1 {
         print_str("BAD\n");
     }
 }
src/analyze.cpp
@@ -273,10 +273,14 @@ static void preview_function_labels(CodeGen *g, AstNode *node, FnTableEntry *fn_
 static void resolve_struct_type(CodeGen *g, ImportTableEntry *import, TypeTableEntry *struct_type) {
     assert(struct_type->id == TypeTableEntryIdStruct);
 
+    if (struct_type->data.structure.fields) {
+        // we already resolved this type. skip
+        return;
+    }
+
     AstNode *decl_node = struct_type->data.structure.decl_node;
 
     assert(struct_type->di_type);
-    assert(!struct_type->data.structure.fields);
 
     int field_count = decl_node->data.struct_decl.fields.length;
     struct_type->data.structure.field_count = field_count;
src/codegen.cpp
@@ -63,6 +63,8 @@ void codegen_set_libc_path(CodeGen *g, Buf *libc_path) {
 }
 
 static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node);
+static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *expr_node, AstNode *node, TypeTableEntry **out_type_entry);
+static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lvalue);
     
 
 static TypeTableEntry *get_type_for_type_node(CodeGen *g, AstNode *type_node) {
@@ -192,6 +194,7 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
 static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeArrayAccessExpr);
 
+    // TODO gen_lvalue
     LLVMValueRef array_ref_value = gen_expr(g, node->data.array_access_expr.array_ref_expr);
     LLVMValueRef subscript_value = gen_expr(g, node->data.array_access_expr.subscript);
 
@@ -209,15 +212,34 @@ static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) {
 static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node, TypeTableEntry **out_type_entry) {
     assert(node->type == NodeTypeFieldAccessExpr);
 
-    //TypeTableEntry *struct_type = get_expr_type(node->data.field_access_expr.struct_expr);
-    LLVMValueRef struct_ptr = gen_expr(g, node->data.field_access_expr.struct_expr);
-    assert(struct_ptr);
+    AstNode *struct_expr_node = node->data.field_access_expr.struct_expr;
 
-    /*
-    if (struct_type->id == TypeTableEntryIdPointer) {
-        zig_panic("TODO pointer field struct access");
+    LLVMValueRef struct_ptr;
+    if (struct_expr_node->type == NodeTypeSymbol) {
+        VariableTableEntry *var = find_variable(struct_expr_node->codegen_node->expr_node.block_context,
+                &struct_expr_node->data.symbol);
+        assert(var);
+
+        if (var->is_ptr && var->type->id == TypeTableEntryIdPointer) {
+            add_debug_source_node(g, node);
+            struct_ptr = LLVMBuildLoad(g->builder, var->value_ref, "");
+        } else {
+            struct_ptr = var->value_ref;
+        }
+    } else if (struct_expr_node->type == NodeTypeFieldAccessExpr) {
+        struct_ptr = gen_field_access_expr(g, struct_expr_node, true);
+        TypeTableEntry *field_type = get_expr_type(struct_expr_node);
+        if (field_type->id == TypeTableEntryIdPointer) {
+            // we have a double pointer so we must dereference it once
+            add_debug_source_node(g, node);
+            struct_ptr = LLVMBuildLoad(g->builder, struct_ptr, "");
+        }
+    } else {
+        struct_ptr = gen_expr(g, struct_expr_node);
     }
-    */
+
+    assert(LLVMGetTypeKind(LLVMTypeOf(struct_ptr)) == LLVMPointerTypeKind);
+    assert(LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(struct_ptr))) == LLVMStructTypeKind);
 
     FieldAccessNode *codegen_field_access = &node->codegen_node->data.field_access_node;
 
@@ -229,15 +251,20 @@ static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node, TypeTableEntry **ou
     return LLVMBuildStructGEP(g->builder, struct_ptr, codegen_field_access->field_index, "");
 }
 
-static LLVMValueRef gen_array_access_expr(CodeGen *g, AstNode *node) {
+static LLVMValueRef gen_array_access_expr(CodeGen *g, AstNode *node, bool is_lvalue) {
     assert(node->type == NodeTypeArrayAccessExpr);
 
     LLVMValueRef ptr = gen_array_ptr(g, node);
-    add_debug_source_node(g, node);
-    return LLVMBuildLoad(g->builder, ptr, "");
+
+    if (is_lvalue) {
+        return ptr;
+    } else {
+        add_debug_source_node(g, node);
+        return LLVMBuildLoad(g->builder, ptr, "");
+    }
 }
 
-static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node) {
+static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lvalue) {
     assert(node->type == NodeTypeFieldAccessExpr);
 
     TypeTableEntry *struct_type = get_expr_type(node->data.field_access_expr.struct_expr);
@@ -255,21 +282,26 @@ static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node) {
     {
         TypeTableEntry *type_entry;
         LLVMValueRef ptr = gen_field_ptr(g, node, &type_entry);
-        return LLVMBuildLoad(g->builder, ptr, "");
+        if (is_lvalue) {
+            return ptr;
+        } else {
+            add_debug_source_node(g, node);
+            return LLVMBuildLoad(g->builder, ptr, "");
+        }
     } else {
         zig_panic("gen_field_access_expr bad struct type");
     }
 }
 
-static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *parent_node, AstNode *node,
+static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *expr_node, AstNode *node,
         TypeTableEntry **out_type_entry)
 {
     LLVMValueRef target_ref;
 
     if (node->type == NodeTypeSymbol) {
-        VariableTableEntry *var = find_variable(parent_node->codegen_node->expr_node.block_context,
+        VariableTableEntry *var = find_variable(expr_node->codegen_node->expr_node.block_context,
                 &node->data.symbol);
-
+        assert(var);
         // semantic checking ensures no variables are constant
         assert(!var->is_const);
 
@@ -631,6 +663,7 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
     AstNode *lhs_node = node->data.bin_op_expr.op1;
 
     TypeTableEntry *op1_type;
+
     LLVMValueRef target_ref = gen_lvalue(g, node, lhs_node, &op1_type);
 
     LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
@@ -957,9 +990,9 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
         case NodeTypeFnCallExpr:
             return gen_fn_call_expr(g, node);
         case NodeTypeArrayAccessExpr:
-            return gen_array_access_expr(g, node);
+            return gen_array_access_expr(g, node, false);
         case NodeTypeFieldAccessExpr:
-            return gen_field_access_expr(g, node);
+            return gen_field_access_expr(g, node, false);
         case NodeTypeUnreachable:
             add_debug_source_node(g, node);
             return LLVMBuildUnreachable(g->builder);
@@ -1153,7 +1186,7 @@ static void do_code_gen(CodeGen *g) {
         assert(proto_node->type == NodeTypeFnProto);
         AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
 
-        LLVMTypeRef ret_type = fn_proto_type_from_type_node(g, fn_proto->return_type);
+        LLVMTypeRef ret_type = get_type_for_type_node(g, fn_proto->return_type)->type_ref;
         int param_count = count_non_void_params(g, &fn_proto->params);
         LLVMTypeRef *param_types = allocate<LLVMTypeRef>(param_count);
         int gen_param_index = 0;
test/run_tests.cpp
@@ -573,6 +573,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
     if foo.c != 100 {
         print_str("BAD\n");
     }
+    test_point_to_self();
     print_str("OK\n");
     return 0;
 }
@@ -588,6 +589,28 @@ fn test_foo(foo : Foo) {
 }
 fn test_mutation(foo : &Foo) {
     foo.c = 100;
+}
+struct Node {
+    val: Val,
+    next: &Node,
+}
+
+struct Val {
+    x: i32,
+}
+fn test_point_to_self() {
+    var root : Node;
+    root.val.x = 1;
+
+    var node : Node;
+    node.next = &root;
+    node.val.x = 2;
+
+    root.next = &node;
+
+    if node.next.next.next.val.x != 1 {
+        print_str("BAD\n");
+    }
 }
     )SOURCE", "OK\n");