Commit 100802cdc0

Andrew Kelley <superjoe30@gmail.com>
2016-05-07 00:46:38
add debug safety for left shifting
See #46
1 parent 0c96920
src/codegen.cpp
@@ -440,7 +440,7 @@ static LLVMValueRef gen_cmp_exchange(CodeGen *g, AstNode *node) {
     LLVMAtomicOrdering failure_order = to_LLVMAtomicOrdering((AtomicOrder)failure_order_val->data.x_enum.tag);
 
     LLVMValueRef result_val = ZigLLVMBuildCmpXchg(g->builder, ptr_val, cmp_val, new_val,
-            success_order, failure_order, "");
+            success_order, failure_order);
 
     return LLVMBuildExtractValue(g->builder, result_val, 1, "");
 }
@@ -1309,6 +1309,36 @@ static LLVMValueRef gen_overflow_op(CodeGen *g, TypeTableEntry *type_entry, AddS
     return result;
 }
 
+static LLVMValueRef gen_overflow_shl_op(CodeGen *g, TypeTableEntry *type_entry,
+        LLVMValueRef val1, LLVMValueRef val2)
+{
+    // for unsigned left shifting, we do the wrapping shift, then logically shift
+    // right the same number of bits
+    // if the values don't match, we have an overflow
+    // for signed left shifting we do the same except arithmetic shift right
+
+    assert(type_entry->id == TypeTableEntryIdInt);
+
+    LLVMValueRef result = LLVMBuildShl(g->builder, val1, val2, "");
+    LLVMValueRef orig_val;
+    if (type_entry->data.integral.is_signed) {
+        orig_val = LLVMBuildAShr(g->builder, result, val2, "");
+    } else {
+        orig_val = LLVMBuildLShr(g->builder, result, val2, "");
+    }
+    LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, orig_val, "");
+
+    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowOk");
+    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowFail");
+    LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, fail_block);
+    gen_debug_safety_crash(g);
+
+    LLVMPositionBuilderAtEnd(g->builder, ok_block);
+    return result;
+}
+
 static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypePrefixOpExpr);
     assert(node->data.prefix_op_expr.primary_expr);
@@ -1484,7 +1514,16 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g, AstNode *source_node,
         case BinOpTypeBitShiftLeft:
         case BinOpTypeAssignBitShiftLeft:
             set_debug_source_node(g, source_node);
-            return LLVMBuildShl(g->builder, val1, val2, "");
+            assert(op1_type->id == TypeTableEntryIdInt);
+            if (op1_type->data.integral.is_wrapping) {
+                return LLVMBuildShl(g->builder, val1, val2, "");
+            } else if (want_debug_safety(g, source_node)) {
+                return gen_overflow_shl_op(g, op1_type, val1, val2);
+            } else if (op1_type->data.integral.is_signed) {
+                return ZigLLVMBuildNSWShl(g->builder, val1, val2, "");
+            } else {
+                return ZigLLVMBuildNUWShl(g->builder, val1, val2, "");
+            }
         case BinOpTypeBitShiftRight:
         case BinOpTypeAssignBitShiftRight:
             assert(op1_type->id == TypeTableEntryIdInt);
src/zig_llvm.cpp
@@ -661,14 +661,25 @@ static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering) {
 
 LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp,
         LLVMValueRef new_val, LLVMAtomicOrdering success_ordering,
-        LLVMAtomicOrdering failure_ordering,
-        const char *name)
+        LLVMAtomicOrdering failure_ordering)
 {
     return wrap(unwrap(builder)->CreateAtomicCmpXchg(unwrap(ptr), unwrap(cmp), unwrap(new_val),
                 mapFromLLVMOrdering(success_ordering), mapFromLLVMOrdering(failure_ordering),
                 CrossThread));
 }
 
+LLVMValueRef ZigLLVMBuildNSWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS,
+        const char *name)
+{
+    return wrap(unwrap(builder)->CreateShl(unwrap(LHS), unwrap(RHS), name, false, true));
+}
+
+LLVMValueRef ZigLLVMBuildNUWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS,
+        const char *name)
+{
+    return wrap(unwrap(builder)->CreateShl(unwrap(LHS), unwrap(RHS), name, false, true));
+}
+
 
 //------------------------------------
 
src/zig_llvm.hpp
@@ -41,7 +41,11 @@ LLVMValueRef LLVMZigBuildCall(LLVMBuilderRef B, LLVMValueRef Fn, LLVMValueRef *A
 
 LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp,
         LLVMValueRef new_val, LLVMAtomicOrdering success_ordering,
-        LLVMAtomicOrdering failure_ordering,
+        LLVMAtomicOrdering failure_ordering);
+
+LLVMValueRef ZigLLVMBuildNSWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS,
+        const char *name);
+LLVMValueRef ZigLLVMBuildNUWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS,
         const char *name);
 
 // 0 is return value, 1 is first arg
test/run_tests.cpp
@@ -1363,6 +1363,26 @@ fn neg(a: i16) -> i16 {
 }
     )SOURCE");
 
+    add_debug_safety_case("signed shift left overflow", R"SOURCE(
+pub fn main(args: [][]u8) -> %void {
+    shl(-16385, 1);
+}
+#static_eval_enable(false)
+fn shl(a: i16, b: i16) -> i16 {
+    a << b
+}
+    )SOURCE");
+
+    add_debug_safety_case("unsigned shift left overflow", R"SOURCE(
+pub fn main(args: [][]u8) -> %void {
+    shl(0b0010111111111111, 3);
+}
+#static_eval_enable(false)
+fn shl(a: u16, b: u16) -> u16 {
+    a << b
+}
+    )SOURCE");
+
 }
 
 //////////////////////////////////////////////////////////////////////////////