Commit 587a4437db

Cody Tapscott <cody+topolarity@tapscott.me>
2022-01-24 19:22:38
Add `union` support to the C backend.
There are some differences vs. the union encoding in the LLVM backend: - Tagged unions with a 0-bit payload do not become their tag type. Instead, they are a struct with an empty `union` as their payload field. - We do not order the `payload`/`tag` storage based on their alignment
1 parent 983dfcd
Changed files (2)
src
codegen
test
src/codegen/c.zig
@@ -589,6 +589,40 @@ pub const DeclGen = struct {
 
                 try writer.writeAll("}");
             },
+            .Union => {
+                const union_obj = val.castTag(.@"union").?.data;
+                const target = dg.module.getTarget();
+                const layout = ty.unionGetLayout(target);
+
+                try writer.writeAll("(");
+                try dg.renderType(writer, ty);
+                try writer.writeAll("){");
+
+                if (ty.unionTagType()) |tag_ty| {
+                    if (layout.tag_size != 0) {
+                        try writer.writeAll(".tag = ");
+                        try dg.renderValue(writer, tag_ty, union_obj.tag);
+                        try writer.writeAll(", ");
+                    }
+                    try writer.writeAll(".payload = {");
+                }
+
+                const index = switch (ty.tag()) {
+                    .union_tagged => ty.castTag(.union_tagged).?.data.tag_ty.enumTagFieldIndex(union_obj.tag).?,
+                    .@"union" => ty.castTag(.@"union").?.data.tag_ty.enumTagFieldIndex(union_obj.tag).?,
+                    else => unreachable,
+                };
+                const field_ty = ty.unionFields().values()[index].ty;
+                const field_name = ty.unionFields().keys()[index];
+                if (field_ty.hasCodeGenBits()) {
+                    try writer.print(".{} = ", .{fmtIdent(field_name)});
+                    try dg.renderValue(writer, field_ty, union_obj.val);
+                }
+                if (ty.unionTagType()) |_| {
+                    try writer.writeAll("}");
+                }
+                try writer.writeAll("}");
+            },
 
             .ComptimeInt => unreachable,
             .ComptimeFloat => unreachable,
@@ -601,7 +635,6 @@ pub const DeclGen = struct {
             .BoundFn => unreachable,
             .Opaque => unreachable,
 
-            .Union,
             .Frame,
             .AnyFrame,
             .Vector,
@@ -781,6 +814,65 @@ pub const DeclGen = struct {
         return name;
     }
 
+    fn renderUnionTypedef(dg: *DeclGen, t: Type) error{ OutOfMemory, AnalysisFail }![]const u8 {
+        const fqn = switch (t.tag()) {
+            .@"union" => try t.castTag(.@"union").?.data.getFullyQualifiedName(dg.typedefs.allocator),
+            .union_tagged => try t.castTag(.union_tagged).?.data.getFullyQualifiedName(dg.typedefs.allocator),
+            else => unreachable,
+        };
+        defer dg.typedefs.allocator.free(fqn);
+
+        const target = dg.module.getTarget();
+        const layout = t.unionGetLayout(target);
+
+        var buffer = std.ArrayList(u8).init(dg.typedefs.allocator);
+        defer buffer.deinit();
+
+        try buffer.appendSlice("typedef ");
+        if (t.unionTagType()) |tag_ty| {
+            const name: CValue = .{ .bytes = "tag" };
+            try buffer.appendSlice("struct {\n ");
+            if (layout.tag_size != 0) {
+                try dg.renderTypeAndName(buffer.writer(), tag_ty, name, .Mut, Value.initTag(.abi_align_default));
+                try buffer.appendSlice(";\n");
+            }
+        }
+
+        try buffer.appendSlice("union {\n");
+        {
+            var it = t.unionFields().iterator();
+            while (it.next()) |entry| {
+                const field_ty = entry.value_ptr.ty;
+                if (!field_ty.hasCodeGenBits()) continue;
+                const alignment = entry.value_ptr.abi_align;
+                const name: CValue = .{ .identifier = entry.key_ptr.* };
+                try buffer.append(' ');
+                try dg.renderTypeAndName(buffer.writer(), field_ty, name, .Mut, alignment);
+                try buffer.appendSlice(";\n");
+            }
+        }
+        try buffer.appendSlice("} ");
+
+        if (t.unionTagType()) |_| {
+            try buffer.appendSlice("payload;\n} ");
+        }
+
+        const name_start = buffer.items.len;
+        try buffer.writer().print("zig_U_{s};\n", .{fmtIdent(fqn)});
+
+        const rendered = buffer.toOwnedSlice();
+        errdefer dg.typedefs.allocator.free(rendered);
+        const name = rendered[name_start .. rendered.len - 2];
+
+        try dg.typedefs.ensureUnusedCapacity(1);
+        dg.typedefs.putAssumeCapacityNoClobber(
+            try t.copy(dg.typedefs_arena),
+            .{ .name = name, .rendered = rendered },
+        );
+
+        return name;
+    }
+
     fn renderErrorUnionTypedef(dg: *DeclGen, t: Type) error{ OutOfMemory, AnalysisFail }![]const u8 {
         const child_type = t.errorUnionPayload();
         const err_set_type = t.errorUnionSet();
@@ -959,6 +1051,12 @@ pub const DeclGen = struct {
 
                 return w.writeAll(name);
             },
+            .Union => {
+                const name = dg.getTypedefName(t) orelse
+                    try dg.renderUnionTypedef(t);
+
+                return w.writeAll(name);
+            },
             .Enum => {
                 // For enums, we simply use the integer tag type.
                 var int_tag_ty_buffer: Type.Payload.Bits = undefined;
@@ -967,7 +1065,6 @@ pub const DeclGen = struct {
                 try dg.renderType(w, int_tag_ty);
             },
 
-            .Union,
             .Frame,
             .AnyFrame,
             .Vector,
@@ -2671,21 +2768,36 @@ fn airStructFieldPtrIndex(f: *Function, inst: Air.Inst.Index, index: u8) !CValue
 
 fn structFieldPtr(f: *Function, inst: Air.Inst.Index, struct_ptr_ty: Type, struct_ptr: CValue, index: u32) !CValue {
     const writer = f.object.writer();
-    const struct_obj = struct_ptr_ty.elemType().castTag(.@"struct").?.data;
-    const field_name = struct_obj.fields.keys()[index];
-    const field_val = struct_obj.fields.values()[index];
-    const addrof = if (field_val.ty.zigTypeTag() == .Array) "" else "&";
+    const struct_ty = struct_ptr_ty.elemType();
+    var field_name: []const u8 = undefined;
+    var field_val_ty: Type = undefined; 
+
+    switch (struct_ty.tag()) {
+        .@"struct" => {
+            const fields = struct_ty.structFields();
+            field_name = fields.keys()[index];
+            field_val_ty = fields.values()[index].ty;
+        },
+        .@"union", .union_tagged => {
+            const fields = struct_ty.unionFields();
+            field_name = fields.keys()[index];
+            field_val_ty = fields.values()[index].ty;
+        },
+        else => unreachable,
+    }
+    const addrof = if (field_val_ty.zigTypeTag() == .Array) "" else "&";
+    const payload = if (struct_ty.tag() == .union_tagged) "payload." else "";
 
     const inst_ty = f.air.typeOfIndex(inst);
     const local = try f.allocLocal(inst_ty, .Const);
     switch (struct_ptr) {
         .local_ref => |i| {
-            try writer.print(" = {s}t{d}.{};\n", .{ addrof, i, fmtIdent(field_name) });
+            try writer.print(" = {s}t{d}.{s}{};\n", .{ addrof, i, payload, fmtIdent(field_name) });
         },
         else => {
             try writer.print(" = {s}", .{addrof});
             try f.writeCValue(writer, struct_ptr);
-            try writer.print("->{};\n", .{fmtIdent(field_name)});
+            try writer.print("->{s}{};\n", .{ payload, fmtIdent(field_name) });
         },
     }
     return local;
@@ -2700,14 +2812,18 @@ fn airStructFieldVal(f: *Function, inst: Air.Inst.Index) !CValue {
     const writer = f.object.writer();
     const struct_byval = try f.resolveInst(extra.struct_operand);
     const struct_ty = f.air.typeOf(extra.struct_operand);
-    const struct_obj = struct_ty.castTag(.@"struct").?.data;
-    const field_name = struct_obj.fields.keys()[extra.field_index];
+    const field_name = switch (struct_ty.tag()) {
+        .@"struct" => struct_ty.structFields().keys()[extra.field_index],
+        .@"union", .union_tagged => struct_ty.unionFields().keys()[extra.field_index],
+        else => unreachable,
+    };
+    const payload = if (struct_ty.tag() == .union_tagged) "payload." else "";
 
     const inst_ty = f.air.typeOfIndex(inst);
     const local = try f.allocLocal(inst_ty, .Const);
     try writer.writeAll(" = ");
     try f.writeCValue(writer, struct_byval);
-    try writer.print(".{};\n", .{fmtIdent(field_name)});
+    try writer.print(".{s}{};\n", .{ payload, fmtIdent(field_name) });
     return local;
 }
 
@@ -3048,9 +3164,13 @@ fn airSetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue {
     const new_tag = try f.resolveInst(bin_op.rhs);
     const writer = f.object.writer();
 
-    try writer.writeAll("*");
+    const union_ty = f.air.typeOf(bin_op.lhs).childType();
+    const target = f.object.dg.module.getTarget();
+    const layout = union_ty.unionGetLayout(target);
+    if (layout.tag_size == 0)  return CValue.none;
+
     try f.writeCValue(writer, union_ptr);
-    try writer.writeAll(" = ");
+    try writer.writeAll("->tag = ");
     try f.writeCValue(writer, new_tag);
     try writer.writeAll(";\n");
 
@@ -3064,12 +3184,17 @@ fn airGetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue {
     const inst_ty = f.air.typeOfIndex(inst);
     const local = try f.allocLocal(inst_ty, .Const);
     const ty_op = f.air.instructions.items(.data)[inst].ty_op;
+    const un_ty = f.air.typeOf(ty_op.operand);
     const writer = f.object.writer();
     const operand = try f.resolveInst(ty_op.operand);
 
-    try writer.writeAll("get_union_tag(");
+    const target = f.object.dg.module.getTarget();
+    const layout = un_ty.unionGetLayout(target);
+    if (layout.tag_size == 0)  return CValue.none;
+
+    try writer.writeAll(" = ");
     try f.writeCValue(writer, operand);
-    try writer.writeAll(");\n");
+    try writer.writeAll(".tag;\n");
     return local;
 }
 
test/behavior.zig
@@ -68,7 +68,7 @@ test {
             _ = @import("behavior/cast_int.zig");
             _ = @import("behavior/int128.zig");
 	    _ = @import("behavior/union.zig");
-//            _ = @import("behavior/translate_c_macros.zig");
+            _ = @import("behavior/translate_c_macros.zig");
 
             if (builtin.zig_backend != .stage2_c) {
                 // Tests that pass for stage1 and the llvm backend.