master
  1const std = @import("std");
  2const Allocator = std.mem.Allocator;
  3const assert = std.debug.assert;
  4const log = std.log.scoped(.spirv_link);
  5
  6const BinaryModule = @import("BinaryModule.zig");
  7const Section = @import("../../codegen/spirv/Section.zig");
  8const spec = @import("../../codegen/spirv/spec.zig");
  9const ResultId = spec.Id;
 10const Word = spec.Word;
 11
 12/// This structure contains all the stuff that we need to parse from the module in
 13/// order to run this pass, as well as some functions to ease its use.
 14const ModuleInfo = struct {
 15    /// Information about a particular function.
 16    const Fn = struct {
 17        /// The index of the first callee in `callee_store`.
 18        first_callee: usize,
 19        /// The return type id of this function
 20        return_type: ResultId,
 21        /// The parameter types of this function
 22        param_types: []const ResultId,
 23        /// The set of (result-id's of) invocation globals that are accessed
 24        /// in this function, or after resolution, that are accessed in this
 25        /// function or any of it's callees.
 26        invocation_globals: std.AutoArrayHashMapUnmanaged(ResultId, void),
 27    };
 28
 29    /// Information about a particular invocation global
 30    const InvocationGlobal = struct {
 31        /// The list of invocation globals that this invocation global
 32        /// depends on.
 33        dependencies: std.AutoArrayHashMapUnmanaged(ResultId, void),
 34        /// The invocation global's type
 35        ty: ResultId,
 36        /// Initializer function. May be `none`.
 37        /// Note that if the initializer is `none`, then `dependencies` is empty.
 38        initializer: ResultId,
 39    };
 40
 41    /// Maps function result-id -> Fn information structure.
 42    functions: std.AutoArrayHashMapUnmanaged(ResultId, Fn),
 43    /// Set of OpFunction result-ids in this module.
 44    entry_points: std.AutoArrayHashMapUnmanaged(ResultId, void),
 45    /// For each function, a list of function result-ids that it calls.
 46    callee_store: []const ResultId,
 47    /// Maps each invocation global result-id to a type-id.
 48    invocation_globals: std.AutoArrayHashMapUnmanaged(ResultId, InvocationGlobal),
 49
 50    /// Fetch the list of callees per function. Guaranteed to contain only unique IDs.
 51    fn callees(self: ModuleInfo, fn_id: ResultId) []const ResultId {
 52        const fn_index = self.functions.getIndex(fn_id).?;
 53        const values = self.functions.values();
 54        const first_callee = values[fn_index].first_callee;
 55        if (fn_index == values.len - 1) {
 56            return self.callee_store[first_callee..];
 57        } else {
 58            const next_first_callee = values[fn_index + 1].first_callee;
 59            return self.callee_store[first_callee..next_first_callee];
 60        }
 61    }
 62
 63    /// Extract most of the required information from the binary. The remaining info is
 64    /// constructed by `resolve()`.
 65    fn parse(
 66        arena: Allocator,
 67        parser: *BinaryModule.Parser,
 68        binary: BinaryModule,
 69    ) BinaryModule.ParseError!ModuleInfo {
 70        var entry_points = std.AutoArrayHashMap(ResultId, void).init(arena);
 71        var functions = std.AutoArrayHashMap(ResultId, Fn).init(arena);
 72        var fn_types = std.AutoHashMap(ResultId, struct {
 73            return_type: ResultId,
 74            param_types: []const ResultId,
 75        }).init(arena);
 76        var calls = std.AutoArrayHashMap(ResultId, void).init(arena);
 77        var callee_store = std.array_list.Managed(ResultId).init(arena);
 78        var function_invocation_globals = std.AutoArrayHashMap(ResultId, void).init(arena);
 79        var result_id_offsets = std.array_list.Managed(u16).init(arena);
 80        var invocation_globals = std.AutoArrayHashMap(ResultId, InvocationGlobal).init(arena);
 81
 82        var maybe_current_function: ?ResultId = null;
 83        var fn_ty_id: ResultId = undefined;
 84
 85        var it = binary.iterateInstructions();
 86        while (it.next()) |inst| {
 87            result_id_offsets.items.len = 0;
 88            try parser.parseInstructionResultIds(binary, inst, &result_id_offsets);
 89
 90            switch (inst.opcode) {
 91                .OpEntryPoint => {
 92                    const entry_point: ResultId = @enumFromInt(inst.operands[1]);
 93                    const entry = try entry_points.getOrPut(entry_point);
 94                    if (entry.found_existing) {
 95                        log.err("Entry point type {f} has duplicate definition", .{entry_point});
 96                        return error.DuplicateId;
 97                    }
 98                },
 99                .OpTypeFunction => {
100                    const fn_type: ResultId = @enumFromInt(inst.operands[0]);
101                    const return_type: ResultId = @enumFromInt(inst.operands[1]);
102                    const param_types: []const ResultId = @ptrCast(inst.operands[2..]);
103
104                    const entry = try fn_types.getOrPut(fn_type);
105                    if (entry.found_existing) {
106                        log.err("Function type {f} has duplicate definition", .{fn_type});
107                        return error.DuplicateId;
108                    }
109
110                    entry.value_ptr.* = .{
111                        .return_type = return_type,
112                        .param_types = param_types,
113                    };
114                },
115                .OpExtInst => {
116                    // Note: format and set are already verified by parseInstructionResultIds().
117                    const global_type: ResultId = @enumFromInt(inst.operands[0]);
118                    const result_id: ResultId = @enumFromInt(inst.operands[1]);
119                    const set_id: ResultId = @enumFromInt(inst.operands[2]);
120                    const set_inst = inst.operands[3];
121
122                    const set = binary.ext_inst_map.get(set_id).?;
123                    if (set == .zig and set_inst == 0) {
124                        const initializer: ResultId = if (inst.operands.len >= 5)
125                            @enumFromInt(inst.operands[4])
126                        else
127                            .none;
128
129                        try invocation_globals.put(result_id, .{
130                            .dependencies = .{},
131                            .ty = global_type,
132                            .initializer = initializer,
133                        });
134                    }
135                },
136                .OpFunction => {
137                    if (maybe_current_function) |current_function| {
138                        log.err("OpFunction {f} does not have an OpFunctionEnd", .{current_function});
139                        return error.InvalidPhysicalFormat;
140                    }
141
142                    maybe_current_function = @enumFromInt(inst.operands[1]);
143                    fn_ty_id = @enumFromInt(inst.operands[3]);
144                    function_invocation_globals.clearRetainingCapacity();
145                },
146                .OpFunctionCall => {
147                    const callee: ResultId = @enumFromInt(inst.operands[2]);
148                    try calls.put(callee, {});
149                },
150                .OpFunctionEnd => {
151                    const current_function = maybe_current_function orelse {
152                        log.err("encountered OpFunctionEnd without corresponding OpFunction", .{});
153                        return error.InvalidPhysicalFormat;
154                    };
155                    const entry = try functions.getOrPut(current_function);
156                    if (entry.found_existing) {
157                        log.err("Function {f} has duplicate definition", .{current_function});
158                        return error.DuplicateId;
159                    }
160
161                    const first_callee = callee_store.items.len;
162                    try callee_store.appendSlice(calls.keys());
163
164                    const fn_type = fn_types.get(fn_ty_id) orelse {
165                        log.err("Function {f} has invalid OpFunction type", .{current_function});
166                        return error.InvalidId;
167                    };
168
169                    entry.value_ptr.* = .{
170                        .first_callee = first_callee,
171                        .return_type = fn_type.return_type,
172                        .param_types = fn_type.param_types,
173                        .invocation_globals = try function_invocation_globals.unmanaged.clone(arena),
174                    };
175                    maybe_current_function = null;
176                    calls.clearRetainingCapacity();
177                },
178                else => {},
179            }
180
181            for (result_id_offsets.items) |off| {
182                const result_id: ResultId = @enumFromInt(inst.operands[off]);
183                if (invocation_globals.contains(result_id)) {
184                    try function_invocation_globals.put(result_id, {});
185                }
186            }
187        }
188
189        if (maybe_current_function) |current_function| {
190            log.err("OpFunction {f} does not have an OpFunctionEnd", .{current_function});
191            return error.InvalidPhysicalFormat;
192        }
193
194        return ModuleInfo{
195            .functions = functions.unmanaged,
196            .entry_points = entry_points.unmanaged,
197            .callee_store = callee_store.items,
198            .invocation_globals = invocation_globals.unmanaged,
199        };
200    }
201
202    /// Derive the remaining info from the structures filled in by parsing.
203    fn resolve(self: *ModuleInfo, arena: Allocator) !void {
204        try self.resolveInvocationGlobalUsage(arena);
205        try self.resolveInvocationGlobalDependencies(arena);
206    }
207
208    /// For each function, extend the list of `invocation_globals` with the
209    /// invocation globals that ALL of its dependencies use.
210    fn resolveInvocationGlobalUsage(self: *ModuleInfo, arena: Allocator) !void {
211        var seen = try std.DynamicBitSetUnmanaged.initEmpty(arena, self.functions.count());
212
213        for (self.functions.keys()) |id| {
214            try self.resolveInvocationGlobalUsageStep(arena, id, &seen);
215        }
216    }
217
218    fn resolveInvocationGlobalUsageStep(
219        self: *ModuleInfo,
220        arena: Allocator,
221        id: ResultId,
222        seen: *std.DynamicBitSetUnmanaged,
223    ) !void {
224        const index = self.functions.getIndex(id) orelse {
225            log.err("function calls invalid function {f}", .{id});
226            return error.InvalidId;
227        };
228
229        if (seen.isSet(index)) {
230            return;
231        }
232        seen.set(index);
233
234        const info = &self.functions.values()[index];
235        for (self.callees(id)) |callee| {
236            try self.resolveInvocationGlobalUsageStep(arena, callee, seen);
237            const callee_info = self.functions.get(callee).?;
238            for (callee_info.invocation_globals.keys()) |global| {
239                try info.invocation_globals.put(arena, global, {});
240            }
241        }
242    }
243
244    /// For each invocation global, populate and fully resolve the `dependencies` set.
245    /// This requires `resolveInvocationGlobalUsage()` to be already done.
246    fn resolveInvocationGlobalDependencies(
247        self: *ModuleInfo,
248        arena: Allocator,
249    ) !void {
250        var seen = try std.DynamicBitSetUnmanaged.initEmpty(arena, self.invocation_globals.count());
251
252        for (self.invocation_globals.keys()) |id| {
253            try self.resolveInvocationGlobalDependenciesStep(arena, id, &seen);
254        }
255    }
256
257    fn resolveInvocationGlobalDependenciesStep(
258        self: *ModuleInfo,
259        arena: Allocator,
260        id: ResultId,
261        seen: *std.DynamicBitSetUnmanaged,
262    ) !void {
263        const index = self.invocation_globals.getIndex(id) orelse {
264            log.err("invalid invocation global {f}", .{id});
265            return error.InvalidId;
266        };
267
268        if (seen.isSet(index)) {
269            return;
270        }
271        seen.set(index);
272
273        const info = &self.invocation_globals.values()[index];
274        if (info.initializer == .none) {
275            return;
276        }
277
278        const initializer = self.functions.get(info.initializer) orelse {
279            log.err("invocation global {f} has invalid initializer {f}", .{ id, info.initializer });
280            return error.InvalidId;
281        };
282
283        for (initializer.invocation_globals.keys()) |dependency| {
284            if (dependency == id) {
285                // The set of invocation global dependencies includes the dependency itself,
286                // so we need to skip that case.
287                continue;
288            }
289
290            try info.dependencies.put(arena, dependency, {});
291            try self.resolveInvocationGlobalDependenciesStep(arena, dependency, seen);
292
293            const dep_info = self.invocation_globals.getPtr(dependency).?;
294
295            for (dep_info.dependencies.keys()) |global| {
296                try info.dependencies.put(arena, global, {});
297            }
298        }
299    }
300};
301
/// Emits the lowered module.
///
/// - Every function gets one extra leading OpFunctionParameter per invocation
///   global it (transitively) uses.
/// - Every call site forwards those parameters.
/// - Every original entry point is wrapped in a new entry-point function. The
///   wrapper declares the globals as function-storage OpVariables, runs their
///   initializers and then calls the original function.
const ModuleBuilder = struct {
    const FunctionType = struct {
        return_type: ResultId,
        param_types: []const ResultId,

        /// Hash map context that compares function types structurally.
        const Context = struct {
            pub fn hash(_: @This(), ty: FunctionType) u32 {
                var hasher = std.hash.Wyhash.init(0);
                hasher.update(std.mem.asBytes(&ty.return_type));
                hasher.update(std.mem.sliceAsBytes(ty.param_types));
                return @truncate(hasher.final());
            }

            pub fn eql(_: @This(), a: FunctionType, b: FunctionType, _: usize) bool {
                if (a.return_type != b.return_type) return false;
                return std.mem.eql(ResultId, a.param_types, b.param_types);
            }
        };
    };

    const FunctionNewInfo = struct {
        /// This is here just so that we don't need to allocate the new
        /// param_types multiple times.
        new_function_type: ResultId,
        /// The first ID of the parameters for the invocation globals.
        /// Each global is allocated here according to the index in
        /// `ModuleInfo.Fn.invocation_globals`.
        global_id_base: u32,

        fn invocationGlobalId(self: FunctionNewInfo, index: usize) ResultId {
            return @enumFromInt(self.global_id_base + @as(u32, @intCast(index)));
        }
    };

    arena: Allocator,
    section: Section,
    /// The ID bound of the new module.
    id_bound: u32,
    /// The first ID of the new entry points. Entry points are allocated from
    /// here according to their index in `info.entry_points`.
    entry_point_new_id_base: u32,
    /// A set of all function types in the new program. SPIR-V mandates that these are unique,
    /// and until a general type deduplication pass is programmed, we just handle it here via this.
    function_types: std.ArrayHashMapUnmanaged(FunctionType, ResultId, FunctionType.Context, true) = .empty,
    /// Maps functions to new information required for creating the module
    function_new_info: std.AutoArrayHashMapUnmanaged(ResultId, FunctionNewInfo) = .empty,
    /// Offset of the functions section in the new binary.
    new_functions_section: ?usize,

    fn init(arena: Allocator, binary: BinaryModule, info: ModuleInfo) !ModuleBuilder {
        var self = ModuleBuilder{
            .arena = arena,
            .section = .{},
            .id_bound = binary.id_bound,
            .entry_point_new_id_base = undefined,
            .new_functions_section = null,
        };
        self.entry_point_new_id_base = @intFromEnum(self.allocIds(@intCast(info.entry_points.count())));
        return self;
    }

    fn allocId(self: *ModuleBuilder) ResultId {
        return self.allocIds(1);
    }

    /// Reserve `n` consecutive fresh result-ids and return the first one.
    fn allocIds(self: *ModuleBuilder, n: u32) ResultId {
        defer self.id_bound += n;
        return @enumFromInt(self.id_bound);
    }

    /// Result-id of the new wrapper function for the entry point at `index`
    /// in `ModuleInfo.entry_points`.
    fn entryPointId(self: ModuleBuilder, index: usize) ResultId {
        return @enumFromInt(self.entry_point_new_id_base + @as(u32, @intCast(index)));
    }

    /// Copy the built instructions (allocated with `a`) and the new id bound into `binary`.
    fn finalize(self: *ModuleBuilder, a: Allocator, binary: *BinaryModule) !void {
        binary.id_bound = self.id_bound;
        binary.instructions = try a.dupe(Word, self.section.instructions.items);
        // Nothing is removed in this pass so we don't need to change any of the maps,
        // just make sure the section is updated.
        binary.sections.functions = self.new_functions_section orelse binary.instructions.len;
    }

    /// Process everything from `binary` up to the first function and emit it into the builder.
    /// This drops the invocation global declarations, their OpNames, the `zig` OpExtInstImport
    /// and all OpTypeFunctions. It also retargets OpEntryPoint and OpExecutionMode[Id]
    /// to the new entry-point wrapper ids.
    fn processPreamble(self: *ModuleBuilder, binary: BinaryModule, info: ModuleInfo) !void {
        var it = binary.iterateInstructions();
        while (it.next()) |inst| {
            switch (inst.opcode) {
                .OpName => {
                    const id: ResultId = @enumFromInt(inst.operands[0]);
                    if (info.invocation_globals.contains(id)) continue;
                },
                .OpExtInstImport => {
                    const set_id: ResultId = @enumFromInt(inst.operands[0]);
                    const set = binary.ext_inst_map.get(set_id).?;
                    if (set == .zig) continue;
                },
                .OpExtInst => {
                    const set_id: ResultId = @enumFromInt(inst.operands[2]);
                    const set_inst = inst.operands[3];
                    const set = binary.ext_inst_map.get(set_id).?;
                    if (set == .zig and set_inst == 0) {
                        continue;
                    }
                },
                .OpEntryPoint => {
                    // Every OpEntryPoint was recorded by `ModuleInfo.parse()`.
                    const original_id: ResultId = @enumFromInt(inst.operands[1]);
                    const index = info.entry_points.getIndex(original_id).?;
                    try self.section.emitRaw(self.arena, .OpEntryPoint, inst.operands.len);
                    self.section.writeWord(inst.operands[0]);
                    self.section.writeOperand(ResultId, self.entryPointId(index));
                    self.section.writeWords(inst.operands[2..]);
                    continue;
                },
                .OpExecutionMode, .OpExecutionModeId => {
                    const original_id: ResultId = @enumFromInt(inst.operands[0]);
                    const index = info.entry_points.getIndex(original_id) orelse {
                        log.err("execution mode targets {f}, which is not an entry point", .{original_id});
                        return error.InvalidId;
                    };
                    try self.section.emitRaw(self.arena, inst.opcode, inst.operands.len);
                    self.section.writeOperand(ResultId, self.entryPointId(index));
                    self.section.writeWords(inst.operands[1..]);
                    continue;
                },
                .OpTypeFunction => {
                    // Re-emitted in `emitFunctionTypes()`. We can do this because
                    // OpTypeFunction's may not currently be used anywhere that is not
                    // directly with an OpFunction. For now we ignore Intels function
                    // pointers extension, that is not a problem with a generalized
                    // pass anyway.
                    continue;
                },
                .OpFunction => break,
                else => {},
            }

            try self.section.emitRawInstruction(self.arena, inst.opcode, inst.operands);
        }
    }

    /// Derive new information required for further emitting this module:
    /// the new function type of every function (invocation globals first, then
    /// the original parameters) and the ids of its invocation-global parameters.
    fn deriveNewFnInfo(self: *ModuleBuilder, info: ModuleInfo) !void {
        for (info.functions.keys(), info.functions.values()) |func, fn_info| {
            const invocation_global_count = fn_info.invocation_globals.count();
            const new_param_types = try self.arena.alloc(ResultId, fn_info.param_types.len + invocation_global_count);
            for (fn_info.invocation_globals.keys(), 0..) |global, i| {
                new_param_types[i] = info.invocation_globals.get(global).?.ty;
            }
            @memcpy(new_param_types[invocation_global_count..], fn_info.param_types);

            const new_type = try self.internFunctionType(fn_info.return_type, new_param_types);
            try self.function_new_info.put(self.arena, func, .{
                .new_function_type = new_type,
                .global_id_base = @intFromEnum(self.allocIds(@intCast(invocation_global_count))),
            });
        }
    }

    /// Emit the new function types, which include the parameters for the invocation globals.
    /// Currently, this function re-emits ALL function types to ensure that there are
    /// no duplicates in the final program.
    /// TODO: The above should be resolved by a generalized deduplication pass, and then
    /// we only need to emit the new function pointers type here.
    fn emitFunctionTypes(self: *ModuleBuilder, info: ModuleInfo) !void {
        // TODO: Handle decorators. Function types usually don't have those
        // though, but stuff like OpName could be a possibility.

        // Entry points retain their old function type, so make sure to emit
        // those in the `function_types` set.
        for (info.entry_points.keys()) |func| {
            const fn_info = info.functions.get(func).?;
            _ = try self.internFunctionType(fn_info.return_type, fn_info.param_types);
        }

        for (self.function_types.keys(), self.function_types.values()) |fn_type, result_id| {
            try self.section.emit(self.arena, .OpTypeFunction, .{
                .id_result = result_id,
                .return_type = fn_type.return_type,
                .id_ref_2 = fn_type.param_types,
            });
        }
    }

    /// Return the result-id for the given function type. A fresh id is allocated
    /// the first time a structurally new type is seen.
    fn internFunctionType(self: *ModuleBuilder, return_type: ResultId, param_types: []const ResultId) !ResultId {
        const entry = try self.function_types.getOrPut(self.arena, .{
            .return_type = return_type,
            .param_types = param_types,
        });

        if (!entry.found_existing) {
            const new_id = self.allocId();
            entry.value_ptr.* = new_id;
        }

        return entry.value_ptr.*;
    }

    /// Rewrite the modules functions and emit them with the new parameter types.
    fn rewriteFunctions(
        self: *ModuleBuilder,
        parser: *BinaryModule.Parser,
        binary: BinaryModule,
        info: ModuleInfo,
    ) !void {
        var result_id_offsets = std.array_list.Managed(u16).init(self.arena);
        var operands = std.array_list.Managed(u32).init(self.arena);

        var maybe_current_function: ?ResultId = null;
        var it = binary.iterateInstructionsFrom(binary.sections.functions);
        self.new_functions_section = self.section.instructions.items.len;
        while (it.next()) |inst| {
            result_id_offsets.items.len = 0;
            try parser.parseInstructionResultIds(binary, inst, &result_id_offsets);

            operands.items.len = 0;
            try operands.appendSlice(inst.operands);

            // Replace the result-ids with the global's new result-id if required.
            for (result_id_offsets.items) |off| {
                const result_id: ResultId = @enumFromInt(operands.items[off]);
                if (info.invocation_globals.contains(result_id)) {
                    const func = maybe_current_function.?;
                    const new_info = self.function_new_info.get(func).?;
                    const fn_info = info.functions.get(func).?;
                    const index = fn_info.invocation_globals.getIndex(result_id).?;
                    operands.items[off] = @intFromEnum(new_info.invocationGlobalId(index));
                }
            }

            switch (inst.opcode) {
                .OpFunction => {
                    // Re-declare the function with the new parameters.
                    const func: ResultId = @enumFromInt(operands.items[1]);
                    const fn_info = info.functions.get(func).?;
                    const new_info = self.function_new_info.get(func).?;

                    try self.section.emitRaw(self.arena, .OpFunction, 4);
                    self.section.writeOperand(ResultId, fn_info.return_type);
                    self.section.writeOperand(ResultId, func);
                    self.section.writeWord(operands.items[2]);
                    self.section.writeOperand(ResultId, new_info.new_function_type);

                    // Emit the OpFunctionParameters for the invocation globals. The functions
                    // actual parameters are emitted unchanged from their original form, so
                    // we don't need to handle those here.

                    for (fn_info.invocation_globals.keys(), 0..) |global, index| {
                        const ty = info.invocation_globals.get(global).?.ty;
                        const id = new_info.invocationGlobalId(index);
                        try self.section.emit(self.arena, .OpFunctionParameter, .{
                            .id_result_type = ty,
                            .id_result = id,
                        });
                    }

                    maybe_current_function = func;
                },
                .OpFunctionCall => {
                    // Add the required invocation globals to the function's new parameter list.
                    const caller = maybe_current_function.?;
                    const callee: ResultId = @enumFromInt(operands.items[2]);
                    const caller_info = info.functions.get(caller).?;
                    const callee_info = info.functions.get(callee).?;
                    const caller_new_info = self.function_new_info.get(caller).?;
                    const total_params = callee_info.invocation_globals.count() + callee_info.param_types.len;

                    try self.section.emitRaw(self.arena, .OpFunctionCall, 3 + total_params);
                    self.section.writeWord(operands.items[0]); // Copy result type-id
                    self.section.writeWord(operands.items[1]); // Copy result-id
                    self.section.writeOperand(ResultId, callee);

                    // Add the new arguments. The caller's global set is a superset of the
                    // callee's after `ModuleInfo.resolve()`, so the lookup cannot fail.
                    for (callee_info.invocation_globals.keys()) |global| {
                        const caller_global_index = caller_info.invocation_globals.getIndex(global).?;
                        const id = caller_new_info.invocationGlobalId(caller_global_index);
                        self.section.writeOperand(ResultId, id);
                    }

                    // Add the original arguments
                    self.section.writeWords(operands.items[3..]);
                },
                else => {
                    try self.section.emitRawInstruction(self.arena, inst.opcode, operands.items);
                },
            }
        }
    }

    /// Emit one wrapper function per original entry point. The wrapper keeps the
    /// original function type, declares every required invocation global as a
    /// function-storage OpVariable, runs the initializers, and then calls the
    /// original entry point with the globals and the wrapper's own parameters.
    fn emitNewEntryPoints(self: *ModuleBuilder, info: ModuleInfo) !void {
        var all_function_invocation_globals = std.AutoArrayHashMap(ResultId, void).init(self.arena);
        var initializer_order = std.array_list.Managed(ResultId).init(self.arena);

        for (info.entry_points.keys(), 0..) |func, entry_point_index| {
            const fn_info = info.functions.get(func).?;
            const ep_id = self.entryPointId(entry_point_index);
            const fn_type = self.function_types.get(.{
                .return_type = fn_info.return_type,
                .param_types = fn_info.param_types,
            }).?;

            try self.section.emit(self.arena, .OpFunction, .{
                .id_result_type = fn_info.return_type,
                .id_result = ep_id,
                .function_control = .{}, // TODO: Copy the attributes from the original function maybe?
                .function_type = fn_type,
            });

            // Emit OpFunctionParameter instructions for the original kernel's parameters.
            const params_id_base: u32 = @intFromEnum(self.allocIds(@intCast(fn_info.param_types.len)));
            for (fn_info.param_types, 0..) |param_type, i| {
                const id: ResultId = @enumFromInt(params_id_base + @as(u32, @intCast(i)));
                try self.section.emit(self.arena, .OpFunctionParameter, .{
                    .id_result_type = param_type,
                    .id_result = id,
                });
            }

            try self.section.emit(self.arena, .OpLabel, .{
                .id_result = self.allocId(),
            });

            // Besides the IDs of the main kernel, we also need the
            // dependencies of the globals.
            // Just quickly construct that set here.
            all_function_invocation_globals.clearRetainingCapacity();
            for (fn_info.invocation_globals.keys()) |global| {
                try all_function_invocation_globals.put(global, {});
                const global_info = info.invocation_globals.get(global).?;
                for (global_info.dependencies.keys()) |dependency| {
                    try all_function_invocation_globals.put(dependency, {});
                }
            }

            // Declare the IDs of the invocation globals.
            const global_id_base: u32 = @intFromEnum(self.allocIds(@intCast(all_function_invocation_globals.count())));
            for (all_function_invocation_globals.keys(), 0..) |global, i| {
                const global_info = info.invocation_globals.get(global).?;

                const id: ResultId = @enumFromInt(global_id_base + @as(u32, @intCast(i)));
                try self.section.emit(self.arena, .OpVariable, .{
                    .id_result_type = global_info.ty,
                    .id_result = id,
                    .storage_class = .function,
                    .initializer = null,
                });
            }

            // Call initializers for invocation globals that need it. An initializer
            // may access the globals it depends on, so those are initialized first.
            // The set above lists each global before its dependencies, so its order
            // cannot be used directly. `dependencies` is transitively closed by
            // `ModuleInfo.resolve()`. In an acyclic dependency graph, a global
            // therefore has strictly more dependencies than any global it depends
            // on. Sorting by dependency count thus yields a valid order, and the
            // sort is stable, so ties keep their set order.
            initializer_order.clearRetainingCapacity();
            for (all_function_invocation_globals.keys()) |global| {
                if (info.invocation_globals.get(global).?.initializer == .none) continue;
                try initializer_order.append(global);
            }
            std.mem.sort(ResultId, initializer_order.items, info, hasFewerDependencies);

            for (initializer_order.items) |global| {
                const global_info = info.invocation_globals.get(global).?;
                const initializer_info = info.functions.get(global_info.initializer).?;
                assert(initializer_info.param_types.len == 0);

                try self.callWithGlobalsAndLinearParams(
                    all_function_invocation_globals,
                    global_info.initializer,
                    initializer_info,
                    global_id_base,
                    undefined,
                );
            }

            // Call the main kernel entry
            try self.callWithGlobalsAndLinearParams(
                all_function_invocation_globals,
                func,
                fn_info,
                global_id_base,
                params_id_base,
            );

            try self.section.emit(self.arena, .OpReturn, {});
            try self.section.emit(self.arena, .OpFunctionEnd, {});
        }
    }

    /// `std.mem.sort` comparator: orders invocation globals by the size of their
    /// resolved dependency set, fewest first.
    fn hasFewerDependencies(info: ModuleInfo, lhs: ResultId, rhs: ResultId) bool {
        const lhs_count = info.invocation_globals.get(lhs).?.dependencies.count();
        const rhs_count = info.invocation_globals.get(rhs).?.dependencies.count();
        return lhs_count < rhs_count;
    }

    /// Emit an OpFunctionCall to `func`. The arguments are, first, the ids of the
    /// callee's invocation globals (located by their index in `all_globals`, offset
    /// from `global_id_base`). These are followed by `param_types.len` consecutive
    /// ids starting at `params_id_base`. `params_id_base` is not read when the callee
    /// has no parameters.
    fn callWithGlobalsAndLinearParams(
        self: *ModuleBuilder,
        all_globals: std.AutoArrayHashMap(ResultId, void),
        func: ResultId,
        callee_info: ModuleInfo.Fn,
        global_id_base: u32,
        params_id_base: u32,
    ) !void {
        const total_arguments = callee_info.invocation_globals.count() + callee_info.param_types.len;
        try self.section.emitRaw(self.arena, .OpFunctionCall, 3 + total_arguments);
        self.section.writeOperand(ResultId, callee_info.return_type);
        self.section.writeOperand(ResultId, self.allocId());
        self.section.writeOperand(ResultId, func);

        // Add the invocation globals
        for (callee_info.invocation_globals.keys()) |global| {
            const index = all_globals.getIndex(global).?;
            const id: ResultId = @enumFromInt(global_id_base + @as(u32, @intCast(index)));
            self.section.writeOperand(ResultId, id);
        }

        // Add the arguments
        for (0..callee_info.param_types.len) |index| {
            const id: ResultId = @enumFromInt(params_id_base + @as(u32, @intCast(index)));
            self.section.writeOperand(ResultId, id);
        }
    }
};
702
/// Run the "lower invocation globals" pass over `binary`.
///
/// What the pass does:
/// - Removes the invocation-global declarations (instruction 0 of the `zig`
///   extended instruction set).
/// - Gives every function one extra parameter per invocation global it
///   (transitively) uses.
/// - Retargets each entry point to a new wrapper. The wrapper declares the
///   globals as function-storage variables, calls their initializers and then
///   calls the original function.
///
/// Scratch data lives in an arena that is released on return. The rewritten
/// instruction stream is allocated with `parser.a` and written back into `binary`.
pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule, progress: std.Progress.Node) !void {
    const step_node = progress.start("Lower invocation globals", 6);
    defer step_node.end();

    var scratch = std.heap.ArenaAllocator.init(parser.a);
    defer scratch.deinit();
    const arena = scratch.allocator();

    // `binary` is only modified by `finalize()`, so one snapshot serves every stage.
    const source = binary.*;

    var module_info = try ModuleInfo.parse(arena, parser, source);
    try module_info.resolve(arena);

    var builder = try ModuleBuilder.init(arena, source, module_info);
    step_node.completeOne();

    try builder.deriveNewFnInfo(module_info);
    step_node.completeOne();

    try builder.processPreamble(source, module_info);
    step_node.completeOne();

    try builder.emitFunctionTypes(module_info);
    step_node.completeOne();

    try builder.rewriteFunctions(parser, source, module_info);
    step_node.completeOne();

    try builder.emitNewEntryPoints(module_info);
    step_node.completeOne();

    try builder.finalize(parser.a, binary);
}