Commit 8608d6e235

Robin Voetter <robin@voetter.nl>
2022-11-27 14:15:08
spirv: div, rem, intcast, some strange integer masking
Implements the div-family and intcast AIR instructions, and starts implementing a mechanism for masking the value of 'strange' integers before they are used in an operation that does not hold under modulo.
1 parent 23e210c
Changed files (1)
src
codegen
src/codegen/spirv.zig
@@ -824,9 +824,21 @@ pub const DeclGen = struct {
         const air_tags = self.air.instructions.items(.tag);
         const maybe_result_id: ?IdRef = switch (air_tags[inst]) {
             // zig fmt: off
-            .add, .addwrap => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd),
-            .sub, .subwrap => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub),
-            .mul, .mulwrap => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul),
+            .add, .addwrap => try self.airArithOp(inst, .OpFAdd, .OpIAdd, .OpIAdd, true),
+            .sub, .subwrap => try self.airArithOp(inst, .OpFSub, .OpISub, .OpISub, true),
+            .mul, .mulwrap => try self.airArithOp(inst, .OpFMul, .OpIMul, .OpIMul, true),
+
+            .div_float,
+            .div_float_optimized,
+            // TODO: Check that this is the right operation.
+            .div_trunc,
+            .div_trunc_optimized,
+            => try self.airArithOp(inst, .OpFDiv, .OpSDiv, .OpUDiv, false),
+            // TODO: Check if this is the right operation
+            // TODO: Make airArithOp for rem not emit a mask for the LHS.
+            .rem,
+            .rem_optimized,
+            => try self.airArithOp(inst, .OpFRem, .OpSRem, .OpSRem, false),
 
             .add_with_overflow => try self.airOverflowArithOp(inst),
 
@@ -838,8 +850,9 @@ pub const DeclGen = struct {
             .bool_and => try self.airBinOpSimple(inst, .OpLogicalAnd),
             .bool_or  => try self.airBinOpSimple(inst, .OpLogicalOr),
 
-            .bitcast        => try self.airBitcast(inst),
-            .not            => try self.airNot(inst),
+            .bitcast => try self.airBitcast(inst),
+            .intcast => try self.airIntcast(inst),
+            .not     => try self.airNot(inst),
 
             .slice_ptr      => try self.airSliceField(inst, 0),
             .slice_len      => try self.airSliceField(inst, 1),
@@ -909,20 +922,47 @@ pub const DeclGen = struct {
         return result_id.toRef();
     }
 
+    fn maskStrangeInt(self: *DeclGen, ty_id: IdResultType, int_id: IdRef, bits: u16) !IdRef {
+        const backing_bits = self.backingIntBits(bits).?;
+        const mask_value = if (bits == 64) 0xFFFF_FFFF_FFFF_FFFF else (@as(u64, 1) << @intCast(u6, bits)) - 1;
+        const mask_lit: spec.LiteralContextDependentNumber = switch (backing_bits) {
+            1...32 => .{ .uint32 = @truncate(u32, mask_value) },
+            33...64 => .{ .uint64 = mask_value },
+            else => unreachable,
+        };
+        // TODO: We should probably optimize these constants a bit.
+        const mask_id = self.spv.allocId();
+        try self.spv.sections.types_globals_constants.emit(self.spv.gpa, .OpConstant, .{
+            .id_result_type = ty_id,
+            .id_result = mask_id,
+            .value = mask_lit,
+        });
+        const result_id = self.spv.allocId();
+        try self.func.body.emit(self.spv.gpa, .OpBitwiseAnd, .{
+            .id_result_type = ty_id,
+            .id_result = result_id,
+            .operand_1 = int_id,
+            .operand_2 = mask_id.toRef(),
+        });
+        return result_id.toRef();
+    }
+
     fn airArithOp(
         self: *DeclGen,
         inst: Air.Inst.Index,
         comptime fop: Opcode,
         comptime sop: Opcode,
         comptime uop: Opcode,
+        /// true if this operation holds under modular arithmetic.
+        comptime modular: bool,
     ) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
         // LHS and RHS are guaranteed to have the same type, and AIR guarantees
         // the result to be the same as the LHS and RHS, which matches SPIR-V.
         const ty = self.air.typeOfIndex(inst);
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
-        const lhs_id = try self.resolve(bin_op.lhs);
-        const rhs_id = try self.resolve(bin_op.rhs);
+        var lhs_id = try self.resolve(bin_op.lhs);
+        var rhs_id = try self.resolve(bin_op.rhs);
 
         const result_id = self.spv.allocId();
         const result_type_id = try self.resolveTypeId(ty);
@@ -938,15 +978,22 @@ pub const DeclGen = struct {
             .composite_integer => {
                 return self.todo("binary operations for composite integers", .{});
             },
-            .strange_integer => {
-                return self.todo("binary operations for strange integers", .{});
+            .strange_integer => blk: {
+                if (!modular) {
+                    lhs_id = try self.maskStrangeInt(result_type_id, lhs_id, info.bits);
+                    rhs_id = try self.maskStrangeInt(result_type_id, rhs_id, info.bits);
+                }
+                break :blk switch (info.signedness) {
+                    .signed => @as(usize, 1),
+                    .unsigned => @as(usize, 2),
+                };
             },
             .integer => switch (info.signedness) {
                 .signed => @as(usize, 1),
                 .unsigned => @as(usize, 2),
             },
             .float => 0,
-            else => unreachable,
+            .bool => unreachable,
         };
 
         const operands = .{
@@ -981,11 +1028,18 @@ pub const DeclGen = struct {
         const operand_ty = self.air.typeOf(extra.lhs);
         const result_ty = self.air.typeOfIndex(inst);
 
+        const info = try self.arithmeticTypeInfo(operand_ty);
+        switch (info.class) {
+            .composite_integer => return self.todo("overflow ops for composite integers", .{}),
+            .strange_integer => return self.todo("overflow ops for strange integers", .{}),
+            .integer => {},
+            .float, .bool => unreachable,
+        }
+
         const operand_ty_id = try self.resolveTypeId(operand_ty);
         const result_type_id = try self.resolveTypeId(result_ty);
 
-        const operand_bits = operand_ty.intInfo(target).bits;
-        const overflow_member_ty = try self.intType(.unsigned, operand_bits);
+        const overflow_member_ty = try self.intType(.unsigned, info.bits);
         const overflow_member_ty_id = self.spv.typeResultId(overflow_member_ty);
 
         const op_result_id = blk: {
@@ -1083,8 +1137,8 @@ pub const DeclGen = struct {
     fn airCmp(self: *DeclGen, inst: Air.Inst.Index, comptime fop: Opcode, comptime sop: Opcode, comptime uop: Opcode) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
-        const lhs_id = try self.resolve(bin_op.lhs);
-        const rhs_id = try self.resolve(bin_op.rhs);
+        var lhs_id = try self.resolve(bin_op.lhs);
+        var rhs_id = try self.resolve(bin_op.rhs);
         const result_id = self.spv.allocId();
         const result_type_id = try self.resolveTypeId(Type.initTag(.bool));
         const op_ty = self.air.typeOf(bin_op.lhs);
@@ -1100,10 +1154,15 @@ pub const DeclGen = struct {
             },
             .float => 0,
             .bool => 1,
-            // TODO: Should strange integers be masked before comparison?
-            .strange_integer,
-            .integer,
-            => switch (info.signedness) {
+            .strange_integer => blk: {
+                lhs_id = try self.maskStrangeInt(result_type_id, lhs_id, info.bits);
+                rhs_id = try self.maskStrangeInt(result_type_id, rhs_id, info.bits);
+                break :blk switch (info.signedness) {
+                    .signed => @as(usize, 1),
+                    .unsigned => @as(usize, 2),
+                };
+            },
+            .integer => switch (info.signedness) {
                 .signed => @as(usize, 1),
                 .unsigned => @as(usize, 2),
             },
@@ -1144,6 +1203,31 @@ pub const DeclGen = struct {
         return try self.bitcast(result_type_id, operand_id);
     }
 
+    fn airIntcast(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+
+        const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+        const operand_id = try self.resolve(ty_op.operand);
+        const dest_ty = self.air.typeOfIndex(inst);
+        const dest_info = try self.arithmeticTypeInfo(dest_ty);
+        const dest_ty_id = try self.resolveTypeId(dest_ty);
+
+        const result_id = self.spv.allocId();
+        switch (dest_info.signedness) {
+            .signed => try self.func.body.emit(self.spv.gpa, .OpSConvert, .{
+                .id_result_type = dest_ty_id,
+                .id_result = result_id,
+                .signed_value = operand_id,
+            }),
+            .unsigned => try self.func.body.emit(self.spv.gpa, .OpUConvert, .{
+                .id_result_type = dest_ty_id,
+                .id_result = result_id,
+                .unsigned_value = operand_id,
+            }),
+        }
+        return result_id.toRef();
+    }
+
     fn airNot(self: *DeclGen, inst: Air.Inst.Index) !?IdRef {
         if (self.liveness.isUnused(inst)) return null;
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;