Commit b6e7a0dadd

Andrew Kelley <superjoe30@gmail.com>
2017-02-16 23:08:55
support arithmetic for non byte aligned integer types
see #261
1 parent fc5d47b
src/all_types.hpp
@@ -1219,6 +1219,39 @@ struct TypeId {
 uint32_t type_id_hash(TypeId);
 bool type_id_eql(TypeId a, TypeId b);
 
+enum ZigLLVMFnId {
+    ZigLLVMFnIdCtz,
+    ZigLLVMFnIdClz,
+    ZigLLVMFnIdOverflowArithmetic,
+};
+
+enum AddSubMul {
+    AddSubMulAdd = 0,
+    AddSubMulSub = 1,
+    AddSubMulMul = 2,
+};
+
+struct ZigLLVMFnKey {
+    ZigLLVMFnId id;
+
+    union {
+        struct {
+            uint32_t bit_count;
+        } ctz;
+        struct {
+            uint32_t bit_count;
+        } clz;
+        struct {
+            AddSubMul add_sub_mul;
+            uint32_t bit_count;
+            bool is_signed;
+        } overflow_arithmetic;
+    } data;
+};
+
+uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey);
+bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b);
+
 struct CodeGen {
     LLVMModuleRef module;
     ZigList<ErrorMsg*> errors;
@@ -1239,6 +1272,7 @@ struct CodeGen {
     HashMap<Buf *, ErrorTableEntry *, buf_hash, buf_eql_buf> error_table;
     HashMap<GenericFnTypeId *, FnTableEntry *, generic_fn_type_id_hash, generic_fn_type_id_eql> generic_table;
     HashMap<Scope *, IrInstruction *, fn_eval_hash, fn_eval_eql> memoized_fn_eval_table;
+    HashMap<ZigLLVMFnKey, LLVMValueRef, zig_llvm_fn_key_hash, zig_llvm_fn_key_eql> llvm_fn_table;
 
     ZigList<ImportTableEntry *> import_queue;
     size_t import_queue_index;
@@ -1363,8 +1397,6 @@ struct CodeGen {
     bool error_during_imports;
     uint32_t next_node_index;
     TypeTableEntry *err_tag_type;
-    LLVMValueRef int_overflow_fns[2][3][4]; // [0-signed,1-unsigned][0-add,1-sub,2-mul][0-8,1-16,2-32,3-64]
-    LLVMValueRef int_builtin_fns[2][4]; // [0-ctz,1-clz][0-8,1-16,2-32,3-64]
 
     const char **clang_argv;
     size_t clang_argv_len;
src/analyze.cpp
@@ -3997,3 +3997,33 @@ bool type_id_eql(TypeId a, TypeId b) {
     }
     zig_unreachable();
 }
+
+uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) {
+    switch (x.id) {
+        case ZigLLVMFnIdCtz:
+            return x.data.ctz.bit_count * 810453934;
+        case ZigLLVMFnIdClz:
+            return x.data.clz.bit_count * 2428952817;
+        case ZigLLVMFnIdOverflowArithmetic:
+            return (x.data.overflow_arithmetic.bit_count * 87135777) +
+                (x.data.overflow_arithmetic.add_sub_mul * 31640542) +
+                (x.data.overflow_arithmetic.is_signed ? 1062315172 : 314955820);
+    }
+    zig_unreachable();
+}
+
+bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) {
+    if (a.id != b.id)
+        return false;
+    switch (a.id) {
+        case ZigLLVMFnIdCtz:
+            return a.data.ctz.bit_count == b.data.ctz.bit_count;
+        case ZigLLVMFnIdClz:
+            return a.data.clz.bit_count == b.data.clz.bit_count;
+        case ZigLLVMFnIdOverflowArithmetic:
+            return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) &&
+                (a.data.overflow_arithmetic.add_sub_mul == b.data.overflow_arithmetic.add_sub_mul) &&
+                (a.data.overflow_arithmetic.is_signed == b.data.overflow_arithmetic.is_signed);
+    }
+    zig_unreachable();
+}
src/codegen.cpp
@@ -62,6 +62,7 @@ CodeGen *codegen_create(Buf *root_source_dir, const ZigTarget *target) {
     g->fn_type_table.init(32);
     g->error_table.init(16);
     g->generic_table.init(16);
+    g->llvm_fn_table.init(16);
     g->memoized_fn_eval_table.init(16);
     g->is_release_build = false;
     g->is_test_build = false;
@@ -352,33 +353,14 @@ static void clear_debug_source_node(CodeGen *g) {
     ZigLLVMClearCurrentDebugLocation(g->builder);
 }
 
-enum AddSubMul {
-    AddSubMulAdd = 0,
-    AddSubMulSub = 1,
-    AddSubMulMul = 2,
-};
-
-static size_t bits_index(size_t size_in_bits) {
-    switch (size_in_bits) {
-        case 8:
-            return 0;
-        case 16:
-            return 1;
-        case 32:
-            return 2;
-        case 64:
-            return 3;
-        default:
-            zig_unreachable();
-    }
-}
-
 static LLVMValueRef get_arithmetic_overflow_fn(CodeGen *g, TypeTableEntry *type_entry,
         const char *signed_name, const char *unsigned_name)
 {
+    char fn_name[64];
+
     assert(type_entry->id == TypeTableEntryIdInt);
     const char *signed_str = type_entry->data.integral.is_signed ? signed_name : unsigned_name;
-    Buf *llvm_name = buf_sprintf("llvm.%s.with.overflow.i%zu", signed_str, type_entry->data.integral.bit_count);
+    sprintf(fn_name, "llvm.%s.with.overflow.i%zu", signed_str, type_entry->data.integral.bit_count);
 
     LLVMTypeRef return_elem_types[] = {
         type_entry->type_ref,
@@ -390,34 +372,39 @@ static LLVMValueRef get_arithmetic_overflow_fn(CodeGen *g, TypeTableEntry *type_
     };
     LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false);
     LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false);
-    LLVMValueRef fn_val = LLVMAddFunction(g->module, buf_ptr(llvm_name), fn_type);
+    LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type);
     assert(LLVMGetIntrinsicID(fn_val));
     return fn_val;
 }
 
 static LLVMValueRef get_int_overflow_fn(CodeGen *g, TypeTableEntry *type_entry, AddSubMul add_sub_mul) {
     assert(type_entry->id == TypeTableEntryIdInt);
-    // [0-signed,1-unsigned][0-add,1-sub,2-mul][0-8,1-16,2-32,3-64]
-    size_t index0 = type_entry->data.integral.is_signed ? 0 : 1;
-    size_t index1 = add_sub_mul;
-    size_t index2 = bits_index(type_entry->data.integral.bit_count);
-    LLVMValueRef *fn = &g->int_overflow_fns[index0][index1][index2];
-    if (*fn) {
-        return *fn;
-    }
+
+    ZigLLVMFnKey key = {};
+    key.id = ZigLLVMFnIdOverflowArithmetic;
+    key.data.overflow_arithmetic.is_signed = type_entry->data.integral.is_signed;
+    key.data.overflow_arithmetic.add_sub_mul = add_sub_mul;
+    key.data.overflow_arithmetic.bit_count = type_entry->data.integral.bit_count;
+
+    auto existing_entry = g->llvm_fn_table.maybe_get(key);
+    if (existing_entry)
+        return existing_entry->value;
+
+    LLVMValueRef fn_val;
     switch (add_sub_mul) {
         case AddSubMulAdd:
-            *fn = get_arithmetic_overflow_fn(g, type_entry, "sadd", "uadd");
+            fn_val = get_arithmetic_overflow_fn(g, type_entry, "sadd", "uadd");
             break;
         case AddSubMulSub:
-            *fn = get_arithmetic_overflow_fn(g, type_entry, "ssub", "usub");
+            fn_val = get_arithmetic_overflow_fn(g, type_entry, "ssub", "usub");
             break;
         case AddSubMulMul:
-            *fn = get_arithmetic_overflow_fn(g, type_entry, "smul", "umul");
+            fn_val = get_arithmetic_overflow_fn(g, type_entry, "smul", "umul");
             break;
-
     }
-    return *fn;
+
+    g->llvm_fn_table.put(key, fn_val);
+    return fn_val;
 }
 
 static LLVMValueRef get_handle_value(CodeGen *g, LLVMValueRef ptr, TypeTableEntry *type, bool is_volatile) {
@@ -1388,13 +1375,22 @@ static LLVMValueRef ir_render_load_ptr(CodeGen *g, IrExecutable *executable, IrI
     bool is_volatile = ptr_type->data.pointer.is_volatile;
 
     uint32_t bit_offset = ptr_type->data.pointer.bit_offset;
-    if (bit_offset == 0)
-        return get_handle_value(g, ptr, child_type, is_volatile);
-
-    assert(!handle_is_ptr(child_type));
-
-    LLVMValueRef containing_int = LLVMBuildLoad(g->builder, ptr, "");
-    LLVMSetVolatile(containing_int, is_volatile);
+    LLVMValueRef containing_int;
+    if (bit_offset == 0) {
+        LLVMValueRef result_val = get_handle_value(g, ptr, child_type, is_volatile);
+        if (LLVMGetTypeKind(LLVMTypeOf(result_val)) == LLVMIntegerTypeKind &&
+            LLVMGetTypeKind(child_type->type_ref) == LLVMIntegerTypeKind &&
+            LLVMGetIntTypeWidth(child_type->type_ref) < LLVMGetIntTypeWidth(LLVMTypeOf(result_val)))
+        {
+            containing_int = result_val;
+        } else {
+            return result_val;
+        }
+    } else {
+        assert(!handle_is_ptr(child_type));
+        containing_int = LLVMBuildLoad(g->builder, ptr, "");
+        LLVMSetVolatile(containing_int, is_volatile);
+    }
 
     uint32_t child_bit_count = type_size_bits(g, child_type);
     uint32_t host_bit_count = LLVMGetIntTypeWidth(LLVMTypeOf(containing_int));
@@ -1748,21 +1744,34 @@ static LLVMValueRef ir_render_unwrap_maybe(CodeGen *g, IrExecutable *executable,
 }
 
 static LLVMValueRef get_int_builtin_fn(CodeGen *g, TypeTableEntry *int_type, BuiltinFnId fn_id) {
-    // [0-ctz,1-clz][0-8,1-16,2-32,3-64]
-    size_t index0 = (fn_id == BuiltinFnIdCtz) ? 0 : 1;
-    size_t index1 = bits_index(int_type->data.integral.bit_count);
-    LLVMValueRef *fn = &g->int_builtin_fns[index0][index1];
-    if (!*fn) {
-        const char *fn_name = (fn_id == BuiltinFnIdCtz) ? "cttz" : "ctlz";
-        Buf *llvm_name = buf_sprintf("llvm.%s.i%zu", fn_name, int_type->data.integral.bit_count);
-        LLVMTypeRef param_types[] = {
-            int_type->type_ref,
-            LLVMInt1Type(),
-        };
-        LLVMTypeRef fn_type = LLVMFunctionType(int_type->type_ref, param_types, 2, false);
-        *fn = LLVMAddFunction(g->module, buf_ptr(llvm_name), fn_type);
+    ZigLLVMFnKey key = {};
+    const char *fn_name;
+    if (fn_id == BuiltinFnIdCtz) {
+        fn_name = "cttz";
+        key.id = ZigLLVMFnIdCtz;
+        key.data.ctz.bit_count = int_type->data.integral.bit_count;
+    } else {
+        fn_name = "ctlz";
+        key.id = ZigLLVMFnIdClz;
+        key.data.clz.bit_count = int_type->data.integral.bit_count;
     }
-    return *fn;
+
+    auto existing_entry = g->llvm_fn_table.maybe_get(key);
+    if (existing_entry)
+        return existing_entry->value;
+
+    char llvm_name[64];
+    sprintf(llvm_name, "llvm.%s.i%zu", fn_name, int_type->data.integral.bit_count);
+    LLVMTypeRef param_types[] = {
+        int_type->type_ref,
+        LLVMInt1Type(),
+    };
+    LLVMTypeRef fn_type = LLVMFunctionType(int_type->type_ref, param_types, 2, false);
+    LLVMValueRef fn_val = LLVMAddFunction(g->module, llvm_name, fn_type);
+
+    g->llvm_fn_table.put(key, fn_val);
+
+    return fn_val;
 }
 
 static LLVMValueRef ir_render_clz(CodeGen *g, IrExecutable *executable, IrInstructionClz *instruction) {
std/math.zig
@@ -71,15 +71,13 @@ fn getReturnTypeForAbs(comptime T: type) -> type {
 fn testMath() {
     @setFnTest(this);
 
+    testMathImpl();
+    comptime testMathImpl();
+}
+
+fn testMathImpl() {
     assert(%%mulOverflow(i32, 3, 4) == 12);
     assert(%%addOverflow(i32, 3, 4) == 7);
     assert(%%subOverflow(i32, 3, 4) == -1);
     assert(%%shlOverflow(i32, 0b11, 4) == 0b110000);
-
-    comptime {
-        assert(%%mulOverflow(i32, 3, 4) == 12);
-        assert(%%addOverflow(i32, 3, 4) == 7);
-        assert(%%subOverflow(i32, 3, 4) == -1);
-        assert(%%shlOverflow(i32, 0b11, 4) == 0b110000);
-    }
 }
test/cases/math.zig
@@ -180,3 +180,24 @@ fn binaryNot() {
 fn testBinaryNot(x: u16) {
     assert(~x == 0b0101010101010101);
 }
+
+fn smallIntAddition() {
+    @setFnTest(this);
+
+    var x: @intType(false, 2) = 0;
+    assert(x == 0);
+
+    x += 1;
+    assert(x == 1);
+
+    x += 1;
+    assert(x == 2);
+
+    x += 1;
+    assert(x == 3);
+
+    var result: @typeOf(x) = 3;
+    assert(@addWithOverflow(@typeOf(x), x, 1, &result));
+
+    assert(result == 0);
+}
test/cases/struct.zig
@@ -244,11 +244,16 @@ fn bitFieldAccess() {
         .b = 2,
         .c = 3,
     };
+    assert(getA(&data) == 1);
     assert(getB(&data) == 2);
     assert(getC(&data) == 3);
     comptime assert(@sizeOf(BitField1) == 1);
 }
 
+fn getA(data: &const BitField1) -> u3 {
+    return data.a;
+}
+
 fn getB(data: &const BitField1) -> u3 {
     return data.b;
 }