Commit fcbeaddbb2

Andrew Kelley <superjoe30@gmail.com>
2016-02-04 23:26:27
codegen: fix switch expressions for enums with payloads
1 parent b87d0ab
Changed files (2)
src/codegen.cpp
@@ -1256,8 +1256,12 @@ static LLVMValueRef gen_cmp_expr(CodeGen *g, AstNode *node) {
                 op1_type->data.integral.is_signed);
         return LLVMBuildICmp(g->builder, pred, val1, val2, "");
     } else if (op1_type->id == TypeTableEntryIdEnum) {
-        LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, false);
-        return LLVMBuildICmp(g->builder, pred, val1, val2, "");
+        if (op1_type->data.enumeration.gen_field_count == 0) {
+            LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, false);
+            return LLVMBuildICmp(g->builder, pred, val1, val2, "");
+        } else {
+            zig_unreachable();
+        }
     } else {
         zig_unreachable();
     }
@@ -2309,9 +2313,25 @@ static LLVMValueRef gen_symbol(CodeGen *g, AstNode *node) {
 static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeSwitchExpr);
 
-    LLVMValueRef target_value = gen_expr(g, node->data.switch_expr.expr);
+    TypeTableEntry *target_type = get_expr_type(node->data.switch_expr.expr);
+    LLVMValueRef target_value_handle = gen_expr(g, node->data.switch_expr.expr);
+    LLVMValueRef target_value;
+    if (handle_is_ptr(target_type)) {
+        if (target_type->id == TypeTableEntryIdEnum) {
+            add_debug_source_node(g, node);
+            LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, target_value_handle, 0, "");
+            target_value = LLVMBuildLoad(g->builder, tag_field_ptr, "");
+        } else {
+            zig_unreachable();
+        }
+    } else {
+        target_value = target_value_handle;
+    }
+
 
-    bool end_unreachable = (get_expr_type(node)->id == TypeTableEntryIdUnreachable);
+    TypeTableEntry *switch_type = get_expr_type(node);
+    bool result_has_bits = type_has_bits(switch_type);
+    bool end_unreachable = (switch_type->id == TypeTableEntryIdUnreachable);
 
     LLVMBasicBlockRef end_block = end_unreachable ?
         nullptr : LLVMAppendBasicBlock(g->cur_fn->fn_value, "SwitchEnd");
@@ -2338,7 +2358,21 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
                 AstNode *item_node = prong_node->data.switch_prong.items.at(item_i);
                 assert(item_node->type != NodeTypeSwitchRange);
                 assert(get_resolved_expr(item_node)->const_val.ok);
-                LLVMValueRef val = gen_expr(g, item_node);
+                LLVMValueRef val_handle = gen_expr(g, item_node);
+                LLVMValueRef val;
+                if (handle_is_ptr(target_type)) {
+                    if (target_type->id == TypeTableEntryIdEnum) {
+                        ConstExprValue *item_const_val = &get_resolved_expr(item_node)->const_val;
+                        assert(item_const_val->ok);
+                        assert(get_expr_type(item_node)->id == TypeTableEntryIdEnum);
+                        val = LLVMConstInt(target_type->data.enumeration.tag_type->type_ref,
+                                item_const_val->data.x_enum.tag, false);
+                    } else {
+                        zig_unreachable();
+                    }
+                } else {
+                    val = val_handle;
+                }
                 LLVMAddCase(switch_instr, val, prong_block);
             }
         }
@@ -2367,11 +2401,14 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
 
     LLVMPositionBuilderAtEnd(g->builder, end_block);
 
-    add_debug_source_node(g, node);
-    LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMTypeOf(incoming_values.at(0)), "");
-    LLVMAddIncoming(phi, incoming_values.items, incoming_blocks.items, incoming_values.length);
-
-    return phi;
+    if (result_has_bits) {
+        add_debug_source_node(g, node);
+        LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMTypeOf(incoming_values.at(0)), "");
+        LLVMAddIncoming(phi, incoming_values.items, incoming_blocks.items, incoming_values.length);
+        return phi;
+    } else {
+        return nullptr;
+    }
 }
 
 static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
test/self_hosted.zig
@@ -51,16 +51,24 @@ error SecondError;
 
 #attribute("test")
 fn constant_enum_with_payload() {
-    should_be_empty(AnEnumWithPayload.Empty);
-    should_be_not_empty(AnEnumWithPayload.Full(13));
+    var empty = AnEnumWithPayload.Empty;
+    var full = AnEnumWithPayload.Full(13);
+    should_be_empty(empty);
+    should_be_not_empty(full);
 }
 
 fn should_be_empty(x: AnEnumWithPayload) {
-    if (x != AnEnumWithPayload.Empty) unreachable{}
+    switch (x) {
+        AnEnumWithPayload.Empty => {},
+        else => unreachable{},
+    }
 }
 
 fn should_be_not_empty(x: AnEnumWithPayload) {
-    if (x == AnEnumWithPayload.Empty) unreachable{}
+    switch (x) {
+        AnEnumWithPayload.Empty => unreachable{},
+        else => {},
+    }
 }
 
 enum AnEnumWithPayload {