Commit 96fd103073

Andrew Kelley <andrew@ziglang.org>
2019-07-04 06:35:28
improve the error message and test coverage
1 parent bfe0bf6
Changed files (4)
doc/langref.html.in
@@ -3024,7 +3024,7 @@ test "switch on tagged union" {
     // Switching on more complex enums is allowed.
     const b = switch (a) {
         // A capture group is allowed on a match, and will return the enum
-        // value matched. If the payloads of both cases are the same
+        // value matched. If the payload types of both cases are the same
         // they can be put into the same switch prong.
         Item.A, Item.E => |item| item,
 
src/ir.cpp
@@ -19229,53 +19229,52 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
         ZigType *enum_type = target_type->data.unionation.tag_type;
         assert(enum_type != nullptr);
         assert(enum_type->id == ZigTypeIdEnum);
+        assert(instruction->prongs_len > 0);
 
-        IrInstruction *prong_value = instruction->prongs_ptr[0]->child;
-        if (type_is_invalid(prong_value->value.type))
+        IrInstruction *first_prong_value = instruction->prongs_ptr[0]->child;
+        if (type_is_invalid(first_prong_value->value.type))
             return ira->codegen->invalid_instruction;
 
-        IrInstruction *casted_prong_value = ir_implicit_cast(ira, prong_value, enum_type);
-        if (type_is_invalid(casted_prong_value->value.type))
+        IrInstruction *first_casted_prong_value = ir_implicit_cast(ira, first_prong_value, enum_type);
+        if (type_is_invalid(first_casted_prong_value->value.type))
             return ira->codegen->invalid_instruction;
 
-        ConstExprValue *prong_val = ir_resolve_const(ira, casted_prong_value, UndefBad);
-        if (!prong_val)
+        ConstExprValue *first_prong_val = ir_resolve_const(ira, first_casted_prong_value, UndefBad);
+        if (first_prong_val == nullptr)
             return ira->codegen->invalid_instruction;
 
-        TypeUnionField *field = find_union_field_by_tag(target_type, &prong_val->data.x_enum_tag);
+        TypeUnionField *first_field = find_union_field_by_tag(target_type, &first_prong_val->data.x_enum_tag);
 
-        if (instruction->prongs_len != 1) {
-            ErrorMsg *invalid_payload = nullptr;
-            Buf *invalid_payload_list = nullptr;
-
-            for (size_t i = 1; i < instruction->prongs_len; i++) {
-                IrInstruction *casted_prong_value = ir_implicit_cast(ira, instruction->prongs_ptr[i]->child, enum_type);
-                if (type_is_invalid(casted_prong_value->value.type))
-                    return ira->codegen->invalid_instruction;
-                
-                ConstExprValue *next_prong = ir_resolve_const(ira, casted_prong_value, UndefBad);
-                if (!next_prong)
-                    return ira->codegen->invalid_instruction;
+        ErrorMsg *invalid_payload_msg = nullptr;
+        for (size_t prong_i = 1; prong_i < instruction->prongs_len; prong_i += 1) {
+            IrInstruction *this_prong_inst = instruction->prongs_ptr[prong_i]->child;
+            if (type_is_invalid(this_prong_inst->value.type))
+                return ira->codegen->invalid_instruction;
 
-                ZigType *payload = find_union_field_by_tag(target_type, &next_prong->data.x_enum_tag)->type_entry;
+            IrInstruction *this_casted_prong_value = ir_implicit_cast(ira, this_prong_inst, enum_type);
+            if (type_is_invalid(this_casted_prong_value->value.type))
+                return ira->codegen->invalid_instruction;
 
-                if (field->type_entry != payload) {
-                    if (!invalid_payload) {
-                        invalid_payload = ir_add_error(ira, &instruction->base,
-                            buf_sprintf("switch prong contains cases with different payloads"));
-                        invalid_payload_list = buf_sprintf("payload types are %s", buf_ptr(&field->type_entry->name));
-                    }
+            ConstExprValue *this_prong = ir_resolve_const(ira, this_casted_prong_value, UndefBad);
+            if (this_prong == nullptr)
+                return ira->codegen->invalid_instruction;
 
-                    if (i == instruction->prongs_len - 1)
-                        buf_append_buf(invalid_payload_list, buf_sprintf(" and %s", buf_ptr(&payload->name)));
-                    else
-                        buf_append_buf(invalid_payload_list, buf_sprintf(", %s", buf_ptr(&payload->name)));
+            TypeUnionField *payload_field = find_union_field_by_tag(target_type, &this_prong->data.x_enum_tag);
+            ZigType *payload_type = payload_field->type_entry;
+            if (first_field->type_entry != payload_type) {
+                if (invalid_payload_msg == nullptr) {
+                    invalid_payload_msg = ir_add_error(ira, &instruction->base,
+                        buf_sprintf("capture group with incompatible types"));
+                    add_error_note(ira->codegen, invalid_payload_msg, first_prong_value->source_node,
+                            buf_sprintf("type '%s' here", buf_ptr(&first_field->type_entry->name)));
                 }
+                add_error_note(ira->codegen, invalid_payload_msg, this_prong_inst->source_node,
+                        buf_sprintf("type '%s' here", buf_ptr(&payload_field->type_entry->name)));
             }
+        }
 
-            if (invalid_payload)
-                add_error_note(ira->codegen, invalid_payload,
-                    ((IrInstruction*)instruction)->source_node, invalid_payload_list);
+        if (invalid_payload_msg != nullptr) {
+            return ira->codegen->invalid_instruction;
         }
 
         if (instr_is_comptime(target_value_ptr)) {
@@ -19288,7 +19287,7 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
                 return ira->codegen->invalid_instruction;
 
             IrInstruction *result = ir_const(ira, &instruction->base,
-                    get_pointer_to_type(ira->codegen, field->type_entry,
+                    get_pointer_to_type(ira->codegen, first_field->type_entry,
                     target_val_ptr->type->data.pointer.is_const));
             ConstExprValue *out_val = &result->value;
             out_val->data.x_ptr.special = ConstPtrSpecialRef;
@@ -19298,8 +19297,8 @@ static IrInstruction *ir_analyze_instruction_switch_var(IrAnalyze *ira, IrInstru
         }
 
         IrInstruction *result = ir_build_union_field_ptr(&ira->new_irb,
-            instruction->base.scope, instruction->base.source_node, target_value_ptr, field, false, false);
-        result->value.type = get_pointer_to_type(ira->codegen, field->type_entry,
+            instruction->base.scope, instruction->base.source_node, target_value_ptr, first_field, false, false);
+        result->value.type = get_pointer_to_type(ira->codegen, first_field->type_entry,
                 target_value_ptr->value.type->data.pointer.is_const);
         return result;
     } else if (target_type->id == ZigTypeIdErrorSet) {
@@ -23007,11 +23006,11 @@ static IrInstruction *ir_analyze_instruction_mul_add(IrAnalyze *ira, IrInstructi
     IrInstruction *type_value = instruction->type_value->child;
     if (type_is_invalid(type_value->value.type))
         return ira->codegen->invalid_instruction;
-    
+
     ZigType *expr_type = ir_resolve_type(ira, type_value);
     if (type_is_invalid(expr_type))
         return ira->codegen->invalid_instruction;
-    
+
     // Only allow float types, and vectors of floats.
     ZigType *float_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type;
     if (float_type->id != ZigTypeIdFloat) {
@@ -25112,7 +25111,7 @@ static IrInstruction *ir_analyze_instruction_float_op(IrAnalyze *ira, IrInstruct
     IrInstruction *type = instruction->type->child;
     if (type_is_invalid(type->value.type))
         return ira->codegen->invalid_instruction;
-    
+
     ZigType *expr_type = ir_resolve_type(ira, type);
     if (type_is_invalid(expr_type))
         return ira->codegen->invalid_instruction;
test/stage1/behavior/switch.zig
@@ -392,20 +392,36 @@ test "switch with null and T peer types and inferred result location type" {
     comptime S.doTheTest(1);
 }
 
-test "switch prongs with cases with identical payloads" {
+test "switch prongs with cases with identical payload types" {
     const Union = union(enum) {
         A: usize,
         B: isize,
         C: usize,
     };
     const S = struct {
-        fn doTheTest(u: Union) void {
+        fn doTheTest() void {
+            doTheSwitch1(Union{ .A = 8 });
+            doTheSwitch2(Union{ .B = -8 });
+        }
+        fn doTheSwitch1(u: Union) void {
             switch (u) {
-                .A, .C => |e| expect(@typeOf(e) == usize),
-                .B => |e| expect(@typeOf(e) == isize),
+                .A, .C => |e| {
+                    expect(@typeOf(e) == usize);
+                    expect(e == 8);
+                },
+                .B => |e| @panic("fail"),
+            }
+        }
+        fn doTheSwitch2(u: Union) void {
+            switch (u) {
+                .A, .C => |e| @panic("fail"),
+                .B => |e| {
+                    expect(@typeOf(e) == isize);
+                    expect(e == -8);
+                },
             }
         }
     };
-    S.doTheTest(Union{ .A = 8 });
-    comptime S.doTheTest(Union{ .B = -8 });
+    S.doTheTest();
+    comptime S.doTheTest();
 }
test/compile_errors.zig
@@ -2,6 +2,24 @@ const tests = @import("tests.zig");
 const builtin = @import("builtin");
 
 pub fn addCases(cases: *tests.CompileErrorContext) void {
+    cases.add(
+        "capture group on switch prong with incompatible payload types",
+        \\const Union = union(enum) {
+        \\    A: usize,
+        \\    B: isize,
+        \\};
+        \\comptime {
+        \\    var u = Union{ .A = 8 };
+        \\    switch (u) {
+        \\        .A, .B => |e| unreachable,
+        \\    }
+        \\}
+    ,
+        "tmp.zig:8:20: error: capture group with incompatible types",
+        "tmp.zig:8:9: note: type 'usize' here",
+        "tmp.zig:8:13: note: type 'isize' here",
+    );
+
     cases.add(
         "wrong type to @hasField",
         \\export fn entry() bool {
@@ -6073,21 +6091,4 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
         "tmp.zig:5:30: error: expression value is ignored",
         "tmp.zig:9:30: error: expression value is ignored",
     );
-
-    cases.add(
-        "capture group on switch prong with different payloads",
-        \\const Union = union(enum) {
-        \\    A: usize,
-        \\    B: isize,
-        \\};
-        \\comptime {
-        \\    var u = Union{ .A = 8 };
-        \\    switch (u) {
-        \\        .A, .B => |e| unreachable,
-        \\    }
-        \\}
-    ,
-        "tmp.zig:8:20: error: switch prong contains cases with different payloads",
-        "tmp.zig:8:20: note: payload types are usize and isize",
-    );
 }