Commit e443b1bed7

Robin Voetter <robin@voetter.nl>
2022-11-27 16:22:01
spirv: switch_br lowering
Implements lowering switch statements in the SPIR-V backend.
1 parent 205d928
Changed files (2)
src
codegen
src/codegen/spirv/Module.zig
@@ -132,6 +132,11 @@ pub fn allocId(self: *Module) spec.IdResult {
     return .{ .id = self.next_result_id };
 }
 
+pub fn allocIds(self: *Module, n: u32) spec.IdResult {
+    defer self.next_result_id += n;
+    return .{ .id = self.next_result_id };
+}
+
 pub fn idBound(self: Module) Word {
     return self.next_result_id;
 }
src/codegen/spirv.zig
@@ -887,11 +887,13 @@ pub const DeclGen = struct {
             .breakpoint => return,
             .cond_br    => return self.airCondBr(inst),
             .constant   => unreachable,
+            .const_ty   => unreachable,
             .dbg_stmt   => return self.airDbgStmt(inst),
             .loop       => return self.airLoop(inst),
             .ret        => return self.airRet(inst),
             .ret_load   => return self.airRetLoad(inst),
             .store      => return self.airStore(inst),
+            .switch_br  => return self.airSwitchBr(inst),
             .unreach    => return self.airUnreach(),
 
             .assembly => try self.airAssembly(inst),
@@ -1679,6 +1681,121 @@ pub const DeclGen = struct {
         });
     }
 
+    fn airSwitchBr(self: *DeclGen, inst: Air.Inst.Index) !void {
+        const target = self.getTarget();
+        const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+        const cond = try self.resolve(pl_op.operand);
+        const cond_ty = self.air.typeOf(pl_op.operand);
+        const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload);
+
+        const cond_words: u32 = switch (cond_ty.zigTypeTag()) {
+            .Int => blk: {
+                const bits = cond_ty.intInfo(target).bits;
+                const backing_bits = self.backingIntBits(bits) orelse {
+                    return self.todo("implement composite int switch", .{});
+                };
+                break :blk if (backing_bits <= 32) 1 else 2;
+            },
+            .Enum => blk: {
+                var buffer: Type.Payload.Bits = undefined;
+                const int_ty = cond_ty.intTagType(&buffer);
+                const int_info = int_ty.intInfo(target);
+                const backing_bits = self.backingIntBits(int_info.bits) orelse {
+                    return self.todo("implement composite int switch", .{});
+                };
+                break :blk if (backing_bits <= 32) 1 else 2;
+            },
+            else => return self.todo("implement switch for type {s}", .{@tagName(cond_ty.zigTypeTag())}), // TODO: Figure out which types apply here, and work around them as we can only do integers.
+        };
+
+        const num_cases = switch_br.data.cases_len;
+
+        // Compute the total number of arms that we need.
+        // Zig switches are grouped by condition, so we need to loop through all of them
+        const num_conditions = blk: {
+            var extra_index: usize = switch_br.end;
+            var case_i: u32 = 0;
+            var num_conditions: u32 = 0;
+            while (case_i < num_cases) : (case_i += 1) {
+                const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
+                const case_body = self.air.extra[case.end + case.data.items_len ..][0..case.data.body_len];
+                extra_index = case.end + case.data.items_len + case_body.len;
+                num_conditions += case.data.items_len;
+            }
+            break :blk num_conditions;
+        };
+
+        // First, pre-allocate the labels for the cases.
+        const first_case_label = self.spv.allocIds(num_cases);
+        // We always need the default case - if zig has none, we will generate unreachable there.
+        const default = self.spv.allocId();
+
+        // Emit the instruction before generating the blocks.
+        try self.func.body.emitRaw(self.spv.gpa, .OpSwitch, 2 + (cond_words + 1) * num_conditions);
+        self.func.body.writeOperand(IdRef, cond);
+        self.func.body.writeOperand(IdRef, default.toRef());
+
+        // Emit each of the cases
+        {
+            var extra_index: usize = switch_br.end;
+            var case_i: u32 = 0;
+            while (case_i < num_cases) : (case_i += 1) {
+                // SPIR-V needs a literal here, which' width depends on the case condition.
+                const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
+                const items = @ptrCast([]const Air.Inst.Ref, self.air.extra[case.end..][0..case.data.items_len]);
+                const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
+                extra_index = case.end + case.data.items_len + case_body.len;
+
+                const label = IdRef{ .id = first_case_label.id + case_i };
+
+                for (items) |item| {
+                    const value = self.air.value(item) orelse {
+                        return self.todo("switch on runtime value???", .{});
+                    };
+                    const int_val = switch (cond_ty.zigTypeTag()) {
+                        .Int => if (cond_ty.isSignedInt()) @bitCast(u64, value.toSignedInt()) else value.toUnsignedInt(target),
+                        .Enum => blk: {
+                            var int_buffer: Value.Payload.U64 = undefined;
+                            // TODO: figure out of cond_ty is correct (something with enum literals)
+                            break :blk value.enumToInt(cond_ty, &int_buffer).toUnsignedInt(target); // TODO: composite integer constants
+                        },
+                        else => unreachable,
+                    };
+                    const int_lit: spec.LiteralContextDependentNumber = switch (cond_words) {
+                        1 => .{ .uint32 = @intCast(u32, int_val) },
+                        2 => .{ .uint64 = int_val },
+                        else => unreachable,
+                    };
+                    self.func.body.writeOperand(spec.LiteralContextDependentNumber, int_lit);
+                    self.func.body.writeOperand(IdRef, label);
+                }
+            }
+        }
+
+        // Now, finally, we can start emitting each of the cases.
+        var extra_index: usize = switch_br.end;
+        var case_i: u32 = 0;
+        while (case_i < num_cases) : (case_i += 1) {
+            const case = self.air.extraData(Air.SwitchBr.Case, extra_index);
+            const items = @ptrCast([]const Air.Inst.Ref, self.air.extra[case.end..][0..case.data.items_len]);
+            const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len];
+            extra_index = case.end + case.data.items_len + case_body.len;
+
+            const label = IdResult{ .id = first_case_label.id + case_i };
+
+            try self.beginSpvBlock(label);
+            try self.genBody(case_body);
+        }
+
+        const else_body = self.air.extra[extra_index..][0..switch_br.data.else_body_len];
+        try self.beginSpvBlock(default);
+        if (else_body.len != 0) {
+            try self.genBody(else_body);
+        } else {
+            try self.func.body.emit(self.spv.gpa, .OpUnreachable, {});
+        }
+    }
+
     fn airUnreach(self: *DeclGen) !void {
         try self.func.body.emit(self.spv.gpa, .OpUnreachable, {});
     }