Commit 54c06bf715

Andrew Kelley <superjoe30@gmail.com>
2018-02-09 03:54:44
error sets: runtime safety for int-to-err and err set cast
1 parent 8fc6e31
src/codegen.cpp
@@ -1958,6 +1958,54 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
     zig_unreachable();
 }
 
+static void add_error_range_check(CodeGen *g, TypeTableEntry *err_set_type, TypeTableEntry *int_type, LLVMValueRef target_val) {
+    assert(err_set_type->id == TypeTableEntryIdErrorSet);
+
+    if (type_is_global_error_set(err_set_type)) {
+        LLVMValueRef zero = LLVMConstNull(int_type->type_ref);
+        LLVMValueRef neq_zero_bit = LLVMBuildICmp(g->builder, LLVMIntNE, target_val, zero, "");
+        LLVMValueRef ok_bit;
+
+        BigInt biggest_possible_err_val = {0};
+        eval_min_max_value_int(g, int_type, &biggest_possible_err_val, true);
+
+        if (bigint_fits_in_bits(&biggest_possible_err_val, 64, false) &&
+            bigint_as_unsigned(&biggest_possible_err_val) < g->errors_by_index.length)
+        {
+            ok_bit = neq_zero_bit;
+        } else {
+            LLVMValueRef error_value_count = LLVMConstInt(int_type->type_ref, g->errors_by_index.length, false);
+            LLVMValueRef in_bounds_bit = LLVMBuildICmp(g->builder, LLVMIntULT, target_val, error_value_count, "");
+            ok_bit = LLVMBuildAnd(g->builder, neq_zero_bit, in_bounds_bit, "");
+        }
+
+        LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
+        LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
+
+        LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
+
+        LLVMPositionBuilderAtEnd(g->builder, fail_block);
+        gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
+
+        LLVMPositionBuilderAtEnd(g->builder, ok_block);
+    } else {
+        LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
+        LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
+
+        uint32_t err_count = err_set_type->data.error_set.err_count;
+        LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, target_val, fail_block, err_count);
+        for (uint32_t i = 0; i < err_count; i += 1) {
+            LLVMValueRef case_value = LLVMConstInt(g->err_tag_type->type_ref, err_set_type->data.error_set.errors[i]->value, false);
+            LLVMAddCase(switch_instr, case_value, ok_block);
+        }
+
+        LLVMPositionBuilderAtEnd(g->builder, fail_block);
+        gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
+
+        LLVMPositionBuilderAtEnd(g->builder, ok_block);
+    }
+}
+
 static LLVMValueRef ir_render_cast(CodeGen *g, IrExecutable *executable,
         IrInstructionCast *cast_instruction)
 {
@@ -2082,7 +2130,9 @@ static LLVMValueRef ir_render_cast(CodeGen *g, IrExecutable *executable,
             assert(actual_type->id == TypeTableEntryIdBool);
             return LLVMBuildZExt(g->builder, expr_val, wanted_type->type_ref, "");
         case CastOpErrSet:
-            // TODO runtime safety for error casting
+            if (ir_want_runtime_safety(g, &cast_instruction->base)) {
+                add_error_range_check(g, wanted_type, g->err_tag_type, expr_val);
+            }
             return expr_val;
     }
     zig_unreachable();
@@ -2154,32 +2204,7 @@ static LLVMValueRef ir_render_int_to_err(CodeGen *g, IrExecutable *executable, I
     LLVMValueRef target_val = ir_llvm_value(g, instruction->target);
 
     if (ir_want_runtime_safety(g, &instruction->base)) {
-        LLVMValueRef zero = LLVMConstNull(actual_type->type_ref);
-        LLVMValueRef neq_zero_bit = LLVMBuildICmp(g->builder, LLVMIntNE, target_val, zero, "");
-        LLVMValueRef ok_bit;
-
-        BigInt biggest_possible_err_val = {0};
-        eval_min_max_value_int(g, actual_type, &biggest_possible_err_val, true);
-
-        if (bigint_fits_in_bits(&biggest_possible_err_val, 64, false) &&
-            bigint_as_unsigned(&biggest_possible_err_val) < g->errors_by_index.length)
-        {
-            ok_bit = neq_zero_bit;
-        } else {
-            LLVMValueRef error_value_count = LLVMConstInt(actual_type->type_ref, g->errors_by_index.length, false);
-            LLVMValueRef in_bounds_bit = LLVMBuildICmp(g->builder, LLVMIntULT, target_val, error_value_count, "");
-            ok_bit = LLVMBuildAnd(g->builder, neq_zero_bit, in_bounds_bit, "");
-        }
-
-        LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
-        LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
-
-        LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
-
-        LLVMPositionBuilderAtEnd(g->builder, fail_block);
-        gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
-
-        LLVMPositionBuilderAtEnd(g->builder, ok_block);
+        add_error_range_check(g, wanted_type, actual_type, target_val);
     }
 
     return gen_widen_or_shorten(g, false, actual_type, g->err_tag_type, target_val);
src/ir.cpp
@@ -8505,19 +8505,49 @@ static IrInstruction *ir_analyze_int_to_err(IrAnalyze *ira, IrInstruction *sourc
         IrInstruction *result = ir_create_const(&ira->new_irb, source_instr->scope,
                 source_instr->source_node, wanted_type);
 
-        BigInt err_count;
-        bigint_init_unsigned(&err_count, ira->codegen->errors_by_index.length);
-        if (bigint_cmp_zero(&val->data.x_bigint) == CmpEQ || bigint_cmp(&val->data.x_bigint, &err_count) != CmpLT) {
-            Buf *val_buf = buf_alloc();
-            bigint_append_buf(val_buf, &val->data.x_bigint, 10);
-            ir_add_error(ira, source_instr,
-                buf_sprintf("integer value %s represents no error", buf_ptr(val_buf)));
+        if (!resolve_inferred_error_set(ira, wanted_type, source_instr->source_node)) {
             return ira->codegen->invalid_instruction;
         }
 
-        size_t index = bigint_as_unsigned(&val->data.x_bigint);
-        result->value.data.x_err_set = ira->codegen->errors_by_index.at(index);
-        return result;
+        if (type_is_global_error_set(wanted_type)) {
+            BigInt err_count;
+            bigint_init_unsigned(&err_count, ira->codegen->errors_by_index.length);
+
+            if (bigint_cmp_zero(&val->data.x_bigint) == CmpEQ || bigint_cmp(&val->data.x_bigint, &err_count) != CmpLT) {
+                Buf *val_buf = buf_alloc();
+                bigint_append_buf(val_buf, &val->data.x_bigint, 10);
+                ir_add_error(ira, source_instr,
+                    buf_sprintf("integer value %s represents no error", buf_ptr(val_buf)));
+                return ira->codegen->invalid_instruction;
+            }
+
+            size_t index = bigint_as_unsigned(&val->data.x_bigint);
+            result->value.data.x_err_set = ira->codegen->errors_by_index.at(index);
+            return result;
+        } else {
+            ErrorTableEntry *err = nullptr;
+            BigInt err_int;
+
+            for (uint32_t i = 0, count = wanted_type->data.error_set.err_count; i < count; i += 1) {
+                ErrorTableEntry *this_err = wanted_type->data.error_set.errors[i];
+                bigint_init_unsigned(&err_int, this_err->value);
+                if (bigint_cmp(&val->data.x_bigint, &err_int) == CmpEQ) {
+                    err = this_err;
+                    break;
+                }
+            }
+
+            if (err == nullptr) {
+                Buf *val_buf = buf_alloc();
+                bigint_append_buf(val_buf, &val->data.x_bigint, 10);
+                ir_add_error(ira, source_instr,
+                    buf_sprintf("integer value %s represents no error in '%s'", buf_ptr(val_buf), buf_ptr(&wanted_type->name)));
+                return ira->codegen->invalid_instruction;
+            }
+
+            result->value.data.x_err_set = err;
+            return result;
+        }
     }
 
     IrInstruction *result = ir_build_int_to_err(&ira->new_irb, source_instr->scope, source_instr->source_node, target);
test/compile_errors.zig
@@ -1,6 +1,38 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: &tests.CompileErrorContext) void {
+    cases.add("implicit cast of error set not a subset",
+        \\const Set1 = error{A, B};
+        \\const Set2 = error{A, C};
+        \\export fn entry() void {
+        \\    foo(Set1.B);
+        \\}
+        \\fn foo(set1: Set1) void {
+        \\    var x: Set2 = set1;
+        \\}
+    ,
+        ".tmp_source.zig:7:19: error: expected 'Set2', found 'Set1'",
+        ".tmp_source.zig:1:23: note: 'error.B' not a member of destination error set");
+
+    cases.add("int to err global invalid number",
+        \\const Set1 = error{A, B};
+        \\comptime {
+        \\    var x: usize = 3;
+        \\    var y = error(x);
+        \\}
+    ,
+        ".tmp_source.zig:4:18: error: integer value 3 represents no error");
+
+    cases.add("int to err non global invalid number",
+        \\const Set1 = error{A, B};
+        \\const Set2 = error{A, C};
+        \\comptime {
+        \\    var x = usize(Set1.B);
+        \\    var y = Set2(x);
+        \\}
+    ,
+        ".tmp_source.zig:5:17: error: integer value 2 represents no error in 'Set2'");
+
     cases.add("@memberCount of error",
         \\comptime {
         \\    _ = @memberCount(error);
test/runtime_safety.zig
@@ -220,7 +220,7 @@ pub fn addCases(cases: &tests.CompareOutputContext) void {
         \\}
     );
 
-    cases.addRuntimeSafety("cast integer to error and no code matches",
+    cases.addRuntimeSafety("cast integer to global error and no code matches",
         \\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
         \\    @import("std").os.exit(126);
         \\}
@@ -232,6 +232,20 @@ pub fn addCases(cases: &tests.CompareOutputContext) void {
         \\}
     );
 
+    cases.addRuntimeSafety("cast integer to non-global error set and no match",
+        \\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\const Set1 = error{A, B};
+        \\const Set2 = error{A, C};
+        \\pub fn main() void {
+        \\    _ = foo(Set1.B);
+        \\}
+        \\fn foo(set1: Set1) Set2 {
+        \\    return Set2(set1);
+        \\}
+    );
+
     cases.addRuntimeSafety("@alignCast misaligned",
         \\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
         \\    @import("std").os.exit(126);
TODO
@@ -17,10 +17,6 @@ you can get the compiler to tell you the possible errors for an inferred error s
 
 foo() catch |err| switch (err) {};
 
-// TODO this is an explicit cast and should actually coerce the type
-   erorr set casting
-   // add a runtime safety check
-
 
 test err should be comptime if error set has 0 members