Commit 5e3c0b7af7

Daniel Kongsgaard <dakongsgaard@gmail.com>
2025-06-13 00:16:23
Allow more operators on bool vectors (#24131)
* Sema: allow binary operations and boolean not on vectors of bool * langref: Clarify use of operators on vectors (`and` and `or` not allowed) closes #24093
1 parent 4a02e08
Changed files (5)
doc/langref.html.in
@@ -1926,8 +1926,10 @@ or
       Vector types are created with the builtin function {#link|@Vector#}.
       </p>
       <p>
-      Vectors support the same builtin operators as their underlying base types.
-      These operations are performed element-wise, and return a vector of the same length
+      Vectors generally support the same builtin operators as their underlying base types.
+      The only exception to this is the keywords `and` and `or` on vectors of bools, since
+      these operators affect control flow, which is not allowed for vectors.
+      All other operations are performed element-wise, and return a vector of the same length
       as the input vectors. This includes:
       </p>
       <ul>
@@ -1937,6 +1939,7 @@ or
           <li>Bitwise operators ({#syntax#}>>{#endsyntax#}, {#syntax#}<<{#endsyntax#}, {#syntax#}&{#endsyntax#},
                                  {#syntax#}|{#endsyntax#}, {#syntax#}~{#endsyntax#}, etc.)</li>
           <li>Comparison operators ({#syntax#}<{#endsyntax#}, {#syntax#}>{#endsyntax#}, {#syntax#}=={#endsyntax#}, etc.)</li>
+          <li>Boolean not ({#syntax#}!{#endsyntax#})</li>
       </ul>
       <p>
       It is prohibited to use a math operator on a mixture of scalars (individual numbers)
lib/std/zig/AstGen.zig
@@ -806,7 +806,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
         .bool_and => return boolBinOp(gz, scope, ri, node, .bool_br_and),
         .bool_or  => return boolBinOp(gz, scope, ri, node, .bool_br_or),
 
-        .bool_not => return simpleUnOp(gz, scope, ri, node, coerced_bool_ri, tree.nodeData(node).node, .bool_not),
+        .bool_not => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, tree.nodeData(node).node, .bool_not),
         .bit_not  => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, tree.nodeData(node).node, .bit_not),
 
         .negation      => return   negation(gz, scope, ri, node),
src/Sema.zig
@@ -1171,11 +1171,11 @@ fn analyzeBodyInner(
             .as_node                      => try sema.zirAsNode(block, inst),
             .as_shift_operand             => try sema.zirAsShiftOperand(block, inst),
             .bit_and                      => try sema.zirBitwise(block, inst, .bit_and),
-            .bit_not                      => try sema.zirBitNot(block, inst),
+            .bit_not                      => try sema.zirBitNot(block, inst, false),
             .bit_or                       => try sema.zirBitwise(block, inst, .bit_or),
             .bitcast                      => try sema.zirBitcast(block, inst),
             .suspend_block                => try sema.zirSuspendBlock(block, inst),
-            .bool_not                     => try sema.zirBoolNot(block, inst),
+            .bool_not                     => try sema.zirBitNot(block, inst, true),
             .bool_br_and                  => try sema.zirBoolBr(block, inst, false),
             .bool_br_or                   => try sema.zirBoolBr(block, inst, true),
             .c_import                     => try sema.zirCImport(block, inst),
@@ -14412,9 +14412,9 @@ fn zirBitwise(
     const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src);
     const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src);
 
-    const is_int = scalar_tag == .int or scalar_tag == .comptime_int;
+    const is_int_or_bool = scalar_tag == .int or scalar_tag == .comptime_int or scalar_tag == .bool;
 
-    if (!is_int) {
+    if (!is_int_or_bool) {
         return sema.fail(block, src, "invalid operands to binary bitwise expression: '{s}' and '{s}'", .{ @tagName(lhs_ty.zigTypeTag(zcu)), @tagName(rhs_ty.zigTypeTag(zcu)) });
     }
 
@@ -14442,7 +14442,12 @@ fn zirBitwise(
     return block.addBinOp(air_tag, casted_lhs, casted_rhs);
 }
 
-fn zirBitNot(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
+fn zirBitNot(
+    sema: *Sema,
+    block: *Block,
+    inst: Zir.Inst.Index,
+    is_bool_not: bool,
+) CompileError!Air.Inst.Ref {
     const tracy = trace(@src());
     defer tracy.end();
 
@@ -14455,10 +14460,14 @@ fn zirBitNot(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.
     const operand = try sema.resolveInst(inst_data.operand);
     const operand_type = sema.typeOf(operand);
     const scalar_type = operand_type.scalarType(zcu);
+    const scalar_tag = scalar_type.zigTypeTag(zcu);
 
-    if (scalar_type.zigTypeTag(zcu) != .int) {
-        return sema.fail(block, src, "unable to perform binary not operation on type '{}'", .{
-            operand_type.fmt(pt),
+    const is_finite_int_or_bool = scalar_tag == .int or scalar_tag == .bool;
+    const is_allowed_type = if (is_bool_not) scalar_tag == .bool else is_finite_int_or_bool;
+
+    if (!is_allowed_type) {
+        return sema.fail(block, src, "unable to perform {s} not operation on type '{}'", .{
+            if (is_bool_not) "boolean" else "binary", operand_type.fmt(pt),
         });
     }
 
@@ -18336,25 +18345,6 @@ fn zirTypeofPeer(
     return Air.internedToRef(result_type.toIntern());
 }
 
-fn zirBoolNot(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
-    const tracy = trace(@src());
-    defer tracy.end();
-
-    const pt = sema.pt;
-    const zcu = pt.zcu;
-    const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
-    const src = block.nodeOffset(inst_data.src_node);
-    const operand_src = block.src(.{ .node_offset_un_op = inst_data.src_node });
-    const uncasted_operand = try sema.resolveInst(inst_data.operand);
-
-    const operand = try sema.coerce(block, .bool, uncasted_operand, operand_src);
-    if (try sema.resolveValue(operand)) |val| {
-        return if (val.isUndef(zcu)) .undef_bool else if (val.toBool()) .bool_false else .bool_true;
-    }
-    try sema.requireRuntimeBlock(block, src, null);
-    return block.addTyOp(.not, .bool, operand);
-}
-
 fn zirBoolBr(
     sema: *Sema,
     parent_block: *Block,
src/Value.zig
@@ -1627,7 +1627,7 @@ pub fn numberMin(lhs: Value, rhs: Value, zcu: *Zcu) Value {
     };
 }
 
-/// operands must be (vectors of) integers; handles undefined scalars.
+/// operands must be (vectors of) integers or bools; handles undefined scalars.
 pub fn bitwiseNot(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     if (ty.zigTypeTag(zcu) == .vector) {
@@ -1645,7 +1645,7 @@ pub fn bitwiseNot(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Va
     return bitwiseNotScalar(val, ty, arena, pt);
 }
 
-/// operands must be integers; handles undefined.
+/// operands must be integers or bools; handles undefined.
 pub fn bitwiseNotScalar(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     if (val.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() }));
@@ -1671,7 +1671,7 @@ pub fn bitwiseNotScalar(val: Value, ty: Type, arena: Allocator, pt: Zcu.PerThrea
     return pt.intValue_big(ty, result_bigint.toConst());
 }
 
-/// operands must be (vectors of) integers; handles undefined scalars.
+/// operands must be (vectors of) integers or bools; handles undefined scalars.
 pub fn bitwiseAnd(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     if (ty.zigTypeTag(zcu) == .vector) {
@@ -1690,7 +1690,7 @@ pub fn bitwiseAnd(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zc
     return bitwiseAndScalar(lhs, rhs, ty, allocator, pt);
 }
 
-/// operands must be integers; handles undefined.
+/// operands must be integers or bools; handles undefined.
 pub fn bitwiseAndScalar(orig_lhs: Value, orig_rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     // If one operand is defined, we turn the other into `0xAA` so the bitwise AND can
@@ -1744,7 +1744,7 @@ fn intValueAa(ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
     return pt.intValue_big(ty, result_bigint.toConst());
 }
 
-/// operands must be (vectors of) integers; handles undefined scalars.
+/// operands must be (vectors of) integers or bools; handles undefined scalars.
 pub fn bitwiseNand(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     if (ty.zigTypeTag(zcu) == .vector) {
@@ -1763,7 +1763,7 @@ pub fn bitwiseNand(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.P
     return bitwiseNandScalar(lhs, rhs, ty, arena, pt);
 }
 
-/// operands must be integers; handles undefined.
+/// operands must be integers or bools; handles undefined.
 pub fn bitwiseNandScalar(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     if (lhs.isUndef(zcu) or rhs.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() }));
@@ -1774,7 +1774,7 @@ pub fn bitwiseNandScalar(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt:
     return bitwiseXor(anded, all_ones, ty, arena, pt);
 }
 
-/// operands must be (vectors of) integers; handles undefined scalars.
+/// operands must be (vectors of) integers or bools; handles undefined scalars.
 pub fn bitwiseOr(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     if (ty.zigTypeTag(zcu) == .vector) {
@@ -1793,7 +1793,7 @@ pub fn bitwiseOr(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu
     return bitwiseOrScalar(lhs, rhs, ty, allocator, pt);
 }
 
-/// operands must be integers; handles undefined.
+/// operands must be integers or bools; handles undefined.
 pub fn bitwiseOrScalar(orig_lhs: Value, orig_rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
     // If one operand is defined, we turn the other into `0xAA` so the bitwise AND can
     // still zero out some bits.
@@ -1827,7 +1827,7 @@ pub fn bitwiseOrScalar(orig_lhs: Value, orig_rhs: Value, ty: Type, arena: Alloca
     return pt.intValue_big(ty, result_bigint.toConst());
 }
 
-/// operands must be (vectors of) integers; handles undefined scalars.
+/// operands must be (vectors of) integers or bools; handles undefined scalars.
 pub fn bitwiseXor(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     if (ty.zigTypeTag(zcu) == .vector) {
@@ -1846,7 +1846,7 @@ pub fn bitwiseXor(lhs: Value, rhs: Value, ty: Type, allocator: Allocator, pt: Zc
     return bitwiseXorScalar(lhs, rhs, ty, allocator, pt);
 }
 
-/// operands must be integers; handles undefined.
+/// operands must be integers or bools; handles undefined.
 pub fn bitwiseXorScalar(lhs: Value, rhs: Value, ty: Type, arena: Allocator, pt: Zcu.PerThread) !Value {
     const zcu = pt.zcu;
     if (lhs.isUndef(zcu) or rhs.isUndef(zcu)) return Value.fromInterned(try pt.intern(.{ .undef = ty.toIntern() }));
test/behavior/vector.zig
@@ -152,12 +152,22 @@ test "vector bit operators" {
 
     const S = struct {
         fn doTheTest() !void {
-            var v: @Vector(4, u8) = [4]u8{ 0b10101010, 0b10101010, 0b10101010, 0b10101010 };
-            var x: @Vector(4, u8) = [4]u8{ 0b11110000, 0b00001111, 0b10101010, 0b01010101 };
-            _ = .{ &v, &x };
-            try expect(mem.eql(u8, &@as([4]u8, v ^ x), &[4]u8{ 0b01011010, 0b10100101, 0b00000000, 0b11111111 }));
-            try expect(mem.eql(u8, &@as([4]u8, v | x), &[4]u8{ 0b11111010, 0b10101111, 0b10101010, 0b11111111 }));
-            try expect(mem.eql(u8, &@as([4]u8, v & x), &[4]u8{ 0b10100000, 0b00001010, 0b10101010, 0b00000000 }));
+            {
+                var v: @Vector(4, bool) = [4]bool{ false, false, true, true };
+                var x: @Vector(4, bool) = [4]bool{ true, false, true, false };
+                _ = .{ &v, &x };
+                try expect(mem.eql(bool, &@as([4]bool, v ^ x), &[4]bool{ true, false, false, true }));
+                try expect(mem.eql(bool, &@as([4]bool, v | x), &[4]bool{ true, false, true, true }));
+                try expect(mem.eql(bool, &@as([4]bool, v & x), &[4]bool{ false, false, true, false }));
+            }
+            {
+                var v: @Vector(4, u8) = [4]u8{ 0b10101010, 0b10101010, 0b10101010, 0b10101010 };
+                var x: @Vector(4, u8) = [4]u8{ 0b11110000, 0b00001111, 0b10101010, 0b01010101 };
+                _ = .{ &v, &x };
+                try expect(mem.eql(u8, &@as([4]u8, v ^ x), &[4]u8{ 0b01011010, 0b10100101, 0b00000000, 0b11111111 }));
+                try expect(mem.eql(u8, &@as([4]u8, v | x), &[4]u8{ 0b11111010, 0b10101111, 0b10101010, 0b11111111 }));
+                try expect(mem.eql(u8, &@as([4]u8, v & x), &[4]u8{ 0b10100000, 0b00001010, 0b10101010, 0b00000000 }));
+            }
         }
     };
     try S.doTheTest();
@@ -659,15 +669,41 @@ test "vector bitwise not operator" {
             }
         }
         fn doTheTest() !void {
-            try doTheTestNot(u8, [_]u8{ 0, 2, 4, 255 });
-            try doTheTestNot(u16, [_]u16{ 0, 2, 4, 255 });
-            try doTheTestNot(u32, [_]u32{ 0, 2, 4, 255 });
-            try doTheTestNot(u64, [_]u64{ 0, 2, 4, 255 });
+            try doTheTestNot(bool, [_]bool{ true, false, true, false });
 
             try doTheTestNot(u8, [_]u8{ 0, 2, 4, 255 });
             try doTheTestNot(u16, [_]u16{ 0, 2, 4, 255 });
             try doTheTestNot(u32, [_]u32{ 0, 2, 4, 255 });
             try doTheTestNot(u64, [_]u64{ 0, 2, 4, 255 });
+
+            try doTheTestNot(i8, [_]i8{ 0, 2, 4, 127 });
+            try doTheTestNot(i16, [_]i16{ 0, 2, 4, 127 });
+            try doTheTestNot(i32, [_]i32{ 0, 2, 4, 127 });
+            try doTheTestNot(i64, [_]i64{ 0, 2, 4, 127 });
+        }
+    };
+
+    try S.doTheTest();
+    try comptime S.doTheTest();
+}
+
+test "vector boolean not operator" {
+    if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
+
+    const S = struct {
+        fn doTheTestNot(comptime T: type, x: @Vector(4, T)) !void {
+            const y = !x;
+            for (@as([4]T, y), 0..) |v, i| {
+                try expect(!x[i] == v);
+            }
+        }
+        fn doTheTest() !void {
+            try doTheTestNot(bool, [_]bool{ true, false, true, false });
         }
     };