Commit d0378057d1

Andrew Kelley <superjoe30@gmail.com>
2016-04-26 06:51:04
support switching on error union type
closes #23
1 parent d1b65c6
src/all_types.hpp
@@ -691,6 +691,7 @@ struct AstNodeSymbolExpr {
     // set this to instead of analyzing the node, pretend it's a type entry and it's this one.
     TypeTableEntry *override_type_entry;
     TypeEnumField *enum_field;
+    uint32_t err_value;
 };
 
 struct AstNodeBoolLiteral {
src/analyze.cpp
@@ -5116,8 +5116,11 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
 
 
     int *field_use_counts = nullptr;
+    HashMap<int, AstNode *, int_hash, int_eq> err_use_nodes;
     if (expr_type->id == TypeTableEntryIdEnum) {
         field_use_counts = allocate<int>(expr_type->data.enumeration.field_count);
+    } else if (expr_type->id == TypeTableEntryIdErrorUnion) {
+        err_use_nodes.init(10);
     }
 
     int *const_chosen_prong_index = &node->data.switch_expr.const_chosen_prong_index;
@@ -5186,8 +5189,54 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
                         add_node_error(g, item_node, buf_sprintf("expected enum tag name"));
                         any_errors = true;
                     }
+                } else if (expr_type->id == TypeTableEntryIdErrorUnion) {
+                    if (item_node->type == NodeTypeSymbol) {
+                        Buf *err_name = &item_node->data.symbol_expr.symbol;
+                        bool is_ok_case = buf_eql_str(err_name, "Ok");
+                        auto err_table_entry = is_ok_case ? nullptr: g->error_table.maybe_get(err_name);
+                        if (is_ok_case || err_table_entry) {
+                            uint32_t err_value = is_ok_case ? 0 : err_table_entry->value->value;
+                            item_node->data.symbol_expr.err_value = err_value;
+                            TypeTableEntry *this_var_type;
+                            if (is_ok_case) {
+                                this_var_type = expr_type->data.error.child_type;
+                            } else {
+                                this_var_type = g->builtin_types.entry_pure_error;
+                            }
+                            if (!var_type) {
+                                var_type = this_var_type;
+                            }
+                            if (this_var_type != var_type) {
+                                all_agree_on_var_type = false;
+                            }
+
+                            // detect duplicate switch values
+                            auto existing_entry = err_use_nodes.maybe_get(err_value);
+                            if (existing_entry) {
+                                add_node_error(g, existing_entry->value,
+                                        buf_sprintf("duplicate switch value: '%s'", buf_ptr(err_name)));
+                                any_errors = true;
+                            } else {
+                                err_use_nodes.put(err_value, item_node);
+                            }
+
+                            if (!any_errors && expr_val->ok) {
+                                if (expr_val->data.x_err.err->value == err_value) {
+                                    *const_chosen_prong_index = prong_i;
+                                }
+                            }
+                        } else {
+                            add_node_error(g, item_node,
+                                    buf_sprintf("use of undeclared error value '%s'", buf_ptr(err_name)));
+                            any_errors = true;
+                        }
+                    } else {
+                        add_node_error(g, item_node, buf_sprintf("expected error value name"));
+                        any_errors = true;
+                    }
                 } else {
                     if (!any_errors && expr_val->ok) {
+                        // note: there is now a function in eval.cpp for doing const expr comparison
                         zig_panic("TODO determine if const exprs are equal");
                     }
                     TypeTableEntry *item_type = analyze_expression(g, import, context, expr_type, item_node);
@@ -5252,17 +5301,25 @@ static TypeTableEntry *analyze_switch_expr(CodeGen *g, ImportTableEntry *import,
         return g->builtin_types.entry_invalid;
     }
 
+    TypeTableEntry *result_type = resolve_peer_type_compatibility(g, import, context, node,
+            peer_nodes, peer_types, prong_count);
+
     if (expr_val->ok) {
         assert(*const_chosen_prong_index != -1);
 
         *const_val = get_resolved_expr(peer_nodes[*const_chosen_prong_index])->const_val;
-        // the target expr depends on a compile var,
-        // so the entire if statement does too
+        // the target expr depends on a compile var because we have an error on unnecessary
+        // switch statement, so the entire switch statement does too
         const_val->depends_on_compile_var = true;
-    }
 
+        if (!const_val->ok) {
+            return add_error_if_type_is_num_lit(g, result_type, node);
+        }
+    } else {
+        return add_error_if_type_is_num_lit(g, result_type, node);
+    }
 
-    return resolve_peer_type_compatibility(g, import, context, node, peer_nodes, peer_types, prong_count);
+    return result_type;
 }
 
 static TypeTableEntry *analyze_return_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
src/codegen.cpp
@@ -2627,12 +2627,6 @@ static LLVMValueRef gen_symbol(CodeGen *g, AstNode *node) {
     }
 
     zig_unreachable();
-
-    /* TODO delete
-    FnTableEntry *fn_entry = node->data.symbol_expr.fn_entry;
-    assert(fn_entry);
-    return fn_entry->fn_value;
-    */
 }
 
 static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
@@ -2653,6 +2647,10 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
             add_debug_source_node(g, node);
             LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, target_value_handle, 0, "");
             target_value = LLVMBuildLoad(g->builder, tag_field_ptr, "");
+        } else if (target_type->id == TypeTableEntryIdErrorUnion) {
+            add_debug_source_node(g, node);
+            LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, target_value_handle, 0, "");
+            target_value = LLVMBuildLoad(g->builder, tag_field_ptr, "");
         } else {
             zig_unreachable();
         }
@@ -2696,12 +2694,23 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
 
                 assert(item_node->type != NodeTypeSwitchRange);
                 LLVMValueRef val;
-                if (target_type->id == TypeTableEntryIdEnum) {
+                if (target_type->id == TypeTableEntryIdEnum ||
+                    target_type->id == TypeTableEntryIdErrorUnion)
+                {
                     assert(item_node->type == NodeTypeSymbol);
-                    TypeEnumField *enum_field = item_node->data.symbol_expr.enum_field;
-                    assert(enum_field);
-                    val = LLVMConstInt(target_type->data.enumeration.tag_type->type_ref,
-                            enum_field->value, false);
+                    TypeEnumField *enum_field = nullptr;
+                    uint32_t err_value = 0;
+                    if (target_type->id == TypeTableEntryIdEnum) {
+                        enum_field = item_node->data.symbol_expr.enum_field;
+                        assert(enum_field);
+                        val = LLVMConstInt(target_type->data.enumeration.tag_type->type_ref,
+                                enum_field->value, false);
+                    } else if (target_type->id == TypeTableEntryIdErrorUnion) {
+                        err_value = item_node->data.symbol_expr.err_value;
+                        val = LLVMConstInt(g->err_tag_type->type_ref, err_value, false);
+                    } else {
+                        zig_unreachable();
+                    }
 
                     if (prong_var && type_has_bits(prong_var->type)) {
                         LLVMBasicBlockRef item_block;
@@ -2721,6 +2730,7 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
                             gen_assign_raw(g, var_node, BinOpTypeAssign,
                                     prong_var->value_ref, target_value, prong_var->type, target_type);
                         } else if (target_type->id == TypeTableEntryIdEnum) {
+                            assert(enum_field);
                             assert(type_has_bits(enum_field->type_entry));
                             LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, target_value_handle,
                                     1, "");
@@ -2731,6 +2741,25 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
 
                             gen_assign_raw(g, var_node, BinOpTypeAssign,
                                     prong_var->value_ref, handle_val, prong_var->type, enum_field->type_entry);
+                        } else if (target_type->id == TypeTableEntryIdErrorUnion) {
+                            if (err_value == 0) {
+                                // variable is the payload
+                                LLVMValueRef err_payload_ptr = LLVMBuildStructGEP(g->builder,
+                                        target_value_handle, 1, "");
+                                LLVMValueRef handle_val = get_handle_value(g, var_node,
+                                        err_payload_ptr, prong_var->type);
+                                gen_assign_raw(g, var_node, BinOpTypeAssign,
+                                        prong_var->value_ref, handle_val, prong_var->type, prong_var->type);
+                            } else {
+                                // variable is the pure error value
+                                LLVMValueRef err_tag_ptr = LLVMBuildStructGEP(g->builder,
+                                        target_value_handle, 0, "");
+                                LLVMValueRef handle_val = LLVMBuildLoad(g->builder, err_tag_ptr, "");
+                                gen_assign_raw(g, var_node, BinOpTypeAssign,
+                                        prong_var->value_ref, handle_val, prong_var->type, g->err_tag_type);
+                            }
+                        } else {
+                            zig_unreachable();
                         }
                         if (make_item_blocks) {
                             LLVMBuildBr(g->builder, prong_block);
test/run_tests.cpp
@@ -1233,6 +1233,18 @@ fn bad_eql_2(a: EnumWithData, b: EnumWithData) -> bool {
     )SOURCE", 2,
             ".tmp_source.zig:3:7: error: operator not allowed for type '[]u8'",
             ".tmp_source.zig:10:7: error: operator not allowed for type 'EnumWithData'");
+
+    add_compile_fail_case("non-const switch number literal", R"SOURCE(
+fn foo() {
+    const x = switch (bar()) {
+        1, 2 => 1,
+        3, 4 => 2,
+        else => 3,
+    };
+}
+#static_eval_enable(false)
+fn bar() -> i32 { 2 }
+    )SOURCE", 1, ".tmp_source.zig:3:15: error: unable to infer expression type");
 }
 
 //////////////////////////////////////////////////////////////////////////////
test/self_hosted.zig
@@ -1339,3 +1339,32 @@ fn character_literals() {
     assert('\'' == single_quote);
 }
 const single_quote = '\'';
+
+
+#attribute("test")
+fn switch_with_multiple_expressions() {
+    const x: i32 = switch (returns_five()) {
+        1, 2, 3 => 1,
+        4, 5, 6 => 2,
+        else => 3,
+    };
+    assert(x == 2);
+}
+#static_eval_enable(false)
+fn returns_five() -> i32 { 5 }
+
+
+#attribute("test")
+fn switch_on_error_union() {
+    const x = switch (returns_ten()) {
+        Ok => |val| val + 1,
+        ItBroke, NoMem => 1,
+        CrappedOut => 2,
+    };
+    assert(x == 11);
+}
+error ItBroke;
+error NoMem;
+error CrappedOut;
+#static_eval_enable(false)
+fn returns_ten() -> %i32 { 10 }