Commit 6d9119fcd9

Andrew Kelley <superjoe30@gmail.com>
2016-01-09 10:16:54
add memcpy and memset intrinsics
1 parent bdca82e
src/analyze.cpp
@@ -1944,8 +1944,11 @@ static TypeTableEntry *analyze_while_expr(CodeGen *g, ImportTableEntry *import,
         if (resolved_type->id != TypeTableEntryIdInvalid) {
             assert(resolved_type->id == TypeTableEntryIdBool);
             bool constant_cond_value = number_literal.data.x_uint;
-            if (constant_cond_value && !node->codegen_node->data.while_node.contains_break) {
-                expr_return_type = g->builtin_types.entry_unreachable;
+            if (constant_cond_value) {
+                node->codegen_node->data.while_node.condition_always_true = true;
+                if (!node->codegen_node->data.while_node.contains_break) {
+                    expr_return_type = g->builtin_types.entry_unreachable;
+                }
             }
         }
     }
@@ -2085,13 +2088,74 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry
                         builtin_fn->param_count, actual_param_count));
         }
 
-        for (int i = 0; i < actual_param_count; i += 1) {
-            AstNode *child = node->data.fn_call_expr.params.at(i);
-            TypeTableEntry *expected_param_type = builtin_fn->param_types[i];
-            analyze_expression(g, import, context, expected_param_type, child);
-        }
+        switch (builtin_fn->id) {
+            case BuiltinFnIdInvalid:
+                zig_unreachable();
+            case BuiltinFnIdArithmeticWithOverflow:
+                for (int i = 0; i < actual_param_count; i += 1) {
+                    AstNode *child = node->data.fn_call_expr.params.at(i);
+                    TypeTableEntry *expected_param_type = builtin_fn->param_types[i];
+                    analyze_expression(g, import, context, expected_param_type, child);
+                }
+                return builtin_fn->return_type;
+            case BuiltinFnIdMemcpy:
+                {
+                    AstNode *dest_node = node->data.fn_call_expr.params.at(0);
+                    AstNode *src_node = node->data.fn_call_expr.params.at(1);
+                    AstNode *len_node = node->data.fn_call_expr.params.at(2);
+                    TypeTableEntry *dest_type = analyze_expression(g, import, context, nullptr, dest_node);
+                    TypeTableEntry *src_type = analyze_expression(g, import, context, nullptr, src_node);
+                    analyze_expression(g, import, context, builtin_fn->param_types[2], len_node);
+
+                    if (dest_type->id != TypeTableEntryIdInvalid &&
+                        dest_type->id != TypeTableEntryIdPointer)
+                    {
+                        add_node_error(g, dest_node,
+                                buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&dest_type->name)));
+                    }
+
+                    if (src_type->id != TypeTableEntryIdInvalid &&
+                        src_type->id != TypeTableEntryIdPointer)
+                    {
+                        add_node_error(g, src_node,
+                                buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&src_type->name)));
+                    }
 
-        return builtin_fn->return_type;
+                    if (dest_type->id == TypeTableEntryIdPointer &&
+                        src_type->id == TypeTableEntryIdPointer)
+                    {
+                        uint64_t dest_align_bits = dest_type->data.pointer.child_type->align_in_bits;
+                        uint64_t src_align_bits = src_type->data.pointer.child_type->align_in_bits;
+                        if (dest_align_bits != src_align_bits) {
+                            add_node_error(g, dest_node, buf_sprintf(
+                                "misaligned memcpy, '%s' has alignment '%" PRIu64 ", '%s' has alignment %" PRIu64,
+                                        buf_ptr(&dest_type->name), dest_align_bits / 8,
+                                        buf_ptr(&src_type->name), src_align_bits / 8));
+                        }
+                    }
+
+                    return builtin_fn->return_type;
+                }
+            case BuiltinFnIdMemset:
+                {
+                    AstNode *dest_node = node->data.fn_call_expr.params.at(0);
+                    AstNode *char_node = node->data.fn_call_expr.params.at(1);
+                    AstNode *len_node = node->data.fn_call_expr.params.at(2);
+                    TypeTableEntry *dest_type = analyze_expression(g, import, context, nullptr, dest_node);
+                    analyze_expression(g, import, context, builtin_fn->param_types[1], char_node);
+                    analyze_expression(g, import, context, builtin_fn->param_types[2], len_node);
+
+                    if (dest_type->id != TypeTableEntryIdInvalid &&
+                        dest_type->id != TypeTableEntryIdPointer)
+                    {
+                        add_node_error(g, dest_node,
+                                buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&dest_type->name)));
+                    }
+
+                    return builtin_fn->return_type;
+                }
+        }
+        zig_unreachable();
     } else {
         add_node_error(g, node,
                 buf_sprintf("invalid builtin function: '%s'", buf_ptr(name)));
src/analyze.hpp
@@ -151,6 +151,8 @@ struct FnTableEntry {
 enum BuiltinFnId {
     BuiltinFnIdInvalid,
     BuiltinFnIdArithmeticWithOverflow,
+    BuiltinFnIdMemcpy,
+    BuiltinFnIdMemset,
 };
 
 struct BuiltinFnEntry {
@@ -354,6 +356,7 @@ struct ImportNode {
 };
 
 struct WhileNode {
+    bool condition_always_true;
     bool contains_break;
 };
 
src/codegen.cpp
@@ -171,6 +171,67 @@ static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
 
                 return overflow_bit;
             }
+        case BuiltinFnIdMemcpy:
+            {
+                int fn_call_param_count = node->data.fn_call_expr.params.length;
+                assert(fn_call_param_count == 3);
+
+                AstNode *dest_node = node->data.fn_call_expr.params.at(0);
+                TypeTableEntry *dest_type = get_expr_type(dest_node);
+
+                LLVMValueRef dest_ptr = gen_expr(g, dest_node);
+                LLVMValueRef src_ptr = gen_expr(g, node->data.fn_call_expr.params.at(1));
+                LLVMValueRef len_val = gen_expr(g, node->data.fn_call_expr.params.at(2));
+
+                LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
+
+                add_debug_source_node(g, node);
+                LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, dest_ptr, ptr_u8, "");
+                LLVMValueRef src_ptr_casted = LLVMBuildBitCast(g->builder, src_ptr, ptr_u8, "");
+
+                uint64_t align_in_bytes = dest_type->data.pointer.child_type->align_in_bits / 8;
+
+                LLVMValueRef params[] = {
+                    dest_ptr_casted, // dest pointer
+                    src_ptr_casted, // source pointer
+                    len_val, // byte count
+                    LLVMConstInt(LLVMInt32Type(), align_in_bytes, false), // align in bytes
+                    LLVMConstNull(LLVMInt1Type()), // is volatile
+                };
+
+                LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 5, "");
+                return nullptr;
+            }
+        case BuiltinFnIdMemset:
+            {
+                int fn_call_param_count = node->data.fn_call_expr.params.length;
+                assert(fn_call_param_count == 3);
+
+                AstNode *dest_node = node->data.fn_call_expr.params.at(0);
+                TypeTableEntry *dest_type = get_expr_type(dest_node);
+
+                LLVMValueRef dest_ptr = gen_expr(g, dest_node);
+                LLVMValueRef char_val = gen_expr(g, node->data.fn_call_expr.params.at(1));
+                LLVMValueRef len_val = gen_expr(g, node->data.fn_call_expr.params.at(2));
+
+                LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
+
+                add_debug_source_node(g, node);
+                LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, dest_ptr, ptr_u8, "");
+
+                uint64_t align_in_bytes = dest_type->data.pointer.child_type->align_in_bits / 8;
+
+                LLVMValueRef params[] = {
+                    dest_ptr_casted, // dest pointer
+                    char_val, // source pointer
+                    len_val, // byte count
+                    LLVMConstInt(LLVMInt32Type(), align_in_bytes, false), // align in bytes
+                    LLVMConstNull(LLVMInt1Type()), // is volatile
+                };
+
+                LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 5, "");
+                return nullptr;
+            }
     }
     zig_unreachable();
 }
@@ -1376,23 +1437,35 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
     assert(node->data.while_expr.condition);
     assert(node->data.while_expr.body);
 
-    if (get_expr_type(node)->id == TypeTableEntryIdUnreachable) {
-        // generate a forever loop. guarantees no break statements
+    bool condition_always_true = node->codegen_node->data.while_node.condition_always_true;
+    bool contains_break = node->codegen_node->data.while_node.contains_break;
+    if (condition_always_true) {
+        // generate a forever loop
 
         LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody");
+        LLVMBasicBlockRef end_block = nullptr;
+        if (contains_break) {
+            end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd");
+        }
 
         add_debug_source_node(g, node);
         LLVMBuildBr(g->builder, body_block);
 
         LLVMPositionBuilderAtEnd(g->builder, body_block);
+        g->break_block_stack.append(end_block);
         g->continue_block_stack.append(body_block);
         gen_expr(g, node->data.while_expr.body);
+        g->break_block_stack.pop();
         g->continue_block_stack.pop();
 
         if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) {
             add_debug_source_node(g, node);
             LLVMBuildBr(g->builder, body_block);
         }
+
+        if (contains_break) {
+            LLVMPositionBuilderAtEnd(g->builder, end_block);
+        }
     } else {
         // generate a normal while loop
 
@@ -1755,20 +1828,6 @@ static LLVMAttribute to_llvm_fn_attr(FnAttrId attr_id) {
 static void do_code_gen(CodeGen *g) {
     assert(!g->errors.length);
 
-    {
-        LLVMTypeRef param_types[] = {
-            LLVMPointerType(LLVMInt8Type(), 0),
-            LLVMPointerType(LLVMInt8Type(), 0),
-            LLVMIntType(g->pointer_size_bytes * 8),
-            LLVMInt32Type(),
-            LLVMInt1Type(),
-        };
-        LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
-        Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8);
-        g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
-        assert(LLVMGetIntrinsicID(g->memcpy_fn_val));
-    }
-
     // Generate module level variables
     for (int i = 0; i < g->global_vars.length; i += 1) {
         VariableTableEntry *var = g->global_vars.at(i);
@@ -2267,6 +2326,57 @@ static void define_builtin_fns(CodeGen *g) {
     define_builtin_fns_int(g, g->builtin_types.entry_i16);
     define_builtin_fns_int(g, g->builtin_types.entry_i32);
     define_builtin_fns_int(g, g->builtin_types.entry_i64);
+    {
+        BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);
+        buf_init_from_str(&builtin_fn->name, "memcpy");
+        builtin_fn->id = BuiltinFnIdMemcpy;
+        builtin_fn->return_type = g->builtin_types.entry_void;
+        builtin_fn->param_count = 3;
+        builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
+        builtin_fn->param_types[0] = nullptr; // manually checked later
+        builtin_fn->param_types[1] = nullptr; // manually checked later
+        builtin_fn->param_types[2] = g->builtin_types.entry_usize;
+
+        LLVMTypeRef param_types[] = {
+            LLVMPointerType(LLVMInt8Type(), 0),
+            LLVMPointerType(LLVMInt8Type(), 0),
+            LLVMIntType(g->pointer_size_bytes * 8),
+            LLVMInt32Type(),
+            LLVMInt1Type(),
+        };
+        LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
+        Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8);
+        g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
+        builtin_fn->fn_val = g->memcpy_fn_val;
+        assert(LLVMGetIntrinsicID(g->memcpy_fn_val));
+
+        g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
+    }
+    {
+        BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);
+        buf_init_from_str(&builtin_fn->name, "memset");
+        builtin_fn->id = BuiltinFnIdMemset;
+        builtin_fn->return_type = g->builtin_types.entry_void;
+        builtin_fn->param_count = 3;
+        builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
+        builtin_fn->param_types[0] = nullptr; // manually checked later
+        builtin_fn->param_types[1] = g->builtin_types.entry_u8;
+        builtin_fn->param_types[2] = g->builtin_types.entry_usize;
+
+        LLVMTypeRef param_types[] = {
+            LLVMPointerType(LLVMInt8Type(), 0),
+            LLVMInt8Type(),
+            LLVMIntType(g->pointer_size_bytes * 8),
+            LLVMInt32Type(),
+            LLVMInt1Type(),
+        };
+        LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
+        Buf *name = buf_sprintf("llvm.memset.p0i8.i%d", g->pointer_size_bytes * 8);
+        builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
+        assert(LLVMGetIntrinsicID(builtin_fn->fn_val));
+
+        g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
+    }
 }
 
 
std/std.zig
@@ -118,13 +118,7 @@ fn buf_print_u64(out_buf: []u8, x: u64) -> usize {
 
     const len = buf.len - index;
 
-    // TODO memcpy intrinsic
-    // @memcpy(out_buf, buf, len);
-    var i: usize = 0;
-    while (i < len) {
-        out_buf[i] = buf[index + i];
-        i += 1;
-    }
+    @memcpy(out_buf.ptr, &buf[index], len);
 
     return len;
 }
test/run_tests.cpp
@@ -973,6 +973,24 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
     return 0;
 }
     )SOURCE", "OK\n");
+
+    add_simple_case("memcpy and memset intrinsics", R"SOURCE(
+use "std.zig";
+pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
+    var foo : [20]u8;
+    var bar : [20]u8;
+
+    @memset(foo.ptr, 'A', foo.len);
+    @memcpy(bar.ptr, foo.ptr, bar.len);
+
+    if (bar[11] != 'A') {
+        print_str("BAD\n");
+    }
+
+    print_str("OK\n");
+    return 0;
+}
+    )SOURCE", "OK\n");
 }