Commit 6e3770e970

Robin Voetter <robin@voetter.nl>
2023-05-18 02:45:21
spirv: implement pointer comparison in for air cmp
It turns out that the Khronos LLVM SPIRV translator does not support OpPtrEqual. Therefore, this instruction is emitted using a series of conversions. This commit breaks intToEnum, because enum was removed from the arithmetic type info. The enum should be converted to an int before this function is called.
1 parent 7077e90
Changed files (4)
src
codegen
test
src/codegen/spirv.zig
@@ -369,17 +369,6 @@ pub const DeclGen = struct {
                         .composite_integer,
                 };
             },
-            .Enum => blk: {
-                var buffer: Type.Payload.Bits = undefined;
-                const int_ty = ty.intTagType(&buffer);
-                const int_info = int_ty.intInfo(target);
-                break :blk ArithmeticTypeInfo{
-                    .bits = int_info.bits,
-                    .is_vector = false,
-                    .signedness = int_info.signedness,
-                    .class = .integer,
-                };
-            },
             // As of yet, there is no vector support in the self-hosted compiler.
             .Vector => self.todo("implement arithmeticTypeInfo for Vector", .{}),
             // TODO: For which types is this the case?
@@ -1742,12 +1731,12 @@ pub const DeclGen = struct {
             .struct_field_ptr_index_2 => try self.airStructFieldPtrIndex(inst, 2),
             .struct_field_ptr_index_3 => try self.airStructFieldPtrIndex(inst, 3),
 
-            .cmp_eq  => try self.airCmp(inst, .OpFOrdEqual,            .OpLogicalEqual,      .OpIEqual),
-            .cmp_neq => try self.airCmp(inst, .OpFOrdNotEqual,         .OpLogicalNotEqual,   .OpINotEqual),
-            .cmp_gt  => try self.airCmp(inst, .OpFOrdGreaterThan,      .OpSGreaterThan,      .OpUGreaterThan),
-            .cmp_gte => try self.airCmp(inst, .OpFOrdGreaterThanEqual, .OpSGreaterThanEqual, .OpUGreaterThanEqual),
-            .cmp_lt  => try self.airCmp(inst, .OpFOrdLessThan,         .OpSLessThan,         .OpULessThan),
-            .cmp_lte => try self.airCmp(inst, .OpFOrdLessThanEqual,    .OpSLessThanEqual,    .OpULessThanEqual),
+            .cmp_eq  => try self.airCmp(inst, .eq),
+            .cmp_neq => try self.airCmp(inst, .neq),
+            .cmp_gt  => try self.airCmp(inst, .gt),
+            .cmp_gte => try self.airCmp(inst, .gte),
+            .cmp_lt  => try self.airCmp(inst, .lt),
+            .cmp_lte => try self.airCmp(inst, .lte),
 
             .arg     => self.airArg(),
             .alloc   => try self.airAlloc(inst),
@@ -2039,58 +2028,122 @@ pub const DeclGen = struct {
         return result_id;
     }
 
-    fn airCmp(self: *DeclGen, inst: Air.Inst.Index, comptime fop: Opcode, comptime sop: Opcode, comptime uop: Opcode) !?IdRef {
-        if (self.liveness.isUnused(inst)) return null;
-        const bin_op = self.air.instructions.items(.data)[inst].bin_op;
-        var lhs_id = try self.resolve(bin_op.lhs);
-        var rhs_id = try self.resolve(bin_op.rhs);
-        const result_id = self.spv.allocId();
-        const result_type_id = try self.resolveTypeId(Type.bool);
-        const op_ty = self.air.typeOf(bin_op.lhs);
-        assert(op_ty.eql(self.air.typeOf(bin_op.rhs), self.module));
+    fn cmp(
+        self: *DeclGen,
+        comptime op: std.math.CompareOperator,
+        bool_ty_id: IdRef,
+        ty: Type,
+        lhs_id: IdRef,
+        rhs_id: IdRef,
+    ) !IdRef {
+        var cmp_lhs_id = lhs_id;
+        var cmp_rhs_id = rhs_id;
+        const opcode: Opcode = opcode: {
+            var int_buffer: Type.Payload.Bits = undefined;
+            const op_ty = switch (ty.zigTypeTag()) {
+                .Int, .Bool, .Float => ty,
+                .Enum => ty.intTagType(&int_buffer),
+                .ErrorSet => Type.u16,
+                .Pointer => blk: {
+                    // Note that while SPIR-V offers OpPtrEqual and OpPtrNotEqual, they are
+                    // currently not implemented in the SPIR-V LLVM translator. Thus, we emit these using
+                    // OpConvertPtrToU...
+                    cmp_lhs_id = self.spv.allocId();
+                    cmp_rhs_id = self.spv.allocId();
+
+                    const usize_ty_id = self.typeId(try self.sizeType());
+
+                    try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
+                        .id_result_type = usize_ty_id,
+                        .id_result = cmp_lhs_id,
+                        .pointer = lhs_id,
+                    });
 
-        // Comparisons are generally applicable to both scalar and vector operations in SPIR-V,
-        // but int and float versions of operations require different opcodes.
-        const info = try self.arithmeticTypeInfo(op_ty);
+                    try self.func.body.emit(self.spv.gpa, .OpConvertPtrToU, .{
+                        .id_result_type = usize_ty_id,
+                        .id_result = cmp_rhs_id,
+                        .pointer = rhs_id,
+                    });
 
-        const opcode_index: usize = switch (info.class) {
-            .composite_integer => {
-                return self.todo("binary operations for composite integers", .{});
-            },
-            .float => 0,
-            .bool => 1,
-            .strange_integer => blk: {
-                const op_ty_ref = try self.resolveType(op_ty, .direct);
-                lhs_id = try self.maskStrangeInt(op_ty_ref, lhs_id, info.bits);
-                rhs_id = try self.maskStrangeInt(op_ty_ref, rhs_id, info.bits);
-                break :blk switch (info.signedness) {
-                    .signed => @as(usize, 1),
-                    .unsigned => @as(usize, 2),
-                };
-            },
-            .integer => switch (info.signedness) {
-                .signed => @as(usize, 1),
-                .unsigned => @as(usize, 2),
-            },
-        };
+                    break :blk Type.usize;
+                },
+                .Optional => unreachable, // TODO
+                else => unreachable,
+            };
 
-        const operands = .{
-            .id_result_type = result_type_id,
-            .id_result = result_id,
-            .operand_1 = lhs_id,
-            .operand_2 = rhs_id,
-        };
+            const info = try self.arithmeticTypeInfo(op_ty);
+            const signedness = switch (info.class) {
+                .composite_integer => {
+                    return self.todo("binary operations for composite integers", .{});
+                },
+                .float => break :opcode switch (op) {
+                    .eq => .OpFOrdEqual,
+                    .neq => .OpFOrdNotEqual,
+                    .lt => .OpFOrdLessThan,
+                    .lte => .OpFOrdLessThanEqual,
+                    .gt => .OpFOrdGreaterThan,
+                    .gte => .OpFOrdGreaterThanEqual,
+                },
+                .bool => break :opcode switch (op) {
+                    .eq => .OpIEqual,
+                    .neq => .OpINotEqual,
+                    else => unreachable,
+                },
+                .strange_integer => sign: {
+                    const op_ty_ref = try self.resolveType(op_ty, .direct);
+                    // Mask operands before performing comparison.
+                    cmp_lhs_id = try self.maskStrangeInt(op_ty_ref, cmp_lhs_id, info.bits);
+                    cmp_rhs_id = try self.maskStrangeInt(op_ty_ref, cmp_rhs_id, info.bits);
+                    break :sign info.signedness;
+                },
+                .integer => info.signedness,
+            };
 
-        switch (opcode_index) {
-            0 => try self.func.body.emit(self.spv.gpa, fop, operands),
-            1 => try self.func.body.emit(self.spv.gpa, sop, operands),
-            2 => try self.func.body.emit(self.spv.gpa, uop, operands),
-            else => unreachable,
-        }
+            break :opcode switch (signedness) {
+                .unsigned => switch (op) {
+                    .eq => .OpIEqual,
+                    .neq => .OpINotEqual,
+                    .lt => .OpULessThan,
+                    .lte => .OpULessThanEqual,
+                    .gt => .OpUGreaterThan,
+                    .gte => .OpUGreaterThanEqual,
+                },
+                .signed => switch (op) {
+                    .eq => .OpIEqual,
+                    .neq => .OpINotEqual,
+                    .lt => .OpSLessThan,
+                    .lte => .OpSLessThanEqual,
+                    .gt => .OpSGreaterThan,
+                    .gte => .OpSGreaterThanEqual,
+                },
+            };
+        };
 
+        const result_id = self.spv.allocId();
+        try self.func.body.emitRaw(self.spv.gpa, opcode, 4);
+        self.func.body.writeOperand(spec.IdResultType, bool_ty_id);
+        self.func.body.writeOperand(spec.IdResult, result_id);
+        self.func.body.writeOperand(spec.IdResultType, cmp_lhs_id);
+        self.func.body.writeOperand(spec.IdResultType, cmp_rhs_id);
         return result_id;
     }
 
+    fn airCmp(
+        self: *DeclGen,
+        inst: Air.Inst.Index,
+        comptime op: std.math.CompareOperator,
+    ) !?IdRef {
+        if (self.liveness.isUnused(inst)) return null;
+        const bin_op = self.air.instructions.items(.data)[inst].bin_op;
+        const lhs_id = try self.resolve(bin_op.lhs);
+        const rhs_id = try self.resolve(bin_op.rhs);
+        const bool_ty_id = try self.resolveTypeId(Type.bool);
+        const ty = self.air.typeOf(bin_op.lhs);
+        assert(ty.eql(self.air.typeOf(bin_op.rhs), self.module));
+
+        return try self.cmp(op, bool_ty_id, ty, lhs_id, rhs_id);
+    }
+
     fn bitcast(self: *DeclGen, target_type_id: IdResultType, value_id: IdRef) !IdRef {
         const result_id = self.spv.allocId();
         try self.func.body.emit(self.spv.gpa, .OpBitcast, .{
test/behavior/basic.zig
@@ -134,21 +134,18 @@ fn first4KeysOfHomeRow() []const u8 {
 
 test "return string from function" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try expect(mem.eql(u8, first4KeysOfHomeRow(), "aoeu"));
 }
 
 test "hex escape" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try expect(mem.eql(u8, "\x68\x65\x6c\x6c\x6f", "hello"));
 }
 
 test "multiline string" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const s1 =
         \\one
@@ -161,7 +158,6 @@ test "multiline string" {
 
 test "multiline string comments at start" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const s1 =
         //\\one
@@ -174,7 +170,6 @@ test "multiline string comments at start" {
 
 test "multiline string comments at end" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const s1 =
         \\one
@@ -187,7 +182,6 @@ test "multiline string comments at end" {
 
 test "multiline string comments in middle" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const s1 =
         \\one
@@ -200,7 +194,6 @@ test "multiline string comments in middle" {
 
 test "multiline string comments at multiple places" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     const s1 =
         \\one
@@ -214,14 +207,11 @@ test "multiline string comments at multiple places" {
 }
 
 test "string concatenation simple" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     try expect(mem.eql(u8, "OK" ++ " IT " ++ "WORKED", "OK IT WORKED"));
 }
 
 test "array mult operator" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try expect(mem.eql(u8, "ab" ** 5, "ababababab"));
 }
@@ -387,7 +377,6 @@ test "take address of parameter" {
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
 
     try testTakeAddressOfParameter(12.34);
 }
@@ -690,8 +679,6 @@ test "explicit cast optional pointers" {
 }
 
 test "pointer comparison" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const a = @as([]const u8, "a");
     const b = &a;
     try expect(ptrEql(b, b));
@@ -892,8 +879,6 @@ test "catch in block has correct result location" {
 }
 
 test "labeled block with runtime branch forwards its result location type to break statements" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const E = enum { a, b };
     var a = false;
     const e: E = blk: {
@@ -1062,8 +1047,6 @@ test "switch inside @as gets correct type" {
 }
 
 test "inline call of function with a switch inside the return statement" {
-    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
-
     const S = struct {
         inline fn foo(x: anytype) @TypeOf(x) {
             return switch (x) {
test/behavior/enum.zig
@@ -20,6 +20,8 @@ test "enum to int" {
 }
 
 fn testIntToEnumEval(x: i32) !void {
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+
     try expect(@intToEnum(IntToEnumNumber, x) == IntToEnumNumber.Three);
 }
 const IntToEnumNumber = enum { Zero, One, Two, Three, Four };
test/behavior/memcpy.zig
@@ -67,14 +67,16 @@ fn testMemcpyDestManyPtr() !void {
 }
 
 comptime {
-    const S = struct {
-        buffer: [8]u8 = undefined,
-        fn set(self: *@This(), items: []const u8) void {
-            @memcpy(self.buffer[0..items.len], items);
-        }
-    };
+    if (builtin.zig_backend != .stage2_spirv64) {
+        const S = struct {
+            buffer: [8]u8 = undefined,
+            fn set(self: *@This(), items: []const u8) void {
+                @memcpy(self.buffer[0..items.len], items);
+            }
+        };
 
-    var s = S{};
-    s.set("hello");
-    if (!std.mem.eql(u8, s.buffer[0..5], "hello")) @compileError("bad");
+        var s = S{};
+        s.set("hello");
+        if (!std.mem.eql(u8, s.buffer[0..5], "hello")) @compileError("bad");
+    }
 }