Commit 2a6fbbd8fb

Andrew Kelley <andrew@ziglang.org>
2019-11-07 05:21:31
introduce `@as` builtin for type coercion
This commit also hooks up type coercion (previously called implicit casting) into the result location mechanism, and additionally hooks up variable declarations, maintaining the property that: var a: T = b; is semantically equivalent to: var a = @as(T, b); See #1757
1 parent 6d28b28
src/all_types.hpp
@@ -48,6 +48,7 @@ struct ResultLoc;
 struct ResultLocPeer;
 struct ResultLocPeerParent;
 struct ResultLocBitCast;
+struct ResultLocCast;
 struct ResultLocReturn;
 
 enum PtrLen {
@@ -1691,6 +1692,7 @@ enum BuiltinFnId {
     BuiltinFnIdFrameType,
     BuiltinFnIdFrameHandle,
     BuiltinFnIdFrameSize,
+    BuiltinFnIdAs,
 };
 
 struct BuiltinFnEntry {
@@ -3458,6 +3460,13 @@ struct IrInstructionPtrCastGen {
     bool safety_check_on;
 };
 
+struct IrInstructionImplicitCast {
+    IrInstruction base;
+
+    IrInstruction *operand;
+    ResultLocCast *result_loc_cast;
+};
+
 struct IrInstructionBitCastSrc {
     IrInstruction base;
 
@@ -3823,14 +3832,6 @@ struct IrInstructionEndExpr {
     ResultLoc *result_loc;
 };
 
-struct IrInstructionImplicitCast {
-    IrInstruction base;
-
-    IrInstruction *dest_type;
-    IrInstruction *target;
-    ResultLoc *result_loc;
-};
-
 // This one is for writing through the result pointer.
 struct IrInstructionResolveResult {
     IrInstruction base;
@@ -3928,6 +3929,7 @@ enum ResultLocId {
     ResultLocIdPeerParent,
     ResultLocIdInstruction,
     ResultLocIdBitCast,
+    ResultLocIdCast,
 };
 
 // Additions to this struct may need to be handled in
@@ -3995,6 +3997,13 @@ struct ResultLocBitCast {
     ResultLoc *parent;
 };
 
+// The source_instruction is the destination type
+struct ResultLocCast {
+    ResultLoc base;
+
+    ResultLoc *parent;
+};
+
 static const size_t slice_ptr_index = 0;
 static const size_t slice_len_index = 1;
 
src/codegen.cpp
@@ -8070,6 +8070,7 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn(g, BuiltinFnIdFrameType, "Frame", 1);
     create_builtin_fn(g, BuiltinFnIdFrameAddress, "frameAddress", 0);
     create_builtin_fn(g, BuiltinFnIdFrameSize, "frameSize", 1);
+    create_builtin_fn(g, BuiltinFnIdAs, "as", 2);
 }
 
 static const char *bool_to_str(bool b) {
src/ir.cpp
@@ -200,6 +200,8 @@ static IrInstruction *ir_gen_union_init_expr(IrBuilder *irb, Scope *scope, AstNo
 static void ir_reset_result(ResultLoc *result_loc);
 static Buf *get_anon_type_name(CodeGen *codegen, IrExecutable *exec, const char *kind_name,
         Scope *scope, AstNode *source_node, Buf *out_bare_name);
+static ResultLocCast *ir_build_cast_result_loc(IrBuilder *irb, IrInstruction *dest_type,
+        ResultLoc *parent_result_loc);
 
 static ConstExprValue *const_ptr_pointee_unchecked(CodeGen *g, ConstExprValue *const_val) {
     assert(get_src_ptr_type(const_val->type) != nullptr);
@@ -2766,6 +2768,18 @@ static IrInstruction *ir_build_load_ptr_gen(IrAnalyze *ira, IrInstruction *sourc
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_implicit_cast(IrBuilder *irb, Scope *scope, AstNode *source_node,
+        IrInstruction *operand, ResultLocCast *result_loc_cast)
+{
+    IrInstructionImplicitCast *instruction = ir_build_instruction<IrInstructionImplicitCast>(irb, scope, source_node);
+    instruction->operand = operand;
+    instruction->result_loc_cast = result_loc_cast;
+
+    ir_ref_instruction(operand, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
 static IrInstruction *ir_build_bit_cast_src(IrBuilder *irb, Scope *scope, AstNode *source_node,
         IrInstruction *operand, ResultLocBitCast *result_loc_bit_cast)
 {
@@ -3063,20 +3077,6 @@ static IrInstruction *ir_build_align_cast(IrBuilder *irb, Scope *scope, AstNode
     return &instruction->base;
 }
 
-static IrInstruction *ir_build_implicit_cast(IrBuilder *irb, Scope *scope, AstNode *source_node,
-        IrInstruction *dest_type, IrInstruction *target, ResultLoc *result_loc)
-{
-    IrInstructionImplicitCast *instruction = ir_build_instruction<IrInstructionImplicitCast>(irb, scope, source_node);
-    instruction->dest_type = dest_type;
-    instruction->target = target;
-    instruction->result_loc = result_loc;
-
-    ir_ref_instruction(dest_type, irb->current_basic_block);
-    ir_ref_instruction(target, irb->current_basic_block);
-
-    return &instruction->base;
-}
-
 static IrInstruction *ir_build_resolve_result(IrBuilder *irb, Scope *scope, AstNode *source_node,
         ResultLoc *result_loc, IrInstruction *ty)
 {
@@ -5374,6 +5374,24 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
                 IrInstruction *bitcast = ir_build_bit_cast_src(irb, scope, arg1_node, arg1_value, result_loc_bit_cast);
                 return ir_lval_wrap(irb, scope, bitcast, lval, result_loc);
             }
+        case BuiltinFnIdAs:
+            {
+                AstNode *dest_type_node = node->data.fn_call_expr.params.at(0);
+                IrInstruction *dest_type = ir_gen_node(irb, dest_type_node, scope);
+                if (dest_type == irb->codegen->invalid_instruction)
+                    return dest_type;
+
+                ResultLocCast *result_loc_cast = ir_build_cast_result_loc(irb, dest_type, result_loc);
+
+                AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
+                IrInstruction *arg1_value = ir_gen_node_extra(irb, arg1_node, scope, LValNone,
+                        &result_loc_cast->base);
+                if (arg1_value == irb->codegen->invalid_instruction)
+                    return arg1_value;
+
+                IrInstruction *result = ir_build_implicit_cast(irb, scope, node, arg1_value, result_loc_cast);
+                return ir_lval_wrap(irb, scope, result, lval, result_loc);
+            }
         case BuiltinFnIdIntToPtr:
             {
                 AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -6214,6 +6232,20 @@ static ResultLocVar *ir_build_var_result_loc(IrBuilder *irb, IrInstruction *allo
     return result_loc_var;
 }
 
+static ResultLocCast *ir_build_cast_result_loc(IrBuilder *irb, IrInstruction *dest_type,
+        ResultLoc *parent_result_loc)
+{
+    ResultLocCast *result_loc_cast = allocate<ResultLocCast>(1);
+    result_loc_cast->base.id = ResultLocIdCast;
+    result_loc_cast->base.source_instruction = dest_type;
+    ir_ref_instruction(dest_type, irb->current_basic_block);
+    result_loc_cast->parent = parent_result_loc;
+
+    ir_build_reset_result(irb, dest_type->scope, dest_type->source_node, &result_loc_cast->base);
+
+    return result_loc_cast;
+}
+
 static void build_decl_var_and_init(IrBuilder *irb, Scope *scope, AstNode *source_node, ZigVar *var,
         IrInstruction *init, const char *name_hint, IrInstruction *is_comptime)
 {
@@ -6282,7 +6314,15 @@ static IrInstruction *ir_gen_var_decl(IrBuilder *irb, Scope *scope, AstNode *nod
 
     // Create a result location for the initialization expression.
     ResultLocVar *result_loc_var = ir_build_var_result_loc(irb, alloca, var);
-    ResultLoc *init_result_loc = (type_instruction == nullptr) ? &result_loc_var->base : nullptr;
+    ResultLoc *init_result_loc;
+    ResultLocCast *result_loc_cast;
+    if (type_instruction != nullptr) {
+        result_loc_cast = ir_build_cast_result_loc(irb, type_instruction, &result_loc_var->base);
+        init_result_loc = &result_loc_cast->base;
+    } else {
+        result_loc_cast = nullptr;
+        init_result_loc = &result_loc_var->base;
+    }
 
     Scope *init_scope = is_comptime_scalar ?
         create_comptime_scope(irb->codegen, variable_declaration->expr, scope) : scope;
@@ -6298,9 +6338,8 @@ static IrInstruction *ir_gen_var_decl(IrBuilder *irb, Scope *scope, AstNode *nod
     if (init_value == irb->codegen->invalid_instruction)
         return irb->codegen->invalid_instruction;
 
-    if (type_instruction != nullptr) {
-        IrInstruction *implicit_cast = ir_build_implicit_cast(irb, scope, node, type_instruction, init_value,
-                &result_loc_var->base);
+    if (result_loc_cast != nullptr) {
+        IrInstruction *implicit_cast = ir_build_implicit_cast(irb, scope, node, init_value, result_loc_cast);
         ir_build_end_expr(irb, scope, node, implicit_cast, &result_loc_var->base);
     }
 
@@ -15435,6 +15474,7 @@ static ZigType *ir_result_loc_expected_type(IrAnalyze *ira, IrInstruction *suspe
         case ResultLocIdNone:
         case ResultLocIdVar:
         case ResultLocIdBitCast:
+        case ResultLocIdCast:
             return nullptr;
         case ResultLocIdInstruction:
             return result_loc->source_instruction->child->value.type;
@@ -15489,6 +15529,7 @@ static bool ir_result_has_type(ResultLoc *result_loc) {
         case ResultLocIdReturn:
         case ResultLocIdInstruction:
         case ResultLocIdBitCast:
+        case ResultLocIdCast:
             return true;
         case ResultLocIdVar:
             return reinterpret_cast<ResultLocVar *>(result_loc)->var->decl_node->data.variable_declaration.type != nullptr;
@@ -15668,6 +15709,61 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe
             result_loc->resolved_loc = parent_result_loc;
             return result_loc->resolved_loc;
         }
+        case ResultLocIdCast: {
+            ResultLocCast *result_cast = reinterpret_cast<ResultLocCast *>(result_loc);
+            ZigType *dest_type = ir_resolve_type(ira, result_cast->base.source_instruction->child);
+            if (type_is_invalid(dest_type))
+                return ira->codegen->invalid_instruction;
+
+            ConstCastOnly const_cast_result = types_match_const_cast_only(ira, dest_type, value_type,
+                    result_cast->base.source_instruction->source_node, false);
+            if (const_cast_result.id == ConstCastResultIdInvalid)
+                return ira->codegen->invalid_instruction;
+            if (const_cast_result.id != ConstCastResultIdOk) {
+                // We will not be able to provide a result location for this value. Allow the
+                // code to create a new result location and then type coerce to the old one.
+                return nullptr;
+            }
+
+            // In this case we can pointer cast the result location.
+            IrInstruction *casted_value;
+            if (value != nullptr) {
+                casted_value = ir_implicit_cast(ira, value, dest_type);
+            } else {
+                casted_value = nullptr;
+            }
+
+            if (casted_value == nullptr || type_is_invalid(casted_value->value.type)) {
+                return casted_value;
+            }
+
+            IrInstruction *parent_result_loc = ir_resolve_result(ira, suspend_source_instr, result_cast->parent,
+                    dest_type, casted_value, force_runtime, non_null_comptime, true);
+            if (parent_result_loc == nullptr || type_is_invalid(parent_result_loc->value.type) ||
+                parent_result_loc->value.type->id == ZigTypeIdUnreachable)
+            {
+                return parent_result_loc;
+            }
+            ZigType *parent_ptr_type = parent_result_loc->value.type;
+            assert(parent_ptr_type->id == ZigTypeIdPointer);
+            if ((err = type_resolve(ira->codegen, parent_ptr_type->data.pointer.child_type,
+                            ResolveStatusAlignmentKnown)))
+            {
+                return ira->codegen->invalid_instruction;
+            }
+            uint64_t parent_ptr_align = get_ptr_align(ira->codegen, parent_ptr_type);
+            if ((err = type_resolve(ira->codegen, value_type, ResolveStatusAlignmentKnown))) {
+                return ira->codegen->invalid_instruction;
+            }
+            ZigType *ptr_type = get_pointer_to_type_extra(ira->codegen, value_type,
+                    parent_ptr_type->data.pointer.is_const, parent_ptr_type->data.pointer.is_volatile, PtrLenSingle,
+                    parent_ptr_align, 0, 0, parent_ptr_type->data.pointer.allow_zero);
+
+            result_loc->written = true;
+            result_loc->resolved_loc = ir_analyze_ptr_cast(ira, suspend_source_instr, parent_result_loc,
+                    ptr_type, result_cast->base.source_instruction, false);
+            return result_loc->resolved_loc;
+        }
         case ResultLocIdBitCast: {
             ResultLocBitCast *result_bit_cast = reinterpret_cast<ResultLocBitCast *>(result_loc);
             ZigType *dest_type = ir_resolve_type(ira, result_bit_cast->base.source_instruction->child);
@@ -15790,18 +15886,6 @@ static IrInstruction *ir_resolve_result(IrAnalyze *ira, IrInstruction *suspend_s
     return result_loc;
 }
 
-static IrInstruction *ir_analyze_instruction_implicit_cast(IrAnalyze *ira, IrInstructionImplicitCast *instruction) {
-    ZigType *dest_type = ir_resolve_type(ira, instruction->dest_type->child);
-    if (type_is_invalid(dest_type))
-        return ira->codegen->invalid_instruction;
-
-    IrInstruction *target = instruction->target->child;
-    if (type_is_invalid(target->value.type))
-        return ira->codegen->invalid_instruction;
-
-    return ir_implicit_cast_with_result(ira, target, dest_type, instruction->result_loc);
-}
-
 static IrInstruction *ir_analyze_instruction_resolve_result(IrAnalyze *ira, IrInstructionResolveResult *instruction) {
     ZigType *implicit_elem_type = ir_resolve_type(ira, instruction->ty->child);
     if (type_is_invalid(implicit_elem_type))
@@ -15864,6 +15948,7 @@ static void ir_reset_result(ResultLoc *result_loc) {
         case ResultLocIdNone:
         case ResultLocIdInstruction:
         case ResultLocIdBitCast:
+        case ResultLocIdCast:
             break;
     }
 }
@@ -16903,25 +16988,14 @@ static IrInstruction *ir_analyze_instruction_call(IrAnalyze *ira, IrInstructionC
 
     if (is_comptime || instr_is_comptime(fn_ref)) {
         if (fn_ref->value.type->id == ZigTypeIdMetaType) {
-            ZigType *dest_type = ir_resolve_type(ira, fn_ref);
-            if (type_is_invalid(dest_type))
+            ZigType *ty = ir_resolve_type(ira, fn_ref);
+            if (ty == nullptr)
                 return ira->codegen->invalid_instruction;
-
-            size_t actual_param_count = call_instruction->arg_count;
-
-            if (actual_param_count != 1) {
-                ir_add_error_node(ira, call_instruction->base.source_node,
-                        buf_sprintf("cast expression expects exactly one parameter"));
-                return ira->codegen->invalid_instruction;
-            }
-
-            IrInstruction *arg = call_instruction->args[0]->child;
-
-            IrInstruction *cast_instruction = ir_analyze_cast(ira, &call_instruction->base, dest_type, arg,
-                    call_instruction->result_loc);
-            if (type_is_invalid(cast_instruction->value.type))
-                return ira->codegen->invalid_instruction;
-            return ir_finish_anal(ira, cast_instruction);
+            ErrorMsg *msg = ir_add_error_node(ira, fn_ref->source_node,
+                buf_sprintf("type '%s' not a function", buf_ptr(&ty->name)));
+            add_error_note(ira->codegen, msg, call_instruction->base.source_node,
+                buf_sprintf("use @as builtin for type coercion"));
+            return ira->codegen->invalid_instruction;
         } else if (fn_ref->value.type->id == ZigTypeIdFn) {
             ZigFn *fn_table_entry = ir_resolve_fn(ira, fn_ref);
             ZigType *fn_type = fn_table_entry ? fn_table_entry->type_entry : fn_ref->value.type;
@@ -25958,6 +26032,26 @@ static IrInstruction *ir_analyze_instruction_end_expr(IrAnalyze *ira, IrInstruct
     return ir_const_void(ira, &instruction->base);
 }
 
+static IrInstruction *ir_analyze_instruction_implicit_cast(IrAnalyze *ira, IrInstructionImplicitCast *instruction) {
+    IrInstruction *operand = instruction->operand->child;
+    if (type_is_invalid(operand->value.type))
+        return operand;
+
+    IrInstruction *result_loc = ir_resolve_result(ira, &instruction->base,
+            &instruction->result_loc_cast->base, operand->value.type, operand, false, false, true);
+    if (result_loc != nullptr && (type_is_invalid(result_loc->value.type) || instr_is_unreachable(result_loc)))
+        return result_loc;
+
+    if (instruction->result_loc_cast->parent->gen_instruction != nullptr) {
+        return instruction->result_loc_cast->parent->gen_instruction;
+    }
+
+    ZigType *dest_type = ir_resolve_type(ira, instruction->result_loc_cast->base.source_instruction->child);
+    if (type_is_invalid(dest_type))
+        return ira->codegen->invalid_instruction;
+    return ir_implicit_cast(ira, operand, dest_type);
+}
+
 static IrInstruction *ir_analyze_instruction_bit_cast_src(IrAnalyze *ira, IrInstructionBitCastSrc *instruction) {
     IrInstruction *operand = instruction->operand->child;
     if (type_is_invalid(operand->value.type))
src/ir_print.cpp
@@ -601,6 +601,12 @@ static void ir_print_result_loc_bit_cast(IrPrint *irp, ResultLocBitCast *result_
     fprintf(irp->f, ")");
 }
 
+static void ir_print_result_loc_cast(IrPrint *irp, ResultLocCast *result_loc_cast) {
+    fprintf(irp->f, "cast(ty=");
+    ir_print_other_instruction(irp, result_loc_cast->base.source_instruction);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_result_loc(IrPrint *irp, ResultLoc *result_loc) {
     switch (result_loc->id) {
         case ResultLocIdInvalid:
@@ -619,6 +625,8 @@ static void ir_print_result_loc(IrPrint *irp, ResultLoc *result_loc) {
             return ir_print_result_loc_peer(irp, (ResultLocPeer *)result_loc);
         case ResultLocIdBitCast:
             return ir_print_result_loc_bit_cast(irp, (ResultLocBitCast *)result_loc);
+        case ResultLocIdCast:
+            return ir_print_result_loc_cast(irp, (ResultLocCast *)result_loc);
         case ResultLocIdPeerParent:
             fprintf(irp->f, "peer_parent");
             return;
@@ -1484,6 +1492,13 @@ static void ir_print_ptr_cast_gen(IrPrint *irp, IrInstructionPtrCastGen *instruc
     fprintf(irp->f, ")");
 }
 
+static void ir_print_implicit_cast(IrPrint *irp, IrInstructionImplicitCast *instruction) {
+    fprintf(irp->f, "@implicitCast(");
+    ir_print_other_instruction(irp, instruction->operand);
+    fprintf(irp->f, ")result=");
+    ir_print_result_loc(irp, &instruction->result_loc_cast->base);
+}
+
 static void ir_print_bit_cast_src(IrPrint *irp, IrInstructionBitCastSrc *instruction) {
     fprintf(irp->f, "@bitCast(");
     ir_print_other_instruction(irp, instruction->operand);
@@ -1739,14 +1754,6 @@ static void ir_print_align_cast(IrPrint *irp, IrInstructionAlignCast *instructio
     fprintf(irp->f, ")");
 }
 
-static void ir_print_implicit_cast(IrPrint *irp, IrInstructionImplicitCast *instruction) {
-    fprintf(irp->f, "@implicitCast(");
-    ir_print_other_instruction(irp, instruction->dest_type);
-    fprintf(irp->f, ",");
-    ir_print_other_instruction(irp, instruction->target);
-    fprintf(irp->f, ")");
-}
-
 static void ir_print_resolve_result(IrPrint *irp, IrInstructionResolveResult *instruction) {
     fprintf(irp->f, "ResolveResult(");
     ir_print_result_loc(irp, instruction->result_loc);