Commit b6ccde47ad

Andrew Kelley <andrew@ziglang.org>
2022-03-28 23:17:05
Sema: allow mixing array and vector operands
* Added peer type resolution for arrays and vectors: the vector type is selected. * Fixed passing the lhs type or rhs type instead of the peer resolved type when calling Value methods during analyzeArithmetic handling of comptime expressions. * `checkVectorizableBinaryOperands` now allows mixing vectors and arrays, as long as one of the operands is a vector. This matches stage1's handling of `^=` but apparently stage1 is inconsistent and does not handle e.g. `*=`. stage2 now will always allow mixing vector and array operands for all operations.
1 parent 691c7cb
Changed files (2)
src
test
behavior
src/Sema.zig
@@ -9617,7 +9617,7 @@ fn analyzeArithmetic(
                     if (lhs_val.isUndef()) {
                         if (lhs_scalar_ty.isSignedInt() and rhs_scalar_ty.isSignedInt()) {
                             if (maybe_rhs_val) |rhs_val| {
-                                if (rhs_val.compare(.neq, Value.negative_one, rhs_ty, target)) {
+                                if (rhs_val.compare(.neq, Value.negative_one, resolved_type, target)) {
                                     return sema.addConstUndef(resolved_type);
                                 }
                             }
@@ -9692,7 +9692,7 @@ fn analyzeArithmetic(
                     if (lhs_val.isUndef()) {
                         if (lhs_scalar_ty.isSignedInt() and rhs_scalar_ty.isSignedInt()) {
                             if (maybe_rhs_val) |rhs_val| {
-                                if (rhs_val.compare(.neq, Value.negative_one, rhs_ty, target)) {
+                                if (rhs_val.compare(.neq, Value.negative_one, resolved_type, target)) {
                                     return sema.addConstUndef(resolved_type);
                                 }
                             }
@@ -9755,7 +9755,7 @@ fn analyzeArithmetic(
                     if (lhs_val.isUndef()) {
                         if (lhs_scalar_ty.isSignedInt() and rhs_scalar_ty.isSignedInt()) {
                             if (maybe_rhs_val) |rhs_val| {
-                                if (rhs_val.compare(.neq, Value.negative_one, rhs_ty, target)) {
+                                if (rhs_val.compare(.neq, Value.negative_one, resolved_type, target)) {
                                     return sema.addConstUndef(resolved_type);
                                 }
                             }
@@ -9845,7 +9845,7 @@ fn analyzeArithmetic(
                         if (lhs_val.compareWithZero(.eq)) {
                             return sema.addConstant(resolved_type, Value.zero);
                         }
-                        if (lhs_val.compare(.eq, Value.one, lhs_ty, target)) {
+                        if (lhs_val.compare(.eq, Value.one, resolved_type, target)) {
                             return casted_rhs;
                         }
                     }
@@ -9861,7 +9861,7 @@ fn analyzeArithmetic(
                     if (rhs_val.compareWithZero(.eq)) {
                         return sema.addConstant(resolved_type, Value.zero);
                     }
-                    if (rhs_val.compare(.eq, Value.one, rhs_ty, target)) {
+                    if (rhs_val.compare(.eq, Value.one, resolved_type, target)) {
                         return casted_lhs;
                     }
                     if (maybe_lhs_val) |lhs_val| {
@@ -9896,7 +9896,7 @@ fn analyzeArithmetic(
                         if (lhs_val.compareWithZero(.eq)) {
                             return sema.addConstant(resolved_type, Value.zero);
                         }
-                        if (lhs_val.compare(.eq, Value.one, lhs_ty, target)) {
+                        if (lhs_val.compare(.eq, Value.one, resolved_type, target)) {
                             return casted_rhs;
                         }
                     }
@@ -9908,7 +9908,7 @@ fn analyzeArithmetic(
                     if (rhs_val.compareWithZero(.eq)) {
                         return sema.addConstant(resolved_type, Value.zero);
                     }
-                    if (rhs_val.compare(.eq, Value.one, rhs_ty, target)) {
+                    if (rhs_val.compare(.eq, Value.one, resolved_type, target)) {
                         return casted_lhs;
                     }
                     if (maybe_lhs_val) |lhs_val| {
@@ -9932,7 +9932,7 @@ fn analyzeArithmetic(
                         if (lhs_val.compareWithZero(.eq)) {
                             return sema.addConstant(resolved_type, Value.zero);
                         }
-                        if (lhs_val.compare(.eq, Value.one, lhs_ty, target)) {
+                        if (lhs_val.compare(.eq, Value.one, resolved_type, target)) {
                             return casted_rhs;
                         }
                     }
@@ -9944,7 +9944,7 @@ fn analyzeArithmetic(
                     if (rhs_val.compareWithZero(.eq)) {
                         return sema.addConstant(resolved_type, Value.zero);
                     }
-                    if (rhs_val.compare(.eq, Value.one, rhs_ty, target)) {
+                    if (rhs_val.compare(.eq, Value.one, resolved_type, target)) {
                         return casted_lhs;
                     }
                     if (maybe_lhs_val) |lhs_val| {
@@ -14521,9 +14521,20 @@ fn checkVectorizableBinaryOperands(
 ) CompileError!void {
     const lhs_zig_ty_tag = try lhs_ty.zigTypeTagOrPoison();
     const rhs_zig_ty_tag = try rhs_ty.zigTypeTagOrPoison();
-    if (lhs_zig_ty_tag == .Vector and rhs_zig_ty_tag == .Vector) {
-        const lhs_len = lhs_ty.vectorLen();
-        const rhs_len = rhs_ty.vectorLen();
+    if (lhs_zig_ty_tag != .Vector and rhs_zig_ty_tag != .Vector) return;
+
+    const lhs_is_vector = switch (lhs_zig_ty_tag) {
+        .Vector, .Array => true,
+        else => false,
+    };
+    const rhs_is_vector = switch (rhs_zig_ty_tag) {
+        .Vector, .Array => true,
+        else => false,
+    };
+
+    if (lhs_is_vector and rhs_is_vector) {
+        const lhs_len = lhs_ty.arrayLen();
+        const rhs_len = rhs_ty.arrayLen();
         if (lhs_len != rhs_len) {
             const msg = msg: {
                 const msg = try sema.errMsg(block, src, "vector length mismatch", .{});
@@ -14534,14 +14545,14 @@ fn checkVectorizableBinaryOperands(
             };
             return sema.failWithOwnedErrorMsg(block, msg);
         }
-    } else if (lhs_zig_ty_tag == .Vector or rhs_zig_ty_tag == .Vector) {
+    } else {
         const target = sema.mod.getTarget();
         const msg = msg: {
             const msg = try sema.errMsg(block, src, "mixed scalar and vector operands: {} and {}", .{
                 lhs_ty.fmt(target), rhs_ty.fmt(target),
             });
             errdefer msg.destroy(sema.gpa);
-            if (lhs_zig_ty_tag == .Vector) {
+            if (lhs_is_vector) {
                 try sema.errNote(block, lhs_src, msg, "vector here", .{});
                 try sema.errNote(block, rhs_src, msg, "scalar here", .{});
             } else {
@@ -21017,6 +21028,18 @@ fn resolvePeerTypes(
                 chosen_i = candidate_i + 1;
                 continue;
             },
+            .Vector => switch (chosen_ty_tag) {
+                .Array => {
+                    chosen = candidate;
+                    chosen_i = candidate_i + 1;
+                    continue;
+                },
+                else => {},
+            },
+            .Array => switch (chosen_ty_tag) {
+                .Vector => continue,
+                else => {},
+            },
             else => {},
         }
 
test/behavior/vector.zig
@@ -879,3 +879,27 @@ test "saturating shift-left" {
     try S.doTheTest();
     comptime try S.doTheTest();
 }
+
+test "multiplication-assignment operator with an array operand" {
+    if (builtin.zig_backend == .stage1) {
+        // stage1 emits a compile error
+        return error.SkipZigTest;
+    }
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        fn doTheTest() !void {
+            var x: @Vector(3, i32) = .{ 1, 2, 3 };
+            x *= [_]i32{ 4, 5, 6 };
+            try expect(x[0] == 4);
+            try expect(x[1] == 10);
+            try expect(x[2] == 18);
+        }
+    };
+    try S.doTheTest();
+    comptime try S.doTheTest();
+}