Commit 63d37b7cff

Andrew Kelley <superjoe30@gmail.com>
2017-02-14 07:08:30
add runtime debug safety for dividing integer min value by -1
closes #260
1 parent 0931b85
src/analyze.cpp
@@ -3447,7 +3447,7 @@ static int64_t max_signed_val(TypeTableEntry *type_entry) {
     }
 }
 
-static int64_t min_signed_val(TypeTableEntry *type_entry) {
+int64_t min_signed_val(TypeTableEntry *type_entry) {
     assert(type_entry->id == TypeTableEntryIdInt);
     if (type_entry->data.integral.bit_count == 64) {
         return INT64_MIN;
src/analyze.hpp
@@ -81,6 +81,7 @@ void complete_enum(CodeGen *g, TypeTableEntry *enum_type);
 bool ir_get_var_is_comptime(VariableTableEntry *var);
 bool const_values_equal(ConstExprValue *a, ConstExprValue *b);
 void eval_min_max_value(CodeGen *g, TypeTableEntry *type_entry, ConstExprValue *const_val, bool is_max);
+int64_t min_signed_val(TypeTableEntry *type_entry);
 
 void render_const_value(Buf *buf, ConstExprValue *const_val);
 void define_local_param_variables(CodeGen *g, FnTableEntry *fn_table_entry, VariableTableEntry **arg_vars);
src/codegen.cpp
@@ -846,14 +846,30 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_debug_safety, LLVMValueRef val
         } else {
             zig_unreachable();
         }
-        LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroOk");
-        LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
-        LLVMBuildCondBr(g->builder, is_zero_bit, fail_block, ok_block);
+        LLVMBasicBlockRef div_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroOk");
+        LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
+        LLVMBuildCondBr(g->builder, is_zero_bit, div_zero_fail_block, div_zero_ok_block);
 
-        LLVMPositionBuilderAtEnd(g->builder, fail_block);
+        LLVMPositionBuilderAtEnd(g->builder, div_zero_fail_block);
         gen_debug_safety_crash(g, PanicMsgIdDivisionByZero);
 
-        LLVMPositionBuilderAtEnd(g->builder, ok_block);
+        LLVMPositionBuilderAtEnd(g->builder, div_zero_ok_block);
+
+        if (type_entry->id == TypeTableEntryIdInt && type_entry->data.integral.is_signed) {
+            LLVMValueRef neg_1_value = LLVMConstInt(type_entry->type_ref, -1, true);
+            LLVMValueRef int_min_value = LLVMConstInt(type_entry->type_ref, min_signed_val(type_entry), true);
+            LLVMBasicBlockRef overflow_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowOk");
+            LLVMBasicBlockRef overflow_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivOverflowFail");
+            LLVMValueRef num_is_int_min = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, int_min_value, "");
+            LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
+            LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
+            LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);
+
+            LLVMPositionBuilderAtEnd(g->builder, overflow_fail_block);
+            gen_debug_safety_crash(g, PanicMsgIdIntegerOverflow);
+
+            LLVMPositionBuilderAtEnd(g->builder, overflow_ok_block);
+        }
     }
 
     if (type_entry->id == TypeTableEntryIdFloat) {
test/run_tests.cpp
@@ -1700,13 +1700,28 @@ pub fn panic(message: []const u8) -> unreachable {
 error Whatever;
 pub fn main(args: [][]u8) -> %void {
     const x = neg(-32768);
-    if (x == 0) return error.Whatever;
+    if (x == 32767) return error.Whatever;
 }
 fn neg(a: i16) -> i16 {
     -a
 }
     )SOURCE");
 
+    add_debug_safety_case("signed integer division overflow", R"SOURCE(
+pub fn panic(message: []const u8) -> unreachable {
+    @breakpoint();
+    while (true) {}
+}
+error Whatever;
+pub fn main(args: [][]u8) -> %void {
+    const x = div(-32768, -1);
+    if (x == 32767) return error.Whatever;
+}
+fn div(a: i16, b: i16) -> i16 {
+    a / b
+}
+    )SOURCE");
+
     add_debug_safety_case("signed shift left overflow", R"SOURCE(
 pub fn panic(message: []const u8) -> unreachable {
     @breakpoint();