Commit 990eccf282

Vexu <git@vexu.eu>
2020-11-23 00:59:44
stage1: implement type coercion of pointer to anon list to array/struct/union/slice
1 parent ed028fd
Changed files (5)
src
stage1
test
src/stage1/ir.cpp
@@ -14996,25 +14996,16 @@ static IrInstGen *ir_analyze_enum_literal(IrAnalyze *ira, IrInst* source_instr,
 }
 
 static IrInstGen *ir_analyze_struct_literal_to_array(IrAnalyze *ira, IrInst* source_instr,
-        IrInstGen *struct_operand, ZigType *wanted_type)
+        IrInstGen *struct_ptr, ZigType *actual_type, ZigType *wanted_type)
 {
     Error err;
 
-    IrInstGen *struct_ptr = ir_get_ref(ira, source_instr, struct_operand, true, false);
-    if (type_is_invalid(struct_ptr->value->type))
-        return ira->codegen->invalid_inst_gen;
-
     if ((err = type_resolve(ira->codegen, wanted_type, ResolveStatusSizeKnown)))
         return ira->codegen->invalid_inst_gen;
     
     size_t array_len = wanted_type->data.array.len;
-    size_t instr_field_count = struct_operand->value->type->data.structure.src_field_count;
-
-    if (instr_field_count != array_len) {
-        ir_add_error(ira, source_instr, buf_sprintf("expected %" ZIG_PRI_usize " fields, found %" ZIG_PRI_usize,
-            array_len, instr_field_count));
-        return ira->codegen->invalid_inst_gen;
-    }
+    size_t instr_field_count = actual_type->data.structure.src_field_count;
+    assert(array_len == instr_field_count);
 
     bool need_comptime = ir_should_inline(ira->old_irb.exec, source_instr->scope)
         || type_requires_comptime(ira->codegen, wanted_type) == ReqCompTimeYes;
@@ -15028,10 +15019,10 @@ static IrInstGen *ir_analyze_struct_literal_to_array(IrAnalyze *ira, IrInst* sou
     IrInstGen *const_result = ir_const(ira, source_instr, wanted_type);
 
     for (size_t i = 0; i < array_len; i += 1) {
-        TypeStructField *src_field = struct_operand->value->type->data.structure.fields[i];
+        TypeStructField *src_field = actual_type->data.structure.fields[i];
 
         IrInstGen *field_ptr = ir_analyze_struct_field_ptr(ira, source_instr, src_field, struct_ptr,
-                struct_operand->value->type, false);
+                actual_type, false);
         if (type_is_invalid(field_ptr->value->type))
             return ira->codegen->invalid_inst_gen;
         IrInstGen *field_value = ir_get_deref(ira, source_instr, field_ptr, nullptr);
@@ -15087,18 +15078,14 @@ static IrInstGen *ir_analyze_struct_literal_to_array(IrAnalyze *ira, IrInst* sou
     heap::c_allocator.deallocate(elem_values, array_len);
     heap::c_allocator.deallocate(casted_fields, array_len);
 
-    return ir_get_deref(ira, source_instr, result_loc_inst, nullptr);
+    return result_loc_inst;
 }
 
 static IrInstGen *ir_analyze_struct_literal_to_struct(IrAnalyze *ira, IrInst* source_instr,
-        IrInstGen *struct_operand, ZigType *wanted_type)
+        IrInstGen *struct_ptr, ZigType *actual_type, ZigType *wanted_type)
 {
     Error err;
 
-    IrInstGen *struct_ptr = ir_get_ref(ira, source_instr, struct_operand, true, false);
-    if (type_is_invalid(struct_ptr->value->type))
-        return ira->codegen->invalid_inst_gen;
-
     if (wanted_type->data.structure.resolve_status == ResolveStatusBeingInferred) {
         ir_add_error(ira, source_instr, buf_sprintf("type coercion of anon struct literal to inferred struct"));
         return ira->codegen->invalid_inst_gen;
@@ -15108,7 +15095,7 @@ static IrInstGen *ir_analyze_struct_literal_to_struct(IrAnalyze *ira, IrInst* so
         return ira->codegen->invalid_inst_gen;
 
     size_t actual_field_count = wanted_type->data.structure.src_field_count;
-    size_t instr_field_count = struct_operand->value->type->data.structure.src_field_count;
+    size_t instr_field_count = actual_type->data.structure.src_field_count;
 
     bool need_comptime = ir_should_inline(ira->old_irb.exec, source_instr->scope)
         || type_requires_comptime(ira->codegen, wanted_type) == ReqCompTimeYes;
@@ -15122,7 +15109,7 @@ static IrInstGen *ir_analyze_struct_literal_to_struct(IrAnalyze *ira, IrInst* so
     IrInstGen *const_result = ir_const(ira, source_instr, wanted_type);
 
     for (size_t i = 0; i < instr_field_count; i += 1) {
-        TypeStructField *src_field = struct_operand->value->type->data.structure.fields[i];
+        TypeStructField *src_field = actual_type->data.structure.fields[i];
         TypeStructField *dst_field = find_struct_type_field(wanted_type, src_field->name);
         if (dst_field == nullptr) {
             ErrorMsg *msg = ir_add_error(ira, source_instr, buf_sprintf("no field named '%s' in struct '%s'",
@@ -15146,7 +15133,7 @@ static IrInstGen *ir_analyze_struct_literal_to_struct(IrAnalyze *ira, IrInst* so
         field_assign_nodes[dst_field->src_index] = src_field->decl_node;
 
         IrInstGen *field_ptr = ir_analyze_struct_field_ptr(ira, source_instr, src_field, struct_ptr,
-                struct_operand->value->type, false);
+                actual_type, false);
         if (type_is_invalid(field_ptr->value->type))
             return ira->codegen->invalid_inst_gen;
         IrInstGen *field_value = ir_get_deref(ira, source_instr, field_ptr, nullptr);
@@ -15226,14 +15213,13 @@ static IrInstGen *ir_analyze_struct_literal_to_struct(IrAnalyze *ira, IrInst* so
     heap::c_allocator.deallocate(field_values, actual_field_count);
     heap::c_allocator.deallocate(casted_fields, actual_field_count);
 
-    return ir_get_deref(ira, source_instr, result_loc_inst, nullptr);
+    return result_loc_inst;
 }
 
 static IrInstGen *ir_analyze_struct_literal_to_union(IrAnalyze *ira, IrInst* source_instr,
-        IrInstGen *value, ZigType *union_type)
+        IrInstGen *struct_ptr, ZigType *struct_type, ZigType *union_type)
 {
     Error err;
-    ZigType *struct_type = value->value->type;
 
     assert(struct_type->id == ZigTypeIdStruct);
     assert(union_type->id == ZigTypeIdUnion);
@@ -15256,7 +15242,11 @@ static IrInstGen *ir_analyze_struct_literal_to_union(IrAnalyze *ira, IrInst* sou
     if (payload_type == nullptr)
         return ira->codegen->invalid_inst_gen;
 
-    IrInstGen *field_value = ir_analyze_struct_value_field_value(ira, source_instr, value, only_field);
+    IrInstGen *field_ptr = ir_analyze_struct_field_ptr(ira, source_instr, only_field, struct_ptr,
+            struct_type, false);
+    if (type_is_invalid(field_ptr->value->type))
+        return ira->codegen->invalid_inst_gen;
+    IrInstGen *field_value =  ir_get_deref(ira, source_instr, field_ptr, nullptr);
     if (type_is_invalid(field_value->value->type))
         return ira->codegen->invalid_inst_gen;
 
@@ -15294,7 +15284,7 @@ static IrInstGen *ir_analyze_struct_literal_to_union(IrAnalyze *ira, IrInst* sou
     if (type_is_invalid(store_ptr_inst->value->type))
         return ira->codegen->invalid_inst_gen;
 
-    return ir_get_deref(ira, source_instr, result_loc_inst, nullptr);
+    return result_loc_inst;
 }
 
 // Add a compile error and return ErrorSemanticAnalyzeFail if the pointer alignment does not work,
@@ -15927,13 +15917,76 @@ static IrInstGen *ir_analyze_cast(IrAnalyze *ira, IrInst *source_instr,
         if (wanted_type->id == ZigTypeIdArray && (is_array_init || field_count == 0) &&
             wanted_type->data.array.len == field_count)
         {
-            return ir_analyze_struct_literal_to_array(ira, source_instr, value, wanted_type);
+            IrInstGen *struct_ptr = ir_get_ref(ira, source_instr, value, true, false);
+            if (type_is_invalid(struct_ptr->value->type))
+                return ira->codegen->invalid_inst_gen;
+
+            IrInstGen *ptr = ir_analyze_struct_literal_to_array(ira, source_instr, struct_ptr, actual_type, wanted_type);
+            if (ptr->value->type->id != ZigTypeIdPointer)
+                return ptr;
+            return ir_get_deref(ira, source_instr, ptr, nullptr);
         } else if (wanted_type->id == ZigTypeIdStruct && !is_slice(wanted_type) &&
                 (!is_array_init || field_count == 0))
         {
-            return ir_analyze_struct_literal_to_struct(ira, source_instr, value, wanted_type);
+            IrInstGen *struct_ptr = ir_get_ref(ira, source_instr, value, true, false);
+            if (type_is_invalid(struct_ptr->value->type))
+                return ira->codegen->invalid_inst_gen;
+
+            IrInstGen *ptr = ir_analyze_struct_literal_to_struct(ira, source_instr, struct_ptr, actual_type, wanted_type);
+            if (ptr->value->type->id != ZigTypeIdPointer)
+                return ptr;
+            return ir_get_deref(ira, source_instr, ptr, nullptr);
         } else if (wanted_type->id == ZigTypeIdUnion && !is_array_init && field_count == 1) {
-            return ir_analyze_struct_literal_to_union(ira, source_instr, value, wanted_type);
+            IrInstGen *struct_ptr = ir_get_ref(ira, source_instr, value, true, false);
+            if (type_is_invalid(struct_ptr->value->type))
+                return ira->codegen->invalid_inst_gen;
+
+            IrInstGen *ptr = ir_analyze_struct_literal_to_union(ira, source_instr, struct_ptr, actual_type, wanted_type);
+            if (ptr->value->type->id != ZigTypeIdPointer)
+                return ptr;
+            return ir_get_deref(ira, source_instr, ptr, nullptr);
+        }
+    }
+
+    // cast from pointer to inferred struct type to pointer to array, union, or struct
+    if (actual_type->id == ZigTypeIdPointer && is_anon_container(actual_type->data.pointer.child_type)) {
+        ZigType *anon_type = actual_type->data.pointer.child_type;
+        const bool is_array_init =
+            anon_type->data.structure.special == StructSpecialInferredTuple;
+        const uint32_t field_count = anon_type->data.structure.src_field_count;
+
+        if (wanted_type->id == ZigTypeIdPointer) {
+            ZigType *wanted_child = wanted_type->data.pointer.child_type;
+            if (wanted_child->id == ZigTypeIdArray && (is_array_init || field_count == 0) &&
+                wanted_child->data.array.len == field_count)
+            {
+                IrInstGen *res = ir_analyze_struct_literal_to_array(ira, source_instr, value, anon_type, wanted_child);
+                if (res->value->type->id == ZigTypeIdPointer)
+                    return res;
+                return ir_get_ref(ira, source_instr, res, wanted_type->data.pointer.is_const, wanted_type->data.pointer.is_volatile);
+            } else if (wanted_child->id == ZigTypeIdStruct && !is_slice(wanted_type) &&
+                    (!is_array_init || field_count == 0))
+            {
+                IrInstGen *res = ir_analyze_struct_literal_to_struct(ira, source_instr, value, anon_type, wanted_child);
+                if (res->value->type->id == ZigTypeIdPointer)
+                    return res;
+                return ir_get_ref(ira, source_instr, res, wanted_type->data.pointer.is_const, wanted_type->data.pointer.is_volatile);
+            } else if (wanted_child->id == ZigTypeIdUnion && !is_array_init && field_count == 1) {
+                IrInstGen *res =  ir_analyze_struct_literal_to_union(ira, source_instr, value, anon_type, wanted_child);
+                if (res->value->type->id == ZigTypeIdPointer)
+                    return res;
+                return ir_get_ref(ira, source_instr, res, wanted_type->data.pointer.is_const, wanted_type->data.pointer.is_volatile);
+            }
+        } else if (is_slice(wanted_type) && (is_array_init || field_count == 0)) {
+            ZigType *slice_child_type = wanted_type->data.structure.fields[slice_ptr_index]->type_entry->data.pointer.child_type;
+            ZigType *slice_array_type = get_array_type(ira->codegen, slice_child_type, field_count, nullptr);
+            IrInstGen *res = ir_analyze_struct_literal_to_array(ira, source_instr, value, anon_type, slice_array_type);
+            if (type_is_invalid(res->value->type))
+                return ira->codegen->invalid_inst_gen;
+            if (res->value->type->id != ZigTypeIdPointer)
+                res = ir_get_ref(ira, source_instr, res, wanted_type->data.pointer.is_const, wanted_type->data.pointer.is_volatile);
+
+            return ir_resolve_ptr_of_array_to_slice(ira, source_instr, res, wanted_type, nullptr);
         }
     }
 
test/stage1/behavior/array.zig
@@ -459,3 +459,31 @@ test "type coercion of anon struct literal to array" {
     S.doTheTest();
     comptime S.doTheTest();
 }
+
+test "type coercion of pointer to anon struct literal to pointer to array" {
+    const S = struct {
+        const U = union{
+            a: u32,
+            b: bool,
+            c: []const u8,
+        };
+
+        fn doTheTest() void {
+            var x1: u8 = 42;
+            const t1 = &.{ x1, 56, 54 };
+            var arr1: *[3]u8 = t1;
+            expect(arr1[0] == 42);
+            expect(arr1[1] == 56);
+            expect(arr1[2] == 54);
+            
+            var x2: U = .{ .a = 42 };
+            const t2 = &.{ x2, .{ .b = true }, .{ .c = "hello" } };
+            var arr2: *[3]U = t2;
+            expect(arr2[0].a == 42);
+            expect(arr2[1].b == true);
+            expect(mem.eql(u8, arr2[2].c, "hello"));
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
test/stage1/behavior/slice.zig
@@ -304,3 +304,33 @@ test "slice of hardcoded address to pointer" {
 
     S.doTheTest();
 }
+
+test "type coercion of pointer to anon struct literal to pointer to slice" {
+    const S = struct {
+        const U = union{
+            a: u32,
+            b: bool,
+            c: []const u8,
+        };
+
+        fn doTheTest() void {
+            var x1: u8 = 42;
+            const t1 = &.{ x1, 56, 54 };
+            var slice1: []u8 = t1;
+            expect(slice1.len == 3);
+            expect(slice1[0] == 42);
+            expect(slice1[1] == 56);
+            expect(slice1[2] == 54);
+            
+            var x2: []const u8 = "hello";
+            const t2 = &.{ x2, ", ", "world!" };
+            var slice2: [][]const u8 = t2;
+            expect(slice2.len == 3);
+            expect(mem.eql(u8, slice2[0], "hello"));
+            expect(mem.eql(u8, slice2[1], ", "));
+            expect(mem.eql(u8, slice2[2], "world!"));
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
test/stage1/behavior/struct.zig
@@ -885,6 +885,39 @@ test "type coercion of anon struct literal to struct" {
     comptime S.doTheTest();
 }
 
+test "type coercion of pointer to anon struct literal to pointer to struct" {
+    const S = struct {
+        const S2 = struct {
+            A: u32,
+            B: []const u8,
+            C: void,
+            D: Foo = .{},
+        };
+
+        const Foo = struct {
+            field: i32 = 1234,
+        };
+
+        fn doTheTest() void {
+            var y: u32 = 42;
+            const t0 = &.{ .A = 123, .B = "foo", .C = {} };
+            const t1 = &.{ .A = y, .B = "foo", .C = {} };
+            const y0: *S2 = t0;
+            var y1: *S2 = t1;
+            expect(y0.A == 123);
+            expect(std.mem.eql(u8, y0.B, "foo"));
+            expect(y0.C == {});
+            expect(y0.D.field == 1234);
+            expect(y1.A == y);
+            expect(std.mem.eql(u8, y1.B, "foo"));
+            expect(y1.C == {});
+            expect(y1.D.field == 1234);
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
+
 test "packed struct with undefined initializers" {
     const S = struct {
         const P = packed struct {
test/stage1/behavior/union.zig
@@ -667,6 +667,33 @@ test "cast from anonymous struct to union" {
     comptime S.doTheTest();
 }
 
+test "cast from pointer to anonymous struct to pointer to union" {
+    const S = struct {
+        const U = union(enum) {
+            A: u32,
+            B: []const u8,
+            C: void,
+        };
+        fn doTheTest() void {
+            var y: u32 = 42;
+            const t0 = &.{ .A = 123 };
+            const t1 = &.{ .B = "foo" };
+            const t2 = &.{ .C = {} };
+            const t3 = &.{ .A = y };
+            const x0: *U = t0;
+            var x1: *U = t1;
+            const x2: *U = t2;
+            var x3: *U = t3;
+            expect(x0.A == 123);
+            expect(std.mem.eql(u8, x1.B, "foo"));
+            expect(x2.* == .C);
+            expect(x3.A == y);
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
+
 test "method call on an empty union" {
     const S = struct {
         const MyUnion = union(Tag) {