Commit 747f4ae3f5

Robin Voetter <robin@voetter.nl>
2024-01-15 21:58:13
spirv: sh[rl](_exact)?
1 parent 3ef5b80
Changed files (2)
src
codegen
test
behavior
src/codegen/spirv.zig
@@ -2111,7 +2111,8 @@ const DeclGen = struct {
             .bool_and => try self.airBinOpSimple(inst, .OpLogicalAnd),
             .bool_or  => try self.airBinOpSimple(inst, .OpLogicalOr),
 
-            .shl => try self.airShift(inst, .OpShiftLeftLogical),
+            .shl, .shl_exact => try self.airShift(inst, .OpShiftLeftLogical, .OpShiftLeftLogical),
+            .shr, .shr_exact => try self.airShift(inst, .OpShiftRightLogical, .OpShiftRightArithmetic),
 
             .min => try self.airMinMax(inst, .lt),
             .max => try self.airMinMax(inst, .gt),
@@ -2254,28 +2255,42 @@ const DeclGen = struct {
         return try self.binOpSimple(ty, lhs_id, rhs_id, opcode);
     }
 
-    fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime opcode: Opcode) !?IdRef {
+    fn airShift(self: *DeclGen, inst: Air.Inst.Index, comptime unsigned: Opcode, comptime signed: Opcode) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
+        const mod = self.module;
         const bin_op = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;
         const lhs_id = try self.resolve(bin_op.lhs);
         const rhs_id = try self.resolve(bin_op.rhs);
-        const result_type_id = try self.resolveTypeId(self.typeOfIndex(inst));
-
-        // the shift and the base must be the same type in SPIR-V, but in Zig the shift is a smaller int.
-        const shift_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
-            .id_result_type = result_type_id,
-            .id_result = shift_id,
-            .unsigned_value = rhs_id,
-        });
+        const result_ty = self.typeOfIndex(inst);
+        const result_ty_ref = try self.resolveType(result_ty, .direct);
 
         const result_id = self.spv.allocId();
-        try self.func.body.emit(self.spv.gpa, opcode, .{
-            .id_result_type = result_type_id,
+
+        // Sometimes Zig doesn't make both of the arguments the same types here. SPIR-V expects that,
+        // so just manually upcast it if required.
+        const shift_ty_ref = try self.resolveType(self.typeOf(bin_op.rhs), .direct);
+        const shift_id = if (shift_ty_ref != result_ty_ref) blk: {
+            const shift_id = self.spv.allocId();
+            try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+                .id_result_type = self.typeId(result_ty_ref),
+                .id_result = shift_id,
+                .unsigned_value = rhs_id,
+            });
+            break :blk shift_id;
+        } else rhs_id;
+
+        const args = .{
+            .id_result_type = self.typeId(result_ty_ref),
             .id_result = result_id,
             .base = lhs_id,
             .shift = shift_id,
-        });
+        };
+
+        if (result_ty.isSignedInt(mod)) {
+            try self.func.body.emit(self.spv.gpa, signed, args);
+        } else {
+            try self.func.body.emit(self.spv.gpa, unsigned, args);
+        }
         return result_id;
     }
 
test/behavior/math.zig
@@ -12,7 +12,6 @@ const math = std.math;
 test "assignment operators" {
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     var i: u32 = 0;
     i += 5;
@@ -649,8 +648,6 @@ test "bit shift a u1" {
 }
 
 test "truncating shift right" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     try testShrTrunc(maxInt(u16));
     try comptime testShrTrunc(maxInt(u16));
 }
@@ -1343,8 +1340,6 @@ fn testShlExact(x: u8) !void {
 }
 
 test "exact shift right" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     try testShrExact(0b10110100);
     try comptime testShrExact(0b10110100);
 }