Commit b425d88737

Guillaume Wenzek <gwenzek@users.noreply.github.com>
2022-10-04 07:31:36
re-enable nvptx tests
1 parent 577f0aa
Changed files (2)
test
test/stage2/nvptx.zig
@@ -23,11 +23,10 @@ pub fn addCases(ctx: *TestContext) !void {
         var case = addPtx(ctx, "nvptx: read special registers");
 
         case.compiles(
-            \\fn threadIdX() usize {
-            \\     var tid = asm volatile ("mov.u32 \t$0, %tid.x;"
-            \\         : [ret] "=r" (-> u32),
-            \\     );
-            \\     return @as(usize, tid);
+            \\fn threadIdX() u32 {
+            \\    return asm ("mov.u32 \t%[r], %tid.x;"
+            \\       : [r] "=r" (-> u32),
+            \\    );
             \\}
             \\
             \\pub export fn special_reg(a: []const i32, out: []i32) callconv(.PtxKernel) void {
@@ -49,6 +48,38 @@ pub fn addCases(ctx: *TestContext) !void {
             \\}
         );
     }
+
+    {
+        var case = addPtx(ctx, "nvptx: reduce in shared mem");
+        case.compiles(
+            \\fn threadIdX() u32 {
+            \\    return asm ("mov.u32 \t%[r], %tid.x;"
+            \\       : [r] "=r" (-> u32),
+            \\    );
+            \\}
+            \\
+            \\ var _sdata: [1024]f32 addrspace(.shared) = undefined;
+            \\ pub export fn reduceSum(d_x: []const f32, out: *f32) callconv(.PtxKernel) void {
+            \\     var sdata = @addrSpaceCast(.generic, &_sdata);
+            \\     const tid: u32 = threadIdX();
+            \\     var sum = d_x[tid];
+            \\     sdata[tid] = sum;
+            \\     asm volatile ("bar.sync \t0;");
+            \\     var s: u32 = 512;
+            \\     while (s > 0) : (s = s >> 1) {
+            \\         if (tid < s) {
+            \\             sum += sdata[tid + s];
+            \\             sdata[tid] = sum;
+            \\         }
+            \\         asm volatile ("bar.sync \t0;");
+            \\     }
+            \\
+            \\     if (tid == 0) {
+            \\         out.* = sum;
+            \\     }
+            \\ }
+        );
+    }
 }
 
 const nvptx_target = std.zig.CrossTarget{
test/cases.zig
@@ -4,6 +4,5 @@ const TestContext = @import("../src/test.zig").TestContext;
 pub fn addCases(ctx: *TestContext) !void {
     try @import("compile_errors.zig").addCases(ctx);
     try @import("stage2/cbe.zig").addCases(ctx);
-    // https://github.com/ziglang/zig/issues/10968
-    //try @import("stage2/nvptx.zig").addCases(ctx);
+    try @import("stage2/nvptx.zig").addCases(ctx); // re-enabled; was commented out with a reference to ziglang/zig issue 10968
 }