Commit adfcc8851b

John Schmidt <john.schmidt.h@gmail.com>
2022-03-17 17:25:42
Implement `@byteSwap` for vectors
Make the behavior tests for this a little more primitive to exercise as little extra functionality as possible.
1 parent 7233a33
Changed files (3)
src
test
behavior
src/codegen/llvm.zig
@@ -6078,9 +6078,26 @@ pub const FuncGen = struct {
         if (bits % 16 == 8) {
             // If not an even byte-multiple, we need zero-extend + shift-left 1 byte
             // The truncated result at the end will be the correct bswap
-            operand_llvm_ty = self.context.intType(bits + 8);
-            const extended = self.builder.buildZExt(operand, operand_llvm_ty, "");
-            operand = self.builder.buildShl(extended, operand_llvm_ty.constInt(8, .False), "");
+            const scalar_llvm_ty = self.context.intType(bits + 8);
+            if (operand_ty.zigTypeTag() == .Vector) {
+                const vec_len = operand_ty.vectorLen();
+                operand_llvm_ty = scalar_llvm_ty.vectorType(vec_len);
+
+                const shifts = try self.gpa.alloc(*const llvm.Value, vec_len);
+                defer self.gpa.free(shifts);
+
+                for (shifts) |*elem| {
+                    elem.* = scalar_llvm_ty.constInt(8, .False);
+                }
+                const shift_vec = llvm.constVector(shifts.ptr, vec_len);
+
+                const extended = self.builder.buildZExt(operand, operand_llvm_ty, "");
+                operand = self.builder.buildShl(extended, shift_vec, "");
+            } else {
+                const extended = self.builder.buildZExt(operand, scalar_llvm_ty, "");
+                operand = self.builder.buildShl(extended, scalar_llvm_ty.constInt(8, .False), "");
+                operand_llvm_ty = scalar_llvm_ty;
+            }
             bits = bits + 8;
         }
 
src/Sema.zig
@@ -13491,30 +13491,77 @@ fn zirByteSwap(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
     const operand = sema.resolveInst(inst_data.operand);
     const operand_ty = sema.typeOf(operand);
-    // TODO implement support for vectors
-    if (operand_ty.zigTypeTag() != .Int) {
-        return sema.fail(block, ty_src, "expected integer type, found '{}'", .{
-            operand_ty,
-        });
+
+    const scalar_ty = if (operand_ty.zigTypeTag() == .Vector)
+        operand_ty.elemType2()
+    else
+        operand_ty;
+
+    switch (operand_ty.zigTypeTag()) {
+        .Int, .ComptimeInt => {},
+        .Vector => {
+            switch (scalar_ty.zigTypeTag()) {
+                .Int, .ComptimeInt => {},
+                else => return sema.fail(block, ty_src, "expected vector of integer type, found vector of '{}'", .{scalar_ty}),
+            }
+        },
+        else => return sema.fail(block, ty_src, "expected integer type or vector of integer type, found '{}'", .{operand_ty}),
     }
+
     const target = sema.mod.getTarget();
-    const bits = operand_ty.intInfo(target).bits;
-    if (bits == 0) return Air.Inst.Ref.zero;
-    if (operand_ty.intInfo(target).bits % 8 != 0) {
-        return sema.fail(block, ty_src, "@byteSwap requires the number of bits to be evenly divisible by 8, but {} has {} bits", .{
-            operand_ty,
-            operand_ty.intInfo(target).bits,
-        });
+    const bits = scalar_ty.intInfo(target).bits;
+    if (bits % 8 != 0) {
+        return sema.fail(
+            block,
+            ty_src,
+            "@byteSwap requires the number of bits to be evenly divisible by 8, but {} has {} bits",
+            .{ scalar_ty, bits },
+        );
     }
 
-    const runtime_src = if (try sema.resolveMaybeUndefVal(block, operand_src, operand)) |val| {
-        if (val.isUndef()) return sema.addConstUndef(operand_ty);
-        const result_val = try val.byteSwap(operand_ty, target, sema.arena);
-        return sema.addConstant(operand_ty, result_val);
-    } else operand_src;
+    switch (operand_ty.zigTypeTag()) {
+        .Int, .ComptimeInt => {
+            if (bits == 0) return Air.Inst.Ref.zero;
 
-    try sema.requireRuntimeBlock(block, runtime_src);
-    return block.addTyOp(.byte_swap, operand_ty, operand);
+            const runtime_src = if (try sema.resolveMaybeUndefVal(block, operand_src, operand)) |val| {
+                if (val.isUndef()) return sema.addConstUndef(operand_ty);
+                const result_val = try val.byteSwap(operand_ty, target, sema.arena);
+                return sema.addConstant(operand_ty, result_val);
+            } else operand_src;
+
+            try sema.requireRuntimeBlock(block, runtime_src);
+            return block.addTyOp(.byte_swap, operand_ty, operand);
+        },
+        .Vector => {
+            if (bits == 0) {
+                return sema.addConstant(
+                    operand_ty,
+                    try Value.Tag.repeated.create(sema.arena, Value.zero),
+                );
+            }
+
+            const runtime_src = if (try sema.resolveMaybeUndefVal(block, operand_src, operand)) |val| {
+                if (val.isUndef())
+                    return sema.addConstUndef(operand_ty);
+
+                const vec_len = operand_ty.vectorLen();
+                var elem_buf: Value.ElemValueBuffer = undefined;
+                const elems = try sema.arena.alloc(Value, vec_len);
+                for (elems) |*elem, i| {
+                    const elem_val = val.elemValueBuffer(i, &elem_buf);
+                    elem.* = try elem_val.byteSwap(operand_ty, target, sema.arena);
+                }
+                return sema.addConstant(
+                    operand_ty,
+                    try Value.Tag.aggregate.create(sema.arena, elems),
+                );
+            } else operand_src;
+
+            try sema.requireRuntimeBlock(block, runtime_src);
+            return block.addTyOp(.byte_swap, operand_ty, operand);
+        },
+        else => unreachable,
+    }
 }
 
 fn zirBitReverse(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
test/behavior/byteswap.zig
@@ -52,32 +52,77 @@ test "@byteSwap integers" {
     try ByteSwapIntTest.run();
 }
 
-test "@byteSwap vectors" {
-    if (builtin.zig_backend == .stage2_llvm) return error.SkipZigTest;
+fn vector8() !void {
+    var v = @Vector(2, u8){ 0x12, 0x13 };
+    var result = @byteSwap(u8, v);
+    try expect(result[0] == 0x12);
+    try expect(result[1] == 0x13);
+}
+
+test "@byteSwap vectors u8" {
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
 
-    const ByteSwapVectorTest = struct {
-        fn run() !void {
-            try t(u8, 2, [_]u8{ 0x12, 0x13 }, [_]u8{ 0x12, 0x13 });
-            try t(u16, 2, [_]u16{ 0x1234, 0x2345 }, [_]u16{ 0x3412, 0x4523 });
-            try t(u24, 2, [_]u24{ 0x123456, 0x234567 }, [_]u24{ 0x563412, 0x674523 });
-        }
+    comptime try vector8();
+    try vector8();
+}
 
-        fn t(
-            comptime I: type,
-            comptime n: comptime_int,
-            input: std.meta.Vector(n, I),
-            expected_vector: std.meta.Vector(n, I),
-        ) !void {
-            const actual_output: [n]I = @byteSwap(I, input);
-            const expected_output: [n]I = expected_vector;
-            try std.testing.expectEqual(expected_output, actual_output);
-        }
-    };
-    comptime try ByteSwapVectorTest.run();
-    try ByteSwapVectorTest.run();
+fn vector16() !void {
+    var v = @Vector(2, u16){ 0x1234, 0x2345 };
+    var result = @byteSwap(u16, v);
+    try expect(result[0] == 0x3412);
+    try expect(result[1] == 0x4523);
+}
+
+test "@byteSwap vectors u16" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
+
+    comptime try vector16();
+    try vector16();
+}
+
+fn vector24() !void {
+    var v = @Vector(2, u24){ 0x123456, 0x234567 };
+    var result = @byteSwap(u24, v);
+    try expect(result[0] == 0x563412);
+    try expect(result[1] == 0x674523);
+}
+
+test "@byteSwap vectors u24" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
+
+    comptime try vector24();
+    try vector24();
+}
+
+fn vector0() !void {
+    var v = @Vector(2, u0){ 0, 0 };
+    var result = @byteSwap(u0, v);
+    try expect(result[0] == 0);
+    try expect(result[1] == 0);
+}
+
+test "@byteSwap vectors u0" {
+    // TODO: vector initialization for @Vector(x, u0) currently fails.
+    if (builtin.zig_backend == .stage2_llvm) return error.SkipZigTest;
+
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
+
+    comptime try vector0();
+    try vector0();
 }