Commit e93a05b6e4

Andrew Kelley <andrew@ziglang.org>
2019-05-15 01:11:37
switching on error sets makes new error set for capture values
closes #769
1 parent c08c222
Changed files (4)
src/all_types.hpp
@@ -2370,7 +2370,8 @@ struct IrInstructionSwitchVar {
     IrInstruction base;
 
     IrInstruction *target_value_ptr;
-    IrInstruction *prong_value;
+    IrInstruction **prongs_ptr;
+    size_t prongs_len;
 };
 
 struct IrInstructionSwitchElseVar {
src/ir.cpp
@@ -1834,14 +1834,17 @@ static IrInstruction *ir_build_switch_target(IrBuilder *irb, Scope *scope, AstNo
 }
 
 static IrInstruction *ir_build_switch_var(IrBuilder *irb, Scope *scope, AstNode *source_node,
-        IrInstruction *target_value_ptr, IrInstruction *prong_value)
+        IrInstruction *target_value_ptr, IrInstruction **prongs_ptr, size_t prongs_len)
 {
     IrInstructionSwitchVar *instruction = ir_build_instruction<IrInstructionSwitchVar>(irb, scope, source_node);
     instruction->target_value_ptr = target_value_ptr;
-    instruction->prong_value = prong_value;
+    instruction->prongs_ptr = prongs_ptr;
+    instruction->prongs_len = prongs_len;
 
     ir_ref_instruction(target_value_ptr, irb->current_basic_block);
-    ir_ref_instruction(prong_value, irb->current_basic_block);
+    for (size_t i = 0; i < prongs_len; i += 1) {
+        ir_ref_instruction(prongs_ptr[i], irb->current_basic_block);
+    }
 
     return &instruction->base;
 }
@@ -6309,7 +6312,7 @@ static IrInstruction *ir_gen_if_err_expr(IrBuilder *irb, Scope *scope, AstNode *
 
 static bool ir_gen_switch_prong_expr(IrBuilder *irb, Scope *scope, AstNode *switch_node, AstNode *prong_node,
         IrBasicBlock *end_block, IrInstruction *is_comptime, IrInstruction *var_is_comptime,
-        IrInstruction *target_value_ptr, IrInstruction *prong_value,
+        IrInstruction *target_value_ptr, IrInstruction **prong_values, size_t prong_values_len,
         ZigList<IrBasicBlock *> *incoming_blocks, ZigList<IrInstruction *> *incoming_values,
         IrInstructionSwitchElseVar **out_switch_else_var)
 {
@@ -6336,8 +6339,9 @@ static bool ir_gen_switch_prong_expr(IrBuilder *irb, Scope *scope, AstNode *swit
             *out_switch_else_var = switch_else_var;
             IrInstruction *var_ptr_value = &switch_else_var->base;
             var_value = var_is_ptr ? var_ptr_value : ir_build_load_ptr(irb, scope, var_symbol_node, var_ptr_value);
-        } else if (prong_value != nullptr) {
-            IrInstruction *var_ptr_value = ir_build_switch_var(irb, scope, var_symbol_node, target_value_ptr, prong_value);
+        } else if (prong_values != nullptr) {
+            IrInstruction *var_ptr_value = ir_build_switch_var(irb, scope, var_symbol_node, target_value_ptr,
+                    prong_values, prong_values_len);
             var_value = var_is_ptr ? var_ptr_value : ir_build_load_ptr(irb, scope, var_symbol_node, var_ptr_value);
         } else {
             var_value = var_is_ptr ? target_value_ptr : ir_build_load_ptr(irb, scope, var_symbol_node, 
@@ -6410,7 +6414,7 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
             IrBasicBlock *prev_block = irb->current_basic_block;
             ir_set_cursor_at_end_and_append_block(irb, else_block);
             if (!ir_gen_switch_prong_expr(irb, subexpr_scope, node, prong_node, end_block,
-                is_comptime, var_is_comptime, target_value_ptr, nullptr, &incoming_blocks, &incoming_values,
+                is_comptime, var_is_comptime, target_value_ptr, nullptr, 0, &incoming_blocks, &incoming_values,
                 &switch_else_var))
             {
                 return irb->codegen->invalid_instruction;
@@ -6478,7 +6482,8 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
 
             ir_set_cursor_at_end_and_append_block(irb, range_block_yes);
             if (!ir_gen_switch_prong_expr(irb, subexpr_scope, node, prong_node, end_block,
-                is_comptime, var_is_comptime, target_value_ptr, nullptr, &incoming_blocks, &incoming_values, nullptr))
+                is_comptime, var_is_comptime, target_value_ptr, nullptr, 0,
+                &incoming_blocks, &incoming_values, nullptr))
             {
                 return irb->codegen->invalid_instruction;
             }
@@ -6497,7 +6502,7 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
             continue;
 
         IrBasicBlock *prong_block = ir_create_basic_block(irb, scope, "SwitchProng");
-        IrInstruction *last_item_value = nullptr;
+        IrInstruction **items = allocate<IrInstruction *>(prong_item_count);
 
         for (size_t item_i = 0; item_i < prong_item_count; item_i += 1) {
             AstNode *item_node = prong_node->data.switch_prong.items.at(item_i);
@@ -6515,15 +6520,14 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
             this_case->value = item_value;
             this_case->block = prong_block;
 
-            last_item_value = item_value;
+            items[item_i] = item_value;
         }
-        IrInstruction *only_item_value = (prong_item_count == 1) ? last_item_value : nullptr;
 
         IrBasicBlock *prev_block = irb->current_basic_block;
         ir_set_cursor_at_end_and_append_block(irb, prong_block);
         if (!ir_gen_switch_prong_expr(irb, subexpr_scope, node, prong_node, end_block,
-            is_comptime, var_is_comptime, target_value_ptr, only_item_value, &incoming_blocks, &incoming_values,
-            nullptr))
+            is_comptime, var_is_comptime, target_value_ptr, items, prong_item_count,
+            &incoming_blocks, &incoming_values, nullptr))
         {
             return irb->codegen->invalid_instruction;
         }
@@ -17423,17 +17427,22 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
     if (type_is_invalid(target_value_ptr->value.type))
         return ira->codegen->invalid_instruction;
 
-    IrInstruction *prong_value = instruction->prong_value->child;
-    if (type_is_invalid(prong_value->value.type))
-        return ira->codegen->invalid_instruction;
-
-    assert(target_value_ptr->value.type->id == ZigTypeIdPointer);
+    ZigType *ref_type = target_value_ptr->value.type;
+    assert(ref_type->id == ZigTypeIdPointer);
     ZigType *target_type = target_value_ptr->value.type->data.pointer.child_type;
     if (target_type->id == ZigTypeIdUnion) {
         ZigType *enum_type = target_type->data.unionation.tag_type;
         assert(enum_type != nullptr);
         assert(enum_type->id == ZigTypeIdEnum);
 
+        if (instruction->prongs_len != 1) {
+            return target_value_ptr;
+        }
+
+        IrInstruction *prong_value = instruction->prongs_ptr[0]->child;
+        if (type_is_invalid(prong_value->value.type))
+            return ira->codegen->invalid_instruction;
+
         IrInstruction *casted_prong_value = ir_implicit_cast(ira, prong_value, enum_type);
         if (type_is_invalid(casted_prong_value->value.type))
             return ira->codegen->invalid_instruction;
@@ -17468,6 +17477,36 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
         result->value.type = get_pointer_to_type(ira->codegen, field->type_entry,
                 target_value_ptr->value.type->data.pointer.is_const);
         return result;
+    } else if (target_type->id == ZigTypeIdErrorSet) {
+        // construct an error set from the prong values
+        ZigType *err_set_type = new_type_table_entry(ZigTypeIdErrorSet);
+        err_set_type->size_in_bits = ira->codegen->builtin_types.entry_global_error_set->size_in_bits;
+        err_set_type->abi_align = ira->codegen->builtin_types.entry_global_error_set->abi_align;
+        err_set_type->abi_size = ira->codegen->builtin_types.entry_global_error_set->abi_size;
+        ZigList<ErrorTableEntry *> error_list = {};
+        buf_resize(&err_set_type->name, 0);
+        buf_appendf(&err_set_type->name, "error{");
+        for (size_t i = 0; i < instruction->prongs_len; i += 1) {
+            ErrorTableEntry *err = ir_resolve_error(ira, instruction->prongs_ptr[i]->child);
+            if (err == nullptr)
+                return ira->codegen->invalid_instruction;
+            error_list.append(err);
+            buf_appendf(&err_set_type->name, "%s,", buf_ptr(&err->name));
+        }
+        err_set_type->data.error_set.errors = error_list.items;
+        err_set_type->data.error_set.err_count = error_list.length;
+        buf_appendf(&err_set_type->name, "}");
+
+
+        ZigType *new_target_value_ptr_type = get_pointer_to_type_extra(ira->codegen,
+            err_set_type,
+            ref_type->data.pointer.is_const, ref_type->data.pointer.is_volatile,
+            ref_type->data.pointer.ptr_len,
+            ref_type->data.pointer.explicit_alignment,
+            ref_type->data.pointer.bit_offset_in_host, ref_type->data.pointer.host_int_bytes,
+            ref_type->data.pointer.allow_zero);
+        return ir_analyze_ptr_cast(ira, &instruction->base, target_value_ptr, new_target_value_ptr_type,
+                &instruction->base, false);
     } else {
         ir_add_error(ira, &instruction->base,
             buf_sprintf("switch on type '%s' provides no expression parameter", buf_ptr(&target_type->name)));
src/ir_print.cpp
@@ -542,8 +542,10 @@ static void ir_print_switch_br(IrPrint *irp, IrInstructionSwitchBr *instruction)
 static void ir_print_switch_var(IrPrint *irp, IrInstructionSwitchVar *instruction) {
     fprintf(irp->f, "switchvar ");
     ir_print_other_instruction(irp, instruction->target_value_ptr);
-    fprintf(irp->f, ", ");
-    ir_print_other_instruction(irp, instruction->prong_value);
+    for (size_t i = 0; i < instruction->prongs_len; i += 1) {
+        fprintf(irp->f, ", ");
+        ir_print_other_instruction(irp, instruction->prongs_ptr[i]);
+    }
 }
 
 static void ir_print_switch_else_var(IrPrint *irp, IrInstructionSwitchElseVar *instruction) {
test/stage1/behavior/switch.zig
@@ -328,3 +328,35 @@ test "else prong of switch on error set excludes other cases" {
     S.doTheTest();
     comptime S.doTheTest();
 }
+
+test "switch prongs with error set cases make a new error set type for capture value" {
+    const S = struct {
+        fn doTheTest() void {
+            expectError(error.B, bar());
+        }
+        const E = E1 || E2;
+
+        const E1 = error{
+            A,
+            B,
+        };
+
+        const E2 = error{
+            C,
+            D,
+        };
+
+        fn foo() E!void {
+            return error.B;
+        }
+
+        fn bar() E1!void {
+            foo() catch |err| switch (err) {
+                error.A, error.B => |e| return e,
+                else => {},
+            };
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}