Commit 436f53f55d

Ali Chraghi <alichraghi@proton.me>
2024-02-19 14:38:39
spirv: implement `@mulWithOverflow`
1 parent 9785014
Changed files (6)
src/codegen/spirv/Assembler.zig
@@ -263,7 +263,10 @@ fn processInstruction(self: *Assembler) !void {
         .OpExtInstImport => blk: {
             const set_name_offset = self.inst.operands.items[1].string;
             const set_name = std.mem.sliceTo(self.inst.string_bytes.items[set_name_offset..], 0);
-            break :blk .{ .value = try self.spv.importInstructionSet(set_name) };
+            const set_tag = std.meta.stringToEnum(spec.InstructionSet, set_name) orelse {
+                return self.fail(set_name_offset, "unknown instruction set: {s}", .{set_name});
+            };
+            break :blk .{ .value = try self.spv.importInstructionSet(set_tag) };
         },
         else => switch (self.inst.opcode.class()) {
             .TypeDeclaration => try self.processTypeInstruction(),
src/codegen/spirv.zig
@@ -2315,6 +2315,7 @@ const DeclGen = struct {
             .sub, .sub_wrap, .sub_optimized => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
             .mul, .mul_wrap, .mul_optimized => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
 
+
             .abs => try self.airAbs(inst),
             .floor => try self.airFloor(inst),
 
@@ -2330,6 +2331,7 @@ const DeclGen = struct {
 
             .add_with_overflow => try self.airAddSubOverflow(inst, .OpIAdd, .OpULessThan, .OpSLessThan),
             .sub_with_overflow => try self.airAddSubOverflow(inst, .OpISub, .OpUGreaterThan, .OpSGreaterThan),
+            .mul_with_overflow => try self.airMulOverflow(inst),
             .shl_with_overflow => try self.airShlOverflow(inst),
 
             .mul_add => try self.airMulAdd(inst),
@@ -2733,8 +2735,8 @@ const DeclGen = struct {
             else => unreachable,
         };
         const set_id = switch (target.os.tag) {
-            .opencl => try self.spv.importInstructionSet("OpenCL.std"),
-            .vulkan => try self.spv.importInstructionSet("GLSL.std.450"),
+            .opencl => try self.spv.importInstructionSet(.@"OpenCL.std"),
+            .vulkan => try self.spv.importInstructionSet(.@"GLSL.std.450"),
             else => unreachable,
         };
 
@@ -2998,6 +3000,61 @@ const DeclGen = struct {
         );
     }
 
+    fn airMulOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
+        const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;
+        const lhs = try self.resolve(extra.lhs);
+        const rhs = try self.resolve(extra.rhs);
+
+        const result_ty = self.typeOfIndex(inst);
+        const operand_ty = self.typeOf(extra.lhs);
+        const ov_ty = result_ty.structFieldType(1, self.module);
+
+        const info = self.arithmeticTypeInfo(operand_ty);
+        switch (info.class) {
+            .composite_integer => return self.todo("overflow ops for composite integers", .{}),
+            .strange_integer, .integer => {},
+            .float, .bool => unreachable,
+        }
+
+        var wip_result = try self.elementWise(operand_ty, true);
+        defer wip_result.deinit();
+        var wip_ov = try self.elementWise(ov_ty, true);
+        defer wip_ov.deinit();
+
+        const zero_id = try self.constInt(wip_result.ty_ref, 0);
+        const zero_ov_id = try self.constInt(wip_ov.ty_ref, 0);
+        const one_ov_id = try self.constInt(wip_ov.ty_ref, 1);
+
+        for (wip_result.results, wip_ov.results, 0..) |*result_id, *ov_id, i| {
+            const lhs_elem_id = try wip_result.elementAt(operand_ty, lhs, i);
+            const rhs_elem_id = try wip_result.elementAt(operand_ty, rhs, i);
+
+            result_id.* = try self.arithOp(wip_result.ty, lhs_elem_id, rhs_elem_id, .OpFMul, .OpIMul, .OpIMul);
+
+            // (a != 0) and (x / a != b)
+            const not_zero_id = try self.cmp(.neq, Type.bool, wip_result.ty, lhs_elem_id, zero_id);
+            const res_rhs_id = try self.arithOp(wip_result.ty, result_id.*, lhs_elem_id, .OpFDiv, .OpSDiv, .OpUDiv);
+            const res_rhs_not_rhs_id = try self.cmp(.neq, Type.bool, wip_result.ty, res_rhs_id, rhs_elem_id);
+            const cond_id = try self.binOpSimple(Type.bool, not_zero_id, res_rhs_not_rhs_id, .OpLogicalAnd);
+
+            ov_id.* = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpSelect, .{
+                .id_result_type = wip_ov.ty_id,
+                .id_result = ov_id.*,
+                .condition = cond_id,
+                .object_1 = one_ov_id,
+                .object_2 = zero_ov_id,
+            });
+        }
+
+        return try self.constructStruct(
+            result_ty,
+            &.{ operand_ty, ov_ty },
+            &.{ try wip_result.finalize(), try wip_ov.finalize() },
+        );
+    }
+
     fn airShlOverflow(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         const mod = self.module;
         const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
test/behavior/for.zig
@@ -226,7 +226,6 @@ test "else continue outer for" {
 
 test "for loop with else branch" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     {
         var x = [_]u32{ 1, 2 };
test/behavior/hasdecl.zig
@@ -12,8 +12,6 @@ const Bar = struct {
 };
 
 test "@hasDecl" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     try expect(@hasDecl(Foo, "public_thing"));
     try expect(!@hasDecl(Foo, "private_thing"));
     try expect(!@hasDecl(Foo, "no_thing"));
@@ -24,8 +22,6 @@ test "@hasDecl" {
 }
 
 test "@hasDecl using a sliced string literal" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     try expect(@hasDecl(@This(), "std") == true);
     try expect(@hasDecl(@This(), "std"[0..0]) == false);
     try expect(@hasDecl(@This(), "std"[0..1]) == false);
test/behavior/math.zig
@@ -788,7 +788,6 @@ test "small int addition" {
 test "basic @mulWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     {
         var a: u8 = 86;
@@ -821,7 +820,6 @@ test "basic @mulWithOverflow" {
 test "extensive @mulWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     {
         var a: u5 = 3;
@@ -998,7 +996,6 @@ test "@mulWithOverflow bitsize > 32" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     {
         var a: u62 = 3;
test/behavior/vector.zig
@@ -1136,7 +1136,6 @@ test "@mulWithOverflow" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const S = struct {
         fn doTheTest() !void {