Commit fde276a3bf

Andrew Kelley <superjoe30@gmail.com>
2017-01-10 22:28:49
IR: implement error for missing or extra switch prongs
1 parent 430e33b
src/all_types.hpp
@@ -1484,6 +1484,7 @@ enum IrInstructionId {
     IrInstructionIdIntToPtr,
     IrInstructionIdPtrToInt,
     IrInstructionIdIntToEnum,
+    IrInstructionIdCheckSwitchProngs,
 };
 
 struct IrInstruction {
@@ -2160,6 +2161,19 @@ struct IrInstructionIntToEnum {
     IrInstruction *target;
 };
 
+struct IrInstructionCheckSwitchProngsRange {
+    IrInstruction *start;
+    IrInstruction *end;
+};
+
+struct IrInstructionCheckSwitchProngs {
+    IrInstruction base;
+
+    IrInstruction *target_value;
+    IrInstructionCheckSwitchProngsRange *ranges;
+    size_t range_count;
+};
+
 enum LValPurpose {
     LValPurposeNone,
     LValPurposeAssign,
src/codegen.cpp
@@ -2274,6 +2274,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
         case IrInstructionIdFnProto:
         case IrInstructionIdTestComptime:
         case IrInstructionIdGeneratedCode:
+        case IrInstructionIdCheckSwitchProngs:
             zig_unreachable();
         case IrInstructionIdReturn:
             return ir_render_return(g, executable, (IrInstructionReturn *)instruction);
src/ir.cpp
@@ -483,6 +483,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionIntToEnum *) {
     return IrInstructionIdIntToEnum;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionCheckSwitchProngs *) {
+    return IrInstructionIdCheckSwitchProngs;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -1981,6 +1985,24 @@ static IrInstruction *ir_build_int_to_enum(IrBuilder *irb, Scope *scope, AstNode
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_check_switch_prongs(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        IrInstruction *target_value, IrInstructionCheckSwitchProngsRange *ranges, size_t range_count)
+{
+    IrInstructionCheckSwitchProngs *instruction = ir_build_instruction<IrInstructionCheckSwitchProngs>(
+            irb, scope, source_node);
+    instruction->target_value = target_value;
+    instruction->ranges = ranges;
+    instruction->range_count = range_count;
+
+    ir_ref_instruction(target_value, irb->current_basic_block);
+    for (size_t i = 0; i < range_count; i += 1) {
+        ir_ref_instruction(ranges[i].start, irb->current_basic_block);
+        ir_ref_instruction(ranges[i].end, irb->current_basic_block);
+    }
+
+    return &instruction->base;
+}
+
 static IrInstruction *ir_instruction_br_get_dep(IrInstructionBr *instruction, size_t index) {
     return nullptr;
 }
@@ -2580,6 +2602,18 @@ static IrInstruction *ir_instruction_inttoenum_get_dep(IrInstructionIntToEnum *i
     }
 }
 
+static IrInstruction *ir_instruction_checkswitchprongs_get_dep(IrInstructionCheckSwitchProngs *instruction,
+        size_t index)
+{
+    if (index == 0) return instruction->target_value;
+    size_t range_index = index - 1;
+    if (range_index < instruction->range_count * 2) {
+        IrInstructionCheckSwitchProngsRange *range = &instruction->ranges[range_index / 2];
+        return (range_index % 2 == 0) ? range->start : range->end;
+    }
+    return nullptr;
+}
+
 static IrInstruction *ir_instruction_get_dep(IrInstruction *instruction, size_t index) {
     switch (instruction->id) {
         case IrInstructionIdInvalid:
@@ -2752,6 +2786,8 @@ static IrInstruction *ir_instruction_get_dep(IrInstruction *instruction, size_t
             return ir_instruction_ptrtoint_get_dep((IrInstructionPtrToInt *) instruction, index);
         case IrInstructionIdIntToEnum:
             return ir_instruction_inttoenum_get_dep((IrInstructionIntToEnum *) instruction, index);
+        case IrInstructionIdCheckSwitchProngs:
+            return ir_instruction_checkswitchprongs_get_dep((IrInstructionCheckSwitchProngs *) instruction, index);
     }
     zig_unreachable();
 }
@@ -4677,6 +4713,7 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
 
     ZigList<IrInstruction *> incoming_values = {0};
     ZigList<IrBasicBlock *> incoming_blocks = {0};
+    ZigList<IrInstructionCheckSwitchProngsRange> check_ranges = {0};
 
     AstNode *else_prong = nullptr;
     for (size_t prong_i = 0; prong_i < prong_count; prong_i += 1) {
@@ -4719,6 +4756,10 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
                         if (end_value == irb->codegen->invalid_instruction)
                             return irb->codegen->invalid_instruction;
 
+                        IrInstructionCheckSwitchProngsRange *check_range = check_ranges.add_one();
+                        check_range->start = start_value;
+                        check_range->end = end_value;
+
                         IrInstruction *start_value_const = ir_build_static_eval(irb, scope, start_node, start_value);
                         IrInstruction *end_value_const = ir_build_static_eval(irb, scope, start_node, end_value);
 
@@ -4738,6 +4779,10 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
                         if (item_value == irb->codegen->invalid_instruction)
                             return irb->codegen->invalid_instruction;
 
+                        IrInstructionCheckSwitchProngsRange *check_range = check_ranges.add_one();
+                        check_range->start = item_value;
+                        check_range->end = item_value;
+
                         IrInstruction *cmp_ok = ir_build_bin_op(irb, scope, item_node, IrBinOpCmpEq,
                                 item_value, target_value, false);
                         if (ok_bit) {
@@ -4776,6 +4821,10 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
                     if (item_value == irb->codegen->invalid_instruction)
                         return irb->codegen->invalid_instruction;
 
+                    IrInstructionCheckSwitchProngsRange *check_range = check_ranges.add_one();
+                    check_range->start = item_value;
+                    check_range->end = item_value;
+
                     IrInstructionSwitchBrCase *this_case = cases.add_one();
                     this_case->value = item_value;
                     this_case->block = prong_block;
@@ -4798,6 +4847,10 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
         }
     }
 
+    if (!else_prong) {
+        ir_build_check_switch_prongs(irb, scope, node, target_value, check_ranges.items, check_ranges.length);
+    }
+
     if (cases.length == 0) {
         ir_build_br(irb, scope, node, else_block, is_comptime);
     } else {
@@ -11091,6 +11144,69 @@ static TypeTableEntry *ir_analyze_instruction_test_comptime(IrAnalyze *ira, IrIn
     return ira->codegen->builtin_types.entry_bool;
 }
 
+static TypeTableEntry *ir_analyze_instruction_check_switch_prongs(IrAnalyze *ira,
+        IrInstructionCheckSwitchProngs *instruction)
+{
+    IrInstruction *target_value = instruction->target_value->other;
+    TypeTableEntry *switch_type = target_value->value.type;
+    if (switch_type->id == TypeTableEntryIdInvalid)
+        return ira->codegen->builtin_types.entry_invalid;
+
+    if (switch_type->id == TypeTableEntryIdEnumTag) {
+        TypeTableEntry *enum_type = switch_type->data.enum_tag.enum_type;
+        size_t *field_use_counts = allocate<size_t>(enum_type->data.enumeration.src_field_count);
+        for (size_t range_i = 0; range_i < instruction->range_count; range_i += 1) {
+            IrInstructionCheckSwitchProngsRange *range = &instruction->ranges[range_i];
+
+            IrInstruction *start_value = range->start->other;
+            if (start_value->value.type->id == TypeTableEntryIdInvalid)
+                return ira->codegen->builtin_types.entry_invalid;
+
+            IrInstruction *end_value = range->end->other;
+            if (end_value->value.type->id == TypeTableEntryIdInvalid)
+                return ira->codegen->builtin_types.entry_invalid;
+
+            size_t start_index;
+            size_t end_index;
+            if (start_value->value.type->id == TypeTableEntryIdEnumTag) {
+                start_index = start_value->value.data.x_bignum.data.x_uint;
+            } else if (start_value->value.type->id == TypeTableEntryIdEnum) {
+                start_index = start_value->value.data.x_enum.tag;
+            } else {
+                zig_unreachable();
+            }
+            if (end_value->value.type->id == TypeTableEntryIdEnumTag) {
+                end_index = end_value->value.data.x_bignum.data.x_uint;
+            } else if (end_value->value.type->id == TypeTableEntryIdEnum) {
+                end_index = end_value->value.data.x_enum.tag;
+            } else {
+                zig_unreachable();
+            }
+
+            for (size_t field_index = start_index; field_index <= end_index; field_index += 1) {
+                field_use_counts[field_index] += 1;
+                if (field_use_counts[field_index] > 1) {
+                    TypeEnumField *type_enum_field = &enum_type->data.enumeration.fields[field_index];
+                    ir_add_error(ira, start_value,
+                        buf_sprintf("duplicate switch value: '%s.%s'", buf_ptr(&enum_type->name),
+                            buf_ptr(type_enum_field->name)));
+                }
+            }
+        }
+        for (uint32_t i = 0; i < enum_type->data.enumeration.src_field_count; i += 1) {
+            if (field_use_counts[i] == 0) {
+                ir_add_error(ira, &instruction->base,
+                    buf_sprintf("enumeration value '%s.%s' not handled in switch", buf_ptr(&enum_type->name),
+                        buf_ptr(enum_type->data.enumeration.fields[i].name)));
+            }
+        }
+    } else {
+        // TODO check prongs of types other than enumtag
+    }
+    ir_build_const_from(ira, &instruction->base, false);
+    return ira->codegen->builtin_types.entry_void;
+}
+
 static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstruction *instruction) {
     switch (instruction->id) {
         case IrInstructionIdInvalid:
@@ -11246,6 +11362,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
             return ir_analyze_instruction_fn_proto(ira, (IrInstructionFnProto *)instruction);
         case IrInstructionIdTestComptime:
             return ir_analyze_instruction_test_comptime(ira, (IrInstructionTestComptime *)instruction);
+        case IrInstructionIdCheckSwitchProngs:
+            return ir_analyze_instruction_check_switch_prongs(ira, (IrInstructionCheckSwitchProngs *)instruction);
         case IrInstructionIdMaybeWrap:
         case IrInstructionIdErrWrapCode:
         case IrInstructionIdErrWrapPayload:
@@ -11351,6 +11469,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdMemcpy:
         case IrInstructionIdBreakpoint:
         case IrInstructionIdOverflowOp: // TODO when we support multiple returns this can be side effect free
+        case IrInstructionIdCheckSwitchProngs:
             return true;
         case IrInstructionIdPhi:
         case IrInstructionIdUnOp:
src/ir_print.cpp
@@ -799,6 +799,20 @@ static void ir_print_int_to_enum(IrPrint *irp, IrInstructionIntToEnum *instructi
     fprintf(irp->f, ")");
 }
 
+static void ir_print_check_switch_prongs(IrPrint *irp, IrInstructionCheckSwitchProngs *instruction) {
+    fprintf(irp->f, "@checkSwitchProngs(");
+    ir_print_other_instruction(irp, instruction->target_value);
+    fprintf(irp->f, ",");
+    for (size_t i = 0; i < instruction->range_count; i += 1) {
+        if (i != 0)
+            fprintf(irp->f, ",");
+        ir_print_other_instruction(irp, instruction->ranges[i].start);
+        fprintf(irp->f, "...");
+        ir_print_other_instruction(irp, instruction->ranges[i].end);
+    }
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -1056,6 +1070,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdIntToEnum:
             ir_print_int_to_enum(irp, (IrInstructionIntToEnum *)instruction);
             break;
+        case IrInstructionIdCheckSwitchProngs:
+            ir_print_check_switch_prongs(irp, (IrInstructionCheckSwitchProngs *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
test/run_tests.cpp
@@ -1167,12 +1167,12 @@ const Number = enum {
 };
 fn f(n: Number) -> i32 {
     switch (n) {
-        One => 1,
-        Two => 2,
-        Three => 3,
+        Number.One => 1,
+        Number.Two => 2,
+        Number.Three => i32(3),
     }
 }
-    )SOURCE", 1, ".tmp_source.zig:9:5: error: enumeration value 'Four' not handled in switch");
+    )SOURCE", 1, ".tmp_source.zig:9:5: error: enumeration value 'Number.Four' not handled in switch");
 
     add_compile_fail_case("import inside function body", R"SOURCE(
 fn f() {
@@ -1430,14 +1430,6 @@ fn f() -> i32 {
 }
     )SOURCE", 1, ".tmp_source.zig:2:15: error: inline parameter not allowed in extern function");
 
-    /* TODO
-    add_compile_fail_case("inline export function", R"SOURCE(
-export inline fn foo(x: i32, y: i32) -> i32{
-    x + y
-}
-    )SOURCE", 1, ".tmp_source.zig:2:1: error: extern functions cannot be inline");
-    */
-
     add_compile_fail_case("convert fixed size array to slice with invalid size", R"SOURCE(
 fn f() {
     var array: [5]u8 = undefined;