Commit bd77bc749a

Andrew Kelley <superjoe30@gmail.com>
2015-12-13 06:55:29
structs are working
1 parent 0f02e29
Changed files (5)
example/structs/structs.zig
@@ -2,8 +2,6 @@ export executable "structs";
 
 use "std.zig";
 
-// Note: this example is not working because codegen is confused about
-// how byvalue structs which are in memory on the stack work
 export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
     let mut foo : Foo;
 
src/analyze.cpp
@@ -758,6 +758,47 @@ static bool is_op_allowed(TypeTableEntry *type, BinOpType op) {
     zig_unreachable();
 }
 
+static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+        TypeTableEntry *expected_type, AstNode *node)
+{
+    TypeTableEntry *wanted_type = resolve_type(g, node->data.cast_expr.type);
+    TypeTableEntry *actual_type = analyze_expression(g, import, context, nullptr, node->data.cast_expr.expr);
+
+    if (wanted_type->id == TypeTableEntryIdInvalid ||
+        actual_type->id == TypeTableEntryIdInvalid)
+    {
+        return g->builtin_types.entry_invalid;
+    }
+
+    CastNode *cast_node = &node->codegen_node->data.cast_node;
+
+    // special casing this for now, TODO think about casting and do a general solution
+    if (wanted_type == g->builtin_types.entry_isize &&
+        actual_type->id == TypeTableEntryIdPointer)
+    {
+        cast_node->op = CastOpPtrToInt;
+        return wanted_type;
+    } else if (wanted_type->id == TypeTableEntryIdInt &&
+                actual_type->id == TypeTableEntryIdInt)
+    {
+        cast_node->op = CastOpIntWidenOrShorten;
+        return wanted_type;
+    } else if (wanted_type == g->builtin_types.entry_string &&
+                actual_type->id == TypeTableEntryIdArray &&
+                actual_type->data.array.child_type == g->builtin_types.entry_u8)
+    {
+        cast_node->op = CastOpArrayToString;
+        context->cast_expr_alloca_list.append(node);
+        return wanted_type;
+    } else {
+        add_node_error(g, node,
+            buf_sprintf("invalid cast from type '%s' to '%s'",
+                buf_ptr(&actual_type->name),
+                buf_ptr(&wanted_type->name)));
+        return g->builtin_types.entry_invalid;
+    }
+}
+
 static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -1100,45 +1141,8 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
                 break;
             }
         case NodeTypeCastExpr:
-            {
-                TypeTableEntry *wanted_type = resolve_type(g, node->data.cast_expr.type);
-                TypeTableEntry *actual_type = analyze_expression(g, import, context, nullptr,
-                        node->data.cast_expr.expr);
-
-                if (wanted_type->id == TypeTableEntryIdInvalid ||
-                    actual_type->id == TypeTableEntryIdInvalid)
-                {
-                    return_type = g->builtin_types.entry_invalid;
-                    break;
-                }
-
-                CastNode *cast_node = &node->codegen_node->data.cast_node;
-
-                // special casing this for now, TODO think about casting and do a general solution
-                if (wanted_type == g->builtin_types.entry_isize &&
-                    actual_type->id == TypeTableEntryIdPointer)
-                {
-                    cast_node->op = CastOpPtrToInt;
-                    return_type = wanted_type;
-                } else if (wanted_type->id == TypeTableEntryIdInt &&
-                           actual_type->id == TypeTableEntryIdInt)
-                {
-                    cast_node->op = CastOpIntWidenOrShorten;
-                    return_type = wanted_type;
-                } else if (wanted_type == g->builtin_types.entry_string &&
-                           actual_type->id == TypeTableEntryIdArray &&
-                           actual_type->data.array.child_type == g->builtin_types.entry_u8)
-                {
-                    cast_node->op = CastOpArrayToString;
-                    return_type = wanted_type;
-                } else {
-                    add_node_error(g, node,
-                        buf_sprintf("TODO handle cast from '%s' to '%s'",
-                            buf_ptr(&actual_type->name), buf_ptr(&wanted_type->name)));
-                    return_type = g->builtin_types.entry_invalid;
-                }
-                break;
-            }
+            return_type = analyze_cast_expr(g, import, context, expected_type, node);
+            break;
         case NodeTypePrefixOpExpr:
             switch (node->data.prefix_op_expr.prefix_op) {
                 case PrefixOpBoolNot:
src/analyze.hpp
@@ -193,6 +193,7 @@ struct BlockContext {
     BlockContext *root; // always points to the BlockContext with the NodeTypeFnDef
     BlockContext *parent; // nullptr when this is the root
     HashMap<Buf *, LocalVariableTableEntry *, buf_hash, buf_eql_buf> variable_table;
+    ZigList<AstNode *> cast_expr_alloca_list;
     LLVMZigDIScope *di_scope;
 };
 
@@ -244,6 +245,9 @@ enum CastOp {
 
 struct CastNode {
     CastOp op;
+    // if op is CastOpArrayToString, this will be a pointer to
+    // the string struct on the stack
+    LLVMValueRef ptr;
 };
 
 struct CodeGenNode {
src/codegen.cpp
@@ -59,30 +59,31 @@ void codegen_set_out_name(CodeGen *g, Buf *out_name) {
 }
 
 static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node);
+    
 
-static LLVMTypeRef to_llvm_type(AstNode *type_node) {
+static TypeTableEntry *get_type_for_type_node(CodeGen *g, AstNode *type_node) {
     assert(type_node->type == NodeTypeType);
     assert(type_node->codegen_node);
     assert(type_node->codegen_node->data.type_node.entry);
-
-    return type_node->codegen_node->data.type_node.entry->type_ref;
+    return type_node->codegen_node->data.type_node.entry;
 }
 
-static LLVMZigDIType *to_llvm_debug_type(AstNode *type_node) {
-    assert(type_node->type == NodeTypeType);
-    assert(type_node->codegen_node);
-    assert(type_node->codegen_node->data.type_node.entry);
+static LLVMTypeRef fn_proto_type_from_type_node(CodeGen *g, AstNode *type_node) {
+    TypeTableEntry *type_entry = get_type_for_type_node(g, type_node);
 
-    return type_node->codegen_node->data.type_node.entry->di_type;
+    if (type_entry->id == TypeTableEntryIdStruct || type_entry->id == TypeTableEntryIdArray) {
+        return get_pointer_to_type(g, type_entry, true)->type_ref;
+    } else {
+        return type_entry->type_ref;
+    }
 }
 
-static TypeTableEntry *get_type_for_type_node(CodeGen *g, AstNode *type_node) {
-    assert(type_node->type == NodeTypeType);
-    assert(type_node->codegen_node);
-    assert(type_node->codegen_node->data.type_node.entry);
-    return type_node->codegen_node->data.type_node.entry;
+static LLVMZigDIType *to_llvm_debug_type(CodeGen *g, AstNode *type_node) {
+    TypeTableEntry *type_entry = get_type_for_type_node(g, type_node);
+    return type_entry->di_type;
 }
 
+
 static bool type_is_unreachable(CodeGen *g, AstNode *type_node) {
     return get_type_for_type_node(g, type_node) == g->builtin_types.entry_unreachable;
 }
@@ -198,20 +199,6 @@ static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) {
     return LLVMBuildInBoundsGEP(g->builder, array_ref_value, indices, 2, "");
 }
 
-static LLVMValueRef gen_field_val(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeFieldAccessExpr);
-
-    LLVMValueRef struct_val = gen_expr(g, node->data.field_access_expr.struct_expr);
-    assert(struct_val);
-
-    FieldAccessNode *codegen_field_access = &node->codegen_node->data.field_access_node;
-    assert(codegen_field_access->field_index >= 0);
-
-    add_debug_source_node(g, node);
-    return LLVMBuildExtractValue(g->builder, struct_val, codegen_field_access->field_index, "");
-}
-
-/*
 static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeFieldAccessExpr);
 
@@ -223,9 +210,9 @@ static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node) {
 
     assert(codegen_field_access->field_index >= 0);
 
+    add_debug_source_node(g, node);
     return LLVMBuildStructGEP(g->builder, struct_ptr, codegen_field_access->field_index, "");
 }
-*/
 
 static LLVMValueRef gen_array_access_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeArrayAccessExpr);
@@ -249,11 +236,8 @@ static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node) {
             zig_panic("gen_field_access_expr bad array field");
         }
     } else if (struct_type->id == TypeTableEntryIdStruct) {
-        /*
         LLVMValueRef ptr = gen_field_ptr(g, node);
         return LLVMBuildLoad(g->builder, ptr, "");
-        */
-        return gen_field_val(g, node);
     } else {
         zig_panic("gen_field_access_expr bad struct type");
     }
@@ -311,14 +295,19 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
             }
         case CastOpArrayToString:
             {
-                LLVMValueRef struct_vals[] = {
-                    expr_val,
-                    LLVMConstInt(g->builtin_types.entry_usize->type_ref, actual_type->data.array.len, false)
-                };
-                unsigned field_count = g->builtin_types.entry_string->data.structure.field_count;
-                assert(field_count == 2);
-                return LLVMConstNamedStruct(g->builtin_types.entry_string->type_ref,
-                        struct_vals, field_count);
+                assert(cast_node->ptr);
+
+                add_debug_source_node(g, node);
+
+                LLVMValueRef ptr_ptr = LLVMBuildStructGEP(g->builder, cast_node->ptr, 0, "");
+                LLVMBuildStore(g->builder, expr_val, ptr_ptr);
+
+                LLVMValueRef len_ptr = LLVMBuildStructGEP(g->builder, cast_node->ptr, 1, "");
+                LLVMValueRef len_val = LLVMConstInt(g->builtin_types.entry_usize->type_ref,
+                        actual_type->data.array.len, false);
+                LLVMBuildStore(g->builder, len_val, len_ptr);
+
+                return cast_node->ptr;
             }
     }
     zig_unreachable();
@@ -580,6 +569,8 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
         assert(array_type->id == TypeTableEntryIdArray);
         op1_type = array_type->data.array.child_type;
         target_ref = gen_array_ptr(g, lhs_node);
+    } else if (lhs_node->type == NodeTypeFieldAccessExpr) {
+        target_ref = gen_field_ptr(g, lhs_node);
     } else {
         zig_panic("bad assign target");
     }
@@ -717,6 +708,7 @@ static LLVMValueRef gen_if_expr(CodeGen *g, AstNode *node) {
 static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *implicit_return_type) {
     assert(block_node->type == NodeTypeBlock);
 
+    BlockContext *old_block_context = g->cur_block_context;
     g->cur_block_context = block_node->codegen_node->data.block_node.block_context;
 
     LLVMValueRef return_value;
@@ -726,6 +718,7 @@ static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *i
     }
 
     if (implicit_return_type) {
+        add_debug_source_node(g, block_node);
         if (implicit_return_type == g->builtin_types.entry_void) {
             LLVMBuildRetVoid(g->builder);
         } else if (implicit_return_type != g->builtin_types.entry_unreachable) {
@@ -733,6 +726,8 @@ static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *i
         }
     }
 
+    g->cur_block_context = old_block_context;
+
     return return_value;
 }
 
@@ -934,6 +929,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
                 } else if (variable->is_ptr) {
                     if (variable->type->id == TypeTableEntryIdArray) {
                         return variable->value_ref;
+                    } else if (variable->type->id == TypeTableEntryIdStruct) {
+                        return variable->value_ref;
                     } else {
                         add_debug_source_node(g, node);
                         return LLVMBuildLoad(g->builder, variable->value_ref, "");
@@ -994,12 +991,12 @@ static LLVMZigDISubroutineType *create_di_function_type(CodeGen *g, AstNodeFnPro
         LLVMZigDIFile *di_file)
 {
     LLVMZigDIType **types = allocate<LLVMZigDIType*>(1 + fn_proto->params.length);
-    types[0] = to_llvm_debug_type(fn_proto->return_type);
+    types[0] = to_llvm_debug_type(g, fn_proto->return_type);
     int types_len = fn_proto->params.length + 1;
     for (int i = 0; i < fn_proto->params.length; i += 1) {
         AstNode *param_node = fn_proto->params.at(i);
         assert(param_node->type == NodeTypeParamDecl);
-        LLVMZigDIType *param_type = to_llvm_debug_type(param_node->data.param_decl.type);
+        LLVMZigDIType *param_type = to_llvm_debug_type(g, param_node->data.param_decl.type);
         types[i + 1] = param_type;
     }
     return LLVMZigCreateSubroutineType(g->dbuilder, di_file, types, types_len, 0);
@@ -1026,7 +1023,7 @@ static void do_code_gen(CodeGen *g) {
         assert(proto_node->type == NodeTypeFnProto);
         AstNodeFnProto *fn_proto = &proto_node->data.fn_proto;
 
-        LLVMTypeRef ret_type = to_llvm_type(fn_proto->return_type);
+        LLVMTypeRef ret_type = fn_proto_type_from_type_node(g, fn_proto->return_type);
         int param_count = count_non_void_params(g, &fn_proto->params);
         LLVMTypeRef *param_types = allocate<LLVMTypeRef>(param_count);
         int gen_param_index = 0;
@@ -1036,7 +1033,7 @@ static void do_code_gen(CodeGen *g) {
             if (is_param_decl_type_void(g, param_node))
                 continue;
             AstNode *type_node = param_node->data.param_decl.type;
-            param_types[gen_param_index] = to_llvm_type(type_node);
+            param_types[gen_param_index] = fn_proto_type_from_type_node(g, type_node);
             gen_param_index += 1;
         }
         LLVMTypeRef function_type = LLVMFunctionType(ret_type, param_types, param_count, fn_proto->is_var_args);
@@ -1061,8 +1058,8 @@ static void do_code_gen(CodeGen *g) {
     }
 
     // Generate function definitions.
-    for (int i = 0; i < g->fn_defs.length; i += 1) {
-        FnTableEntry *fn_table_entry = g->fn_defs.at(i);
+    for (int fn_i = 0; fn_i < g->fn_defs.length; fn_i += 1) {
+        FnTableEntry *fn_table_entry = g->fn_defs.at(fn_i);
         ImportTableEntry *import = fn_table_entry->import_entry;
         AstNode *fn_def_node = fn_table_entry->fn_def_node;
         LLVMValueRef fn = fn_table_entry->fn_value;
@@ -1101,8 +1098,8 @@ static void do_code_gen(CodeGen *g) {
         LLVMGetParams(fn, params);
 
         int non_void_index = 0;
-        for (int i = 0; i < fn_proto->params.length; i += 1) {
-            AstNode *param_decl = fn_proto->params.at(i);
+        for (int param_i = 0; param_i < fn_proto->params.length; param_i += 1) {
+            AstNode *param_decl = fn_proto->params.at(param_i);
             assert(param_decl->type == NodeTypeParamDecl);
             if (is_param_decl_type_void(g, param_decl))
                 continue;
@@ -1115,8 +1112,8 @@ static void do_code_gen(CodeGen *g) {
 
         // Set up debug info for blocks and variables and
         // allocate all local variables
-        for (int i = 0; i < codegen_fn_def->all_block_contexts.length; i += 1) {
-            BlockContext *block_context = codegen_fn_def->all_block_contexts.at(i);
+        for (int bc_i = 0; bc_i < codegen_fn_def->all_block_contexts.length; bc_i += 1) {
+            BlockContext *block_context = codegen_fn_def->all_block_contexts.at(bc_i);
 
             if (block_context->parent) {
                 LLVMZigDILexicalBlock *di_block = LLVMZigCreateLexicalBlock(g->dbuilder,
@@ -1157,6 +1154,16 @@ static void do_code_gen(CodeGen *g) {
                         import->di_file, var->decl_node->line + 1,
                         var->type->di_type, !g->strip_debug_symbols, 0, arg_no);
             }
+
+            // allocate structs which are the result of casts
+            for (int cea_i = 0; cea_i < block_context->cast_expr_alloca_list.length; cea_i += 1) {
+                AstNode *cast_expr_node = block_context->cast_expr_alloca_list.at(cea_i);
+                assert(cast_expr_node->type == NodeTypeCastExpr);
+                CastNode *cast_codegen = &cast_expr_node->codegen_node->data.cast_node;
+                TypeTableEntry *type_entry = get_type_for_type_node(g, cast_expr_node->data.cast_expr.type);
+                add_debug_source_node(g, cast_expr_node);
+                cast_codegen->ptr = LLVMBuildAlloca(g->builder, type_entry->type_ref, "");
+            }
         }
 
         TypeTableEntry *implicit_return_type = codegen_fn_def->implicit_return_type;
test/run_tests.cpp
@@ -477,6 +477,27 @@ export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
 }
     )SOURCE", "OK\n");
 
+    add_simple_case("structs", R"SOURCE(
+use "std.zig";
+
+export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 {
+    let mut foo : Foo;
+    foo.a = foo.a + 1;
+    foo.b = foo.a == 1;
+    test_foo(foo);
+    return 0;
+}
+struct Foo {
+    a : i32,
+    b : bool,
+    c : f32,
+}
+fn test_foo(foo : Foo) {
+    if foo.b {
+        print_str("OK\n" as string);
+    }
+}
+    )SOURCE", "OK\n");
 }
 
 static void add_compile_failure_test_cases(void) {