Commit 2fc34eaa58

Jimmi Holst Christensen <jimmiholstchristensen@gmail.com>
2018-04-28 16:27:31
Functions with infered error set can now return literals fixes #852
1 parent 3178528
Changed files (3)
src/analyze.cpp
@@ -6131,4 +6131,3 @@ bool type_can_fail(TypeTableEntry *type_entry) {
 bool fn_type_can_fail(FnTypeId *fn_type_id) {
     return type_can_fail(fn_type_id->return_type) || fn_type_id->cc == CallingConventionAsync;
 }
-
src/ir.cpp
@@ -8111,7 +8111,7 @@ static void update_errors_helper(CodeGen *g, ErrorTableEntry ***errors, size_t *
     *errors = reallocate(*errors, old_errors_count, *errors_count);
 }
 
-static TypeTableEntry *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, IrInstruction **instructions, size_t instruction_count) {
+static TypeTableEntry *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, TypeTableEntry *expected_type, IrInstruction **instructions, size_t instruction_count) {
     assert(instruction_count >= 1);
     IrInstruction *prev_inst = instructions[0];
     if (type_is_invalid(prev_inst->value.type)) {
@@ -8158,16 +8158,6 @@ static TypeTableEntry *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_nod
             continue;
         }
 
-        if (prev_type->id == TypeTableEntryIdNullLit) {
-            prev_inst = cur_inst;
-            continue;
-        }
-
-        if (cur_type->id == TypeTableEntryIdNullLit) {
-            any_are_null = true;
-            continue;
-        }
-
         if (prev_type->id == TypeTableEntryIdErrorSet) {
             assert(err_set_type != nullptr);
             if (cur_type->id == TypeTableEntryIdErrorSet) {
@@ -8427,6 +8417,16 @@ static TypeTableEntry *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_nod
             }
         }
 
+        if (prev_type->id == TypeTableEntryIdNullLit) {
+            prev_inst = cur_inst;
+            continue;
+        }
+
+        if (cur_type->id == TypeTableEntryIdNullLit) {
+            any_are_null = true;
+            continue;
+        }
+
         if (types_match_const_cast_only(ira, prev_type, cur_type, source_node).id == ConstCastResultIdOk) {
             continue;
         }
@@ -8610,6 +8610,10 @@ static TypeTableEntry *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_nod
     } else if (err_set_type != nullptr) {
         if (prev_inst->value.type->id == TypeTableEntryIdErrorSet) {
             return err_set_type;
+        } else if (prev_inst->value.type->id == TypeTableEntryIdErrorUnion) {
+            return get_error_union_type(ira->codegen, err_set_type, prev_inst->value.type->data.error_union.payload_type);
+        } else if (expected_type != nullptr && expected_type->id == TypeTableEntryIdErrorUnion) {
+            return get_error_union_type(ira->codegen, err_set_type, expected_type->data.error_union.payload_type);
         } else {
             if (prev_inst->value.type->id == TypeTableEntryIdNumLitInt ||
                 prev_inst->value.type->id == TypeTableEntryIdNumLitFloat)
@@ -8621,8 +8625,6 @@ static TypeTableEntry *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_nod
                 ir_add_error_node(ira, source_node,
                     buf_sprintf("unable to make error union out of null literal"));
                 return ira->codegen->builtin_types.entry_invalid;
-            } else if (prev_inst->value.type->id == TypeTableEntryIdErrorUnion) {
-                return get_error_union_type(ira->codegen, err_set_type, prev_inst->value.type->data.error_union.payload_type);
             } else {
                 return get_error_union_type(ira->codegen, err_set_type, prev_inst->value.type);
             }
@@ -10645,7 +10647,7 @@ static TypeTableEntry *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp
     }
 
     IrInstruction *instructions[] = {op1, op2};
-    TypeTableEntry *resolved_type = ir_resolve_peer_types(ira, source_node, instructions, 2);
+    TypeTableEntry *resolved_type = ir_resolve_peer_types(ira, source_node, nullptr, instructions, 2);
     if (type_is_invalid(resolved_type))
         return resolved_type;
     type_ensure_zero_bits_known(ira->codegen, resolved_type);
@@ -11035,7 +11037,7 @@ static TypeTableEntry *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstructionBinOp
     IrInstruction *op1 = bin_op_instruction->op1->other;
     IrInstruction *op2 = bin_op_instruction->op2->other;
     IrInstruction *instructions[] = {op1, op2};
-    TypeTableEntry *resolved_type = ir_resolve_peer_types(ira, bin_op_instruction->base.source_node, instructions, 2);
+    TypeTableEntry *resolved_type = ir_resolve_peer_types(ira, bin_op_instruction->base.source_node, nullptr, instructions, 2);
     if (type_is_invalid(resolved_type))
         return resolved_type;
     IrBinOp op_id = bin_op_instruction->op_id;
@@ -13004,7 +13006,7 @@ static TypeTableEntry *ir_analyze_instruction_phi(IrAnalyze *ira, IrInstructionP
         return first_value->value.type;
     }
 
-    TypeTableEntry *resolved_type = ir_resolve_peer_types(ira, phi_instruction->base.source_node,
+    TypeTableEntry *resolved_type = ir_resolve_peer_types(ira, phi_instruction->base.source_node, nullptr,
             new_incoming_values.items, new_incoming_values.length);
     if (type_is_invalid(resolved_type))
         return resolved_type;
@@ -18696,7 +18698,7 @@ TypeTableEntry *ir_analyze(CodeGen *codegen, IrExecutable *old_exec, IrExecutabl
     } else if (ira->src_implicit_return_type_list.length == 0) {
         return codegen->builtin_types.entry_unreachable;
     } else {
-        return ir_resolve_peer_types(ira, expected_type_source_node, ira->src_implicit_return_type_list.items,
+        return ir_resolve_peer_types(ira, expected_type_source_node, expected_type, ira->src_implicit_return_type_list.items,
                 ira->src_implicit_return_type_list.length);
     }
 }
test/cases/error.zig
@@ -202,3 +202,42 @@ const Error = error{};
 fn foo3(b: usize) Error!usize {
     return b;
 }
+
+
+test "error: Infer error set from literals" {
+    _ = nullLiteral("n") catch |err| handleErrors(err);
+    _ = floatLiteral("n") catch |err| handleErrors(err);
+    _ = intLiteral("n") catch |err| handleErrors(err);
+    _ = comptime nullLiteral("n") catch |err| handleErrors(err);
+    _ = comptime floatLiteral("n") catch |err| handleErrors(err);
+    _ = comptime intLiteral("n") catch |err| handleErrors(err);
+}
+
+fn handleErrors(err: var) noreturn {
+    switch (err) {
+        error.T => {}
+    }
+
+    unreachable;
+}
+
+fn nullLiteral(str: []const u8) !?i64 {
+    if (str[0] == 'n')
+        return null;
+
+    return error.T;
+}
+
+fn floatLiteral(str: []const u8) !?f64 {
+    if (str[0] == 'n')
+        return 1.0;
+
+    return error.T;
+}
+
+fn intLiteral(str: []const u8) !?i64 {
+    if (str[0] == 'n')
+        return 1;
+
+    return error.T;
+}