Commit 093e0d1bb0

Andrew Kelley <superjoe30@gmail.com>
2016-02-05 01:21:08
support variable in switch expression prongs
See #43
1 parent a4cba90
src/all_types.hpp
@@ -510,6 +510,7 @@ struct AstNodeSwitchProng {
     // populated by semantic analyzer
     BlockContext *block_context;
     VariableTableEntry *var;
+    bool var_is_target_expr;
 };
 
 struct AstNodeSwitchRange {
src/analyze.cpp
@@ -4495,6 +4495,7 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
             AstNode *prong_node = node->data.switch_expr.prongs.at(prong_i);
 
             TypeTableEntry *var_type;
+            bool var_is_target_expr;
             if (prong_node->data.switch_prong.items.length == 0) {
                 if (else_prong) {
                     add_node_error(g, prong_node, buf_sprintf("multiple else prongs in switch expression"));
@@ -4502,7 +4503,11 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
                     else_prong = prong_node;
                 }
                 var_type = expr_type;
+                var_is_target_expr = true;
             } else {
+                bool all_agree_on_var_type = true;
+                var_type = nullptr;
+
                 for (int item_i = 0; item_i < prong_node->data.switch_prong.items.length; item_i += 1) {
                     AstNode *item_node = prong_node->data.switch_prong.items.at(item_i);
                     if (item_node->type == NodeTypeSwitchRange) {
@@ -4515,6 +4520,12 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
                             TypeEnumField *type_enum_field = get_enum_field(expr_type, field_name);
                             if (type_enum_field) {
                                 item_node->data.symbol_expr.enum_field = type_enum_field;
+                                if (!var_type) {
+                                    var_type = type_enum_field->type_entry;
+                                }
+                                if (type_enum_field->type_entry != var_type) {
+                                    all_agree_on_var_type = false;
+                                }
                             } else {
                                 add_node_error(g, item_node,
                                         buf_sprintf("enum '%s' has no field '%s'",
@@ -4534,7 +4545,12 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
                         }
                     }
                 }
-                var_type = expr_type;
+                if (!var_type || !all_agree_on_var_type) {
+                    var_type = expr_type;
+                    var_is_target_expr = true;
+                } else {
+                    var_is_target_expr = false;
+                }
             }
 
             BlockContext *child_context = new_block_context(node, context);
@@ -4546,6 +4562,7 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
                 var_node->block_context = child_context;
                 prong_node->data.switch_prong.var = add_local_var(g, var_node, child_context, var_name,
                         var_type, true);
+                prong_node->data.switch_prong.var_is_target_expr = var_is_target_expr;
             }
 
             peer_types[prong_i] = analyze_expression(g, import, child_context, expected_type,
src/codegen.cpp
@@ -184,6 +184,15 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, TypeTableEntry *type_entry,
     return *fn;
 }
 
+static LLVMValueRef get_handle_value(CodeGen *g, AstNode *source_node, LLVMValueRef ptr, TypeTableEntry *type) {
+    if (handle_is_ptr(type)) {
+        return ptr;
+    } else {
+        add_debug_source_node(g, source_node);
+        return LLVMBuildLoad(g->builder, ptr, "");
+    }
+}
+
 static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeFnCallExpr);
     AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
@@ -1024,11 +1033,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
 
                 if (type_has_bits(child_type)) {
                     LLVMValueRef child_val_ptr = LLVMBuildStructGEP(g->builder, expr_val, 1, "");
-                    if (handle_is_ptr(child_type)) {
-                        return child_val_ptr;
-                    } else {
-                        return LLVMBuildLoad(g->builder, child_val_ptr, "");
-                    }
+                    return get_handle_value(g, expr_node, child_val_ptr, child_type);
                 } else {
                     return nullptr;
                 }
@@ -1073,11 +1078,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
                 } else {
                     add_debug_source_node(g, node);
                     LLVMValueRef maybe_field_ptr = LLVMBuildStructGEP(g->builder, expr_val, 0, "");
-                    if (handle_is_ptr(child_type)) {
-                        return maybe_field_ptr;
-                    } else {
-                        return LLVMBuildLoad(g->builder, maybe_field_ptr, "");
-                    }
+                    return get_handle_value(g, node, maybe_field_ptr, child_type);
                 }
             }
     }
@@ -1412,11 +1413,7 @@ static LLVMValueRef gen_unwrap_maybe(CodeGen *g, AstNode *node, LLVMValueRef may
     } else {
         add_debug_source_node(g, node);
         LLVMValueRef maybe_field_ptr = LLVMBuildStructGEP(g->builder, maybe_struct_ref, 0, "");
-        if (handle_is_ptr(child_type)) {
-            return maybe_field_ptr;
-        } else {
-            return LLVMBuildLoad(g->builder, maybe_field_ptr, "");
-        }
+        return get_handle_value(g, node, maybe_field_ptr, child_type);
     }
 }
 
@@ -1580,12 +1577,7 @@ static LLVMValueRef gen_unwrap_err_expr(CodeGen *g, AstNode *node) {
         return nullptr;
     }
     LLVMValueRef child_val_ptr = LLVMBuildStructGEP(g->builder, expr_val, 1, "");
-    LLVMValueRef child_val;
-    if (handle_is_ptr(child_type)) {
-        child_val = child_val_ptr;
-    } else {
-        child_val = LLVMBuildLoad(g->builder, child_val_ptr, "");
-    }
+    LLVMValueRef child_val = get_handle_value(g, node, child_val_ptr, child_type);
 
     if (!have_end_block) {
         return child_val;
@@ -1667,11 +1659,7 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
                 if (type_has_bits(child_type)) {
                     add_debug_source_node(g, node);
                     LLVMValueRef val_ptr = LLVMBuildStructGEP(g->builder, value, 1, "");
-                    if (handle_is_ptr(child_type)) {
-                        return val_ptr;
-                    } else {
-                        return LLVMBuildLoad(g->builder, val_ptr, "");
-                    }
+                    return get_handle_value(g, node, val_ptr, child_type);
                 } else {
                     return nullptr;
                 }
@@ -2294,12 +2282,7 @@ static LLVMValueRef gen_symbol(CodeGen *g, AstNode *node) {
             return nullptr;
         } else if (variable->is_ptr) {
             assert(variable->value_ref);
-            if (handle_is_ptr(variable->type)) {
-                return variable->value_ref;
-            } else {
-                add_debug_source_node(g, node);
-                return LLVMBuildLoad(g->builder, variable->value_ref, "");
-            }
+            return get_handle_value(g, node, variable->value_ref, variable->type);
         } else {
             return variable->value_ref;
         }
@@ -2347,6 +2330,8 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
     AstNode *else_prong = nullptr;
     for (int prong_i = 0; prong_i < prong_count; prong_i += 1) {
         AstNode *prong_node = node->data.switch_expr.prongs.at(prong_i);
+        VariableTableEntry *prong_var = prong_node->data.switch_prong.var;
+
         LLVMBasicBlockRef prong_block;
         if (prong_node->data.switch_prong.items.length == 0) {
             assert(!else_prong);
@@ -2354,8 +2339,12 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
             prong_block = else_block;
         } else {
             prong_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "SwitchProng");
-            for (int item_i = 0; item_i < prong_node->data.switch_prong.items.length; item_i += 1) {
+            int prong_item_count = prong_node->data.switch_prong.items.length;
+            bool make_item_blocks = prong_var && prong_item_count > 1;
+
+            for (int item_i = 0; item_i < prong_item_count; item_i += 1) {
                 AstNode *item_node = prong_node->data.switch_prong.items.at(item_i);
+
                 assert(item_node->type != NodeTypeSwitchRange);
                 LLVMValueRef val;
                 if (target_type->id == TypeTableEntryIdEnum) {
@@ -2364,14 +2353,50 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
                     assert(enum_field);
                     val = LLVMConstInt(target_type->data.enumeration.tag_type->type_ref,
                             enum_field->value, false);
+
+                    if (prong_var && type_has_bits(prong_var->type)) {
+                        LLVMBasicBlockRef item_block;
+
+                        if (make_item_blocks) {
+                            item_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "SwitchProngItem");
+                            LLVMAddCase(switch_instr, val, item_block);
+                            LLVMPositionBuilderAtEnd(g->builder, item_block);
+                        } else {
+                            LLVMAddCase(switch_instr, val, prong_block);
+                            LLVMPositionBuilderAtEnd(g->builder, prong_block);
+                        }
+
+                        AstNode *var_node = prong_node->data.switch_prong.var_symbol;
+                        add_debug_source_node(g, var_node);
+                        if (prong_node->data.switch_prong.var_is_target_expr) {
+                            gen_assign_raw(g, var_node, BinOpTypeAssign,
+                                    prong_var->value_ref, target_value, prong_var->type, target_type);
+                        } else if (target_type->id == TypeTableEntryIdEnum) {
+                            assert(type_has_bits(enum_field->type_entry));
+                            LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, target_value_handle,
+                                    1, "");
+                            LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr,
+                                    LLVMPointerType(enum_field->type_entry->type_ref, 0), "");
+                            LLVMValueRef handle_val = get_handle_value(g, var_node, bitcasted_union_field_ptr,
+                                    enum_field->type_entry);
+
+                            gen_assign_raw(g, var_node, BinOpTypeAssign,
+                                    prong_var->value_ref, handle_val, prong_var->type, enum_field->type_entry);
+                        }
+                        if (make_item_blocks) {
+                            LLVMBuildBr(g->builder, prong_block);
+                        }
+                    } else {
+                        LLVMAddCase(switch_instr, val, prong_block);
+                    }
                 } else {
                     assert(get_resolved_expr(item_node)->const_val.ok);
                     val = gen_expr(g, item_node);
+                    LLVMAddCase(switch_instr, val, prong_block);
                 }
-                LLVMAddCase(switch_instr, val, prong_block);
             }
         }
-        assert(!prong_node->data.switch_prong.var_symbol);
+
         LLVMPositionBuilderAtEnd(g->builder, prong_block);
         AstNode *prong_expr = prong_node->data.switch_prong.expr;
         LLVMValueRef prong_val = gen_expr(g, prong_expr);
test/self_hosted.zig
@@ -147,3 +147,29 @@ enum SwitchStatmentFoo {
     C,
     D,
 }
+
+
+#attribute("test")
+fn switch_prong_with_var() {
+    switch_prong_with_var_fn(SwitchProngWithVarEnum.One(13));
+    switch_prong_with_var_fn(SwitchProngWithVarEnum.Two(13.0));
+    switch_prong_with_var_fn(SwitchProngWithVarEnum.Meh);
+}
+enum SwitchProngWithVarEnum {
+    One: i32,
+    Two: f32,
+    Meh,
+}
+fn switch_prong_with_var_fn(a: SwitchProngWithVarEnum) {
+    switch(a) {
+        One => |x| {
+            if (x != 13) unreachable{};
+        },
+        Two => |x| {
+            if (x != 13.0) unreachable{};
+        },
+        Meh => |x| {
+            const v: void = x;
+        },
+    }
+}