Commit cae93c860b

LemonBoy <thatlemon@gmail.com>
2020-01-13 22:18:49
Allow switching on pointer types
Closes #4074
1 parent 84930fe
Changed files (2)
src
test
stage1
behavior
src/codegen.cpp
@@ -4876,14 +4876,30 @@ static LLVMValueRef ir_render_pop_count(CodeGen *g, IrExecutable *executable, Ir
 }
 
 static LLVMValueRef ir_render_switch_br(CodeGen *g, IrExecutable *executable, IrInstructionSwitchBr *instruction) {
-    LLVMValueRef target_value = ir_llvm_value(g, instruction->target_value);
+    ZigType *target_type = instruction->target_value->value->type;
     LLVMBasicBlockRef else_block = instruction->else_block->llvm_block;
+
+    LLVMValueRef target_value = ir_llvm_value(g, instruction->target_value);
+    if (target_type->id == ZigTypeIdPointer) {
+        const ZigType *usize = g->builtin_types.entry_usize;
+        target_value = LLVMBuildPtrToInt(g->builder, target_value, usize->llvm_type, "");
+    }
+
     LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, target_value, else_block,
-            (unsigned)instruction->case_count);
+                                                (unsigned)instruction->case_count);
+
     for (size_t i = 0; i < instruction->case_count; i += 1) {
         IrInstructionSwitchBrCase *this_case = &instruction->cases[i];
-        LLVMAddCase(switch_instr, ir_llvm_value(g, this_case->value), this_case->block->llvm_block);
+
+        LLVMValueRef case_value = ir_llvm_value(g, this_case->value);
+        if (target_type->id == ZigTypeIdPointer) {
+            const ZigType *usize = g->builtin_types.entry_usize;
+            case_value = LLVMBuildPtrToInt(g->builder, case_value, usize->llvm_type, "");
+        }
+
+        LLVMAddCase(switch_instr, case_value, this_case->block->llvm_block);
     }
+
     return nullptr;
 }
 
test/stage1/behavior/switch.zig
@@ -452,3 +452,30 @@ test "switch on global mutable var isn't constant-folded" {
         poll();
     }
 }
+
+test "switch on pointer type" {
+    const S = struct {
+        const X = struct {
+            field: u32,
+        };
+
+        const P1 = @intToPtr(*X, 0x400);
+        const P2 = @intToPtr(*X, 0x800);
+        const P3 = @intToPtr(*X, 0xC00);
+
+        fn doTheTest(arg: *X) i32 {
+            switch (arg) {
+                P1 => return 1,
+                P2 => return 2,
+                else => return 3,
+            }
+        }
+    };
+
+    expect(1 == S.doTheTest(S.P1));
+    expect(2 == S.doTheTest(S.P2));
+    expect(3 == S.doTheTest(S.P3));
+    comptime expect(1 == S.doTheTest(S.P1));
+    comptime expect(2 == S.doTheTest(S.P2));
+    comptime expect(3 == S.doTheTest(S.P3));
+}