Commit 54c06bf715
Changed files (5)
src/codegen.cpp
@@ -1958,6 +1958,54 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
zig_unreachable();
}
+static void add_error_range_check(CodeGen *g, TypeTableEntry *err_set_type, TypeTableEntry *int_type, LLVMValueRef target_val) {
+ assert(err_set_type->id == TypeTableEntryIdErrorSet);
+
+ if (type_is_global_error_set(err_set_type)) {
+ LLVMValueRef zero = LLVMConstNull(int_type->type_ref);
+ LLVMValueRef neq_zero_bit = LLVMBuildICmp(g->builder, LLVMIntNE, target_val, zero, "");
+ LLVMValueRef ok_bit;
+
+ BigInt biggest_possible_err_val = {0};
+ eval_min_max_value_int(g, int_type, &biggest_possible_err_val, true);
+
+ if (bigint_fits_in_bits(&biggest_possible_err_val, 64, false) &&
+ bigint_as_unsigned(&biggest_possible_err_val) < g->errors_by_index.length)
+ {
+ ok_bit = neq_zero_bit;
+ } else {
+ LLVMValueRef error_value_count = LLVMConstInt(int_type->type_ref, g->errors_by_index.length, false);
+ LLVMValueRef in_bounds_bit = LLVMBuildICmp(g->builder, LLVMIntULT, target_val, error_value_count, "");
+ ok_bit = LLVMBuildAnd(g->builder, neq_zero_bit, in_bounds_bit, "");
+ }
+
+ LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
+ LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
+
+ LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
+
+ LLVMPositionBuilderAtEnd(g->builder, fail_block);
+ gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
+
+ LLVMPositionBuilderAtEnd(g->builder, ok_block);
+ } else {
+ LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
+ LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
+
+ uint32_t err_count = err_set_type->data.error_set.err_count;
+ LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, target_val, fail_block, err_count);
+ for (uint32_t i = 0; i < err_count; i += 1) {
+ LLVMValueRef case_value = LLVMConstInt(g->err_tag_type->type_ref, err_set_type->data.error_set.errors[i]->value, false);
+ LLVMAddCase(switch_instr, case_value, ok_block);
+ }
+
+ LLVMPositionBuilderAtEnd(g->builder, fail_block);
+ gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
+
+ LLVMPositionBuilderAtEnd(g->builder, ok_block);
+ }
+}
+
static LLVMValueRef ir_render_cast(CodeGen *g, IrExecutable *executable,
IrInstructionCast *cast_instruction)
{
@@ -2082,7 +2130,9 @@ static LLVMValueRef ir_render_cast(CodeGen *g, IrExecutable *executable,
assert(actual_type->id == TypeTableEntryIdBool);
return LLVMBuildZExt(g->builder, expr_val, wanted_type->type_ref, "");
case CastOpErrSet:
- // TODO runtime safety for error casting
+ if (ir_want_runtime_safety(g, &cast_instruction->base)) {
+ add_error_range_check(g, wanted_type, g->err_tag_type, expr_val);
+ }
return expr_val;
}
zig_unreachable();
@@ -2154,32 +2204,7 @@ static LLVMValueRef ir_render_int_to_err(CodeGen *g, IrExecutable *executable, I
LLVMValueRef target_val = ir_llvm_value(g, instruction->target);
if (ir_want_runtime_safety(g, &instruction->base)) {
- LLVMValueRef zero = LLVMConstNull(actual_type->type_ref);
- LLVMValueRef neq_zero_bit = LLVMBuildICmp(g->builder, LLVMIntNE, target_val, zero, "");
- LLVMValueRef ok_bit;
-
- BigInt biggest_possible_err_val = {0};
- eval_min_max_value_int(g, actual_type, &biggest_possible_err_val, true);
-
- if (bigint_fits_in_bits(&biggest_possible_err_val, 64, false) &&
- bigint_as_unsigned(&biggest_possible_err_val) < g->errors_by_index.length)
- {
- ok_bit = neq_zero_bit;
- } else {
- LLVMValueRef error_value_count = LLVMConstInt(actual_type->type_ref, g->errors_by_index.length, false);
- LLVMValueRef in_bounds_bit = LLVMBuildICmp(g->builder, LLVMIntULT, target_val, error_value_count, "");
- ok_bit = LLVMBuildAnd(g->builder, neq_zero_bit, in_bounds_bit, "");
- }
-
- LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrOk");
- LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "IntToErrFail");
-
- LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
-
- LLVMPositionBuilderAtEnd(g->builder, fail_block);
- gen_safety_crash(g, PanicMsgIdInvalidErrorCode);
-
- LLVMPositionBuilderAtEnd(g->builder, ok_block);
+ add_error_range_check(g, wanted_type, actual_type, target_val);
}
return gen_widen_or_shorten(g, false, actual_type, g->err_tag_type, target_val);
src/ir.cpp
@@ -8505,19 +8505,49 @@ static IrInstruction *ir_analyze_int_to_err(IrAnalyze *ira, IrInstruction *sourc
IrInstruction *result = ir_create_const(&ira->new_irb, source_instr->scope,
source_instr->source_node, wanted_type);
- BigInt err_count;
- bigint_init_unsigned(&err_count, ira->codegen->errors_by_index.length);
- if (bigint_cmp_zero(&val->data.x_bigint) == CmpEQ || bigint_cmp(&val->data.x_bigint, &err_count) != CmpLT) {
- Buf *val_buf = buf_alloc();
- bigint_append_buf(val_buf, &val->data.x_bigint, 10);
- ir_add_error(ira, source_instr,
- buf_sprintf("integer value %s represents no error", buf_ptr(val_buf)));
+ if (!resolve_inferred_error_set(ira, wanted_type, source_instr->source_node)) {
return ira->codegen->invalid_instruction;
}
- size_t index = bigint_as_unsigned(&val->data.x_bigint);
- result->value.data.x_err_set = ira->codegen->errors_by_index.at(index);
- return result;
+ if (type_is_global_error_set(wanted_type)) {
+ BigInt err_count;
+ bigint_init_unsigned(&err_count, ira->codegen->errors_by_index.length);
+
+ if (bigint_cmp_zero(&val->data.x_bigint) == CmpEQ || bigint_cmp(&val->data.x_bigint, &err_count) != CmpLT) {
+ Buf *val_buf = buf_alloc();
+ bigint_append_buf(val_buf, &val->data.x_bigint, 10);
+ ir_add_error(ira, source_instr,
+ buf_sprintf("integer value %s represents no error", buf_ptr(val_buf)));
+ return ira->codegen->invalid_instruction;
+ }
+
+ size_t index = bigint_as_unsigned(&val->data.x_bigint);
+ result->value.data.x_err_set = ira->codegen->errors_by_index.at(index);
+ return result;
+ } else {
+ ErrorTableEntry *err = nullptr;
+ BigInt err_int;
+
+ for (uint32_t i = 0, count = wanted_type->data.error_set.err_count; i < count; i += 1) {
+ ErrorTableEntry *this_err = wanted_type->data.error_set.errors[i];
+ bigint_init_unsigned(&err_int, this_err->value);
+ if (bigint_cmp(&val->data.x_bigint, &err_int) == CmpEQ) {
+ err = this_err;
+ break;
+ }
+ }
+
+ if (err == nullptr) {
+ Buf *val_buf = buf_alloc();
+ bigint_append_buf(val_buf, &val->data.x_bigint, 10);
+ ir_add_error(ira, source_instr,
+ buf_sprintf("integer value %s represents no error in '%s'", buf_ptr(val_buf), buf_ptr(&wanted_type->name)));
+ return ira->codegen->invalid_instruction;
+ }
+
+ result->value.data.x_err_set = err;
+ return result;
+ }
}
IrInstruction *result = ir_build_int_to_err(&ira->new_irb, source_instr->scope, source_instr->source_node, target);
test/compile_errors.zig
@@ -1,6 +1,38 @@
const tests = @import("tests.zig");
pub fn addCases(cases: &tests.CompileErrorContext) void {
+ cases.add("implicit cast of error set not a subset",
+ \\const Set1 = error{A, B};
+ \\const Set2 = error{A, C};
+ \\export fn entry() void {
+ \\ foo(Set1.B);
+ \\}
+ \\fn foo(set1: Set1) void {
+ \\ var x: Set2 = set1;
+ \\}
+ ,
+ ".tmp_source.zig:7:19: error: expected 'Set2', found 'Set1'",
+ ".tmp_source.zig:1:23: note: 'error.B' not a member of destination error set");
+
+ cases.add("int to err global invalid number",
+ \\const Set1 = error{A, B};
+ \\comptime {
+ \\ var x: usize = 3;
+ \\ var y = error(x);
+ \\}
+ ,
+ ".tmp_source.zig:4:18: error: integer value 3 represents no error");
+
+ cases.add("int to err non global invalid number",
+ \\const Set1 = error{A, B};
+ \\const Set2 = error{A, C};
+ \\comptime {
+ \\ var x = usize(Set1.B);
+ \\ var y = Set2(x);
+ \\}
+ ,
+ ".tmp_source.zig:5:17: error: integer value 2 represents no error in 'Set2'");
+
cases.add("@memberCount of error",
\\comptime {
\\ _ = @memberCount(error);
test/runtime_safety.zig
@@ -220,7 +220,7 @@ pub fn addCases(cases: &tests.CompareOutputContext) void {
\\}
);
- cases.addRuntimeSafety("cast integer to error and no code matches",
+ cases.addRuntimeSafety("cast integer to global error and no code matches",
\\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
\\ @import("std").os.exit(126);
\\}
@@ -232,6 +232,20 @@ pub fn addCases(cases: &tests.CompareOutputContext) void {
\\}
);
+ cases.addRuntimeSafety("cast integer to non-global error set and no match",
+ \\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
+ \\ @import("std").os.exit(126);
+ \\}
+ \\const Set1 = error{A, B};
+ \\const Set2 = error{A, C};
+ \\pub fn main() void {
+ \\ _ = foo(Set1.B);
+ \\}
+ \\fn foo(set1: Set1) Set2 {
+ \\ return Set2(set1);
+ \\}
+ );
+
cases.addRuntimeSafety("@alignCast misaligned",
\\pub fn panic(message: []const u8, stack_trace: ?&@import("builtin").StackTrace) noreturn {
\\ @import("std").os.exit(126);
TODO
@@ -17,10 +17,6 @@ you can get the compiler to tell you the possible errors for an inferred error s
foo() catch |err| switch (err) {};
-// TODO this is an explicit cast and should actually coerce the type
- erorr set casting
- // add a runtime safety check
-
test err should be comptime if error set has 0 members