Commit b8d17b11a7

Andrew Kelley <superjoe30@gmail.com>
2016-05-06 03:07:04
add tests for integer overflow crashing
see #46
1 parent 094336f
Changed files (2)
src/codegen.cpp
@@ -1097,7 +1097,7 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) {
         LLVMBuildStore(g->builder, slice_start_ptr, ptr_field_ptr);
 
         LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 1, "");
-        LLVMValueRef len_value = LLVMBuildSub(g->builder, end_val, start_val, "");
+        LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, "");
         LLVMBuildStore(g->builder, len_value, len_field_ptr);
 
         return tmp_struct_ptr;
@@ -1115,7 +1115,7 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) {
         LLVMBuildStore(g->builder, slice_start_ptr, ptr_field_ptr);
 
         LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 1, "");
-        LLVMValueRef len_value = LLVMBuildSub(g->builder, end_val, start_val, "");
+        LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, "");
         LLVMBuildStore(g->builder, len_value, len_field_ptr);
 
         return tmp_struct_ptr;
@@ -1160,7 +1160,7 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) {
         LLVMBuildStore(g->builder, slice_start_ptr, ptr_field_ptr);
 
         LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, len_index, "");
-        LLVMValueRef len_value = LLVMBuildSub(g->builder, end_val, start_val, "");
+        LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, "");
         LLVMBuildStore(g->builder, len_value, len_field_ptr);
 
         return tmp_struct_ptr;
@@ -1287,6 +1287,28 @@ static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *expr_node, AstNode *node,
     return target_ref;
 }
 
+static LLVMValueRef gen_overflow_op(CodeGen *g, TypeTableEntry *type_entry, AddSubMul op,
+        LLVMValueRef val1, LLVMValueRef val2)
+{
+    LLVMValueRef fn_val = get_int_overflow_fn(g, type_entry, op);
+    LLVMValueRef params[] = {
+        val1,
+        val2,
+    };
+    LLVMValueRef result_struct = LLVMBuildCall(g->builder, fn_val, params, 2, "");
+    LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
+    LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
+    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowFail");
+    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowOk");
+    LLVMBuildCondBr(g->builder, overflow_bit, fail_block, ok_block);
+
+    LLVMPositionBuilderAtEnd(g->builder, fail_block);
+    gen_debug_safety_crash(g);
+
+    LLVMPositionBuilderAtEnd(g->builder, ok_block);
+    return result;
+}
+
 static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypePrefixOpExpr);
     assert(node->data.prefix_op_expr.primary_expr);
@@ -1300,12 +1322,20 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
         case PrefixOpNegation:
             {
                 LLVMValueRef expr = gen_expr(g, expr_node);
-                if (expr_type->id == TypeTableEntryIdInt) {
-                    set_debug_source_node(g, node);
-                    return LLVMBuildNeg(g->builder, expr, "");
-                } else if (expr_type->id == TypeTableEntryIdFloat) {
-                    set_debug_source_node(g, node);
+                set_debug_source_node(g, node);
+                if (expr_type->id == TypeTableEntryIdFloat) {
                     return LLVMBuildFNeg(g->builder, expr, "");
+                } else if (expr_type->id == TypeTableEntryIdInt) {
+                    if (expr_type->data.integral.is_wrapping) {
+                        return LLVMBuildNeg(g->builder, expr, "");
+                    } else if (want_debug_safety(g, expr_node)) {
+                        LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(expr));
+                        return gen_overflow_op(g, expr_type, AddSubMulSub, zero, expr);
+                    } else if (expr_type->data.integral.is_signed) {
+                        return LLVMBuildNSWNeg(g->builder, expr, "");
+                    } else {
+                        return LLVMBuildNUWNeg(g->builder, expr, "");
+                    }
                 } else {
                     zig_unreachable();
                 }
@@ -1431,28 +1461,6 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
     zig_unreachable();
 }
 
-static LLVMValueRef gen_overflow_op(CodeGen *g, TypeTableEntry *type_entry, AddSubMul op,
-        LLVMValueRef val1, LLVMValueRef val2)
-{
-    LLVMValueRef fn_val = get_int_overflow_fn(g, type_entry, op);
-    LLVMValueRef params[] = {
-        val1,
-        val2,
-    };
-    LLVMValueRef result_struct = LLVMBuildCall(g->builder, fn_val, params, 2, "");
-    LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
-    LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
-    LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowFail");
-    LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "OverflowOk");
-    LLVMBuildCondBr(g->builder, overflow_bit, fail_block, ok_block);
-
-    LLVMPositionBuilderAtEnd(g->builder, fail_block);
-    gen_debug_safety_crash(g);
-
-    LLVMPositionBuilderAtEnd(g->builder, ok_block);
-    return result;
-}
-
 static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g, AstNode *source_node,
     LLVMValueRef val1, LLVMValueRef val2,
     TypeTableEntry *op1_type, TypeTableEntry *op2_type,
@@ -2727,7 +2735,7 @@ static LLVMValueRef gen_for_expr(CodeGen *g, AstNode *node) {
 
     LLVMPositionBuilderAtEnd(g->builder, continue_block);
     set_debug_source_node(g, node);
-    LLVMValueRef new_index_val = LLVMBuildAdd(g->builder, index_val, one_const, "");
+    LLVMValueRef new_index_val = LLVMBuildNSWAdd(g->builder, index_val, one_const, "");
     LLVMBuildStore(g->builder, new_index_val, index_ptr);
     LLVMBuildBr(g->builder, cond_block);
 
test/run_tests.cpp
@@ -1323,6 +1323,46 @@ fn bar(a: []i32) -> i32 {
 fn baz(a: i32) {}
     )SOURCE");
 
+    add_debug_safety_case("integer addition overflow", R"SOURCE(
+pub fn main(args: [][]u8) -> %void {
+    add(65530, 10);
+}
+#static_eval_enable(false)
+fn add(a: u16, b: u16) -> u16 {
+    a + b
+}
+    )SOURCE");
+
+    add_debug_safety_case("integer subtraction overflow", R"SOURCE(
+pub fn main(args: [][]u8) -> %void {
+    sub(10, 20);
+}
+#static_eval_enable(false)
+fn sub(a: u16, b: u16) -> u16 {
+    a - b
+}
+    )SOURCE");
+
+    add_debug_safety_case("integer multiplication overflow", R"SOURCE(
+pub fn main(args: [][]u8) -> %void {
+    mul(300, 6000);
+}
+#static_eval_enable(false)
+fn mul(a: u16, b: u16) -> u16 {
+    a * b
+}
+    )SOURCE");
+
+    add_debug_safety_case("integer negation overflow", R"SOURCE(
+pub fn main(args: [][]u8) -> %void {
+    neg(-32768);
+}
+#static_eval_enable(false)
+fn neg(a: i16) -> i16 {
+    -a
+}
+    )SOURCE");
+
 }
 
 //////////////////////////////////////////////////////////////////////////////