Commit 6281ad91df

Robin Voetter <robin@voetter.nl>
2023-10-21 17:26:59
spirv: self-referential pointers via new fwd_ptr_type
Its a little ugly but it works.
1 parent 6e955af
src/codegen/spirv/Assembler.zig
@@ -304,10 +304,16 @@ fn processTypeInstruction(self: *Assembler) !AsmValue {
             // and so some consideration must be taken when entering this in the type system.
             return self.todo("process OpTypeArray", .{});
         },
-        .OpTypePointer => try self.spv.ptrType(
-            try self.resolveTypeRef(operands[2].ref_id),
-            @as(spec.StorageClass, @enumFromInt(operands[1].value)),
-        ),
+        .OpTypePointer => blk: {
+            break :blk try self.spv.resolve(.{
+                .ptr_type = .{
+                    .storage_class = @enumFromInt(operands[1].value),
+                    .child_type = try self.resolveTypeRef(operands[2].ref_id),
+                    // TODO: This should be a proper reference resolved via OpTypeForwardPointer
+                    .fwd = @enumFromInt(std.math.maxInt(u32)),
+                },
+            });
+        },
         .OpTypeFunction => blk: {
             const param_operands = operands[2..];
             const param_types = try self.spv.gpa.alloc(CacheRef, param_operands.len);
src/codegen/spirv/Cache.zig
@@ -22,6 +22,8 @@ const Opcode = spec.Opcode;
 const IdResult = spec.IdResult;
 const StorageClass = spec.StorageClass;
 
+const InternPool = @import("../../InternPool.zig");
+
 const Self = @This();
 
 map: std.AutoArrayHashMapUnmanaged(void, void) = .{},
@@ -31,6 +33,8 @@ extra: std.ArrayListUnmanaged(u32) = .{},
 string_bytes: std.ArrayListUnmanaged(u8) = .{},
 strings: std.AutoArrayHashMapUnmanaged(void, u32) = .{},
 
+recursive_ptrs: std.AutoHashMapUnmanaged(Ref, void) = .{},
+
 const Item = struct {
     tag: Tag,
     /// The result-id that this item uses.
@@ -62,18 +66,21 @@ const Tag = enum {
     /// Function (proto)type
     /// data is payload to FunctionType
     type_function,
-    /// Pointer type in the CrossWorkgroup storage class
-    /// data is child type
-    type_ptr_generic,
-    /// Pointer type in the CrossWorkgroup storage class
-    /// data is child type
-    type_ptr_crosswgp,
-    /// Pointer type in the Function storage class
-    /// data is child type
-    type_ptr_function,
+    // /// Pointer type in the CrossWorkgroup storage class
+    // /// data is child type
+    // type_ptr_generic,
+    // /// Pointer type in the CrossWorkgroup storage class
+    // /// data is child type
+    // type_ptr_crosswgp,
+    // /// Pointer type in the Function storage class
+    // /// data is child type
+    // type_ptr_function,
     /// Simple pointer type that does not have any decorations.
     /// data is payload to SimplePointerType
     type_ptr_simple,
+    /// A forward declaration for a pointer.
+    /// data is ForwardPointerType
+    type_fwd_ptr,
     /// Simple structure type that does not have any decorations.
     /// data is payload to SimpleStructType
     type_struct_simple,
@@ -142,6 +149,12 @@ const Tag = enum {
     const SimplePointerType = struct {
         storage_class: StorageClass,
         child_type: Ref,
+        fwd: Ref,
+    };
+
+    const ForwardPointerType = struct {
+        storage_class: StorageClass,
+        zig_child_type: InternPool.Index,
     };
 
     /// Trailing:
@@ -163,14 +176,14 @@ const Tag = enum {
         fn encode(value: f64) Float64 {
             const bits = @as(u64, @bitCast(value));
             return .{
-                .low = @as(u32, @truncate(bits)),
-                .high = @as(u32, @truncate(bits >> 32)),
+                .low = @truncate(bits),
+                .high = @truncate(bits >> 32),
             };
         }
 
         fn decode(self: Float64) f64 {
             const bits = @as(u64, self.low) | (@as(u64, self.high) << 32);
-            return @as(f64, @bitCast(bits));
+            return @bitCast(bits);
         }
     };
 
@@ -192,8 +205,8 @@ const Tag = enum {
         fn encode(ty: Ref, value: u64) Int64 {
             return .{
                 .ty = ty,
-                .low = @as(u32, @truncate(value)),
-                .high = @as(u32, @truncate(value >> 32)),
+                .low = @truncate(value),
+                .high = @truncate(value >> 32),
             };
         }
 
@@ -210,8 +223,8 @@ const Tag = enum {
         fn encode(ty: Ref, value: i64) Int64 {
             return .{
                 .ty = ty,
-                .low = @as(u32, @truncate(@as(u64, @bitCast(value)))),
-                .high = @as(u32, @truncate(@as(u64, @bitCast(value)) >> 32)),
+                .low = @truncate(@as(u64, @bitCast(value))),
+                .high = @truncate(@as(u64, @bitCast(value)) >> 32),
             };
         }
 
@@ -237,6 +250,7 @@ pub const Key = union(enum) {
     array_type: ArrayType,
     function_type: FunctionType,
     ptr_type: PointerType,
+    fwd_ptr_type: ForwardPointerType,
     struct_type: StructType,
     opaque_type: OpaqueType,
 
@@ -273,12 +287,18 @@ pub const Key = union(enum) {
     pub const PointerType = struct {
         storage_class: StorageClass,
         child_type: Ref,
+        fwd: Ref,
         // TODO: Decorations:
         // - Alignment
         // - ArrayStride,
         // - MaxByteOffset,
     };
 
+    pub const ForwardPointerType = struct {
+        zig_child_type: InternPool.Index,
+        storage_class: StorageClass,
+    };
+
     pub const StructType = struct {
         // TODO: Decorations.
         /// The name of the structure. Can be `.none`.
@@ -313,21 +333,21 @@ pub const Key = union(enum) {
         /// Turns this value into the corresponding 32-bit literal, 2s complement signed.
         fn toBits32(self: Int) u32 {
             return switch (self.value) {
-                .uint64 => |val| @as(u32, @intCast(val)),
-                .int64 => |val| if (val < 0) @as(u32, @bitCast(@as(i32, @intCast(val)))) else @as(u32, @intCast(val)),
+                .uint64 => |val| @intCast(val),
+                .int64 => |val| if (val < 0) @bitCast(@as(i32, @intCast(val))) else @intCast(val),
             };
         }
 
         fn toBits64(self: Int) u64 {
             return switch (self.value) {
                 .uint64 => |val| val,
-                .int64 => |val| @as(u64, @bitCast(val)),
+                .int64 => |val| @bitCast(val),
             };
         }
 
         fn to(self: Int, comptime T: type) T {
             return switch (self.value) {
-                inline else => |val| @as(T, @intCast(val)),
+                inline else => |val| @intCast(val),
             };
         }
     };
@@ -387,7 +407,7 @@ pub const Key = union(enum) {
             },
             inline else => |key| std.hash.autoHash(&hasher, key),
         }
-        return @as(u32, @truncate(hasher.final()));
+        return @truncate(hasher.final());
     }
 
     fn eql(a: Key, b: Key) bool {
@@ -419,7 +439,7 @@ pub const Key = union(enum) {
 
         pub fn eql(ctx: @This(), a: Key, b_void: void, b_index: usize) bool {
             _ = b_void;
-            return ctx.self.lookup(@as(Ref, @enumFromInt(b_index))).eql(a);
+            return ctx.self.lookup(@enumFromInt(b_index)).eql(a);
         }
 
         pub fn hash(ctx: @This(), a: Key) u32 {
@@ -450,6 +470,7 @@ pub fn deinit(self: *Self, spv: *const Module) void {
     self.extra.deinit(spv.gpa);
     self.string_bytes.deinit(spv.gpa);
     self.strings.deinit(spv.gpa);
+    self.recursive_ptrs.deinit(spv.gpa);
 }
 
 /// Actually materialize the database into spir-v instructions.
@@ -460,7 +481,7 @@ pub fn materialize(self: *const Self, spv: *Module) !Section {
     var section = Section{};
     errdefer section.deinit(spv.gpa);
     for (self.items.items(.result_id), 0..) |result_id, index| {
-        try self.emit(spv, result_id, @as(Ref, @enumFromInt(index)), &section);
+        try self.emit(spv, result_id, @enumFromInt(index), &section);
     }
     return section;
 }
@@ -538,6 +559,15 @@ fn emit(
             });
             // TODO: Decorations?
         },
+        .fwd_ptr_type => |fwd| {
+            // Only emit the OpTypeForwardPointer if its actually required.
+            if (self.recursive_ptrs.contains(ref)) {
+                try section.emit(spv.gpa, .OpTypeForwardPointer, .{
+                    .pointer_type = result_id,
+                    .storage_class = fwd.storage_class,
+                });
+            }
+        },
         .struct_type => |struct_type| {
             try section.emitRaw(spv.gpa, .OpTypeStruct, 1 + struct_type.member_types.len);
             section.writeOperand(IdResult, result_id);
@@ -549,7 +579,7 @@ fn emit(
             }
             for (struct_type.memberNames(), 0..) |member_name, i| {
                 if (self.getString(member_name)) |name| {
-                    try spv.memberDebugName(result_id, @as(u32, @intCast(i)), name);
+                    try spv.memberDebugName(result_id, @intCast(i), name);
                 }
             }
             // TODO: Decorations?
@@ -625,13 +655,12 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
     const adapter: Key.Adapter = .{ .self = self };
     const entry = try self.map.getOrPutAdapted(spv.gpa, key, adapter);
     if (entry.found_existing) {
-        return @as(Ref, @enumFromInt(entry.index));
+        return @enumFromInt(entry.index);
     }
-    const result_id = spv.allocId();
     const item: Item = switch (key) {
         inline .void_type, .bool_type => .{
             .tag = .type_simple,
-            .result_id = result_id,
+            .result_id = spv.allocId(),
             .data = @intFromEnum(key.toSimpleType()),
         },
         .int_type => |int| blk: {
@@ -641,87 +670,104 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
             };
             break :blk .{
                 .tag = t,
-                .result_id = result_id,
+                .result_id = spv.allocId(),
                 .data = int.bits,
             };
         },
         .float_type => |float| .{
             .tag = .type_float,
-            .result_id = result_id,
+            .result_id = spv.allocId(),
             .data = float.bits,
         },
         .vector_type => |vector| .{
             .tag = .type_vector,
-            .result_id = result_id,
+            .result_id = spv.allocId(),
             .data = try self.addExtra(spv, vector),
         },
         .array_type => |array| .{
             .tag = .type_array,
-            .result_id = result_id,
+            .result_id = spv.allocId(),
             .data = try self.addExtra(spv, array),
         },
         .function_type => |function| blk: {
             const extra = try self.addExtra(spv, Tag.FunctionType{
-                .param_len = @as(u32, @intCast(function.parameters.len)),
+                .param_len = @intCast(function.parameters.len),
                 .return_type = function.return_type,
             });
-            try self.extra.appendSlice(spv.gpa, @as([]const u32, @ptrCast(function.parameters)));
+            try self.extra.appendSlice(spv.gpa, @ptrCast(function.parameters));
             break :blk .{
                 .tag = .type_function,
-                .result_id = result_id,
+                .result_id = spv.allocId(),
                 .data = extra,
             };
         },
-        .ptr_type => |ptr| switch (ptr.storage_class) {
-            .Generic => Item{
-                .tag = .type_ptr_generic,
-                .result_id = result_id,
-                .data = @intFromEnum(ptr.child_type),
-            },
-            .CrossWorkgroup => Item{
-                .tag = .type_ptr_crosswgp,
-                .result_id = result_id,
-                .data = @intFromEnum(ptr.child_type),
-            },
-            .Function => Item{
-                .tag = .type_ptr_function,
-                .result_id = result_id,
-                .data = @intFromEnum(ptr.child_type),
-            },
-            else => |storage_class| Item{
-                .tag = .type_ptr_simple,
-                .result_id = result_id,
-                .data = try self.addExtra(spv, Tag.SimplePointerType{
-                    .storage_class = storage_class,
-                    .child_type = ptr.child_type,
-                }),
-            },
+        // .ptr_type => |ptr| switch (ptr.storage_class) {
+        //     .Generic => Item{
+        //         .tag = .type_ptr_generic,
+        //         .result_id = spv.allocId(),
+        //         .data = @intFromEnum(ptr.child_type),
+        //     },
+        //     .CrossWorkgroup => Item{
+        //         .tag = .type_ptr_crosswgp,
+        //         .result_id = spv.allocId(),
+        //         .data = @intFromEnum(ptr.child_type),
+        //     },
+        //     .Function => Item{
+        //         .tag = .type_ptr_function,
+        //         .result_id = spv.allocId(),
+        //         .data = @intFromEnum(ptr.child_type),
+        //     },
+        //     else => |storage_class| Item{
+        //         .tag = .type_ptr_simple,
+        //         .result_id = spv.allocId(),
+        //         .data = try self.addExtra(spv, Tag.SimplePointerType{
+        //             .storage_class = storage_class,
+        //             .child_type = ptr.child_type,
+        //         }),
+        //     },
+        // },
+        .ptr_type => |ptr| Item{
+            .tag = .type_ptr_simple,
+            .result_id = self.resultId(ptr.fwd),
+            .data = try self.addExtra(spv, Tag.SimplePointerType{
+                .storage_class = ptr.storage_class,
+                .child_type = ptr.child_type,
+                .fwd = ptr.fwd,
+            }),
+        },
+        .fwd_ptr_type => |fwd| Item{
+            .tag = .type_fwd_ptr,
+            .result_id = spv.allocId(),
+            .data = try self.addExtra(spv, Tag.ForwardPointerType{
+                .zig_child_type = fwd.zig_child_type,
+                .storage_class = fwd.storage_class,
+            }),
         },
         .struct_type => |struct_type| blk: {
             const extra = try self.addExtra(spv, Tag.SimpleStructType{
                 .name = struct_type.name,
-                .members_len = @as(u32, @intCast(struct_type.member_types.len)),
+                .members_len = @intCast(struct_type.member_types.len),
             });
-            try self.extra.appendSlice(spv.gpa, @as([]const u32, @ptrCast(struct_type.member_types)));
+            try self.extra.appendSlice(spv.gpa, @ptrCast(struct_type.member_types));
 
             if (struct_type.member_names) |member_names| {
-                try self.extra.appendSlice(spv.gpa, @as([]const u32, @ptrCast(member_names)));
+                try self.extra.appendSlice(spv.gpa, @ptrCast(member_names));
                 break :blk Item{
                     .tag = .type_struct_simple_with_member_names,
-                    .result_id = result_id,
+                    .result_id = spv.allocId(),
                     .data = extra,
                 };
             } else {
                 break :blk Item{
                     .tag = .type_struct_simple,
-                    .result_id = result_id,
+                    .result_id = spv.allocId(),
                     .data = extra,
                 };
             }
         },
         .opaque_type => |opaque_type| Item{
             .tag = .type_opaque,
-            .result_id = result_id,
+            .result_id = spv.allocId(),
             .data = @intFromEnum(opaque_type.name),
         },
         .int => |int| blk: {
@@ -729,13 +775,13 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
             if (int_type.signedness == .unsigned and int_type.bits == 8) {
                 break :blk .{
                     .tag = .uint8,
-                    .result_id = result_id,
+                    .result_id = spv.allocId(),
                     .data = int.to(u8),
                 };
             } else if (int_type.signedness == .unsigned and int_type.bits == 32) {
                 break :blk .{
                     .tag = .uint32,
-                    .result_id = result_id,
+                    .result_id = spv.allocId(),
                     .data = int.to(u32),
                 };
             }
@@ -745,32 +791,32 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
                     if (val >= 0 and val <= std.math.maxInt(u32)) {
                         break :blk .{
                             .tag = .uint_small,
-                            .result_id = result_id,
+                            .result_id = spv.allocId(),
                             .data = try self.addExtra(spv, Tag.UInt32{
                                 .ty = int.ty,
-                                .value = @as(u32, @intCast(val)),
+                                .value = @intCast(val),
                             }),
                         };
                     } else if (val >= std.math.minInt(i32) and val <= std.math.maxInt(i32)) {
                         break :blk .{
                             .tag = .int_small,
-                            .result_id = result_id,
+                            .result_id = spv.allocId(),
                             .data = try self.addExtra(spv, Tag.Int32{
                                 .ty = int.ty,
-                                .value = @as(i32, @intCast(val)),
+                                .value = @intCast(val),
                             }),
                         };
                     } else if (val < 0) {
                         break :blk .{
                             .tag = .int_large,
-                            .result_id = result_id,
-                            .data = try self.addExtra(spv, Tag.Int64.encode(int.ty, @as(i64, @intCast(val)))),
+                            .result_id = spv.allocId(),
+                            .data = try self.addExtra(spv, Tag.Int64.encode(int.ty, @intCast(val))),
                         };
                     } else {
                         break :blk .{
                             .tag = .uint_large,
-                            .result_id = result_id,
-                            .data = try self.addExtra(spv, Tag.UInt64.encode(int.ty, @as(u64, @intCast(val)))),
+                            .result_id = spv.allocId(),
+                            .data = try self.addExtra(spv, Tag.UInt64.encode(int.ty, @intCast(val))),
                         };
                     }
                 },
@@ -779,29 +825,29 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
         .float => |float| switch (self.lookup(float.ty).float_type.bits) {
             16 => .{
                 .tag = .float16,
-                .result_id = result_id,
+                .result_id = spv.allocId(),
                 .data = @as(u16, @bitCast(float.value.float16)),
             },
             32 => .{
                 .tag = .float32,
-                .result_id = result_id,
+                .result_id = spv.allocId(),
                 .data = @as(u32, @bitCast(float.value.float32)),
             },
             64 => .{
                 .tag = .float64,
-                .result_id = result_id,
+                .result_id = spv.allocId(),
                 .data = try self.addExtra(spv, Tag.Float64.encode(float.value.float64)),
             },
             else => unreachable,
         },
         .undef => |undef| .{
             .tag = .undef,
-            .result_id = result_id,
+            .result_id = spv.allocId(),
             .data = @intFromEnum(undef.ty),
         },
         .null => |null_info| .{
             .tag = .null,
-            .result_id = result_id,
+            .result_id = spv.allocId(),
             .data = @intFromEnum(null_info.ty),
         },
         .bool => |bool_info| .{
@@ -809,13 +855,13 @@ pub fn resolve(self: *Self, spv: *Module, key: Key) !Ref {
                 true => Tag.bool_true,
                 false => Tag.bool_false,
             },
-            .result_id = result_id,
+            .result_id = spv.allocId(),
             .data = @intFromEnum(bool_info.ty),
         },
     };
     try self.items.append(spv.gpa, item);
 
-    return @as(Ref, @enumFromInt(entry.index));
+    return @enumFromInt(entry.index);
 }
 
 /// Turn a Ref back into a Key.
@@ -830,14 +876,14 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
         },
         .type_int_signed => .{ .int_type = .{
             .signedness = .signed,
-            .bits = @as(u16, @intCast(data)),
+            .bits = @intCast(data),
         } },
         .type_int_unsigned => .{ .int_type = .{
             .signedness = .unsigned,
-            .bits = @as(u16, @intCast(data)),
+            .bits = @intCast(data),
         } },
         .type_float => .{ .float_type = .{
-            .bits = @as(u16, @intCast(data)),
+            .bits = @intCast(data),
         } },
         .type_vector => .{ .vector_type = self.extraData(Tag.VectorType, data) },
         .type_array => .{ .array_type = self.extraData(Tag.ArrayType, data) },
@@ -846,40 +892,50 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
             return .{
                 .function_type = .{
                     .return_type = payload.data.return_type,
-                    .parameters = @as([]const Ref, @ptrCast(self.extra.items[payload.trail..][0..payload.data.param_len])),
+                    .parameters = @ptrCast(self.extra.items[payload.trail..][0..payload.data.param_len]),
                 },
             };
         },
-        .type_ptr_generic => .{
-            .ptr_type = .{
-                .storage_class = .Generic,
-                .child_type = @as(Ref, @enumFromInt(data)),
-            },
-        },
-        .type_ptr_crosswgp => .{
-            .ptr_type = .{
-                .storage_class = .CrossWorkgroup,
-                .child_type = @as(Ref, @enumFromInt(data)),
-            },
-        },
-        .type_ptr_function => .{
-            .ptr_type = .{
-                .storage_class = .Function,
-                .child_type = @as(Ref, @enumFromInt(data)),
-            },
-        },
+        // .type_ptr_generic => .{
+        //     .ptr_type = .{
+        //         .storage_class = .Generic,
+        //         .child_type = @enumFromInt(data),
+        //     },
+        // },
+        // .type_ptr_crosswgp => .{
+        //     .ptr_type = .{
+        //         .storage_class = .CrossWorkgroup,
+        //         .child_type = @enumFromInt(data),
+        //     },
+        // },
+        // .type_ptr_function => .{
+        //     .ptr_type = .{
+        //         .storage_class = .Function,
+        //         .child_type = @enumFromInt(data),
+        //     },
+        // },
         .type_ptr_simple => {
             const payload = self.extraData(Tag.SimplePointerType, data);
             return .{
                 .ptr_type = .{
                     .storage_class = payload.storage_class,
                     .child_type = payload.child_type,
+                    .fwd = payload.fwd,
+                },
+            };
+        },
+        .type_fwd_ptr => {
+            const payload = self.extraData(Tag.ForwardPointerType, data);
+            return .{
+                .fwd_ptr_type = .{
+                    .zig_child_type = payload.zig_child_type,
+                    .storage_class = payload.storage_class,
                 },
             };
         },
         .type_struct_simple => {
             const payload = self.extraDataTrail(Tag.SimpleStructType, data);
-            const member_types = @as([]const Ref, @ptrCast(self.extra.items[payload.trail..][0..payload.data.members_len]));
+            const member_types: []const Ref = @ptrCast(self.extra.items[payload.trail..][0..payload.data.members_len]);
             return .{
                 .struct_type = .{
                     .name = payload.data.name,
@@ -891,8 +947,8 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
         .type_struct_simple_with_member_names => {
             const payload = self.extraDataTrail(Tag.SimpleStructType, data);
             const trailing = self.extra.items[payload.trail..];
-            const member_types = @as([]const Ref, @ptrCast(trailing[0..payload.data.members_len]));
-            const member_names = @as([]const String, @ptrCast(trailing[payload.data.members_len..][0..payload.data.members_len]));
+            const member_types: []const Ref = @ptrCast(trailing[0..payload.data.members_len]);
+            const member_names: []const String = @ptrCast(trailing[payload.data.members_len..][0..payload.data.members_len]);
             return .{
                 .struct_type = .{
                     .name = payload.data.name,
@@ -903,16 +959,16 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
         },
         .type_opaque => .{
             .opaque_type = .{
-                .name = @as(String, @enumFromInt(data)),
+                .name = @enumFromInt(data),
             },
         },
         .float16 => .{ .float = .{
             .ty = self.get(.{ .float_type = .{ .bits = 16 } }),
-            .value = .{ .float16 = @as(f16, @bitCast(@as(u16, @intCast(data)))) },
+            .value = .{ .float16 = @bitCast(@as(u16, @intCast(data))) },
         } },
         .float32 => .{ .float = .{
             .ty = self.get(.{ .float_type = .{ .bits = 32 } }),
-            .value = .{ .float32 = @as(f32, @bitCast(data)) },
+            .value = .{ .float32 = @bitCast(data) },
         } },
         .float64 => .{ .float = .{
             .ty = self.get(.{ .float_type = .{ .bits = 64 } }),
@@ -955,17 +1011,17 @@ pub fn lookup(self: *const Self, ref: Ref) Key {
             } };
         },
         .undef => .{ .undef = .{
-            .ty = @as(Ref, @enumFromInt(data)),
+            .ty = @enumFromInt(data),
         } },
         .null => .{ .null = .{
-            .ty = @as(Ref, @enumFromInt(data)),
+            .ty = @enumFromInt(data),
         } },
         .bool_true => .{ .bool = .{
-            .ty = @as(Ref, @enumFromInt(data)),
+            .ty = @enumFromInt(data),
             .value = true,
         } },
         .bool_false => .{ .bool = .{
-            .ty = @as(Ref, @enumFromInt(data)),
+            .ty = @enumFromInt(data),
             .value = false,
         } },
     };
@@ -981,7 +1037,7 @@ pub fn resultId(self: Self, ref: Ref) IdResult {
 fn get(self: *const Self, key: Key) Ref {
     const adapter: Key.Adapter = .{ .self = self };
     const index = self.map.getIndexAdapted(key, adapter).?;
-    return @as(Ref, @enumFromInt(index));
+    return @enumFromInt(index);
 }
 
 fn addExtra(self: *Self, spv: *Module, extra: anytype) !u32 {
@@ -991,15 +1047,16 @@ fn addExtra(self: *Self, spv: *Module, extra: anytype) !u32 {
 }
 
 fn addExtraAssumeCapacity(self: *Self, extra: anytype) !u32 {
-    const payload_offset = @as(u32, @intCast(self.extra.items.len));
+    const payload_offset: u32 = @intCast(self.extra.items.len);
     inline for (@typeInfo(@TypeOf(extra)).Struct.fields) |field| {
         const field_val = @field(extra, field.name);
-        const word = switch (field.type) {
+        const word: u32 = switch (field.type) {
             u32 => field_val,
-            i32 => @as(u32, @bitCast(field_val)),
+            i32 => @bitCast(field_val),
             Ref => @intFromEnum(field_val),
             StorageClass => @intFromEnum(field_val),
             String => @intFromEnum(field_val),
+            InternPool.Index => @intFromEnum(field_val),
             else => @compileError("Invalid type: " ++ @typeName(field.type)),
         };
         self.extra.appendAssumeCapacity(word);
@@ -1018,10 +1075,11 @@ fn extraDataTrail(self: Self, comptime T: type, offset: u32) struct { data: T, t
         const word = self.extra.items[offset + i];
         @field(result, field.name) = switch (field.type) {
             u32 => word,
-            i32 => @as(i32, @bitCast(word)),
-            Ref => @as(Ref, @enumFromInt(word)),
-            StorageClass => @as(StorageClass, @enumFromInt(word)),
-            String => @as(String, @enumFromInt(word)),
+            i32 => @bitCast(word),
+            Ref => @enumFromInt(word),
+            StorageClass => @enumFromInt(word),
+            String => @enumFromInt(word),
+            InternPool.Index => @enumFromInt(word),
             else => @compileError("Invalid type: " ++ @typeName(field.type)),
         };
     }
@@ -1049,7 +1107,7 @@ pub const String = enum(u32) {
             _ = ctx;
             var hasher = std.hash.Wyhash.init(0);
             hasher.update(a);
-            return @as(u32, @truncate(hasher.final()));
+            return @truncate(hasher.final());
         }
     };
 };
@@ -1064,10 +1122,10 @@ pub fn addString(self: *Self, spv: *Module, str: []const u8) !String {
         try self.string_bytes.ensureUnusedCapacity(spv.gpa, 1 + str.len);
         self.string_bytes.appendSliceAssumeCapacity(str);
         self.string_bytes.appendAssumeCapacity(0);
-        entry.value_ptr.* = @as(u32, @intCast(offset));
+        entry.value_ptr.* = @intCast(offset);
     }
 
-    return @as(String, @enumFromInt(entry.index));
+    return @enumFromInt(entry.index);
 }
 
 pub fn getString(self: *const Self, ref: String) ?[]const u8 {
src/codegen/spirv/Module.zig
@@ -507,17 +507,6 @@ pub fn arrayType(self: *Module, len: u32, elem_ty_ref: CacheRef) !CacheRef {
     } });
 }
 
-pub fn ptrType(
-    self: *Module,
-    child: CacheRef,
-    storage_class: spec.StorageClass,
-) !CacheRef {
-    return try self.resolve(.{ .ptr_type = .{
-        .storage_class = storage_class,
-        .child_type = child,
-    } });
-}
-
 pub fn constInt(self: *Module, ty_ref: CacheRef, value: anytype) !IdRef {
     const ty = self.cache.lookup(ty_ref).int_type;
     const Value = Cache.Key.Int.Value;
src/codegen/spirv.zig
@@ -209,6 +209,10 @@ const DeclGen = struct {
     /// See Object.type_map
     type_map: *TypeMap,
 
+    /// Child types of pointers that are currently in progress of being resolved. If a pointer
+    /// is already in this map, its recursive.
+    wip_pointers: std.AutoHashMapUnmanaged(struct { InternPool.Index, StorageClass }, CacheRef) = .{},
+
     /// We need to keep track of result ids for block labels, as well as the 'incoming'
     /// blocks for a block.
     blocks: BlockMap = .{},
@@ -295,6 +299,7 @@ const DeclGen = struct {
     pub fn deinit(self: *DeclGen) void {
         self.args.deinit(self.gpa);
         self.inst_results.deinit(self.gpa);
+        self.wip_pointers.deinit(self.gpa);
         self.blocks.deinit(self.gpa);
         self.func.deinit(self.gpa);
         self.base_line_stack.deinit(self.gpa);
@@ -1100,9 +1105,30 @@ const DeclGen = struct {
     }
 
     fn ptrType(self: *DeclGen, child_ty: Type, storage_class: StorageClass) !CacheRef {
-        // TODO: This function will be rewritten so that forward declarations work properly
+        const key = .{ child_ty.toIntern(), storage_class };
+        const entry = try self.wip_pointers.getOrPut(self.gpa, key);
+        if (entry.found_existing) {
+            const fwd_ref = entry.value_ptr.*;
+            try self.spv.cache.recursive_ptrs.put(self.spv.gpa, fwd_ref, {});
+            return fwd_ref;
+        }
+
+        const fwd_ref = try self.spv.resolve(.{ .fwd_ptr_type = .{
+            .zig_child_type = child_ty.toIntern(),
+            .storage_class = storage_class,
+        } });
+        entry.value_ptr.* = fwd_ref;
+
         const child_ty_ref = try self.resolveType(child_ty, .indirect);
-        return try self.spv.ptrType(child_ty_ref, storage_class);
+        _ = try self.spv.resolve(.{ .ptr_type = .{
+            .storage_class = storage_class,
+            .child_type = child_ty_ref,
+            .fwd = fwd_ref,
+        } });
+
+        assert(self.wip_pointers.remove(key));
+
+        return fwd_ref;
     }
 
     /// Generate a union type. Union types are always generated with the
@@ -1323,12 +1349,12 @@ const DeclGen = struct {
             .Pointer => {
                 const ptr_info = ty.ptrInfo(mod);
 
+                // Note: Don't cache this pointer type, it would mess up the recursive pointer functionality
+                // in ptrType()!
+
                 const storage_class = spvStorageClass(ptr_info.flags.address_space);
-                const child_ty_ref = try self.resolveType(ptr_info.child.toType(), .indirect);
-                const ptr_ty_ref = try self.spv.resolve(.{ .ptr_type = .{
-                    .storage_class = storage_class,
-                    .child_type = child_ty_ref,
-                } });
+                const ptr_ty_ref = try self.ptrType(ptr_info.child.toType(), storage_class);
+
                 if (ptr_info.flags.size != .Slice) {
                     return ptr_ty_ref;
                 }
@@ -4371,6 +4397,7 @@ const DeclGen = struct {
             }
 
             // TODO: Multiple results
+            // TODO: Check that the output type from assembly is the same as the type actually expected by Zig.
         }
 
         return null;
test/behavior/bugs/12000.zig
@@ -9,7 +9,6 @@ test {
     if (builtin.zig_backend == .stage2_x86) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var t: T = .{ .next = null };
     try std.testing.expect(t.next == null);
test/behavior/bugs/1735.zig
@@ -44,7 +44,6 @@ const a = struct {
 test "initialization" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var t = a.init();
     try std.testing.expect(t.foo.len == 0);
test/behavior/bugs/1914.zig
@@ -12,8 +12,6 @@ const b_list: []B = &[_]B{};
 const a = A{ .b_list_pointer = &b_list };
 
 test "segfault bug" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const assert = std.debug.assert;
     const obj = B{ .a_pointer = &a };
     assert(obj.a_pointer == &a); // this makes zig crash
@@ -30,7 +28,5 @@ pub const B2 = struct {
 var b_value = B2{ .pointer_array = &[_]*A2{} };
 
 test "basic stuff" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     std.debug.assert(&b_value == &b_value);
 }
test/behavior/bugs/2006.zig
@@ -7,7 +7,6 @@ const S = struct {
 };
 test "bug 2006" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var a: S = undefined;
     a = S{ .p = undefined };
test/behavior/bugs/3007.zig
@@ -22,7 +22,6 @@ test "fixed" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     default_foo = get_foo() catch null; // This Line
     try std.testing.expect(!default_foo.?.free);
test/behavior/bugs/6947.zig
@@ -8,7 +8,6 @@ test {
     if (builtin.zig_backend == .stage2_x86) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var slice: []void = undefined;
     destroy(&slice[0]);
test/behavior/bugs/7325.zig
@@ -81,7 +81,6 @@ test {
     if (builtin.zig_backend == .stage2_x86) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var param: ParamType = .{
         .one_of = .{ .name = "name" },
test/behavior/error.zig
@@ -943,7 +943,6 @@ test "returning an error union containing a type with no runtime bits" {
 test "try used in recursive function with inferred error set" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const Value = union(enum) {
         values: []const @This(),
test/behavior/eval.zig
@@ -391,7 +391,6 @@ test "return 0 from function that has u0 return type" {
 test "statically initialized struct" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     st_init_str_foo.x += 1;
     try expect(st_init_str_foo.x == 14);
@@ -498,7 +497,6 @@ test "comptime shlWithOverflow" {
 test "const ptr to variable data changes at runtime" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try expect(foo_ref.name[0] == 'a');
     foo_ref.name = "b";
@@ -1551,8 +1549,6 @@ test "comptime function turns function value to function pointer" {
 }
 
 test "container level const and var have unique addresses" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const S = struct {
         x: i32,
         y: i32,
test/behavior/generics.zig
@@ -205,7 +205,6 @@ fn foo2(arg: anytype) bool {
 
 test "generic struct" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var a1 = GenNode(i32){
         .value = 13,
test/behavior/null.zig
@@ -185,7 +185,6 @@ test "unwrap optional which is field of global var" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     struct_with_optional.field = null;
     if (struct_with_optional.field) |payload| {
test/behavior/optional.zig
@@ -193,7 +193,6 @@ test "nested orelse" {
 test "self-referential struct through a slice of optional" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         const Node = struct {
test/behavior/ptrcast.zig
@@ -130,7 +130,6 @@ test "lower reinterpreted comptime field ptr (with under-aligned fields)" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     // Test lowering a field ptr
     comptime var bytes align(2) = [_]u8{ 1, 2, 3, 4, 5, 6 };
@@ -153,7 +152,6 @@ test "lower reinterpreted comptime field ptr" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     // Test lowering a field ptr
     comptime var bytes align(4) = [_]u8{ 1, 2, 3, 4, 5, 6, 7, 8 };
test/behavior/struct.zig
@@ -292,7 +292,6 @@ const Val = struct {
 test "struct point to self" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var root: Node = undefined;
     root.val.x = 1;
@@ -347,7 +346,6 @@ test "self-referencing struct via array member" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const T = struct {
         children: [1]*@This(),
@@ -370,7 +368,6 @@ const EmptyStruct = struct {
 
 test "align 1 field before self referential align 8 field as slice return type" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const result = alloc(Expr);
     try expect(result.len == 0);
@@ -1422,7 +1419,6 @@ test "fieldParentPtr of a zero-bit field" {
 
 test "struct field has a pointer to an aligned version of itself" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const E = struct {
         next: *align(1) @This(),
@@ -1518,7 +1514,6 @@ test "function pointer in struct returns the struct" {
 
 test "no dependency loop on optional field wrapped in generic function" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn Atomic(comptime T: type) type {
test/behavior/struct_contains_null_ptr_itself.zig
@@ -5,7 +5,6 @@ const builtin = @import("builtin");
 test "struct contains null pointer which contains original struct" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var x: ?*NodeLineComment = null;
     try expect(x == null);
test/behavior/struct_contains_slice_of_itself.zig
@@ -13,7 +13,6 @@ const NodeAligned = struct {
 
 test "struct contains slice of itself" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var other_nodes = [_]Node{
         Node{
@@ -54,7 +53,6 @@ test "struct contains slice of itself" {
 test "struct contains aligned slice of itself" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var other_nodes = [_]NodeAligned{
         NodeAligned{