Commit fe38d8142f

Andrew Kelley <superjoe30@gmail.com>
2018-03-23 01:22:15
create multiple llvm.memcpy and llvm.memset with different align params
1 parent 62668e3
src/all_types.hpp
@@ -1415,6 +1415,8 @@ enum ZigLLVMFnId {
     ZigLLVMFnIdOverflowArithmetic,
     ZigLLVMFnIdFloor,
     ZigLLVMFnIdCeil,
+    ZigLLVMFnIdMemcpy,
+    ZigLLVMFnIdMemset,
 };
 
 enum AddSubMul {
@@ -1441,6 +1443,13 @@ struct ZigLLVMFnKey {
             uint32_t bit_count;
             bool is_signed;
         } overflow_arithmetic;
+        struct {
+            uint32_t dest_align;
+            uint32_t src_align;
+        } memcpy;
+        struct {
+            uint32_t dest_align;
+        } memset;
     } data;
 };
 
@@ -1629,8 +1638,6 @@ struct CodeGen {
     ImportTableEntry *root_import;
     ImportTableEntry *bootstrap_import;
     ImportTableEntry *test_runner_import;
-    LLVMValueRef memcpy_fn_val;
-    LLVMValueRef memset_fn_val;
     LLVMValueRef trap_fn_val;
     LLVMValueRef return_address_fn_val;
     LLVMValueRef frame_address_fn_val;
src/analyze.cpp
@@ -5681,6 +5681,10 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) {
             return ((uint32_t)(x.data.overflow_arithmetic.bit_count) * 87135777) +
                 ((uint32_t)(x.data.overflow_arithmetic.add_sub_mul) * 31640542) +
                 ((uint32_t)(x.data.overflow_arithmetic.is_signed) ? 1062315172 : 314955820);
+        case ZigLLVMFnIdMemcpy:
+            return x.data.memcpy.dest_align * 2325524557 + x.data.memcpy.src_align * 519976394;
+        case ZigLLVMFnIdMemset:
+            return x.data.memset.dest_align * 388171592;
     }
     zig_unreachable();
 }
@@ -5700,6 +5704,11 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) {
             return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) &&
                 (a.data.overflow_arithmetic.add_sub_mul == b.data.overflow_arithmetic.add_sub_mul) &&
                 (a.data.overflow_arithmetic.is_signed == b.data.overflow_arithmetic.is_signed);
+        case ZigLLVMFnIdMemcpy:
+            return (a.data.memcpy.dest_align == b.data.memcpy.dest_align) &&
+                   (a.data.memcpy.src_align == b.data.memcpy.src_align);
+        case ZigLLVMFnIdMemset:
+            return (a.data.memset.dest_align == b.data.memset.dest_align);
     }
     zig_unreachable();
 }
src/codegen.cpp
@@ -324,8 +324,12 @@ static void addLLVMFnAttrInt(LLVMValueRef fn_val, const char *attr_name, uint64_
     return addLLVMAttrInt(fn_val, -1, attr_name, attr_val);
 }
 
-static void addLLVMArgAttr(LLVMValueRef arg_val, unsigned param_index, const char *attr_name) {
-    return addLLVMAttr(arg_val, param_index + 1, attr_name);
+static void addLLVMArgAttr(LLVMValueRef fn_val, unsigned param_index, const char *attr_name) {
+    return addLLVMAttr(fn_val, param_index + 1, attr_name);
+}
+
+static void addLLVMArgAttrInt(LLVMValueRef fn_val, unsigned param_index, const char *attr_name, uint64_t attr_val) {
+    return addLLVMAttrInt(fn_val, param_index + 1, attr_name, attr_val);
 }
 
 static void addLLVMCallsiteAttr(LLVMValueRef call_instr, unsigned param_index, const char *attr_name) {
@@ -912,23 +916,31 @@ static void gen_safety_crash(CodeGen *g, PanicMsgId msg_id) {
     gen_panic(g, get_panic_msg_ptr_val(g, msg_id), nullptr);
 }
 
-static LLVMValueRef get_memcpy_fn_val(CodeGen *g) {
-    if (g->memcpy_fn_val)
-        return g->memcpy_fn_val;
+static LLVMValueRef get_memcpy_fn_val(CodeGen *g, uint32_t dest_align, uint32_t src_align) {
+    ZigLLVMFnKey key = {};
+    key.id = ZigLLVMFnIdMemcpy;
+    key.data.memcpy.dest_align = dest_align;
+    key.data.memcpy.src_align = src_align;
+
+    auto existing_entry = g->llvm_fn_table.maybe_get(key);
+    if (existing_entry)
+        return existing_entry->value;
 
     LLVMTypeRef param_types[] = {
         LLVMPointerType(LLVMInt8Type(), 0),
         LLVMPointerType(LLVMInt8Type(), 0),
         LLVMIntType(g->pointer_size_bytes * 8),
-        LLVMInt32Type(),
         LLVMInt1Type(),
     };
-    LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
+    LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 4, false);
     Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8);
-    g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
-    assert(LLVMGetIntrinsicID(g->memcpy_fn_val));
+    LLVMValueRef fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
+    addLLVMArgAttrInt(fn_val, 0, "align", dest_align);
+    addLLVMArgAttrInt(fn_val, 1, "align", src_align);
+    assert(LLVMGetIntrinsicID(fn_val));
 
-    return g->memcpy_fn_val;
+    g->llvm_fn_table.put(key, fn_val);
+    return fn_val;
 }
 
 static LLVMValueRef get_coro_destroy_fn_val(CodeGen *g) {
@@ -1293,15 +1305,15 @@ static LLVMValueRef get_safety_crash_err_fn(CodeGen *g) {
     LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, err_name_val, slice_len_index, "");
     LLVMValueRef err_name_len = gen_load_untyped(g, len_field_ptr, 0, false, "");
 
+    LLVMValueRef memcpy_fn_val = get_memcpy_fn_val(g, u8_align_bytes, u8_align_bytes);
     LLVMValueRef params[] = {
         offset_buf_ptr, // dest pointer
         err_name_ptr, // source pointer
         err_name_len, // size bytes
-        LLVMConstInt(LLVMInt32Type(), u8_align_bytes, false), // align bytes
         LLVMConstNull(LLVMInt1Type()), // is volatile
     };
 
-    LLVMBuildCall(g->builder, get_memcpy_fn_val(g), params, 5, "");
+    LLVMBuildCall(g->builder, memcpy_fn_val, params, 4, "");
 
     LLVMValueRef const_prefix_len = LLVMConstInt(LLVMTypeOf(err_name_len), strlen(unwrap_err_msg_text), false);
     LLVMValueRef full_buf_len = LLVMBuildNUWAdd(g->builder, const_prefix_len, err_name_len, "");
@@ -1535,15 +1547,16 @@ static LLVMValueRef gen_assign_raw(CodeGen *g, LLVMValueRef ptr, TypeTableEntry
         LLVMValueRef volatile_bit = ptr_type->data.pointer.is_volatile ?
             LLVMConstAllOnes(LLVMInt1Type()) : LLVMConstNull(LLVMInt1Type());
 
+        LLVMValueRef memcpy_fn_val = get_memcpy_fn_val(g, align_bytes, align_bytes);
+
         LLVMValueRef params[] = {
             dest_ptr, // dest pointer
             src_ptr, // source pointer
             LLVMConstInt(usize->type_ref, size_bytes, false),
-            LLVMConstInt(LLVMInt32Type(), align_bytes, false),
             volatile_bit,
         };
 
-        LLVMBuildCall(g->builder, get_memcpy_fn_val(g), params, 5, "");
+        LLVMBuildCall(g->builder, memcpy_fn_val, params, 4, "");
         return nullptr;
     }
 
@@ -2483,23 +2496,29 @@ static LLVMValueRef ir_render_bool_not(CodeGen *g, IrExecutable *executable, IrI
     return LLVMBuildICmp(g->builder, LLVMIntEQ, value, zero, "");
 }
 
-static LLVMValueRef get_memset_fn_val(CodeGen *g) {
-    if (g->memset_fn_val)
-        return g->memset_fn_val;
+static LLVMValueRef get_memset_fn_val(CodeGen *g, uint32_t dest_align) {
+    ZigLLVMFnKey key = {};
+    key.id = ZigLLVMFnIdMemset;
+    key.data.memset.dest_align = dest_align;
+
+    auto existing_entry = g->llvm_fn_table.maybe_get(key);
+    if (existing_entry)
+        return existing_entry->value;
 
     LLVMTypeRef param_types[] = {
         LLVMPointerType(LLVMInt8Type(), 0),
         LLVMInt8Type(),
         LLVMIntType(g->pointer_size_bytes * 8),
-        LLVMInt32Type(),
         LLVMInt1Type(),
     };
-    LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
+    LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 4, false);
     Buf *name = buf_sprintf("llvm.memset.p0i8.i%d", g->pointer_size_bytes * 8);
-    g->memset_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
-    assert(LLVMGetIntrinsicID(g->memset_fn_val));
+    LLVMValueRef fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
+    addLLVMArgAttrInt(fn_val, 0, "align", dest_align);
+    assert(LLVMGetIntrinsicID(fn_val));
 
-    return g->memset_fn_val;
+    g->llvm_fn_table.put(key, fn_val);
+    return fn_val;
 }
 
 static LLVMValueRef ir_render_decl_var(CodeGen *g, IrExecutable *executable,
@@ -2535,21 +2554,21 @@ static LLVMValueRef ir_render_decl_var(CodeGen *g, IrExecutable *executable,
 
             assert(var->align_bytes > 0);
 
+            LLVMValueRef memset_fn_val = get_memset_fn_val(g, var->align_bytes);
+
             // memset uninitialized memory to 0xa
             LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
             LLVMValueRef fill_char = LLVMConstInt(LLVMInt8Type(), 0xaa, false);
             LLVMValueRef dest_ptr = LLVMBuildBitCast(g->builder, var->value_ref, ptr_u8, "");
             LLVMValueRef byte_count = LLVMConstInt(usize->type_ref, size_bytes, false);
-            LLVMValueRef align_in_bytes = LLVMConstInt(LLVMInt32Type(), var->align_bytes, false);
             LLVMValueRef params[] = {
                 dest_ptr,
                 fill_char,
                 byte_count,
-                align_in_bytes,
                 LLVMConstNull(LLVMInt1Type()), // is volatile
             };
 
-            LLVMBuildCall(g->builder, get_memset_fn_val(g), params, 5, "");
+            LLVMBuildCall(g->builder, memset_fn_val, params, 4, "");
         }
     }
 
@@ -3399,17 +3418,16 @@ static LLVMValueRef ir_render_memset(CodeGen *g, IrExecutable *executable, IrIns
     LLVMValueRef is_volatile = ptr_type->data.pointer.is_volatile ?
         LLVMConstAllOnes(LLVMInt1Type()) : LLVMConstNull(LLVMInt1Type());
 
-    LLVMValueRef align_val = LLVMConstInt(LLVMInt32Type(), ptr_type->data.pointer.alignment, false);
+    LLVMValueRef memset_fn_val = get_memset_fn_val(g, ptr_type->data.pointer.alignment);
 
     LLVMValueRef params[] = {
         dest_ptr_casted,
         char_val,
         len_val,
-        align_val,
         is_volatile,
     };
 
-    LLVMBuildCall(g->builder, get_memset_fn_val(g), params, 5, "");
+    LLVMBuildCall(g->builder, memset_fn_val, params, 4, "");
     return nullptr;
 }
 
@@ -3432,18 +3450,17 @@ static LLVMValueRef ir_render_memcpy(CodeGen *g, IrExecutable *executable, IrIns
     LLVMValueRef is_volatile = (dest_ptr_type->data.pointer.is_volatile || src_ptr_type->data.pointer.is_volatile) ?
         LLVMConstAllOnes(LLVMInt1Type()) : LLVMConstNull(LLVMInt1Type());
 
-    uint32_t min_align_bytes = min(src_ptr_type->data.pointer.alignment, dest_ptr_type->data.pointer.alignment);
-    LLVMValueRef align_val = LLVMConstInt(LLVMInt32Type(), min_align_bytes, false);
+    LLVMValueRef memcpy_fn_val = get_memcpy_fn_val(g, dest_ptr_type->data.pointer.alignment,
+            src_ptr_type->data.pointer.alignment);
 
     LLVMValueRef params[] = {
         dest_ptr_casted,
         src_ptr_casted,
         len_val,
-        align_val,
         is_volatile,
     };
 
-    LLVMBuildCall(g->builder, get_memcpy_fn_val(g), params, 5, "");
+    LLVMBuildCall(g->builder, memcpy_fn_val, params, 4, "");
     return nullptr;
 }