Commit 9031cc54f2

Andrew Kelley <andrew@ziglang.org>
2022-05-18 01:51:35
Sema: implement `@intCast` for vectors
1 parent 691fba3
Changed files (2)
src
test
behavior
src/Sema.zig
@@ -7020,22 +7020,25 @@ fn intCast(
     operand_src: LazySrcLoc,
     runtime_safety: bool,
 ) CompileError!Air.Inst.Ref {
-    // TODO: Add support for vectors
-    const dest_is_comptime_int = try sema.checkIntType(block, dest_ty_src, dest_ty);
-    _ = try sema.checkIntType(block, operand_src, sema.typeOf(operand));
+    const operand_ty = sema.typeOf(operand);
+    const dest_scalar_ty = try sema.checkIntOrVectorAllowComptime(block, dest_ty, dest_ty_src);
+    const operand_scalar_ty = try sema.checkIntOrVectorAllowComptime(block, operand_ty, operand_src);
 
     if (try sema.isComptimeKnown(block, operand_src, operand)) {
         return sema.coerce(block, dest_ty, operand, operand_src);
-    } else if (dest_is_comptime_int) {
+    } else if (dest_scalar_ty.zigTypeTag() == .ComptimeInt) {
         return sema.fail(block, operand_src, "unable to cast runtime value to 'comptime_int'", .{});
     }
 
+    try sema.checkVectorizableBinaryOperands(block, operand_src, dest_ty, operand_ty, dest_ty_src, operand_src);
+    const is_vector = dest_ty.zigTypeTag() == .Vector;
+
     if ((try sema.typeHasOnePossibleValue(block, dest_ty_src, dest_ty))) |opv| {
         // requirement: intCast(u0, input) iff input == 0
         if (runtime_safety and block.wantSafety()) {
             try sema.requireRuntimeBlock(block, operand_src);
             const target = sema.mod.getTarget();
-            const wanted_info = dest_ty.intInfo(target);
+            const wanted_info = dest_scalar_ty.intInfo(target);
             const wanted_bits = wanted_info.bits;
 
             if (wanted_bits == 0) {
@@ -7051,9 +7054,8 @@ fn intCast(
     try sema.requireRuntimeBlock(block, operand_src);
     if (runtime_safety and block.wantSafety()) {
         const target = sema.mod.getTarget();
-        const operand_ty = sema.typeOf(operand);
-        const actual_info = operand_ty.intInfo(target);
-        const wanted_info = dest_ty.intInfo(target);
+        const actual_info = operand_scalar_ty.intInfo(target);
+        const wanted_info = dest_scalar_ty.intInfo(target);
         const actual_bits = actual_info.bits;
         const wanted_bits = wanted_info.bits;
         const actual_value_bits = actual_bits - @boolToInt(actual_info.signedness == .signed);
@@ -7062,7 +7064,11 @@ fn intCast(
         // range shrinkage
         // requirement: int value fits into target type
         if (wanted_value_bits < actual_value_bits) {
-            const dest_max_val = try dest_ty.maxInt(sema.arena, target);
+            const dest_max_val_scalar = try dest_scalar_ty.maxInt(sema.arena, target);
+            const dest_max_val = if (is_vector)
+                try Value.Tag.repeated.create(sema.arena, dest_max_val_scalar)
+            else
+                dest_max_val_scalar;
             const dest_max = try sema.addConstant(operand_ty, dest_max_val);
             const diff = try block.addBinOp(.subwrap, dest_max, operand);
 
@@ -7080,19 +7086,59 @@ fn intCast(
                 } else dest_max_val;
                 const dest_range = try sema.addConstant(unsigned_operand_ty, dest_range_val);
 
-                const is_in_range = try block.addBinOp(.cmp_lte, diff_unsigned, dest_range);
-                try sema.addSafetyCheck(block, is_in_range, .cast_truncated_data);
+                const ok = if (is_vector) ok: {
+                    const is_in_range = try block.addCmpVector(diff_unsigned, dest_range, .lte, try sema.addType(operand_ty));
+                    const all_in_range = try block.addInst(.{
+                        .tag = .reduce,
+                        .data = .{ .reduce = .{
+                            .operand = is_in_range,
+                            .operation = .And,
+                        } },
+                    });
+                    break :ok all_in_range;
+                } else ok: {
+                    const is_in_range = try block.addBinOp(.cmp_lte, diff_unsigned, dest_range);
+                    break :ok is_in_range;
+                };
+                try sema.addSafetyCheck(block, ok, .cast_truncated_data);
             } else {
-                const is_in_range = try block.addBinOp(.cmp_lte, diff, dest_max);
-                try sema.addSafetyCheck(block, is_in_range, .cast_truncated_data);
+                const ok = if (is_vector) ok: {
+                    const is_in_range = try block.addCmpVector(diff, dest_max, .lte, try sema.addType(operand_ty));
+                    const all_in_range = try block.addInst(.{
+                        .tag = .reduce,
+                        .data = .{ .reduce = .{
+                            .operand = is_in_range,
+                            .operation = .And,
+                        } },
+                    });
+                    break :ok all_in_range;
+                } else ok: {
+                    const is_in_range = try block.addBinOp(.cmp_lte, diff, dest_max);
+                    break :ok is_in_range;
+                };
+                try sema.addSafetyCheck(block, ok, .cast_truncated_data);
             }
-        }
-        // no shrinkage, yes sign loss
-        // requirement: signed to unsigned >= 0
-        else if (actual_info.signedness == .signed and wanted_info.signedness == .unsigned) {
-            const zero_inst = try sema.addConstant(operand_ty, Value.zero);
-            const is_in_range = try block.addBinOp(.cmp_gte, operand, zero_inst);
-            try sema.addSafetyCheck(block, is_in_range, .cast_truncated_data);
+        } else if (actual_info.signedness == .signed and wanted_info.signedness == .unsigned) {
+            // no shrinkage, yes sign loss
+            // requirement: signed to unsigned >= 0
+            const ok = if (is_vector) ok: {
+                const zero_val = try Value.Tag.repeated.create(sema.arena, Value.zero);
+                const zero_inst = try sema.addConstant(operand_ty, zero_val);
+                const is_in_range = try block.addCmpVector(operand, zero_inst, .lte, try sema.addType(operand_ty));
+                const all_in_range = try block.addInst(.{
+                    .tag = .reduce,
+                    .data = .{ .reduce = .{
+                        .operand = is_in_range,
+                        .operation = .And,
+                    } },
+                });
+                break :ok all_in_range;
+            } else ok: {
+                const zero_inst = try sema.addConstant(operand_ty, Value.zero);
+                const is_in_range = try block.addBinOp(.cmp_gte, operand, zero_inst);
+                break :ok is_in_range;
+            };
+            try sema.addSafetyCheck(block, ok, .cast_truncated_data);
         }
     }
     return block.addTyOp(.intcast, dest_ty, operand);
@@ -14517,8 +14563,8 @@ fn zirTruncate(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     const dest_scalar_ty = try sema.resolveType(block, dest_ty_src, extra.lhs);
     const operand = sema.resolveInst(extra.rhs);
     const dest_is_comptime_int = try sema.checkIntType(block, dest_ty_src, dest_scalar_ty);
-    const operand_scalar_ty = try sema.checkIntOrVectorAllowComptime(block, operand, operand_src);
     const operand_ty = sema.typeOf(operand);
+    const operand_scalar_ty = try sema.checkIntOrVectorAllowComptime(block, operand_ty, operand_src);
     const is_vector = operand_ty.zigTypeTag() == .Vector;
     const dest_ty = if (is_vector)
         try Type.vector(sema.arena, operand_ty.vectorLen(), dest_scalar_ty)
@@ -14686,7 +14732,7 @@ fn zirByteSwap(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
     const operand = sema.resolveInst(inst_data.operand);
     const operand_ty = sema.typeOf(operand);
-    const scalar_ty = try sema.checkIntOrVectorAllowComptime(block, operand, operand_src);
+    const scalar_ty = try sema.checkIntOrVectorAllowComptime(block, operand_ty, operand_src);
     const target = sema.mod.getTarget();
     const bits = scalar_ty.intInfo(target).bits;
     if (bits % 8 != 0) {
@@ -14743,7 +14789,7 @@ fn zirBitReverse(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
     const operand = sema.resolveInst(inst_data.operand);
     const operand_ty = sema.typeOf(operand);
-    _ = try sema.checkIntOrVectorAllowComptime(block, operand, operand_src);
+    _ = try sema.checkIntOrVectorAllowComptime(block, operand_ty, operand_src);
 
     if (try sema.typeHasOnePossibleValue(block, operand_src, operand_ty)) |val| {
         return sema.addConstant(operand_ty, val);
@@ -15095,10 +15141,9 @@ fn checkIntOrVector(
 fn checkIntOrVectorAllowComptime(
     sema: *Sema,
     block: *Block,
-    operand: Air.Inst.Ref,
+    operand_ty: Type,
     operand_src: LazySrcLoc,
 ) CompileError!Type {
-    const operand_ty = sema.typeOf(operand);
     switch (try operand_ty.zigTypeTagOrPoison()) {
         .Int, .ComptimeInt => return operand_ty,
         .Vector => {
test/behavior/cast.zig
@@ -587,7 +587,11 @@ test "cast *[1][*]const u8 to [*]const ?[*]const u8" {
 }
 
 test "vector casts" {
-    if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
 
     const S = struct {
         fn doTheTest() !void {