Commit 24bfefa75e

Ali Cheraghi <alichraghi@proton.me>
2025-06-22 11:09:28
std.mem.byteSwapAllFields: support untagged unions
1 parent 3034d1e
Changed files (2)
lib/std/mem.zig
@@ -2150,7 +2150,7 @@ pub fn byteSwapAllFields(comptime S: type, ptr: *S) void {
                     } else {
                         byteSwapAllFields(f.type, &@field(ptr, f.name));
                     },
-                    .array => byteSwapAllFields(f.type, &@field(ptr, f.name)),
+                    .@"union", .array => byteSwapAllFields(f.type, &@field(ptr, f.name)),
                     .@"enum" => {
                         @field(ptr, f.name) = @enumFromInt(@byteSwap(@intFromEnum(@field(ptr, f.name))));
                     },
@@ -2164,10 +2164,25 @@ pub fn byteSwapAllFields(comptime S: type, ptr: *S) void {
                 }
             }
         },
+        .@"union" => |union_info| {
+            if (union_info.tag_type != null) {
+                @compileError("byteSwapAllFields expects an untagged union");
+            }
+
+            const first_size = @bitSizeOf(union_info.fields[0].type);
+            inline for (union_info.fields) |field| {
+                if (@bitSizeOf(field.type) != first_size) {
+                    @compileError("Unable to byte-swap unions with varying field sizes");
+                }
+            }
+
+            const BackingInt = std.meta.Int(.unsigned, @bitSizeOf(S));
+            ptr.* = @bitCast(@byteSwap(@as(BackingInt, @bitCast(ptr.*))));
+        },
         .array => {
             for (ptr) |*item| {
                 switch (@typeInfo(@TypeOf(item.*))) {
-                    .@"struct", .array => byteSwapAllFields(@TypeOf(item.*), item),
+                    .@"struct", .@"union", .array => byteSwapAllFields(@TypeOf(item.*), item),
                     .@"enum" => {
                         item.* = @enumFromInt(@byteSwap(@intFromEnum(item.*)));
                     },
@@ -2193,6 +2208,7 @@ test byteSwapAllFields {
         f3: [1]u8,
         f4: bool,
         f5: f32,
+        f6: extern union { f0: u16, f1: u16 },
     };
     const K = extern struct {
         f0: u8,
@@ -2209,6 +2225,7 @@ test byteSwapAllFields {
         .f3 = .{0x12},
         .f4 = true,
         .f5 = @as(f32, @bitCast(@as(u32, 0x4640e400))),
+        .f6 = .{ .f0 = 0x1234 },
     };
     var k = K{
         .f0 = 0x12,
@@ -2227,6 +2244,7 @@ test byteSwapAllFields {
         .f3 = .{0x12},
         .f4 = true,
         .f5 = @as(f32, @bitCast(@as(u32, 0x00e44046))),
+        .f6 = .{ .f0 = 0x3412 },
     }, s);
     try std.testing.expectEqual(K{
         .f0 = 0x12,
lib/std/testing.zig
@@ -153,7 +153,18 @@ fn expectEqualInner(comptime T: type, expected: T, actual: T) !void {
 
         .@"union" => |union_info| {
             if (union_info.tag_type == null) {
-                @compileError("Unable to compare untagged union values for type " ++ @typeName(@TypeOf(actual)));
+                const first_size = @bitSizeOf(union_info.fields[0].type);
+                inline for (union_info.fields) |field| {
+                    if (@bitSizeOf(field.type) != first_size) {
+                        @compileError("Unable to compare untagged unions with varying field sizes for type " ++ @typeName(@TypeOf(actual)));
+                    }
+                }
+
+                const BackingInt = std.meta.Int(.unsigned, @bitSizeOf(T));
+                return expectEqual(
+                    @as(BackingInt, @bitCast(expected)),
+                    @as(BackingInt, @bitCast(actual)),
+                );
             }
 
             const Tag = std.meta.Tag(@TypeOf(expected));