Commit df4f77024e

Andrew Kelley <andrew@ziglang.org>
2019-05-15 00:06:02
else value when switching on error set has
optional capture value which is subset. see #769
1 parent 6536b40
Changed files (5)
src/all_types.hpp
@@ -2149,6 +2149,7 @@ enum IrInstructionId {
     IrInstructionIdCondBr,
     IrInstructionIdSwitchBr,
     IrInstructionIdSwitchVar,
+    IrInstructionIdSwitchElseVar,
     IrInstructionIdSwitchTarget,
     IrInstructionIdPhi,
     IrInstructionIdUnOp,
@@ -2372,6 +2373,13 @@ struct IrInstructionSwitchVar {
     IrInstruction *prong_value;
 };
 
+struct IrInstructionSwitchElseVar {
+    IrInstruction base;
+
+    IrInstruction *target_value_ptr;
+    IrInstructionSwitchBr *switch_br;
+};
+
 struct IrInstructionSwitchTarget {
     IrInstruction base;
 
src/codegen.cpp
@@ -5572,6 +5572,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
         case IrInstructionIdTypeName:
         case IrInstructionIdDeclRef:
         case IrInstructionIdSwitchVar:
+        case IrInstructionIdSwitchElseVar:
         case IrInstructionIdByteOffsetOf:
         case IrInstructionIdBitOffsetOf:
         case IrInstructionIdTypeInfo:
src/ir.cpp
@@ -419,6 +419,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionSwitchVar *) {
     return IrInstructionIdSwitchVar;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionSwitchElseVar *) {
+    return IrInstructionIdSwitchElseVar;
+}
+
 static constexpr IrInstructionId ir_instruction_id(IrInstructionSwitchTarget *) {
     return IrInstructionIdSwitchTarget;
 }
@@ -1791,7 +1795,7 @@ static IrInstruction *ir_build_pop_count(IrBuilder *irb, Scope *scope, AstNode *
     return &instruction->base;
 }
 
-static IrInstruction *ir_build_switch_br(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *target_value,
+static IrInstructionSwitchBr *ir_build_switch_br(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *target_value,
         IrBasicBlock *else_block, size_t case_count, IrInstructionSwitchBrCase *cases, IrInstruction *is_comptime,
         IrInstruction *switch_prongs_void)
 {
@@ -1815,7 +1819,7 @@ static IrInstruction *ir_build_switch_br(IrBuilder *irb, Scope *scope, AstNode *
         ir_ref_bb(cases[i].block);
     }
 
-    return &instruction->base;
+    return instruction;
 }
 
 static IrInstruction *ir_build_switch_target(IrBuilder *irb, Scope *scope, AstNode *source_node,
@@ -1842,6 +1846,18 @@ static IrInstruction *ir_build_switch_var(IrBuilder *irb, Scope *scope, AstNode
     return &instruction->base;
 }
 
+// For this instruction the switch_br must be set later.
+static IrInstructionSwitchElseVar *ir_build_switch_else_var(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        IrInstruction *target_value_ptr)
+{
+    IrInstructionSwitchElseVar *instruction = ir_build_instruction<IrInstructionSwitchElseVar>(irb, scope, source_node);
+    instruction->target_value_ptr = target_value_ptr;
+
+    ir_ref_instruction(target_value_ptr, irb->current_basic_block);
+
+    return instruction;
+}
+
 static IrInstruction *ir_build_union_tag(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *value) {
     IrInstructionUnionTag *instruction = ir_build_instruction<IrInstructionUnionTag>(irb, scope, source_node);
     instruction->value = value;
@@ -6294,7 +6310,8 @@ static IrInstruction *ir_gen_if_err_expr(IrBuilder *irb, Scope *scope, AstNode *
 static bool ir_gen_switch_prong_expr(IrBuilder *irb, Scope *scope, AstNode *switch_node, AstNode *prong_node,
         IrBasicBlock *end_block, IrInstruction *is_comptime, IrInstruction *var_is_comptime,
         IrInstruction *target_value_ptr, IrInstruction *prong_value,
-        ZigList<IrBasicBlock *> *incoming_blocks, ZigList<IrInstruction *> *incoming_values)
+        ZigList<IrBasicBlock *> *incoming_blocks, ZigList<IrInstruction *> *incoming_values,
+        IrInstructionSwitchElseVar **out_switch_else_var)
 {
     assert(switch_node->type == NodeTypeSwitchExpr);
     assert(prong_node->type == NodeTypeSwitchProng);
@@ -6312,13 +6329,17 @@ static bool ir_gen_switch_prong_expr(IrBuilder *irb, Scope *scope, AstNode *swit
         ZigVar *var = ir_create_var(irb, var_symbol_node, scope,
                 var_name, is_const, is_const, is_shadowable, var_is_comptime);
         child_scope = var->child_scope;
-        IrInstruction *var_value;
-        if (prong_value) {
-            IrInstruction *var_ptr_value = ir_build_switch_var(irb, scope, var_symbol_node, target_value_ptr, prong_value);
-            var_value = var_is_ptr ? var_ptr_value : ir_build_load_ptr(irb, scope, var_symbol_node, var_ptr_value);
+        IrInstruction *var_ptr_value;
+        if (prong_value != nullptr) {
+            var_ptr_value = ir_build_switch_var(irb, scope, var_symbol_node, target_value_ptr, prong_value);
         } else {
-            var_value = var_is_ptr ? target_value_ptr : ir_build_load_ptr(irb, scope, var_symbol_node, target_value_ptr);
+            IrInstructionSwitchElseVar *switch_else_var = ir_build_switch_else_var(irb, scope, var_symbol_node,
+                    target_value_ptr);
+            *out_switch_else_var = switch_else_var;
+            var_ptr_value = &switch_else_var->base;
         }
+        IrInstruction *var_value = var_is_ptr ?
+            var_ptr_value : ir_build_load_ptr(irb, scope, var_symbol_node, var_ptr_value);
         IrInstruction *var_type = nullptr; // infer the type
         ir_build_var_decl_src(irb, scope, var_symbol_node, var, var_type, nullptr, var_value);
     } else {
@@ -6364,6 +6385,8 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
     ZigList<IrBasicBlock *> incoming_blocks = {0};
     ZigList<IrInstructionCheckSwitchProngsRange> check_ranges = {0};
 
+    IrInstructionSwitchElseVar *switch_else_var = nullptr;
+
     // First do the else and the ranges
     Scope *subexpr_scope = create_runtime_scope(irb->codegen, node, scope, is_comptime);
     Scope *comptime_scope = create_comptime_scope(irb->codegen, node, scope);
@@ -6384,7 +6407,8 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
             IrBasicBlock *prev_block = irb->current_basic_block;
             ir_set_cursor_at_end_and_append_block(irb, else_block);
             if (!ir_gen_switch_prong_expr(irb, subexpr_scope, node, prong_node, end_block,
-                is_comptime, var_is_comptime, target_value_ptr, nullptr, &incoming_blocks, &incoming_values))
+                is_comptime, var_is_comptime, target_value_ptr, nullptr, &incoming_blocks, &incoming_values,
+                &switch_else_var))
             {
                 return irb->codegen->invalid_instruction;
             }
@@ -6451,7 +6475,7 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
 
             ir_set_cursor_at_end_and_append_block(irb, range_block_yes);
             if (!ir_gen_switch_prong_expr(irb, subexpr_scope, node, prong_node, end_block,
-                is_comptime, var_is_comptime, target_value_ptr, nullptr, &incoming_blocks, &incoming_values))
+                is_comptime, var_is_comptime, target_value_ptr, nullptr, &incoming_blocks, &incoming_values, nullptr))
             {
                 return irb->codegen->invalid_instruction;
             }
@@ -6495,7 +6519,8 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
         IrBasicBlock *prev_block = irb->current_basic_block;
         ir_set_cursor_at_end_and_append_block(irb, prong_block);
         if (!ir_gen_switch_prong_expr(irb, subexpr_scope, node, prong_node, end_block,
-            is_comptime, var_is_comptime, target_value_ptr, only_item_value, &incoming_blocks, &incoming_values))
+            is_comptime, var_is_comptime, target_value_ptr, only_item_value, &incoming_blocks, &incoming_values,
+            nullptr))
         {
             return irb->codegen->invalid_instruction;
         }
@@ -6510,7 +6535,11 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
     if (cases.length == 0) {
         ir_build_br(irb, scope, node, else_block, is_comptime);
     } else {
-        ir_build_switch_br(irb, scope, node, target_value, else_block, cases.length, cases.items, is_comptime, switch_prongs_void);
+        IrInstructionSwitchBr *switch_br = ir_build_switch_br(irb, scope, node, target_value, else_block,
+                cases.length, cases.items, is_comptime, switch_prongs_void);
+        if (switch_else_var != nullptr) {
+            switch_else_var->switch_br = switch_br;
+        }
     }
 
     if (!else_prong) {
@@ -7474,8 +7503,9 @@ static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNod
     cases[0].block = resume_block;
     cases[1].value = ir_mark_gen(ir_build_const_u8(irb, parent_scope, node, 1));
     cases[1].block = canceled_block;
-    ir_mark_gen(ir_build_switch_br(irb, parent_scope, node, suspend_code, irb->exec->coro_suspend_block,
-            2, cases, const_bool_false, nullptr));
+    IrInstructionSwitchBr *switch_br = ir_build_switch_br(irb, parent_scope, node, suspend_code,
+            irb->exec->coro_suspend_block, 2, cases, const_bool_false, nullptr);
+    ir_mark_gen(&switch_br->base);
 
     ir_set_cursor_at_end_and_append_block(irb, cleanup_block);
     IrBasicBlock **incoming_blocks = allocate<IrBasicBlock *>(2);
@@ -8972,6 +9002,15 @@ static bool slice_is_const(ZigType *type) {
     return type->data.structure.fields[slice_ptr_index].type_entry->data.pointer.is_const;
 }
 
+static void populate_error_set_table(ErrorTableEntry **errors, ZigType *set) {
+    assert(set->id == ZigTypeIdErrorSet);
+    for (uint32_t i = 0; i < set->data.error_set.err_count; i += 1) {
+        ErrorTableEntry *error_entry = set->data.error_set.errors[i];
+        assert(errors[error_entry->value] == nullptr);
+        errors[error_entry->value] = error_entry;
+    }
+}
+
 static ZigType *get_error_set_intersection(IrAnalyze *ira, ZigType *set1, ZigType *set2,
         AstNode *source_node)
 {
@@ -8991,11 +9030,7 @@ static ZigType *get_error_set_intersection(IrAnalyze *ira, ZigType *set1, ZigTyp
         return set1;
     }
     ErrorTableEntry **errors = allocate<ErrorTableEntry *>(ira->codegen->errors_by_index.length);
-    for (uint32_t i = 0; i < set1->data.error_set.err_count; i += 1) {
-        ErrorTableEntry *error_entry = set1->data.error_set.errors[i];
-        assert(errors[error_entry->value] == nullptr);
-        errors[error_entry->value] = error_entry;
-    }
+    populate_error_set_table(errors, set1);
     ZigList<ErrorTableEntry *> intersection_list = {};
 
     ZigType *err_set_type = new_type_table_entry(ZigTypeIdErrorSet);
@@ -10410,6 +10445,24 @@ ConstExprValue *ir_eval_const_value(CodeGen *codegen, Scope *scope, AstNode *nod
     return ir_exec_const_result(codegen, analyzed_executable);
 }
 
+static ErrorTableEntry *ir_resolve_error(IrAnalyze *ira, IrInstruction *err_value) {
+    if (type_is_invalid(err_value->value.type))
+        return nullptr;
+
+    if (err_value->value.type->id != ZigTypeIdErrorSet) {
+        ir_add_error(ira, err_value,
+                buf_sprintf("expected error, found '%s'", buf_ptr(&err_value->value.type->name)));
+        return nullptr;
+    }
+
+    ConstExprValue *const_val = ir_resolve_const(ira, err_value, UndefBad);
+    if (!const_val)
+        return nullptr;
+
+    assert(const_val->data.x_err_set != nullptr);
+    return const_val->data.x_err_set;
+}
+
 static ZigType *ir_resolve_type(IrAnalyze *ira, IrInstruction *type_value) {
     if (type_is_invalid(type_value->value.type))
         return ira->codegen->builtin_types.entry_invalid;
@@ -17229,11 +17282,11 @@ static IrInstruction *ir_analyze_instruction_switch_br(IrAnalyze *ira,
     }
 
     IrBasicBlock *new_else_block = ir_get_new_bb(ira, switch_br_instruction->else_block, &switch_br_instruction->base);
-    IrInstruction *result = ir_build_switch_br(&ira->new_irb,
+    IrInstructionSwitchBr *switch_br = ir_build_switch_br(&ira->new_irb,
         switch_br_instruction->base.scope, switch_br_instruction->base.source_node,
         target_value, new_else_block, case_count, cases, nullptr, nullptr);
-    result->value.type = ira->codegen->builtin_types.entry_unreachable;
-    return ir_finish_anal(ira, result);
+    switch_br->base.value.type = ira->codegen->builtin_types.entry_unreachable;
+    return ir_finish_anal(ira, &switch_br->base);
 }
 
 static IrInstruction *ir_analyze_instruction_switch_target(IrAnalyze *ira,
@@ -17419,6 +17472,85 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
     }
 }
 
+static IrInstruction *ir_analyze_instruction_switch_else_var(IrAnalyze *ira,
+        IrInstructionSwitchElseVar *instruction)
+{
+    IrInstruction *target_value_ptr = instruction->target_value_ptr->child;
+    if (type_is_invalid(target_value_ptr->value.type))
+        return ira->codegen->invalid_instruction;
+
+    ZigType *ref_type = target_value_ptr->value.type;
+    assert(ref_type->id == ZigTypeIdPointer);
+    ZigType *target_type = target_value_ptr->value.type->data.pointer.child_type;
+    if (target_type->id == ZigTypeIdErrorSet) {
+        //  make a new set that has the other cases removed
+        if (!resolve_inferred_error_set(ira->codegen, target_type, instruction->base.source_node)) {
+            return ira->codegen->invalid_instruction;
+        }
+        if (type_is_global_error_set(target_type)) {
+            // the type of the else capture variable still has to be the global error set.
+            // once the runtime hint system is more sophisticated, we could add some hint information here.
+            return target_value_ptr;
+        }
+        // Make note of the errors handled by other cases
+        ErrorTableEntry **errors = allocate<ErrorTableEntry *>(ira->codegen->errors_by_index.length);
+        for (size_t case_i = 0; case_i < instruction->switch_br->case_count; case_i += 1) {
+            IrInstructionSwitchBrCase *br_case = &instruction->switch_br->cases[case_i];
+            IrInstruction *case_expr = br_case->value->child;
+            if (case_expr->value.type->id == ZigTypeIdErrorSet) {
+                ErrorTableEntry *err = ir_resolve_error(ira, case_expr);
+                if (err == nullptr)
+                    return ira->codegen->invalid_instruction;
+                errors[err->value] = err;
+            } else if (case_expr->value.type->id == ZigTypeIdMetaType) {
+                ZigType *err_set_type = ir_resolve_type(ira, case_expr);
+                if (type_is_invalid(err_set_type))
+                    return ira->codegen->invalid_instruction;
+                populate_error_set_table(errors, err_set_type);
+            } else {
+                zig_unreachable();
+            }
+        }
+        ZigList<ErrorTableEntry *> result_list = {};
+
+        ZigType *err_set_type = new_type_table_entry(ZigTypeIdErrorSet);
+        buf_resize(&err_set_type->name, 0);
+        buf_appendf(&err_set_type->name, "error{");
+
+        // Look at all the errors in the type switched on and add them to the result_list
+        // if they are not handled by cases.
+        for (uint32_t i = 0; i < target_type->data.error_set.err_count; i += 1) {
+            ErrorTableEntry *error_entry = target_type->data.error_set.errors[i];
+            ErrorTableEntry *existing_entry = errors[error_entry->value];
+            if (existing_entry == nullptr) {
+                result_list.append(error_entry);
+                buf_appendf(&err_set_type->name, "%s,", buf_ptr(&error_entry->name));
+            }
+        }
+        free(errors);
+
+        err_set_type->data.error_set.err_count = result_list.length;
+        err_set_type->data.error_set.errors = result_list.items;
+        err_set_type->size_in_bits = ira->codegen->builtin_types.entry_global_error_set->size_in_bits;
+        err_set_type->abi_align = ira->codegen->builtin_types.entry_global_error_set->abi_align;
+        err_set_type->abi_size = ira->codegen->builtin_types.entry_global_error_set->abi_size;
+
+        buf_appendf(&err_set_type->name, "}");
+
+        ZigType *new_target_value_ptr_type = get_pointer_to_type_extra(ira->codegen,
+            err_set_type,
+            ref_type->data.pointer.is_const, ref_type->data.pointer.is_volatile,
+            ref_type->data.pointer.ptr_len,
+            ref_type->data.pointer.explicit_alignment,
+            ref_type->data.pointer.bit_offset_in_host, ref_type->data.pointer.host_int_bytes,
+            ref_type->data.pointer.allow_zero);
+        return ir_analyze_ptr_cast(ira, &instruction->base, target_value_ptr, new_target_value_ptr_type,
+                &instruction->base, false);
+    }
+
+    return target_value_ptr;
+}
+
 static IrInstruction *ir_analyze_instruction_union_tag(IrAnalyze *ira, IrInstructionUnionTag *instruction) {
     IrInstruction *value = instruction->value->child;
     return ir_analyze_union_tag(ira, &instruction->base, value);
@@ -23094,6 +23226,8 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio
             return ir_analyze_instruction_switch_target(ira, (IrInstructionSwitchTarget *)instruction);
         case IrInstructionIdSwitchVar:
             return ir_analyze_instruction_switch_var(ira, (IrInstructionSwitchVar *)instruction);
+        case IrInstructionIdSwitchElseVar:
+            return ir_analyze_instruction_switch_else_var(ira, (IrInstructionSwitchElseVar *)instruction);
         case IrInstructionIdUnionTag:
             return ir_analyze_instruction_union_tag(ira, (IrInstructionUnionTag *)instruction);
         case IrInstructionIdImport:
@@ -23457,6 +23591,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdCtz:
         case IrInstructionIdPopCount:
         case IrInstructionIdSwitchVar:
+        case IrInstructionIdSwitchElseVar:
         case IrInstructionIdSwitchTarget:
         case IrInstructionIdUnionTag:
         case IrInstructionIdRef:
src/ir_print.cpp
@@ -546,6 +546,11 @@ static void ir_print_switch_var(IrPrint *irp, IrInstructionSwitchVar *instructio
     ir_print_other_instruction(irp, instruction->prong_value);
 }
 
+static void ir_print_switch_else_var(IrPrint *irp, IrInstructionSwitchElseVar *instruction) {
+    fprintf(irp->f, "switchelsevar ");
+    ir_print_other_instruction(irp, &instruction->switch_br->base);
+}
+
 static void ir_print_switch_target(IrPrint *irp, IrInstructionSwitchTarget *instruction) {
     fprintf(irp->f, "switchtarget ");
     ir_print_other_instruction(irp, instruction->target_value_ptr);
@@ -1559,6 +1564,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdSwitchVar:
             ir_print_switch_var(irp, (IrInstructionSwitchVar *)instruction);
             break;
+        case IrInstructionIdSwitchElseVar:
+            ir_print_switch_else_var(irp, (IrInstructionSwitchElseVar *)instruction);
+            break;
         case IrInstructionIdSwitchTarget:
             ir_print_switch_target(irp, (IrInstructionSwitchTarget *)instruction);
             break;
test/stage1/behavior/switch.zig
@@ -1,4 +1,6 @@
-const expect = @import("std").testing.expect;
+const std = @import("std");
+const expect = std.testing.expect;
+const expectError = std.testing.expectError;
 
 test "switch with numbers" {
     testSwitchWithNumbers(13);
@@ -296,3 +298,33 @@ test "anon enum literal used in switch on union enum" {
         },
     }
 }
+
+test "else prong of switch on error set excludes other cases" {
+    const S = struct {
+        fn doTheTest() void {
+            expectError(error.C, bar());
+        }
+        const E = error{
+            A,
+            B,
+        } || E2;
+
+        const E2 = error{
+            C,
+            D,
+        };
+
+        fn foo() E!void {
+            return error.C;
+        }
+
+        fn bar() E2!void {
+            foo() catch |err| switch (err) {
+                error.A, error.B => {},
+                else => |e| return e,
+            };
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}