Commit 9b7ad12481

Robert Scott <keyboard.operator@gmail.com>
2019-05-09 11:28:14
Implement @unionInit
1 parent 163a8e9
Changed files (5)
src/all_types.hpp
@@ -1471,6 +1471,7 @@ enum BuiltinFnId {
     BuiltinFnIdErrorReturnTrace,
     BuiltinFnIdAtomicRmw,
     BuiltinFnIdAtomicLoad,
+    BuiltinFnIdUnionInit,
 };
 
 struct BuiltinFnEntry {
@@ -2297,6 +2298,7 @@ enum IrInstructionId {
     IrInstructionIdArrayToVector,
     IrInstructionIdAssertZero,
     IrInstructionIdAssertNonNull,
+    IrInstructionIdUnionInit2,
 };
 
 struct IrInstruction {
@@ -3503,6 +3505,17 @@ struct IrInstructionAssertNonNull {
     IrInstruction *target;
 };
 
+// TODO, need a better name. Using 2 because there is currently a IrInstructionUnionInit
+// It seems like the first one should only be used during the analyze phase, but still
+// don't understand it all.
+struct IrInstructionUnionInit2 {
+    IrInstruction base;
+
+    IrInstruction *union_type_value;
+    IrInstruction *field_name_expr;
+    IrInstruction *value;
+};
+
 static const size_t slice_ptr_index = 0;
 static const size_t slice_len_index = 1;
 
src/codegen.cpp
@@ -5616,6 +5616,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
         case IrInstructionIdLoadPtr:
         case IrInstructionIdBitCast:
         case IrInstructionIdGlobalAsm:
+        case IrInstructionIdUnionInit2:
             zig_unreachable();
 
         case IrInstructionIdDeclVarGen:
@@ -7409,6 +7410,7 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn(g, BuiltinFnIdToBytes, "sliceToBytes", 1);
     create_builtin_fn(g, BuiltinFnIdFromBytes, "bytesToSlice", 2);
     create_builtin_fn(g, BuiltinFnIdThis, "This", 0);
+    create_builtin_fn(g, BuiltinFnIdUnionInit, "unionInit", 3);
 }
 
 static const char *bool_to_str(bool b) {
src/ir.cpp
@@ -188,7 +188,7 @@ static ConstExprValue *const_ptr_pointee_unchecked(CodeGen *g, ConstExprValue *c
     assert(get_src_ptr_type(const_val->type) != nullptr);
     assert(const_val->special == ConstValSpecialStatic);
     ConstExprValue *result;
-    
+
     switch (type_has_one_possible_value(g, const_val->type->data.pointer.child_type)) {
         case OnePossibleValueInvalid:
             zig_unreachable();
@@ -200,7 +200,7 @@ static ConstExprValue *const_ptr_pointee_unchecked(CodeGen *g, ConstExprValue *c
         case OnePossibleValueNo:
             break;
     }
-    
+
     switch (const_val->data.x_ptr.special) {
         case ConstPtrSpecialInvalid:
             zig_unreachable();
@@ -1011,6 +1011,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionAssertNonNull *)
     return IrInstructionIdAssertNonNull;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionInit2 *) {
+    return IrInstructionIdUnionInit2;
+}
+
 template<typename T>
 static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
     T *special_instruction = allocate<T>(1);
@@ -1312,6 +1316,7 @@ static IrInstruction *ir_build_union_field_ptr(IrBuilder *irb, Scope *scope, Ast
     return &instruction->base;
 }
 
+
 static IrInstruction *ir_build_call(IrBuilder *irb, Scope *scope, AstNode *source_node,
         ZigFn *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args,
         bool is_comptime, FnInline fn_inline, bool is_async, IrInstruction *async_allocator,
@@ -3025,6 +3030,21 @@ static IrInstruction *ir_build_check_runtime_scope(IrBuilder *irb, Scope *scope,
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_union_init_2(IrBuilder *irb, Scope *scope, AstNode *source_node,
+    IrInstruction *union_type_value, IrInstruction *field_name_expr, IrInstruction *value) {
+    IrInstructionUnionInit2 *instruction = ir_build_instruction<IrInstructionUnionInit2>(irb, scope, source_node);
+    instruction->union_type_value = union_type_value;
+    instruction->field_name_expr = field_name_expr;
+    instruction->value = value;
+
+    ir_ref_instruction(union_type_value, irb->current_basic_block);
+    ir_ref_instruction(field_name_expr, irb->current_basic_block);
+    ir_ref_instruction(value, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
+
 static IrInstruction *ir_build_vector_to_array(IrAnalyze *ira, IrInstruction *source_instruction,
         IrInstruction *vector, ZigType *result_type)
 {
@@ -3868,7 +3888,7 @@ static void populate_invalid_variable_in_scope(CodeGen *g, Scope *scope, AstNode
     TldVar *tld_var = allocate<TldVar>(1);
     init_tld(&tld_var->base, TldIdVar, var_name, VisibModPub, node, &scope_decls->base);
     tld_var->base.resolution = TldResolutionInvalid;
-    tld_var->var = add_variable(g, node, &scope_decls->base, var_name, false, 
+    tld_var->var = add_variable(g, node, &scope_decls->base, var_name, false,
             &g->invalid_instruction->value, &tld_var->base, g->builtin_types.entry_invalid);
     scope_decls->decl_table.put(var_name, &tld_var->base);
 }
@@ -5098,6 +5118,29 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
                 }
                 return ir_lval_wrap(irb, scope, result, lval);
             }
+        case BuiltinFnIdUnionInit:
+            {
+
+                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
+                IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
+                if (arg0_value == irb->codegen->invalid_instruction)
+                    return arg0_value;
+
+                AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
+                IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope);
+                if (arg1_value == irb->codegen->invalid_instruction)
+                    return arg1_value;
+
+                AstNode *arg2_node = node->data.fn_call_expr.params.at(2);
+                IrInstruction *arg2_value = ir_gen_node(irb, arg2_node, scope);
+                if (arg2_value == irb->codegen->invalid_instruction)
+                    return arg2_value;
+
+                IrInstruction *result = ir_build_union_init_2(irb, scope, node, arg0_value, arg1_value, arg2_value);
+
+                // TODO: Not sure if we need ir_lval_wrap or not.
+                return result;
+            }
     }
     zig_unreachable();
 }
@@ -6328,7 +6371,7 @@ static bool ir_gen_switch_prong_expr(IrBuilder *irb, Scope *scope, AstNode *swit
                     prong_values, prong_values_len);
             var_value = var_is_ptr ? var_ptr_value : ir_build_load_ptr(irb, scope, var_symbol_node, var_ptr_value);
         } else {
-            var_value = var_is_ptr ? target_value_ptr : ir_build_load_ptr(irb, scope, var_symbol_node, 
+            var_value = var_is_ptr ? target_value_ptr : ir_build_load_ptr(irb, scope, var_symbol_node,
 target_value_ptr);
         }
         IrInstruction *var_type = nullptr; // infer the type
@@ -12372,7 +12415,7 @@ static IrInstruction *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp *
         } else {
             return is_non_null;
         }
-    } else if (is_equality_cmp && 
+    } else if (is_equality_cmp &&
         ((op1->value.type->id == ZigTypeIdNull && op2->value.type->id == ZigTypeIdPointer &&
             op2->value.type->data.pointer.ptr_len == PtrLenC) ||
         (op2->value.type->id == ZigTypeIdNull && op1->value.type->id == ZigTypeIdPointer &&
@@ -19383,7 +19426,7 @@ static IrInstruction *ir_analyze_instruction_c_import(IrAnalyze *ira, IrInstruct
             ir_add_error_node(ira, node, buf_sprintf("C import failed: unable to make dir: %s", err_str(err)));
             return ira->codegen->invalid_instruction;
         }
-        
+
         if ((err = os_write_file(&tmp_c_file_path, &cimport_scope->buf))) {
             ir_add_error_node(ira, node, buf_sprintf("C import failed: unable to write .h file: %s", err_str(err)));
             return ira->codegen->invalid_instruction;
@@ -20333,7 +20376,7 @@ static IrInstruction *ir_analyze_instruction_memcpy(IrAnalyze *ira, IrInstructio
         return ira->codegen->invalid_instruction;
 
     // TODO test this at comptime with u8 and non-u8 types
-    // TODO test with dest ptr being a global runtime variable 
+    // TODO test with dest ptr being a global runtime variable
     if (casted_dest_ptr->value.special == ConstValSpecialStatic &&
         casted_src_ptr->value.special == ConstValSpecialStatic &&
         casted_count->value.special == ConstValSpecialStatic &&
@@ -23151,6 +23194,35 @@ static IrInstruction *ir_analyze_instruction_check_runtime_scope(IrAnalyze *ira,
     return ir_const_void(ira, &instruction->base);
 }
 
+static IrInstruction *ir_analyze_instruction_union_init_2(IrAnalyze *ira, IrInstructionUnionInit2 *union_init_instruction)
+{
+    Error err;
+    IrInstruction *union_type_value = union_init_instruction->union_type_value->child;
+    ZigType *union_type = ir_resolve_type(ira, union_type_value);
+    if (type_is_invalid(union_type)) {
+        return ira->codegen->invalid_instruction;
+    }
+
+    if (union_type->id != ZigTypeIdUnion)
+        return ira->codegen->invalid_instruction;
+
+    if ((err = ensure_complete_type(ira->codegen, union_type)))
+        return ira->codegen->invalid_instruction;
+
+    IrInstruction *field_name_expr = union_init_instruction->field_name_expr->child;
+    Buf *field_name = ir_resolve_str(ira, field_name_expr);
+    if (!field_name)
+        return ira->codegen->invalid_instruction;
+
+    IrInstructionContainerInitFieldsField *fields = allocate<IrInstructionContainerInitFieldsField>(1);
+
+    fields[0].name = field_name;
+    fields[0].value = union_init_instruction->value;
+    fields[0].source_node = union_init_instruction->base.source_node;
+
+    return ir_analyze_container_init_fields_union(ira, &union_init_instruction->base, union_type, 1, fields);
+}
+
 static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstruction *instruction) {
     switch (instruction->id) {
         case IrInstructionIdInvalid:
@@ -23445,6 +23517,8 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio
             return ir_analyze_instruction_enum_to_int(ira, (IrInstructionEnumToInt *)instruction);
         case IrInstructionIdCheckRuntimeScope:
             return ir_analyze_instruction_check_runtime_scope(ira, (IrInstructionCheckRuntimeScope *)instruction);
+        case IrInstructionIdUnionInit2:
+            return ir_analyze_instruction_union_init_2(ira, (IrInstructionUnionInit2 *)instruction);
     }
     zig_unreachable();
 }
@@ -23681,6 +23755,8 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdEnumToInt:
         case IrInstructionIdVectorToArray:
         case IrInstructionIdArrayToVector:
+        case IrInstructionIdUnionInit2:
+
             return false;
 
         case IrInstructionIdAsm:
src/ir_print.cpp
@@ -1453,6 +1453,17 @@ static void ir_print_decl_var_gen(IrPrint *irp, IrInstructionDeclVarGen *decl_va
     }
 }
 
+
+static void ir_print_uniont_init_2(IrPrint *irp, IrInstructionUnionInit2 *instruction) {
+    fprintf(irp->f, "@unionInit(");
+    ir_print_other_instruction(irp, instruction->union_type_value);
+    fprintf(irp->f, ", ");
+    ir_print_other_instruction(irp, instruction->field_name_expr);
+    fprintf(irp->f, ", ");
+    ir_print_other_instruction(irp, instruction->value);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
     ir_print_prefix(irp, instruction);
     switch (instruction->id) {
@@ -1920,6 +1931,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdResizeSlice:
             ir_print_resize_slice(irp, (IrInstructionResizeSlice *)instruction);
             break;
+        case IrInstructionIdUnionInit2:
+            ir_print_uniont_init_2(irp, (IrInstructionUnionInit2 *)instruction);
+            break;
     }
     fprintf(irp->f, "\n");
 }
test/stage1/behavior/union.zig
@@ -374,7 +374,7 @@ const Attribute = union(enum) {
 fn setAttribute(attr: Attribute) void {}
 
 fn Setter(attr: Attribute) type {
-    return struct{
+    return struct {
         fn set() void {
             setAttribute(attr);
         }
@@ -402,3 +402,38 @@ test "comptime union field value equality" {
     expect(a0 != a1);
     expect(b0 != b1);
 }
+
+test "unionInit can modify a union type" {
+    const UnionInitEnum = union(enum) {
+        Boolean: bool,
+        Byte: u8,
+    };
+
+    var value: UnionInitEnum = undefined;
+
+    value = @unionInit(UnionInitEnum, "Boolean", true);
+    expect(value.Boolean == true);
+    value.Boolean = false;
+    expect(value.Boolean == false);
+
+    value = @unionInit(UnionInitEnum, "Byte", 2);
+    expect(value.Byte == 2);
+    value.Byte = 3;
+    expect(value.Byte == 3);
+}
+
+test "unionInit can modify a pointer value" {
+    const UnionInitEnum = union(enum) {
+        Boolean: bool,
+        Byte: u8,
+    };
+
+    var value: UnionInitEnum = undefined;
+    var value_ptr = &value;
+
+    value_ptr.* = @unionInit(UnionInitEnum, "Boolean", true);
+    expect(value.Boolean == true);
+
+    value_ptr.* = @unionInit(UnionInitEnum, "Byte", 2);
+    expect(value.Byte == 2);
+}