Commit cd46daf7d0

John Schmidt <john.schmidt.h@gmail.com>
2022-03-24 23:27:23
sema: coerce inputs to vectors in zirSelect
1 parent f47db0a
Changed files (2)
src
test
behavior
src/Sema.zig
@@ -14805,6 +14805,7 @@ fn analyzeShuffle(
 fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
     const extra = sema.code.extraData(Zir.Inst.Select, inst_data.payload_index).data;
+    const target = sema.mod.getTarget();
 
     const elem_ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
     const pred_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
@@ -14813,35 +14814,21 @@ fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
 
     const elem_ty = try sema.resolveType(block, elem_ty_src, extra.elem_type);
     try sema.checkVectorElemType(block, elem_ty_src, elem_ty);
-    const pred = sema.resolveInst(extra.pred);
-    const a = sema.resolveInst(extra.a);
-    const b = sema.resolveInst(extra.b);
-    const target = sema.mod.getTarget();
-
-    const pred_ty = sema.typeOf(pred);
-    switch (try pred_ty.zigTypeTagOrPoison()) {
-        .Vector => {
-            const scalar_ty = pred_ty.childType();
-            if (!scalar_ty.eql(Type.bool, target)) {
-                const bool_vec_ty = try Type.vector(sema.arena, pred_ty.vectorLen(), Type.bool);
-                return sema.fail(block, pred_src, "Expected '{}', found '{}'", .{ bool_vec_ty.fmt(target), pred_ty.fmt(target) });
-            }
-        },
-        else => return sema.fail(block, pred_src, "Expected vector type, found '{}'", .{pred_ty.fmt(target)}),
-    }
+    const pred_uncoerced = sema.resolveInst(extra.pred);
+    const pred_ty = sema.typeOf(pred_uncoerced);
 
-    const vec_len = pred_ty.vectorLen();
-    const vec_ty = try Type.vector(sema.arena, vec_len, elem_ty);
+    const vec_len_u64 = switch (try pred_ty.zigTypeTagOrPoison()) {
+        .Vector, .Array => pred_ty.arrayLen(),
+        else => return sema.fail(block, pred_src, "expected vector or array, found '{}'", .{pred_ty.fmt(target)}),
+    };
+    const vec_len = try sema.usizeCast(block, pred_src, vec_len_u64);
 
-    const a_ty = sema.typeOf(a);
-    if (!a_ty.eql(vec_ty, target)) {
-        return sema.fail(block, a_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), a_ty.fmt(target) });
-    }
+    const bool_vec_ty = try Type.vector(sema.arena, vec_len, Type.bool);
+    const pred = try sema.coerce(block, bool_vec_ty, pred_uncoerced, pred_src);
 
-    const b_ty = sema.typeOf(b);
-    if (!b_ty.eql(vec_ty, target)) {
-        return sema.fail(block, b_src, "Expected '{}', found '{}'", .{ vec_ty.fmt(target), b_ty.fmt(target) });
-    }
+    const vec_ty = try Type.vector(sema.arena, vec_len, elem_ty);
+    const a = try sema.coerce(block, vec_ty, sema.resolveInst(extra.a), a_src);
+    const b = try sema.coerce(block, vec_ty, sema.resolveInst(extra.b), b_src);
 
     const maybe_pred = try sema.resolveMaybeUndefVal(block, pred_src, pred);
     const maybe_a = try sema.resolveMaybeUndefVal(block, a_src, a);
test/behavior/select.zig
@@ -3,18 +3,18 @@ const builtin = @import("builtin");
 const mem = std.mem;
 const expect = std.testing.expect;
 
-test "@select" {
+test "@select vectors" {
     if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
 
-    try doTheTest();
-    comptime try doTheTest();
+    comptime try selectVectors();
+    try selectVectors();
 }
 
-fn doTheTest() !void {
+fn selectVectors() !void {
     var a = @Vector(4, bool){ true, false, true, false };
     var b = @Vector(4, i32){ -1, 4, 999, -31 };
     var c = @Vector(4, i32){ -5, 1, 0, 1234 };
@@ -30,3 +30,32 @@ fn doTheTest() !void {
     var xyz = @select(f32, x, y, z);
     try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
 }
+
+test "@select arrays" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+    comptime try selectArrays();
+    try selectArrays();
+}
+
+fn selectArrays() !void {
+    var a = [4]bool{ false, true, false, true };
+    var b = [4]usize{ 0, 1, 2, 3 };
+    var c = [4]usize{ 4, 5, 6, 7 };
+    var abc = @select(usize, a, b, c);
+    try expect(abc[0] == 4);
+    try expect(abc[1] == 1);
+    try expect(abc[2] == 6);
+    try expect(abc[3] == 3);
+
+    var x = [4]bool{ false, false, false, true };
+    var y = [4]f32{ 0.001, 33.4, 836, -3381.233 };
+    var z = [4]f32{ 0.0, 312.1, -145.9, 9993.55 };
+    var xyz = @select(f32, x, y, z);
+    try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
+}