Commit 54c097f50d

Ali Cheraghi <alichraghi@proton.me>
2025-03-12 05:08:50
spirv: packed struct init + field val access
1 parent 50539a2
src/codegen/spirv/Module.zig
@@ -613,6 +613,17 @@ pub fn functionType(self: *Module, return_ty_id: IdRef, param_type_ids: []const
     return result_id;
 }
 
+pub fn constant(self: *Module, result_ty_id: IdRef, value: spec.LiteralContextDependentNumber) !IdRef {
+    const result_id = self.allocId();
+    const section = &self.sections.types_globals_constants;
+    try section.emit(self.gpa, .OpConstant, .{
+        .id_result_type = result_ty_id,
+        .id_result = result_id,
+        .value = value,
+    });
+    return result_id;
+}
+
 pub fn constBool(self: *Module, value: bool) !IdRef {
     if (self.cache.bool_const[@intFromBool(value)]) |b| return b;
 
src/codegen/spirv.zig
@@ -714,6 +714,7 @@ const NavGen = struct {
         const int_info = scalar_ty.intInfo(zcu);
         // Use backing bits so that negatives are sign extended
         const backing_bits = self.backingIntBits(int_info.bits).?; // Assertion failure means big int
+        assert(backing_bits != 0); // u0 is comptime
 
         const signedness: Signedness = switch (@typeInfo(@TypeOf(value))) {
             .int => |int| int.signedness,
@@ -721,35 +722,35 @@ const NavGen = struct {
             else => unreachable,
         };
 
-        const value64: u64 = switch (signedness) {
-            .signed => @bitCast(@as(i64, @intCast(value))),
-            .unsigned => @as(u64, @intCast(value)),
-        };
+        const final_value: spec.LiteralContextDependentNumber = blk: {
+            if (self.spv.hasFeature(.kernel)) {
+                const value64: u64 = switch (signedness) {
+                    .signed => @bitCast(@as(i64, @intCast(value))),
+                    .unsigned => @as(u64, @intCast(value)),
+                };
 
-        // Manually truncate the value to the right amount of bits.
-        const truncated_value = if (backing_bits == 64)
-            value64
-        else
-            value64 & (@as(u64, 1) << @intCast(backing_bits)) - 1;
+                // Manually truncate the value to the right amount of bits.
+                const truncated_value = if (backing_bits == 64)
+                    value64
+                else
+                    value64 & (@as(u64, 1) << @intCast(backing_bits)) - 1;
 
-        const result_ty_id = try self.resolveType(scalar_ty, .indirect);
-        const result_id = self.spv.allocId();
+                break :blk switch (backing_bits) {
+                    1...32 => .{ .uint32 = @truncate(truncated_value) },
+                    33...64 => .{ .uint64 = truncated_value },
+                    else => unreachable, // TODO: Large integer constants
+                };
+            }
 
-        const section = &self.spv.sections.types_globals_constants;
-        switch (backing_bits) {
-            0 => unreachable, // u0 is comptime
-            1...32 => try section.emit(self.spv.gpa, .OpConstant, .{
-                .id_result_type = result_ty_id,
-                .id_result = result_id,
-                .value = .{ .uint32 = @truncate(truncated_value) },
-            }),
-            33...64 => try section.emit(self.spv.gpa, .OpConstant, .{
-                .id_result_type = result_ty_id,
-                .id_result = result_id,
-                .value = .{ .uint64 = truncated_value },
-            }),
-            else => unreachable, // TODO: Large integer constants
-        }
+            break :blk switch (backing_bits) {
+                1...32 => if (signedness == .signed) .{ .int32 = @intCast(value) } else .{ .uint32 = @intCast(value) },
+                33...64 => if (signedness == .signed) .{ .int64 = value } else .{ .uint64 = value },
+                else => unreachable, // TODO: Large integer constants
+            };
+        };
+
+        const result_ty_id = try self.resolveType(scalar_ty, .indirect);
+        const result_id = try self.spv.constant(result_ty_id, final_value);
 
         if (!ty.isVector(zcu)) return result_id;
         return self.constructCompositeSplat(ty, result_id);
@@ -804,8 +805,6 @@ const NavGen = struct {
             return self.spv.constUndef(result_ty_id);
         }
 
-        const section = &self.spv.sections.types_globals_constants;
-
         const cacheable_id = cache: {
             switch (ip.indexToKey(val.toIntern())) {
                 .int_type,
@@ -860,13 +859,7 @@ const NavGen = struct {
                         80, 128 => unreachable, // TODO
                         else => unreachable,
                     };
-                    const result_id = self.spv.allocId();
-                    try section.emit(self.spv.gpa, .OpConstant, .{
-                        .id_result_type = result_ty_id,
-                        .id_result = result_id,
-                        .value = lit,
-                    });
-                    break :cache result_id;
+                    break :cache try self.spv.constant(result_ty_id, lit);
                 },
                 .err => |err| {
                     const value = try pt.getErrorValue(err.name);
@@ -989,8 +982,17 @@ const NavGen = struct {
                     },
                     .struct_type => {
                         const struct_type = zcu.typeToStruct(ty).?;
+
                         if (struct_type.layout == .@"packed") {
-                            return self.todo("packed struct constants", .{});
+                            // TODO: composite int
+                            // TODO: endianness
+                            const bits: u16 = @intCast(ty.bitSize(zcu));
+                            const bytes = std.mem.alignForward(u16, self.backingIntBits(bits).?, 8) / 8;
+                            var limbs: [8]u8 = undefined;
+                            @memset(&limbs, 0);
+                            val.writeToPackedMemory(ty, pt, limbs[0..bytes], 0) catch unreachable;
+                            const backing_ty = Type.fromInterned(struct_type.backingIntTypeUnordered(ip));
+                            return try self.constInt(backing_ty, @as(u64, @bitCast(limbs)));
                         }
 
                         var types = std.ArrayList(Type).init(self.gpa);
@@ -4309,6 +4311,7 @@ const NavGen = struct {
     ) !Temporary {
         const pt = self.pt;
         const zcu = pt.zcu;
+        const ip = &zcu.intern_pool;
         const scalar_ty = lhs.ty.scalarType(zcu);
         const is_vector = lhs.ty.isVector(zcu);
 
@@ -4319,6 +4322,11 @@ const NavGen = struct {
                 const ty = lhs.ty.intTagType(zcu);
                 return try self.cmp(op, lhs.pun(ty), rhs.pun(ty));
             },
+            .@"struct" => {
+                const struct_ty = zcu.typeToPackedStruct(scalar_ty).?;
+                const ty = Type.fromInterned(struct_ty.backingIntTypeUnordered(ip));
+                return try self.cmp(op, lhs.pun(ty), rhs.pun(ty));
+            },
             .error_set => {
                 assert(!is_vector);
                 const err_int_ty = try pt.errorIntType();
@@ -4746,8 +4754,42 @@ const NavGen = struct {
         switch (result_ty.zigTypeTag(zcu)) {
             .@"struct" => {
                 if (zcu.typeToPackedStruct(result_ty)) |struct_type| {
-                    _ = struct_type;
-                    unreachable; // TODO
+                    comptime assert(Type.packed_struct_layout_version == 2);
+                    const backing_int_ty = Type.fromInterned(struct_type.backingIntTypeUnordered(ip));
+                    var running_int_id = try self.constInt(backing_int_ty, 0);
+                    var running_bits: u16 = 0;
+                    for (struct_type.field_types.get(ip), elements) |field_ty_ip, element| {
+                        const field_ty = Type.fromInterned(field_ty_ip);
+                        if (!field_ty.hasRuntimeBitsIgnoreComptime(zcu)) continue;
+                        const field_id = try self.resolve(element);
+                        const ty_bit_size: u16 = @intCast(field_ty.bitSize(zcu));
+                        const field_int_ty = try self.pt.intType(.unsigned, ty_bit_size);
+                        const field_int_id = blk: {
+                            if (field_ty.isPtrAtRuntime(zcu)) {
+                                assert(self.spv.hasFeature(.addresses) or
+                                    (self.spv.hasFeature(.physical_storage_buffer) and field_ty.ptrAddressSpace(zcu) == .storage_buffer));
+                                break :blk try self.intFromPtr(field_id);
+                            }
+                            break :blk try self.bitCast(field_int_ty, field_ty, field_id);
+                        };
+                        const shift_rhs = try self.constInt(backing_int_ty, running_bits);
+                        const extended_int_conv = try self.buildIntConvert(backing_int_ty, .{
+                            .ty = field_int_ty,
+                            .value = .{ .singleton = field_int_id },
+                        });
+                        const shifted = try self.buildBinary(.sll, extended_int_conv, .{
+                            .ty = backing_int_ty,
+                            .value = .{ .singleton = shift_rhs },
+                        });
+                        const running_int_tmp = try self.buildBinary(
+                            .bit_or,
+                            .{ .ty = backing_int_ty, .value = .{ .singleton = running_int_id } },
+                            shifted,
+                        );
+                        running_int_id = try running_int_tmp.materialize(self);
+                        running_bits += ty_bit_size;
+                    }
+                    return running_int_id;
                 }
 
                 const types = try self.gpa.alloc(Type, elements.len);
@@ -5156,6 +5198,7 @@ const NavGen = struct {
     fn airStructFieldVal(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
         const pt = self.pt;
         const zcu = pt.zcu;
+        const ip = &zcu.intern_pool;
         const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
         const struct_field = self.air.extraData(Air.StructField, ty_pl.payload).data;
 
@@ -5168,7 +5211,28 @@ const NavGen = struct {
 
         switch (object_ty.zigTypeTag(zcu)) {
             .@"struct" => switch (object_ty.containerLayout(zcu)) {
-                .@"packed" => unreachable, // TODO
+                .@"packed" => {
+                    const struct_ty = zcu.typeToPackedStruct(object_ty).?;
+                    const backing_int_ty = Type.fromInterned(struct_ty.backingIntTypeUnordered(ip));
+                    const bit_offset = pt.structPackedFieldBitOffset(struct_ty, field_index);
+                    const bit_offset_id = try self.constInt(.u16, bit_offset);
+                    const signedness = if (field_ty.isInt(zcu)) field_ty.intInfo(zcu).signedness else .unsigned;
+                    const field_bit_size: u16 = @intCast(field_ty.bitSize(zcu));
+                    const int_ty = try pt.intType(signedness, field_bit_size);
+                    const shift_lhs: Temporary = .{ .ty = backing_int_ty, .value = .{ .singleton = object_id } };
+                    const shift = try self.buildBinary(.srl, shift_lhs, .{ .ty = .u16, .value = .{ .singleton = bit_offset_id } });
+                    const mask_id = try self.constInt(backing_int_ty, (@as(u64, 1) << @as(u6, @intCast(field_bit_size))) - 1);
+                    const masked = try self.buildBinary(.bit_and, shift, .{ .ty = backing_int_ty, .value = .{ .singleton = mask_id } });
+                    const result_id = blk: {
+                        if (self.backingIntBits(field_bit_size).? == self.backingIntBits(@intCast(backing_int_ty.bitSize(zcu))).?)
+                            break :blk try self.bitCast(int_ty, backing_int_ty, try masked.materialize(self));
+                        const trunc = try self.buildIntConvert(int_ty, masked);
+                        break :blk try trunc.materialize(self);
+                    };
+                    if (field_ty.ip_index == .bool_type) return try self.convertToDirect(.bool, result_id);
+                    if (field_ty.isInt(zcu)) return result_id;
+                    return try self.bitCast(field_ty, int_ty, result_id);
+                },
                 else => return try self.extractField(field_ty, object_id, field_index),
             },
             .@"union" => switch (object_ty.containerLayout(zcu)) {
test/behavior/bitcast.zig
@@ -165,7 +165,6 @@ test "@bitCast packed structs at runtime and comptime" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const Full = packed struct {
         number: u16,
@@ -226,7 +225,6 @@ test "bitcast packed struct to integer and back" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const LevelUpMove = packed struct {
         move_id: u9,
test/behavior/packed-struct.zig
@@ -123,7 +123,6 @@ test "correct sizeOf and offsets in packed structs" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const PStruct = packed struct {
         bool_a: bool,
@@ -191,7 +190,6 @@ test "nested packed structs" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S1 = packed struct { a: u8, b: u8, c: u8 };
 
@@ -257,7 +255,6 @@ test "nested packed struct unaligned" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (native_endian != .little) return error.SkipZigTest; // Byte aligned packed struct field pointers have not been implemented yet
 
     const S1 = packed struct {
@@ -895,7 +892,6 @@ test "packed struct passed to callconv(.c) function" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
@@ -944,7 +940,6 @@ test "packed struct initialized in bitcast" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const T = packed struct { val: u8 };
@@ -982,7 +977,6 @@ test "pointer to container level packed struct field" {
 test "store undefined to packed result location" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     var x: u4 = 0;
@@ -992,8 +986,6 @@ test "store undefined to packed result location" {
 }
 
 test "bitcast back and forth" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     // Originally reported at https://github.com/ziglang/zig/issues/9914
     const S = packed struct { one: u6, two: u1 };
     const s = S{ .one = 0b110101, .two = 0b1 };
@@ -1290,8 +1282,6 @@ test "2-byte packed struct argument in C calling convention" {
 }
 
 test "packed struct contains optional pointer" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const foo: packed struct {
         a: ?*@This() = null,
     } = .{};
@@ -1299,8 +1289,6 @@ test "packed struct contains optional pointer" {
 }
 
 test "packed struct equality" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const Foo = packed struct {
         a: u4,
         b: u4,
@@ -1321,8 +1309,6 @@ test "packed struct equality" {
 }
 
 test "packed struct with signed field" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     var s: packed struct {
         a: i2,
         b: u6,
test/behavior/packed_struct_explicit_backing_int.zig
@@ -9,7 +9,6 @@ test "packed struct explicit backing integer" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S1 = packed struct { a: u8, b: u8, c: u8 };
 
test/behavior/ptrcast.zig
@@ -287,8 +287,6 @@ test "@ptrCast undefined value at comptime" {
 }
 
 test "comptime @ptrCast with packed struct leaves value unmodified" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const S = packed struct { three: u3 };
     const st: S = .{ .three = 6 };
     try expect(st.three == 6);
test/behavior/struct.zig
@@ -1023,7 +1023,6 @@ test "packed struct with undefined initializers" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         const P = packed struct {
@@ -1221,7 +1220,6 @@ test "packed struct aggregate init" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
 
     const S = struct {
@@ -1971,7 +1969,6 @@ test "struct field default value is a call" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const Z = packed struct {
         a: u32,
test/behavior/vector.zig
@@ -11,6 +11,7 @@ test "implicit cast vector to array - bool" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -29,6 +30,7 @@ test "vector wrap operators" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64 and
         !comptime std.Target.x86.featureSetHas(builtin.cpu.features, .sse4_1)) return error.SkipZigTest;