master
1const std = @import("std");
2const Allocator = std.mem.Allocator;
3const assert = std.debug.assert;
4const log = std.log.scoped(.spirv_link);
5
6const BinaryModule = @import("BinaryModule.zig");
7const Section = @import("../../codegen/spirv/Section.zig");
8const spec = @import("../../codegen/spirv/spec.zig");
9const ResultId = spec.Id;
10const Word = spec.Word;
11
/// Everything this pass needs to know about the input module, together with a few
/// helpers to query it. `parse()` gathers most of the data and `resolve()` derives
/// the rest.
const ModuleInfo = struct {
    /// Information about a particular function.
    const Fn = struct {
        /// The index of this function's first callee in `callee_store`.
        first_callee: usize,
        /// The return type id of this function.
        return_type: ResultId,
        /// The parameter type ids of this function.
        param_types: []const ResultId,
        /// The set of (result-ids of) invocation globals that are accessed
        /// in this function or, after resolution, in this function or any
        /// of its callees.
        invocation_globals: std.AutoArrayHashMapUnmanaged(ResultId, void),
    };

    /// Information about a particular invocation global.
    const InvocationGlobal = struct {
        /// The set of invocation globals that this invocation global
        /// depends on. Filled in by `resolveInvocationGlobalDependencies()`.
        dependencies: std.AutoArrayHashMapUnmanaged(ResultId, void),
        /// The invocation global's type.
        ty: ResultId,
        /// Initializer function. May be `none`.
        /// Note that if the initializer is `none`, then `dependencies` is empty.
        initializer: ResultId,
    };

    /// Maps function result-id -> Fn information structure. The insertion order
    /// matches the order in which callee lists were appended to `callee_store`.
    functions: std.AutoArrayHashMapUnmanaged(ResultId, Fn),
    /// Set of function result-ids that are named by an OpEntryPoint in this module.
    entry_points: std.AutoArrayHashMapUnmanaged(ResultId, void),
    /// For each function, a list of function result-ids that it calls. The list
    /// of a function starts at `Fn.first_callee` and ends where the list of the
    /// next function (in `functions` order) starts.
    callee_store: []const ResultId,
    /// Maps each invocation global result-id to its information structure.
    invocation_globals: std.AutoArrayHashMapUnmanaged(ResultId, InvocationGlobal),

    /// Fetch the list of callees per function. Guaranteed to contain only unique IDs.
    /// `fn_id` must be a key of `functions`.
    fn callees(self: ModuleInfo, fn_id: ResultId) []const ResultId {
        const fn_index = self.functions.getIndex(fn_id).?;
        const values = self.functions.values();
        const first_callee = values[fn_index].first_callee;
        // The last function owns the tail of the store, every other function
        // ends where its successor begins.
        const end = if (fn_index + 1 < values.len)
            values[fn_index + 1].first_callee
        else
            self.callee_store.len;
        return self.callee_store[first_callee..end];
    }

    /// Extract most of the required information from the binary. The remaining info is
    /// constructed by `resolve()`.
    ///
    /// All returned containers are allocated from `arena`; `Fn.param_types` points
    /// directly into the operands of `binary`.
    ///
    /// Errors:
    /// - `error.DuplicateId` when an entry point, function type, function or
    ///   invocation global is defined more than once.
    /// - `error.InvalidPhysicalFormat` when OpFunction/OpFunctionEnd are not paired.
    /// - `error.InvalidId` when a function refers to an unknown function type.
    /// - Anything `parser.parseInstructionResultIds()` or allocation reports.
    fn parse(
        arena: Allocator,
        parser: *BinaryModule.Parser,
        binary: BinaryModule,
    ) BinaryModule.ParseError!ModuleInfo {
        var entry_points = std.AutoArrayHashMap(ResultId, void).init(arena);
        var functions = std.AutoArrayHashMap(ResultId, Fn).init(arena);
        var fn_types = std.AutoHashMap(ResultId, struct {
            return_type: ResultId,
            param_types: []const ResultId,
        }).init(arena);
        // Unique callees of the function that is currently being parsed.
        var calls = std.AutoArrayHashMap(ResultId, void).init(arena);
        var callee_store = std.array_list.Managed(ResultId).init(arena);
        // Invocation globals referenced since the last OpFunction.
        var function_invocation_globals = std.AutoArrayHashMap(ResultId, void).init(arena);
        var result_id_offsets = std.array_list.Managed(u16).init(arena);
        var invocation_globals = std.AutoArrayHashMap(ResultId, InvocationGlobal).init(arena);

        var maybe_current_function: ?ResultId = null;
        // Only meaningful while `maybe_current_function` is non-null.
        var fn_ty_id: ResultId = undefined;

        var it = binary.iterateInstructions();
        while (it.next()) |inst| {
            result_id_offsets.clearRetainingCapacity();
            try parser.parseInstructionResultIds(binary, inst, &result_id_offsets);

            switch (inst.opcode) {
                .OpEntryPoint => {
                    const entry_point: ResultId = @enumFromInt(inst.operands[1]);
                    const entry = try entry_points.getOrPut(entry_point);
                    if (entry.found_existing) {
                        log.err("Entry point {f} has duplicate definition", .{entry_point});
                        return error.DuplicateId;
                    }
                },
                .OpTypeFunction => {
                    const fn_type: ResultId = @enumFromInt(inst.operands[0]);
                    const return_type: ResultId = @enumFromInt(inst.operands[1]);
                    const param_types: []const ResultId = @ptrCast(inst.operands[2..]);

                    const entry = try fn_types.getOrPut(fn_type);
                    if (entry.found_existing) {
                        log.err("Function type {f} has duplicate definition", .{fn_type});
                        return error.DuplicateId;
                    }

                    entry.value_ptr.* = .{
                        .return_type = return_type,
                        .param_types = param_types,
                    };
                },
                .OpExtInst => {
                    // Note: format and set are already verified by parseInstructionResultIds().
                    const global_type: ResultId = @enumFromInt(inst.operands[0]);
                    const result_id: ResultId = @enumFromInt(inst.operands[1]);
                    const set_id: ResultId = @enumFromInt(inst.operands[2]);
                    const set_inst = inst.operands[3];

                    const set = binary.ext_inst_map.get(set_id).?;
                    // Instruction 0 of the zig set declares an invocation global,
                    // optionally followed by its initializer function.
                    if (set == .zig and set_inst == 0) {
                        const initializer: ResultId = if (inst.operands.len >= 5)
                            @enumFromInt(inst.operands[4])
                        else
                            .none;

                        // Report redefinitions the same way as for every other
                        // kind of definition instead of silently overwriting.
                        const entry = try invocation_globals.getOrPut(result_id);
                        if (entry.found_existing) {
                            log.err("Invocation global {f} has duplicate definition", .{result_id});
                            return error.DuplicateId;
                        }

                        entry.value_ptr.* = .{
                            .dependencies = .empty,
                            .ty = global_type,
                            .initializer = initializer,
                        };
                    }
                },
                .OpFunction => {
                    if (maybe_current_function) |current_function| {
                        log.err("OpFunction {f} does not have an OpFunctionEnd", .{current_function});
                        return error.InvalidPhysicalFormat;
                    }

                    maybe_current_function = @enumFromInt(inst.operands[1]);
                    fn_ty_id = @enumFromInt(inst.operands[3]);
                    // Drop whatever was referenced outside of a function body.
                    function_invocation_globals.clearRetainingCapacity();
                },
                .OpFunctionCall => {
                    const callee: ResultId = @enumFromInt(inst.operands[2]);
                    try calls.put(callee, {});
                },
                .OpFunctionEnd => {
                    const current_function = maybe_current_function orelse {
                        log.err("encountered OpFunctionEnd without corresponding OpFunction", .{});
                        return error.InvalidPhysicalFormat;
                    };
                    const entry = try functions.getOrPut(current_function);
                    if (entry.found_existing) {
                        log.err("Function {f} has duplicate definition", .{current_function});
                        return error.DuplicateId;
                    }

                    const first_callee = callee_store.items.len;
                    try callee_store.appendSlice(calls.keys());

                    const fn_type = fn_types.get(fn_ty_id) orelse {
                        log.err("Function {f} has invalid OpFunction type", .{current_function});
                        return error.InvalidId;
                    };

                    entry.value_ptr.* = .{
                        .first_callee = first_callee,
                        .return_type = fn_type.return_type,
                        .param_types = fn_type.param_types,
                        .invocation_globals = try function_invocation_globals.unmanaged.clone(arena),
                    };
                    maybe_current_function = null;
                    calls.clearRetainingCapacity();
                },
                else => {},
            }

            // Record every invocation global that this instruction refers to.
            for (result_id_offsets.items) |off| {
                const result_id: ResultId = @enumFromInt(inst.operands[off]);
                if (invocation_globals.contains(result_id)) {
                    try function_invocation_globals.put(result_id, {});
                }
            }
        }

        if (maybe_current_function) |current_function| {
            log.err("OpFunction {f} does not have an OpFunctionEnd", .{current_function});
            return error.InvalidPhysicalFormat;
        }

        return ModuleInfo{
            .functions = functions.unmanaged,
            .entry_points = entry_points.unmanaged,
            .callee_store = callee_store.items,
            .invocation_globals = invocation_globals.unmanaged,
        };
    }

    /// Derive the remaining info from the structures filled in by parsing.
    fn resolve(self: *ModuleInfo, arena: Allocator) !void {
        // The dependency resolution reads the per-function sets that the usage
        // resolution completes, so the order of these two calls matters.
        try self.resolveInvocationGlobalUsage(arena);
        try self.resolveInvocationGlobalDependencies(arena);
    }

    /// For each function, extend the set of `invocation_globals` with the
    /// invocation globals that its callees (transitively) use.
    fn resolveInvocationGlobalUsage(self: *ModuleInfo, arena: Allocator) !void {
        var seen = try std.DynamicBitSetUnmanaged.initEmpty(arena, self.functions.count());

        for (self.functions.keys()) |id| {
            try self.resolveInvocationGlobalUsageStep(arena, id, &seen);
        }
    }

    /// Depth-first step of `resolveInvocationGlobalUsage()`. `seen` is indexed by
    /// the function's index in `functions`. Returns `error.InvalidId` when `id`
    /// is not a known function.
    fn resolveInvocationGlobalUsageStep(
        self: *ModuleInfo,
        arena: Allocator,
        id: ResultId,
        seen: *std.DynamicBitSetUnmanaged,
    ) !void {
        const index = self.functions.getIndex(id) orelse {
            log.err("function calls invalid function {f}", .{id});
            return error.InvalidId;
        };

        if (seen.isSet(index)) {
            return;
        }
        // Marked before descending, so a call cycle terminates instead of
        // recursing forever.
        seen.set(index);

        const info = &self.functions.values()[index];
        for (self.callees(id)) |callee| {
            try self.resolveInvocationGlobalUsageStep(arena, callee, seen);
            const callee_info = self.functions.get(callee).?;
            for (callee_info.invocation_globals.keys()) |global| {
                try info.invocation_globals.put(arena, global, {});
            }
        }
    }

    /// For each invocation global, populate and fully resolve the `dependencies` set.
    /// This requires `resolveInvocationGlobalUsage()` to be already done.
    fn resolveInvocationGlobalDependencies(
        self: *ModuleInfo,
        arena: Allocator,
    ) !void {
        var seen = try std.DynamicBitSetUnmanaged.initEmpty(arena, self.invocation_globals.count());

        for (self.invocation_globals.keys()) |id| {
            try self.resolveInvocationGlobalDependenciesStep(arena, id, &seen);
        }
    }

    /// Depth-first step of `resolveInvocationGlobalDependencies()`. `seen` is
    /// indexed by the global's index in `invocation_globals`. Returns
    /// `error.InvalidId` when `id` is not a known invocation global or when its
    /// initializer is not a known function.
    fn resolveInvocationGlobalDependenciesStep(
        self: *ModuleInfo,
        arena: Allocator,
        id: ResultId,
        seen: *std.DynamicBitSetUnmanaged,
    ) !void {
        const index = self.invocation_globals.getIndex(id) orelse {
            log.err("invalid invocation global {f}", .{id});
            return error.InvalidId;
        };

        if (seen.isSet(index)) {
            return;
        }
        seen.set(index);

        const info = &self.invocation_globals.values()[index];
        if (info.initializer == .none) {
            return;
        }

        const initializer = self.functions.get(info.initializer) orelse {
            log.err("invocation global {f} has invalid initializer {f}", .{ id, info.initializer });
            return error.InvalidId;
        };

        for (initializer.invocation_globals.keys()) |dependency| {
            if (dependency == id) {
                // The set of invocation global dependencies includes the dependency itself,
                // so we need to skip that case.
                continue;
            }

            try info.dependencies.put(arena, dependency, {});
            try self.resolveInvocationGlobalDependenciesStep(arena, dependency, seen);

            const dep_info = self.invocation_globals.getPtr(dependency).?;

            // Pull in the dependencies of the dependency as well.
            for (dep_info.dependencies.keys()) |global| {
                try info.dependencies.put(arena, global, {});
            }
        }
    }
};
301
/// Builds the rewritten module: invocation globals become extra function
/// parameters, and every entry point gets a new wrapper function that declares
/// the globals as function-local variables, runs their initializers and then
/// calls the original entry point function.
const ModuleBuilder = struct {
    /// A function type described by its structure rather than by its result-id.
    const FunctionType = struct {
        return_type: ResultId,
        param_types: []const ResultId,

        /// Hash map context that hashes and compares the return type and the
        /// contents of `param_types`.
        const Context = struct {
            pub fn hash(_: @This(), ty: FunctionType) u32 {
                var hasher = std.hash.Wyhash.init(0);
                hasher.update(std.mem.asBytes(&ty.return_type));
                hasher.update(std.mem.sliceAsBytes(ty.param_types));
                return @truncate(hasher.final());
            }

            pub fn eql(_: @This(), a: FunctionType, b: FunctionType, _: usize) bool {
                if (a.return_type != b.return_type) return false;
                return std.mem.eql(ResultId, a.param_types, b.param_types);
            }
        };
    };

    const FunctionNewInfo = struct {
        /// This is here just so that we don't need to allocate the new
        /// param_types multiple times.
        new_function_type: ResultId,
        /// The first ID of the parameters for the invocation globals.
        /// Each global is allocated here according to the index in
        /// `ModuleInfo.Fn.invocation_globals`.
        global_id_base: u32,

        /// The ID of the parameter that carries the invocation global at `index`.
        fn invocationGlobalId(self: FunctionNewInfo, index: usize) ResultId {
            return @enumFromInt(self.global_id_base + @as(u32, @intCast(index)));
        }
    };

    arena: Allocator,
    section: Section,
    /// The ID bound of the new module.
    id_bound: u32,
    /// The first ID of the new entry points. Entry points are allocated from
    /// here according to their index in `info.entry_points`.
    entry_point_new_id_base: u32,
    /// A set of all function types in the new program. SPIR-V mandates that these are unique,
    /// and until a general type deduplication pass is programmed, we just handle it here via this.
    function_types: std.ArrayHashMapUnmanaged(FunctionType, ResultId, FunctionType.Context, true) = .empty,
    /// Maps functions to new information required for creating the module
    function_new_info: std.AutoArrayHashMapUnmanaged(ResultId, FunctionNewInfo) = .empty,
    /// Offset of the functions section in the new binary.
    new_functions_section: ?usize,

    /// Create a builder that continues allocating IDs at `binary.id_bound` and
    /// immediately reserves one new ID per entry point in `info`.
    fn init(arena: Allocator, binary: BinaryModule, info: ModuleInfo) !ModuleBuilder {
        var self = ModuleBuilder{
            .arena = arena,
            .section = .{},
            .id_bound = binary.id_bound,
            .entry_point_new_id_base = undefined,
            .new_functions_section = null,
        };
        self.entry_point_new_id_base = @intFromEnum(self.allocIds(@intCast(info.entry_points.count())));
        return self;
    }

    /// Allocate a single fresh result-id.
    fn allocId(self: *ModuleBuilder) ResultId {
        return self.allocIds(1);
    }

    /// Allocate `n` consecutive fresh result-ids and return the first one.
    fn allocIds(self: *ModuleBuilder, n: u32) ResultId {
        defer self.id_bound += n;
        return @enumFromInt(self.id_bound);
    }

    /// The ID of the new wrapper function for the entry point at `index` in
    /// `info.entry_points`.
    fn entryPointNewId(self: *const ModuleBuilder, index: usize) ResultId {
        return @enumFromInt(self.entry_point_new_id_base + @as(u32, @intCast(index)));
    }

    /// Write the result back into `binary`. The new instruction stream is
    /// duplicated using `a`, so it does not depend on the builder's arena.
    /// NOTE(review): the previous `binary.instructions` slice is overwritten
    /// without being freed here; confirm who owns it.
    fn finalize(self: *ModuleBuilder, a: Allocator, binary: *BinaryModule) !void {
        binary.id_bound = self.id_bound;
        binary.instructions = try a.dupe(Word, self.section.instructions.items);
        // Nothing is removed in this pass so we don't need to change any of the maps,
        // just make sure the section is updated.
        binary.sections.functions = self.new_functions_section orelse binary.instructions.len;
    }

    /// Process everything from `binary` up to the first function and emit it into the builder.
    ///
    /// Instructions are copied verbatim, except that:
    /// - OpName of an invocation global, the import of the zig instruction set,
    ///   invocation global declarations and OpTypeFunction are dropped;
    /// - OpEntryPoint, OpExecutionMode and OpExecutionModeId are retargeted to
    ///   the new wrapper ID of the entry point.
    ///
    /// Returns `error.InvalidId` when an execution mode targets a function that
    /// is not an entry point.
    fn processPreamble(self: *ModuleBuilder, binary: BinaryModule, info: ModuleInfo) !void {
        var it = binary.iterateInstructions();
        while (it.next()) |inst| {
            switch (inst.opcode) {
                .OpName => {
                    const id: ResultId = @enumFromInt(inst.operands[0]);
                    if (info.invocation_globals.contains(id)) continue;
                },
                .OpExtInstImport => {
                    const set_id: ResultId = @enumFromInt(inst.operands[0]);
                    const set = binary.ext_inst_map.get(set_id).?;
                    if (set == .zig) continue;
                },
                .OpExtInst => {
                    const set_id: ResultId = @enumFromInt(inst.operands[2]);
                    const set_inst = inst.operands[3];
                    const set = binary.ext_inst_map.get(set_id).?;
                    if (set == .zig and set_inst == 0) {
                        continue;
                    }
                },
                .OpEntryPoint => {
                    const original_id: ResultId = @enumFromInt(inst.operands[1]);
                    // Every OpEntryPoint target was recorded by `ModuleInfo.parse()`.
                    const new_id_index = info.entry_points.getIndex(original_id).?;
                    const new_id = self.entryPointNewId(new_id_index);
                    try self.section.emitRaw(self.arena, .OpEntryPoint, inst.operands.len);
                    self.section.writeWord(inst.operands[0]);
                    self.section.writeOperand(ResultId, new_id);
                    self.section.writeWords(inst.operands[2..]);
                    continue;
                },
                .OpExecutionMode, .OpExecutionModeId => {
                    const original_id: ResultId = @enumFromInt(inst.operands[0]);
                    // Unlike OpEntryPoint, nothing has validated this ID yet, so
                    // report a malformed module instead of crashing on it.
                    const new_id_index = info.entry_points.getIndex(original_id) orelse {
                        log.err("execution mode targets {f}, which is not an entry point", .{original_id});
                        return error.InvalidId;
                    };
                    const new_id = self.entryPointNewId(new_id_index);
                    try self.section.emitRaw(self.arena, inst.opcode, inst.operands.len);
                    self.section.writeOperand(ResultId, new_id);
                    self.section.writeWords(inst.operands[1..]);
                    continue;
                },
                .OpTypeFunction => {
                    // Re-emitted in `emitFunctionTypes()`. We can do this because
                    // OpTypeFunction's may not currently be used anywhere that is not
                    // directly with an OpFunction. For now we ignore Intels function
                    // pointers extension, that is not a problem with a generalized
                    // pass anyway.
                    continue;
                },
                .OpFunction => break,
                else => {},
            }

            try self.section.emitRawInstruction(self.arena, inst.opcode, inst.operands);
        }
    }

    /// Derive new information required for further emitting this module:
    /// for every function, the new function type (invocation global types
    /// first, then the original parameter types) and the IDs of the new
    /// parameters that carry the invocation globals.
    fn deriveNewFnInfo(self: *ModuleBuilder, info: ModuleInfo) !void {
        for (info.functions.keys(), info.functions.values()) |func, fn_info| {
            const invocation_global_count = fn_info.invocation_globals.count();
            const new_param_types = try self.arena.alloc(ResultId, fn_info.param_types.len + invocation_global_count);
            for (fn_info.invocation_globals.keys(), 0..) |global, i| {
                new_param_types[i] = info.invocation_globals.get(global).?.ty;
            }
            @memcpy(new_param_types[invocation_global_count..], fn_info.param_types);

            // The type is interned before the parameter IDs are allocated;
            // keep this order, it determines the IDs that are handed out.
            const new_type = try self.internFunctionType(fn_info.return_type, new_param_types);
            try self.function_new_info.put(self.arena, func, .{
                .new_function_type = new_type,
                .global_id_base = @intFromEnum(self.allocIds(@intCast(invocation_global_count))),
            });
        }
    }

    /// Emit the new function types, which include the parameters for the invocation globals.
    /// Currently, this function re-emits ALL function types to ensure that there are
    /// no duplicates in the final program.
    /// TODO: The above should be resolved by a generalized deduplication pass, and then
    /// we only need to emit the new function pointers type here.
    fn emitFunctionTypes(self: *ModuleBuilder, info: ModuleInfo) !void {
        // TODO: Handle decorators. Function types usually don't have those
        // though, but stuff like OpName could be a possibility.

        // Entry points retain their old function type, so make sure to emit
        // those in the `function_types` set.
        for (info.entry_points.keys()) |func| {
            const fn_info = info.functions.get(func).?;
            _ = try self.internFunctionType(fn_info.return_type, fn_info.param_types);
        }

        for (self.function_types.keys(), self.function_types.values()) |fn_type, result_id| {
            try self.section.emit(self.arena, .OpTypeFunction, .{
                .id_result = result_id,
                .return_type = fn_type.return_type,
                .id_ref_2 = fn_type.param_types,
            });
        }
    }

    /// Return the ID of the function type with the given structure, allocating a
    /// new ID when the type has not been seen before. `param_types` is stored as
    /// part of the key without being copied, so it has to stay valid for as long
    /// as `function_types` is used.
    fn internFunctionType(self: *ModuleBuilder, return_type: ResultId, param_types: []const ResultId) !ResultId {
        const entry = try self.function_types.getOrPut(self.arena, .{
            .return_type = return_type,
            .param_types = param_types,
        });

        if (!entry.found_existing) {
            const new_id = self.allocId();
            entry.value_ptr.* = new_id;
        }

        return entry.value_ptr.*;
    }

    /// Rewrite the modules functions and emit them with the new parameter types.
    ///
    /// References to invocation globals are replaced by the parameter that
    /// carries the global in the enclosing function, OpFunction gets its new
    /// type and the extra OpFunctionParameters, and OpFunctionCall passes the
    /// callee's invocation globals in front of the original arguments.
    ///
    /// Returns `error.InvalidPhysicalFormat` when a call does not pass as many
    /// arguments as the callee's function type declares.
    fn rewriteFunctions(
        self: *ModuleBuilder,
        parser: *BinaryModule.Parser,
        binary: BinaryModule,
        info: ModuleInfo,
    ) !void {
        var result_id_offsets = std.array_list.Managed(u16).init(self.arena);
        var operands = std.array_list.Managed(u32).init(self.arena);

        var maybe_current_function: ?ResultId = null;
        var it = binary.iterateInstructionsFrom(binary.sections.functions);
        self.new_functions_section = self.section.instructions.items.len;
        while (it.next()) |inst| {
            result_id_offsets.clearRetainingCapacity();
            try parser.parseInstructionResultIds(binary, inst, &result_id_offsets);

            // Work on a private copy so that IDs can be patched in place.
            operands.clearRetainingCapacity();
            try operands.appendSlice(inst.operands);

            // Replace the result-ids with the global's new result-id if required.
            for (result_id_offsets.items) |off| {
                const result_id: ResultId = @enumFromInt(operands.items[off]);
                if (info.invocation_globals.contains(result_id)) {
                    const func = maybe_current_function.?;
                    const new_info = self.function_new_info.get(func).?;
                    const fn_info = info.functions.get(func).?;
                    const index = fn_info.invocation_globals.getIndex(result_id).?;
                    operands.items[off] = @intFromEnum(new_info.invocationGlobalId(index));
                }
            }

            switch (inst.opcode) {
                .OpFunction => {
                    // Re-declare the function with the new parameters.
                    const func: ResultId = @enumFromInt(operands.items[1]);
                    const fn_info = info.functions.get(func).?;
                    const new_info = self.function_new_info.get(func).?;

                    try self.section.emitRaw(self.arena, .OpFunction, 4);
                    self.section.writeOperand(ResultId, fn_info.return_type);
                    self.section.writeOperand(ResultId, func);
                    self.section.writeWord(operands.items[2]);
                    self.section.writeOperand(ResultId, new_info.new_function_type);

                    // Emit the OpFunctionParameters for the invocation globals. The functions
                    // actual parameters are emitted unchanged from their original form, so
                    // we don't need to handle those here.

                    for (fn_info.invocation_globals.keys(), 0..) |global, index| {
                        const ty = info.invocation_globals.get(global).?.ty;
                        const id = new_info.invocationGlobalId(index);
                        try self.section.emit(self.arena, .OpFunctionParameter, .{
                            .id_result_type = ty,
                            .id_result = id,
                        });
                    }

                    maybe_current_function = func;
                },
                .OpFunctionCall => {
                    // Add the required invocation globals to the function's new parameter list.
                    const caller = maybe_current_function.?;
                    const callee: ResultId = @enumFromInt(operands.items[2]);
                    const caller_info = info.functions.get(caller).?;
                    const callee_info = info.functions.get(callee).?;
                    const caller_new_info = self.function_new_info.get(caller).?;
                    const original_args = operands.items[3..];

                    // The instruction header emitted below declares its size based
                    // on the callee's declared parameter count, while the words that
                    // follow are the arguments that the call actually passes. Reject
                    // a mismatch rather than emitting an inconsistent instruction.
                    if (original_args.len != callee_info.param_types.len) {
                        log.err("call to function {f} passes {d} arguments, but {d} are expected", .{
                            callee,
                            original_args.len,
                            callee_info.param_types.len,
                        });
                        return error.InvalidPhysicalFormat;
                    }

                    const total_params = callee_info.invocation_globals.count() + callee_info.param_types.len;

                    try self.section.emitRaw(self.arena, .OpFunctionCall, 3 + total_params);
                    self.section.writeWord(operands.items[0]); // Copy result type-id
                    self.section.writeWord(operands.items[1]); // Copy result-id
                    self.section.writeOperand(ResultId, callee);

                    // Add the new arguments
                    for (callee_info.invocation_globals.keys()) |global| {
                        const caller_global_index = caller_info.invocation_globals.getIndex(global).?;
                        const id = caller_new_info.invocationGlobalId(caller_global_index);
                        self.section.writeOperand(ResultId, id);
                    }

                    // Add the original arguments
                    self.section.writeWords(original_args);
                },
                else => {
                    try self.section.emitRawInstruction(self.arena, inst.opcode, operands.items);
                },
            }
        }
    }

    /// Emit one wrapper function per entry point. The wrapper has the original
    /// function's signature, declares every required invocation global as a
    /// function-local variable, calls the initializers of the globals that have
    /// one and finally calls the original function.
    fn emitNewEntryPoints(self: *ModuleBuilder, info: ModuleInfo) !void {
        var all_function_invocation_globals = std.AutoArrayHashMap(ResultId, void).init(self.arena);

        for (info.entry_points.keys(), 0..) |func, entry_point_index| {
            const fn_info = info.functions.get(func).?;
            const ep_id = self.entryPointNewId(entry_point_index);
            // Interned by `emitFunctionTypes()`.
            const fn_type = self.function_types.get(.{
                .return_type = fn_info.return_type,
                .param_types = fn_info.param_types,
            }).?;

            try self.section.emit(self.arena, .OpFunction, .{
                .id_result_type = fn_info.return_type,
                .id_result = ep_id,
                .function_control = .{}, // TODO: Copy the attributes from the original function maybe?
                .function_type = fn_type,
            });

            // Emit OpFunctionParameter instructions for the original kernel's parameters.
            const params_id_base: u32 = @intFromEnum(self.allocIds(@intCast(fn_info.param_types.len)));
            for (fn_info.param_types, 0..) |param_type, i| {
                const id: ResultId = @enumFromInt(params_id_base + @as(u32, @intCast(i)));
                try self.section.emit(self.arena, .OpFunctionParameter, .{
                    .id_result_type = param_type,
                    .id_result = id,
                });
            }

            try self.section.emit(self.arena, .OpLabel, .{
                .id_result = self.allocId(),
            });

            // Besides the IDs of the main kernel, we also need the
            // dependencies of the globals.
            // Just quickly construct that set here.
            all_function_invocation_globals.clearRetainingCapacity();
            for (fn_info.invocation_globals.keys()) |global| {
                try all_function_invocation_globals.put(global, {});
                const global_info = info.invocation_globals.get(global).?;
                for (global_info.dependencies.keys()) |dependency| {
                    try all_function_invocation_globals.put(dependency, {});
                }
            }

            // Declare the IDs of the invocation globals.
            const global_id_base: u32 = @intFromEnum(self.allocIds(@intCast(all_function_invocation_globals.count())));
            for (all_function_invocation_globals.keys(), 0..) |global, i| {
                const global_info = info.invocation_globals.get(global).?;

                const id: ResultId = @enumFromInt(global_id_base + @as(u32, @intCast(i)));
                try self.section.emit(self.arena, .OpVariable, .{
                    .id_result_type = global_info.ty,
                    .id_result = id,
                    .storage_class = .function,
                    .initializer = null,
                });
            }

            // Call initializers for invocation globals that need it
            // NOTE(review): this follows the insertion order of the set, in which a
            // global precedes the dependencies that were added after it; confirm
            // that no initializer relies on a dependency being initialized first.
            for (all_function_invocation_globals.keys()) |global| {
                const global_info = info.invocation_globals.get(global).?;
                if (global_info.initializer == .none) continue;

                const initializer_info = info.functions.get(global_info.initializer).?;
                assert(initializer_info.param_types.len == 0);

                // Initializers take no regular parameters (asserted above), so
                // the parameter ID base is never read.
                try self.callWithGlobalsAndLinearParams(
                    all_function_invocation_globals,
                    global_info.initializer,
                    initializer_info,
                    global_id_base,
                    undefined,
                );
            }

            // Call the main kernel entry
            try self.callWithGlobalsAndLinearParams(
                all_function_invocation_globals,
                func,
                fn_info,
                global_id_base,
                params_id_base,
            );

            try self.section.emit(self.arena, .OpReturn, {});
            try self.section.emit(self.arena, .OpFunctionEnd, {});
        }
    }

    /// Emit a call to `func`. The invocation globals of the callee are passed
    /// first, each as `global_id_base` plus the global's index in `all_globals`,
    /// followed by one argument per regular parameter, numbered consecutively
    /// starting at `params_id_base`. Every invocation global of the callee must
    /// be present in `all_globals`.
    fn callWithGlobalsAndLinearParams(
        self: *ModuleBuilder,
        all_globals: std.AutoArrayHashMap(ResultId, void),
        func: ResultId,
        callee_info: ModuleInfo.Fn,
        global_id_base: u32,
        params_id_base: u32,
    ) !void {
        const total_arguments = callee_info.invocation_globals.count() + callee_info.param_types.len;
        try self.section.emitRaw(self.arena, .OpFunctionCall, 3 + total_arguments);
        self.section.writeOperand(ResultId, callee_info.return_type);
        self.section.writeOperand(ResultId, self.allocId());
        self.section.writeOperand(ResultId, func);

        // Add the invocation globals
        for (callee_info.invocation_globals.keys()) |global| {
            const index = all_globals.getIndex(global).?;
            const id: ResultId = @enumFromInt(global_id_base + @as(u32, @intCast(index)));
            self.section.writeOperand(ResultId, id);
        }

        // Add the arguments
        for (0..callee_info.param_types.len) |index| {
            const id: ResultId = @enumFromInt(params_id_base + @as(u32, @intCast(index)));
            self.section.writeOperand(ResultId, id);
        }
    }
};
702
/// Run the pass on `binary`, replacing its instructions in place.
///
/// All intermediate state lives in an arena backed by `parser.a` that is freed
/// before returning; the final instruction stream is allocated with `parser.a`
/// itself. Progress is reported as a child of `progress` with six steps.
pub fn run(parser: *BinaryModule.Parser, binary: *BinaryModule, progress: std.Progress.Node) !void {
    const node = progress.start("Lower invocation globals", 6);
    defer node.end();

    var arena_state = std.heap.ArenaAllocator.init(parser.a);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    // Step 1: gather and resolve everything we need to know about the module.
    var module_info = try ModuleInfo.parse(arena, parser, binary.*);
    try module_info.resolve(arena);

    // The builder steps below allocate result-ids as they go, so their order
    // must not be changed.
    var module_builder = try ModuleBuilder.init(arena, binary.*, module_info);
    node.completeOne();

    // Step 2: compute the new function types and parameter IDs.
    try module_builder.deriveNewFnInfo(module_info);
    node.completeOne();

    // Step 3: copy everything in front of the first function.
    try module_builder.processPreamble(binary.*, module_info);
    node.completeOne();

    // Step 4: emit the function types.
    try module_builder.emitFunctionTypes(module_info);
    node.completeOne();

    // Step 5: emit the original functions with their new parameters.
    try module_builder.rewriteFunctions(parser, binary.*, module_info);
    node.completeOne();

    // Step 6: emit the entry point wrappers.
    try module_builder.emitNewEntryPoints(module_info);
    node.completeOne();

    try module_builder.finalize(parser.a, binary);
}