Commit b845c9d532

Robin Voetter <robin@voetter.nl>
2023-09-18 22:39:44
spirv: generate module initializer
1 parent 5d844fa
Changed files (4)
src
codegen
test
src/codegen/spirv/Module.zig
@@ -94,6 +94,8 @@ pub const Global = struct {
     begin_inst: u32,
     /// The past-end offset into `self.flobals.section`.
     end_inst: u32,
+    /// The result-id of the function that initializes this value.
+    initializer_id: IdRef,
 };
 
 /// This models a kernel entry point.
@@ -174,9 +176,6 @@ globals: struct {
     section: Section = .{},
 } = .{},
 
-/// The function IDs of global variable initializers
-initializers: std.ArrayListUnmanaged(IdRef) = .{},
-
 pub fn init(gpa: Allocator, arena: Allocator) Module {
     return .{
         .gpa = gpa,
@@ -205,8 +204,6 @@ pub fn deinit(self: *Module) void {
     self.globals.globals.deinit(self.gpa);
     self.globals.section.deinit(self.gpa);
 
-    self.initializers.deinit(self.gpa);
-
     self.* = undefined;
 }
 
@@ -289,6 +286,10 @@ fn addEntryPointDeps(
     const decl = self.declPtr(decl_index);
     const deps = self.decl_deps.items[decl.begin_dep..decl.end_dep];
 
+    if (seen.isSet(@intFromEnum(decl_index))) {
+        return;
+    }
+
     seen.set(@intFromEnum(decl_index));
 
     if (self.globalPtr(decl_index)) |global| {
@@ -296,9 +297,7 @@ fn addEntryPointDeps(
     }
 
     for (deps) |dep| {
-        if (!seen.isSet(@intFromEnum(dep))) {
-            try self.addEntryPointDeps(dep, seen, interface);
-        }
+        try self.addEntryPointDeps(dep, seen, interface);
     }
 }
 
@@ -330,20 +329,76 @@ fn entryPoints(self: *Module) !Section {
     return entry_points;
 }
 
+/// Generate a function that calls all initialization functions,
+/// in unspecified order (an order should not be required here).
+/// It generated as follows:
+/// %init = OpFunction %void None
+/// foreach %initializer:
+/// OpFunctionCall %initializer
+/// OpReturn
+/// OpFunctionEnd
+fn initializer(self: *Module, entry_points: *Section) !Section {
+    var section = Section{};
+    errdefer section.deinit(self.gpa);
+
+    // const void_ty_ref = try self.resolveType(Type.void, .direct);
+    const void_ty_ref = try self.resolve(.void_type);
+    const void_ty_id = self.resultId(void_ty_ref);
+    const init_proto_ty_ref = try self.resolve(.{ .function_type = .{
+        .return_type = void_ty_ref,
+        .parameters = &.{},
+    } });
+
+    const init_id = self.allocId();
+    try section.emit(self.gpa, .OpFunction, .{
+        .id_result_type = void_ty_id,
+        .id_result = init_id,
+        .function_control = .{},
+        .function_type = self.resultId(init_proto_ty_ref),
+    });
+    try section.emit(self.gpa, .OpLabel, .{
+        .id_result = self.allocId(),
+    });
+
+    var seen = try std.DynamicBitSetUnmanaged.initEmpty(self.gpa, self.decls.items.len);
+    defer seen.deinit(self.gpa);
+
+    var interface = std.ArrayList(IdRef).init(self.gpa);
+    defer interface.deinit();
+
+    for (self.globals.globals.keys(), self.globals.globals.values()) |decl_index, global| {
+        try self.addEntryPointDeps(decl_index, &seen, &interface);
+        try section.emit(self.gpa, .OpFunctionCall, .{
+            .id_result_type = void_ty_id,
+            .id_result = self.allocId(),
+            .function = global.initializer_id,
+        });
+    }
+
+    try section.emit(self.gpa, .OpReturn, {});
+    try section.emit(self.gpa, .OpFunctionEnd, {});
+
+    try entry_points.emit(self.gpa, .OpEntryPoint, .{
+        // TODO: Rusticl does not support this because its poorly defined.
+        // Do we need to generate a workaround here?
+        .execution_model = .Kernel,
+        .entry_point = init_id,
+        .name = "zig global initializer",
+        .interface = interface.items,
+    });
+
+    try self.sections.execution_modes.emit(self.gpa, .OpExecutionMode, .{
+        .entry_point = init_id,
+        .mode = .Initializer,
+    });
+
+    return section;
+}
+
 /// Emit this module as a spir-v binary.
 pub fn flush(self: *Module, file: std.fs.File) !void {
     // See SPIR-V Spec section 2.3, "Physical Layout of a SPIR-V Module and Instruction"
 
-    const header = [_]Word{
-        spec.magic_number,
-        // TODO: From cpu features
-        //   Emit SPIR-V 1.4 for now. This is the highest version that Intel's CPU OpenCL supports.
-        (1 << 16) | (4 << 8),
-        0, // TODO: Register Zig compiler magic number.
-        self.idBound(),
-        0, // Schema (currently reserved for future use)
-    };
-
     // TODO: Perform topological sort on the globals.
     var globals = try self.orderGlobals();
     defer globals.deinit(self.gpa);
@@ -354,6 +409,19 @@ pub fn flush(self: *Module, file: std.fs.File) !void {
     var types_constants = try self.cache.materialize(self);
     defer types_constants.deinit(self.gpa);
 
+    var init_func = try self.initializer(&entry_points);
+    defer init_func.deinit(self.gpa);
+
+    const header = [_]Word{
+        spec.magic_number,
+        // TODO: From cpu features
+        //   Emit SPIR-V 1.4 for now. This is the highest version that Intel's CPU OpenCL supports.
+        (1 << 16) | (4 << 8),
+        0, // TODO: Register Zig compiler magic number.
+        self.idBound(),
+        0, // Schema (currently reserved for future use)
+    };
+
     // Note: needs to be kept in order according to section 2.3!
     const buffers = &[_][]const Word{
         &header,
@@ -368,6 +436,7 @@ pub fn flush(self: *Module, file: std.fs.File) !void {
         self.sections.types_globals_constants.toWords(),
         globals.toWords(),
         self.sections.functions.toWords(),
+        init_func.toWords(),
     };
 
     var iovc_buffers: [buffers.len]std.os.iovec_const = undefined;
@@ -529,6 +598,7 @@ pub fn allocDecl(self: *Module, kind: DeclKind) !Decl.Index {
             .result_id = undefined,
             .begin_inst = undefined,
             .end_inst = undefined,
+            .initializer_id = undefined,
         }),
     }
 
@@ -558,10 +628,14 @@ pub fn beginGlobal(self: *Module) u32 {
     return @as(u32, @intCast(self.globals.section.instructions.items.len));
 }
 
-pub fn endGlobal(self: *Module, global_index: Decl.Index, begin_inst: u32) void {
+pub fn endGlobal(self: *Module, global_index: Decl.Index, begin_inst: u32, result_id: IdRef, initializer_id: IdRef) void {
     const global = self.globalPtr(global_index).?;
-    global.begin_inst = begin_inst;
-    global.end_inst = @as(u32, @intCast(self.globals.section.instructions.items.len));
+    global.* = .{
+        .result_id = result_id,
+        .begin_inst = begin_inst,
+        .end_inst = @intCast(self.globals.section.instructions.items.len),
+        .initializer_id = initializer_id,
+    };
 }
 
 pub fn declareEntryPoint(self: *Module, decl_index: Decl.Index, name: []const u8) !void {
src/codegen/spirv.zig
@@ -1494,7 +1494,6 @@ pub const DeclGen = struct {
                 .id_result = decl_id,
                 .storage_class = actual_storage_class,
             });
-            self.spv.globalPtr(spv_decl_index).?.result_id = decl_id;
 
             // Now emit the instructions that initialize the variable.
             const initializer_id = self.spv.allocId();
@@ -1517,14 +1516,12 @@ pub const DeclGen = struct {
             });
 
             // TODO: We should be able to get rid of this by now...
-            self.spv.endGlobal(spv_decl_index, begin);
+            self.spv.endGlobal(spv_decl_index, begin, decl_id, initializer_id);
 
             try self.func.body.emit(self.spv.gpa, .OpReturn, {});
             try self.func.body.emit(self.spv.gpa, .OpFunctionEnd, {});
             try self.spv.addFunction(spv_decl_index, self.func);
 
-            try self.spv.initializers.append(self.spv.gpa, initializer_id);
-
             const fqn = ip.stringToSlice(try decl.getFullyQualifiedName(self.module));
             try self.spv.sections.debug_names.emit(self.gpa, .OpName, .{
                 .target = decl_id,
test/behavior/array.zig
@@ -48,7 +48,6 @@ fn getArrayLen(a: []const u32) usize {
 test "array concat with undefined" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {
@@ -88,7 +87,6 @@ test "array concat with tuple" {
 
 test "array init with concat" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const a = 'a';
     var i: [4]u8 = [2]u8{ a, 'b' } ++ [2]u8{ 'c', 'd' };
@@ -98,7 +96,6 @@ test "array init with concat" {
 test "array init with mult" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const a = 'a';
     var i: [8]u8 = [2]u8{ a, 'b' } ** 4;
@@ -241,7 +238,6 @@ fn plusOne(x: u32) u32 {
 test "single-item pointer to array indexing and slicing" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try testSingleItemPtrArrayIndexSlice();
     try comptime testSingleItemPtrArrayIndexSlice();
@@ -384,7 +380,6 @@ test "runtime initialize array elem and then implicit cast to slice" {
 test "array literal as argument to function" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn entry(two: i32) !void {
@@ -413,7 +408,6 @@ test "double nested array to const slice cast in array literal" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn entry(two: i32) !void {
@@ -651,7 +645,6 @@ test "tuple to array handles sentinel" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         const a = .{ 1, 2, 3 };
test/behavior/basic.zig
@@ -330,7 +330,6 @@ const FnPtrWrapper = struct {
 
 test "const ptr from var variable" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var x: u64 = undefined;
     var y: u64 = undefined;
@@ -581,7 +580,7 @@ test "comptime cast fn to ptr" {
 }
 
 test "equality compare fn ptrs" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // Test passes but should not
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var a = &emptyFn;
     try expect(a == a);
@@ -639,7 +638,6 @@ test "global constant is loaded with a runtime-known index" {
 
 test "multiline string literal is null terminated" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const s1 =
         \\one
@@ -711,7 +709,6 @@ test "comptime manyptr concatenation" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const s = "epic";
     const actual = manyptrConcat(s);
@@ -1027,7 +1024,6 @@ comptime {
 
 test "switch inside @as gets correct type" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var a: u32 = 0;
     var b: [2]u32 = undefined;
@@ -1136,8 +1132,6 @@ test "orelse coercion as function argument" {
 }
 
 test "runtime-known globals initialized with undefined" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const S = struct {
         var array: [10]u32 = [_]u32{ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
         var vp: [*]u32 = undefined;