Commit bcdb3a9006

Andrew Kelley <andrew@ziglang.org>
2019-11-28 06:02:53
more progress
1 parent bf3ac66
Changed files (5)
src/ir.cpp
@@ -41,6 +41,9 @@ struct IrAnalyze {
     ZigList<IrInstruction *> src_implicit_return_type_list;
     ZigList<IrSuspendPosition> resume_stack;
     IrBasicBlock *const_predecessor_bb;
+
+    // For the purpose of using in a debugger
+    void dump();
 };
 
 enum ConstCastResultId {
@@ -350,6 +353,7 @@ static bool types_have_same_zig_comptime_repr(CodeGen *codegen, ZigType *expecte
         case ZigTypeIdErrorSet:
         case ZigTypeIdOpaque:
         case ZigTypeIdAnyFrame:
+        case ZigTypeIdFn:
             return true;
         case ZigTypeIdFloat:
             return expected->data.floating.bit_count == actual->data.floating.bit_count;
@@ -361,7 +365,6 @@ static bool types_have_same_zig_comptime_repr(CodeGen *codegen, ZigType *expecte
         case ZigTypeIdErrorUnion:
         case ZigTypeIdEnum:
         case ZigTypeIdUnion:
-        case ZigTypeIdFn:
         case ZigTypeIdArgTuple:
         case ZigTypeIdVector:
         case ZigTypeIdFnFrame:
@@ -3941,6 +3944,7 @@ static IrInstruction *ir_gen_block(IrBuilder *irb, Scope *parent_scope, AstNode
         scope_block->peer_parent = allocate<ResultLocPeerParent>(1, "ResultLocPeerParent");
         scope_block->peer_parent->base.id = ResultLocIdPeerParent;
         scope_block->peer_parent->base.source_instruction = scope_block->is_comptime;
+        scope_block->peer_parent->base.allow_write_through_const = result_loc->allow_write_through_const;
         scope_block->peer_parent->end_bb = scope_block->end_block;
         scope_block->peer_parent->is_comptime = scope_block->is_comptime;
         scope_block->peer_parent->parent = result_loc;
@@ -4195,6 +4199,7 @@ static ResultLocPeerParent *ir_build_result_peers(IrBuilder *irb, IrInstruction
     ResultLocPeerParent *peer_parent = allocate<ResultLocPeerParent>(1);
     peer_parent->base.id = ResultLocIdPeerParent;
     peer_parent->base.source_instruction = cond_br_inst;
+    peer_parent->base.allow_write_through_const = parent->allow_write_through_const;
     peer_parent->end_bb = end_block;
     peer_parent->is_comptime = is_comptime;
     peer_parent->parent = parent;
@@ -6388,6 +6393,7 @@ static ResultLocVar *ir_build_var_result_loc(IrBuilder *irb, IrInstruction *allo
     ResultLocVar *result_loc_var = allocate<ResultLocVar>(1);
     result_loc_var->base.id = ResultLocIdVar;
     result_loc_var->base.source_instruction = alloca;
+    result_loc_var->base.allow_write_through_const = true;
     result_loc_var->var = var;
 
     ir_build_reset_result(irb, alloca->scope, alloca->source_node, &result_loc_var->base);
@@ -6401,6 +6407,7 @@ static ResultLocCast *ir_build_cast_result_loc(IrBuilder *irb, IrInstruction *de
     ResultLocCast *result_loc_cast = allocate<ResultLocCast>(1);
     result_loc_cast->base.id = ResultLocIdCast;
     result_loc_cast->base.source_instruction = dest_type;
+    result_loc_cast->base.allow_write_through_const = parent_result_loc->allow_write_through_const;
     ir_ref_instruction(dest_type, irb->current_basic_block);
     result_loc_cast->parent = parent_result_loc;
 
@@ -7581,6 +7588,7 @@ static IrInstruction *ir_gen_switch_expr(IrBuilder *irb, Scope *scope, AstNode *
 
     ResultLocPeerParent *peer_parent = allocate<ResultLocPeerParent>(1);
     peer_parent->base.id = ResultLocIdPeerParent;
+    peer_parent->base.allow_write_through_const = result_loc->allow_write_through_const;
     peer_parent->end_bb = end_block;
     peer_parent->is_comptime = is_comptime;
     peer_parent->parent = result_loc;
@@ -13396,16 +13404,24 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst
         return ir_analyze_undefined_to_anything(ira, source_instr, value, wanted_type);
     }
 
-    // T to ?E!T
-    if (wanted_type->id == ZigTypeIdOptional && wanted_type->data.maybe.child_type->id == ZigTypeIdErrorUnion &&
-        actual_type->id != ZigTypeIdOptional)
-    {
+    // T to ?U, where T implicitly casts to U
+    if (wanted_type->id == ZigTypeIdOptional && actual_type->id != ZigTypeIdOptional) {
         IrInstruction *cast1 = ir_implicit_cast2(ira, source_instr, value, wanted_type->data.maybe.child_type);
         if (type_is_invalid(cast1->value->type))
             return ira->codegen->invalid_instruction;
         return ir_implicit_cast2(ira, source_instr, cast1, wanted_type);
     }
 
+    // T to E!U, where T implicitly casts to U
+    if (wanted_type->id == ZigTypeIdErrorUnion && actual_type->id != ZigTypeIdErrorUnion &&
+        actual_type->id != ZigTypeIdErrorSet)
+    {
+        IrInstruction *cast1 = ir_implicit_cast2(ira, source_instr, value, wanted_type->data.error_union.payload_type);
+        if (type_is_invalid(cast1->value->type))
+            return ira->codegen->invalid_instruction;
+        return ir_implicit_cast2(ira, source_instr, cast1, wanted_type);
+    }
+
     ErrorMsg *parent_msg = ir_add_error_node(ira, source_instr->source_node,
         buf_sprintf("expected type '%s', found '%s'",
             buf_ptr(&wanted_type->name),
@@ -16046,7 +16062,7 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe
             bool force_comptime;
             if (!ir_resolve_comptime(ira, alloca_src->is_comptime->child, &force_comptime))
                 return ira->codegen->invalid_instruction;
-            bool is_comptime = force_comptime || (value != nullptr &&
+            bool is_comptime = force_comptime || (!force_runtime && value != nullptr &&
                     value->value->special != ConstValSpecialRuntime && result_loc_var->var->gen_is_const);
 
             if (alloca_src->base.child == nullptr || is_comptime) {
@@ -16064,7 +16080,7 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe
                     alloca_gen = ir_analyze_alloca(ira, result_loc->source_instruction, value_type, align,
                             alloca_src->name_hint, force_comptime);
                 }
-                if (alloca_src->base.child != nullptr) {
+                if (alloca_src->base.child != nullptr && !result_loc->written) {
                     alloca_src->base.child->ref_count = 0;
                 }
                 alloca_src->base.child = alloca_gen;
@@ -16079,6 +16095,10 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe
             return result_loc->resolved_loc;
         }
         case ResultLocIdReturn: {
+            if (value != nullptr) {
+                reinterpret_cast<ResultLocReturn *>(result_loc)->implicit_return_type_done = true;
+                ira->src_implicit_return_type_list.append(value);
+            }
             if (!non_null_comptime) {
                 bool is_comptime = value != nullptr && value->value->special != ConstValSpecialRuntime;
                 if (is_comptime)
@@ -16121,10 +16141,10 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe
                 return result_loc->resolved_loc;
             }
 
-            bool is_comptime;
-            if (!ir_resolve_comptime(ira, peer_parent->is_comptime->child, &is_comptime))
+            bool is_condition_comptime;
+            if (!ir_resolve_comptime(ira, peer_parent->is_comptime->child, &is_condition_comptime))
                 return ira->codegen->invalid_instruction;
-            if (is_comptime) {
+            if (is_condition_comptime) {
                 peer_parent->skipped = true;
                 if (non_null_comptime) {
                     return ir_resolve_result(ira, suspend_source_instr, peer_parent->parent,
@@ -16136,17 +16156,18 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe
             if ((err = ir_result_has_type(ira, peer_parent->parent, &peer_parent_has_type)))
                 return ira->codegen->invalid_instruction;
             if (peer_parent_has_type) {
-                if (peer_parent->parent->id == ResultLocIdReturn && value != nullptr) {
-                    reinterpret_cast<ResultLocReturn *>(peer_parent->parent)->implicit_return_type_done = true;
-                    ira->src_implicit_return_type_list.append(value);
-                }
                 peer_parent->skipped = true;
                 IrInstruction *parent_result_loc = ir_resolve_result(ira, suspend_source_instr, peer_parent->parent,
-                        value_type, value, force_runtime || !is_comptime, true, true);
-                if (parent_result_loc != nullptr) {
-                    peer_parent->parent->written = true;
+                        value_type, value, force_runtime || !is_condition_comptime, true, true);
+                if (parent_result_loc == nullptr || type_is_invalid(parent_result_loc->value->type) ||
+                    parent_result_loc->value->type->id == ZigTypeIdUnreachable)
+                {
+                    return parent_result_loc;
                 }
-                return parent_result_loc;
+                peer_parent->parent->written = true;
+                result_loc->written = true;
+                result_loc->resolved_loc = parent_result_loc;
+                return result_loc->resolved_loc;
             }
 
             if (peer_parent->resolved_type == nullptr) {
@@ -16168,14 +16189,14 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe
             {
                 return parent_result_loc;
             }
-            // because is_comptime is false, we mark this a runtime pointer
+            // because is_condition_comptime is false, we mark this a runtime pointer
             parent_result_loc->value->special = ConstValSpecialRuntime;
             result_loc->written = true;
             result_loc->resolved_loc = parent_result_loc;
             return result_loc->resolved_loc;
         }
         case ResultLocIdCast: {
-            if (value != nullptr && value->value->special != ConstValSpecialRuntime)
+            if (value != nullptr && value->value->special != ConstValSpecialRuntime && !non_null_comptime)
                 return nullptr;
             ResultLocCast *result_cast = reinterpret_cast<ResultLocCast *>(result_loc);
             ZigType *dest_type = ir_resolve_type(ira, result_cast->base.source_instruction->child);
@@ -16204,6 +16225,7 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe
             {
                 return parent_result_loc;
             }
+
             ZigType *parent_ptr_type = parent_result_loc->value->type;
             assert(parent_ptr_type->id == ZigTypeIdPointer);
             if ((err = type_resolve(ira->codegen, parent_ptr_type->data.pointer.child_type,
@@ -16354,7 +16376,7 @@ static IrInstruction *ir_resolve_result(IrAnalyze *ira, IrInstruction *suspend_s
     {
         result_loc_pass1->written = false;
         return ir_analyze_unwrap_optional_payload(ira, suspend_source_instr, result_loc, false, true);
-    } else if (actual_elem_type->id == ZigTypeIdErrorUnion && value_type->id != ZigTypeIdErrorUnion) {
+    } else if (actual_elem_type->id == ZigTypeIdErrorUnion && value_type->id != ZigTypeIdErrorUnion && value == nullptr) {
         if (value_type->id == ZigTypeIdErrorSet) {
             return ir_analyze_unwrap_err_code(ira, suspend_source_instr, result_loc, true);
         } else {
@@ -28238,3 +28260,11 @@ void IrInstruction::dump() {
         ir_print_instruction(inst->scope->codegen, stderr, inst->child, 0, IrPassGen);
     }
 }
+
+void IrAnalyze::dump() {
+    ir_print(this->codegen, stderr, this->new_irb.exec, 0, IrPassGen);
+    if (this->new_irb.current_basic_block != nullptr) {
+        fprintf(stderr, "Current basic block:\n");
+        ir_print_basic_block(this->codegen, stderr, this->new_irb.current_basic_block, 1, IrPassGen);
+    }
+}
src/ir_print.cpp
@@ -2530,6 +2530,37 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction, bool
     fprintf(irp->f, "\n");
 }
 
+static void irp_print_basic_block(IrPrint *irp, IrBasicBlock *current_block) {
+    fprintf(irp->f, "%s_%" ZIG_PRI_usize ":\n", current_block->name_hint, current_block->debug_id);
+    for (size_t instr_i = 0; instr_i < current_block->instruction_list.length; instr_i += 1) {
+        IrInstruction *instruction = current_block->instruction_list.at(instr_i);
+        if (irp->pass != IrPassSrc) {
+            irp->printed.put(instruction, 0);
+            irp->pending.clear();
+        }
+        ir_print_instruction(irp, instruction, false);
+        for (size_t j = 0; j < irp->pending.length; ++j)
+            ir_print_instruction(irp, irp->pending.at(j), true);
+    }
+}
+
+void ir_print_basic_block(CodeGen *codegen, FILE *f, IrBasicBlock *bb, int indent_size, IrPass pass) {
+    IrPrint ir_print = {};
+    ir_print.pass = pass;
+    ir_print.codegen = codegen;
+    ir_print.f = f;
+    ir_print.indent = indent_size;
+    ir_print.indent_size = indent_size;
+    ir_print.printed = {};
+    ir_print.printed.init(64);
+    ir_print.pending = {};
+
+    irp_print_basic_block(&ir_print, bb);
+
+    ir_print.pending.deinit();
+    ir_print.printed.deinit();
+}
+
 void ir_print(CodeGen *codegen, FILE *f, IrExecutable *executable, int indent_size, IrPass pass) {
     IrPrint ir_print = {};
     IrPrint *irp = &ir_print;
@@ -2543,18 +2574,7 @@ void ir_print(CodeGen *codegen, FILE *f, IrExecutable *executable, int indent_si
     irp->pending = {};
 
     for (size_t bb_i = 0; bb_i < executable->basic_block_list.length; bb_i += 1) {
-        IrBasicBlock *current_block = executable->basic_block_list.at(bb_i);
-        fprintf(irp->f, "%s_%" ZIG_PRI_usize ":\n", current_block->name_hint, current_block->debug_id);
-        for (size_t instr_i = 0; instr_i < current_block->instruction_list.length; instr_i += 1) {
-            IrInstruction *instruction = current_block->instruction_list.at(instr_i);
-            if (irp->pass != IrPassSrc) {
-                irp->printed.put(instruction, 0);
-                irp->pending.clear();
-            }
-            ir_print_instruction(irp, instruction, false);
-            for (size_t j = 0; j < irp->pending.length; ++j)
-                ir_print_instruction(irp, irp->pending.at(j), true);
-        }
+        irp_print_basic_block(irp, executable->basic_block_list.at(bb_i));
     }
 
     irp->pending.deinit();
src/ir_print.hpp
@@ -15,6 +15,7 @@
 void ir_print(CodeGen *codegen, FILE *f, IrExecutable *executable, int indent_size, IrPass pass);
 void ir_print_instruction(CodeGen *codegen, FILE *f, IrInstruction *instruction, int indent_size, IrPass pass);
 void ir_print_const_expr(CodeGen *codegen, FILE *f, ZigValue *value, int indent_size, IrPass pass);
+void ir_print_basic_block(CodeGen *codegen, FILE *f, IrBasicBlock *bb, int indent_size, IrPass pass);
 
 const char* ir_instruction_type_str(IrInstructionId id);
 
test/stage1/behavior/cast.zig
@@ -697,3 +697,16 @@ test "cast i8 fn call peers to i32 result" {
     S.doTheTest();
     comptime S.doTheTest();
 }
+
+test "return u8 coercing into ?u32 return type" {
+    const S = struct {
+        fn doTheTest() void {
+            expect(foo(123).? == 123);
+        }
+        fn foo(arg: u8) ?u32 {
+            return arg;
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
test/stage1/behavior/union.zig
@@ -582,3 +582,32 @@ test "update the tag value for zero-sized unions" {
     x = S{ .U1 = {} };
     expect(x == .U1);
 }
+
+test "function call result coerces from tagged union to the tag" {
+    const S = struct {
+        const Arch = union(enum) {
+            One,
+            Two: usize,
+        };
+
+        const ArchTag = @TagType(Arch);
+
+        fn doTheTest() void {
+            var x: ArchTag = getArch1();
+            expect(x == .One);
+
+            var y: ArchTag = getArch2();
+            expect(y == .Two);
+        }
+
+        pub fn getArch1() Arch {
+            return .One;
+        }
+
+        pub fn getArch2() Arch {
+            return .{ .Two = 99 };
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}