Commit e21369a153

Andrew Kelley <superjoe30@gmail.com>
2015-12-23 11:19:22
codegen: support byvalue struct assignment
1 parent ebd7aeb
Changed files (5)
example/structs/structs.zig
@@ -19,6 +19,8 @@ pub fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
 
     test_point_to_self();
 
+    test_byval_assign();
+
     print_str("OK\n");
     return 0;
 }
@@ -62,3 +64,17 @@ fn test_point_to_self() {
         print_str("BAD\n");
     }
 }
+
+fn test_byval_assign() {
+    var foo1 : Foo;
+    var foo2 : Foo;
+
+    foo1.a = 1234;
+
+    if foo2.a != 0 { print_str("BAD\n"); }
+
+    foo2 = foo1;
+
+    if foo2.a != 1234 { print_str("BAD - byval assignment failed\n"); }
+
+}
src/analyze.cpp
@@ -276,8 +276,11 @@ static void resolve_struct_type(CodeGen *g, ImportTableEntry *import, TypeTableE
     AstNode *decl_node = struct_type->data.structure.decl_node;
 
     if (struct_type->data.structure.embedded_in_current) {
-        add_node_error(g, decl_node,
-                buf_sprintf("struct has infinite size"));
+        if (!struct_type->data.structure.reported_infinite_err) {
+            struct_type->data.structure.reported_infinite_err = true;
+            add_node_error(g, decl_node,
+                    buf_sprintf("struct has infinite size"));
+        }
         return;
     }
 
src/analyze.hpp
@@ -43,9 +43,11 @@ struct TypeTableEntryStruct {
     bool is_packed;
     int field_count;
     TypeStructField *fields;
+    uint64_t size_bytes;
 
     // set this flag temporarily to detect infinite loops
     bool embedded_in_current;
+    bool reported_infinite_err;
 };
 
 struct TypeTableEntryNumLit {
@@ -201,6 +203,7 @@ struct CodeGen {
     bool verbose;
     ErrColor err_color;
     ImportTableEntry *root_import;
+    LLVMValueRef memcpy_fn_val;
 };
 
 struct VariableTableEntry {
src/codegen.cpp
@@ -666,15 +666,36 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) {
 
     LLVMValueRef target_ref = gen_lvalue(g, node, lhs_node, &op1_type);
 
+    TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
+
     LLVMValueRef value = gen_expr(g, node->data.bin_op_expr.op2);
 
-    if (node->data.bin_op_expr.bin_op == BinOpTypeAssign) {
-        // value is ready as is
-    } else {
+    if (op1_type->id == TypeTableEntryIdStruct) {
+        assert(op2_type->id == TypeTableEntryIdStruct);
+        assert(op1_type == op2_type);
+        assert(node->data.bin_op_expr.bin_op == BinOpTypeAssign);
+
+        LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
+
+        add_debug_source_node(g, node);
+        LLVMValueRef src_ptr = LLVMBuildBitCast(g->builder, value, ptr_u8, "");
+        LLVMValueRef dest_ptr = LLVMBuildBitCast(g->builder, target_ref, ptr_u8, "");
+
+        LLVMValueRef params[] = {
+            dest_ptr, // dest pointer
+            src_ptr, // source pointer
+            LLVMConstInt(LLVMIntType(g->pointer_size_bytes * 8), op1_type->size_in_bits / 8, false), // byte count
+            LLVMConstInt(LLVMInt32Type(), op1_type->align_in_bits / 8, false), // align in bits
+            LLVMConstNull(LLVMInt1Type()), // is volatile
+        };
+
+        return LLVMBuildCall(g->builder, g->memcpy_fn_val, params, 5, "");
+    }
+
+    if (node->data.bin_op_expr.bin_op != BinOpTypeAssign) {
         add_debug_source_node(g, node->data.bin_op_expr.op1);
         LLVMValueRef left_value = LLVMBuildLoad(g->builder, target_ref, "");
 
-        TypeTableEntry *op2_type = get_expr_type(node->data.bin_op_expr.op2);
         value = gen_arithmetic_bin_op(g, left_value, value, op1_type, op2_type, node);
     }
 
@@ -1158,6 +1179,20 @@ static LLVMAttribute to_llvm_fn_attr(FnAttrId attr_id) {
 static void do_code_gen(CodeGen *g) {
     assert(!g->errors.length);
 
+    {
+        LLVMTypeRef param_types[] = {
+            LLVMPointerType(LLVMInt8Type(), 0),
+            LLVMPointerType(LLVMInt8Type(), 0),
+            LLVMIntType(g->pointer_size_bytes * 8),
+            LLVMInt32Type(),
+            LLVMInt1Type(),
+        };
+        LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
+        Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8);
+        g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
+        assert(LLVMGetIntrinsicID(g->memcpy_fn_val));
+    }
+
     // Generate module level variables
     for (int i = 0; i < g->global_vars.length; i += 1) {
         VariableTableEntry *var = g->global_vars.at(i);
@@ -1379,7 +1414,7 @@ static void define_builtin_types(CodeGen *g) {
         TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdBool);
         entry->type_ref = LLVMInt1Type();
         buf_init_from_str(&entry->name, "bool");
-        entry->size_in_bits = 1;
+        entry->size_in_bits = 8;
         entry->align_in_bits = 8;
         entry->di_type = LLVMZigCreateDebugBasicType(g->dbuilder, buf_ptr(&entry->name),
                 entry->size_in_bits, entry->align_in_bits,
test/run_tests.cpp
@@ -574,6 +574,7 @@ export fn main(argc : isize, argv : &&u8, env : &&u8) -> i32 {
         print_str("BAD\n");
     }
     test_point_to_self();
+    test_byval_assign();
     print_str("OK\n");
     return 0;
 }
@@ -611,6 +612,18 @@ fn test_point_to_self() {
     if node.next.next.next.val.x != 1 {
         print_str("BAD\n");
     }
+}
+fn test_byval_assign() {
+    var foo1 : Foo;
+    var foo2 : Foo;
+
+    foo1.a = 1234;
+
+    if foo2.a != 0 { print_str("BAD\n"); }
+
+    foo2 = foo1;
+
+    if foo2.a != 1234 { print_str("BAD - byval assignment failed\n"); }
 }
     )SOURCE", "OK\n");