Commit 5f5ab49168

Veikka Tuominen <git@vexu.eu>
2023-01-19 17:44:09
Value: implement `compareAllWithZero` for `bytes` and `str_lit`
Closes #10692
1 parent a492a60
Changed files (4)
src/Sema.zig
@@ -11842,7 +11842,7 @@ fn zirShl(
                 if (scalar_ty.zigTypeTag() == .ComptimeInt) {
                     break :val shifted.wrapped_result;
                 }
-                if (shifted.overflow_bit.compareAllWithZero(.eq)) {
+                if (shifted.overflow_bit.compareAllWithZero(.eq, sema.mod)) {
                     break :val shifted.wrapped_result;
                 }
                 return sema.fail(block, src, "operation caused overflow", .{});
@@ -12831,7 +12831,7 @@ fn zirDiv(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Ins
         const lhs_val = maybe_lhs_val orelse unreachable;
         const rhs_val = maybe_rhs_val orelse unreachable;
         const rem = lhs_val.floatRem(rhs_val, resolved_type, sema.arena, mod) catch unreachable;
-        if (!rem.compareAllWithZero(.eq)) {
+        if (!rem.compareAllWithZero(.eq, mod)) {
             return sema.fail(block, src, "ambiguous coercion of division operands '{s}' and '{s}'; non-zero remainder '{}'", .{
                 @tagName(lhs_ty.tag()), @tagName(rhs_ty.tag()), rem.fmtValue(resolved_type, sema.mod),
             });
@@ -13024,7 +13024,7 @@ fn zirDivExact(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
             if (maybe_rhs_val) |rhs_val| {
                 if (is_int) {
                     const modulus_val = try lhs_val.intMod(rhs_val, resolved_type, sema.arena, mod);
-                    if (!(modulus_val.compareAllWithZero(.eq))) {
+                    if (!(modulus_val.compareAllWithZero(.eq, mod))) {
                         return sema.fail(block, src, "exact division produced remainder", .{});
                     }
                     const res = try lhs_val.intDiv(rhs_val, resolved_type, sema.arena, mod);
@@ -13035,7 +13035,7 @@ fn zirDivExact(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
                     return sema.addConstant(resolved_type, res);
                 } else {
                     const modulus_val = try lhs_val.floatMod(rhs_val, resolved_type, sema.arena, mod);
-                    if (!(modulus_val.compareAllWithZero(.eq))) {
+                    if (!(modulus_val.compareAllWithZero(.eq, mod))) {
                         return sema.fail(block, src, "exact division produced remainder", .{});
                     }
                     return sema.addConstant(
src/type.zig
@@ -5533,7 +5533,7 @@ pub const Type = extern union {
         }
         const S = struct {
             fn fieldWithRange(int_ty: Type, int_val: Value, end: usize, m: *Module) ?usize {
-                if (int_val.compareAllWithZero(.lt)) return null;
+                if (int_val.compareAllWithZero(.lt, m)) return null;
                 var end_payload: Value.Payload.U64 = .{
                     .base = .{ .tag = .int_u64 },
                     .data = end,
@@ -6556,12 +6556,12 @@ pub const Type = extern union {
                 if (!d.mutable and d.pointee_type.eql(Type.u8, mod)) {
                     switch (d.size) {
                         .Slice => {
-                            if (sent.compareAllWithZero(.eq)) {
+                            if (sent.compareAllWithZero(.eq, mod)) {
                                 return Type.initTag(.const_slice_u8_sentinel_0);
                             }
                         },
                         .Many => {
-                            if (sent.compareAllWithZero(.eq)) {
+                            if (sent.compareAllWithZero(.eq, mod)) {
                                 return Type.initTag(.manyptr_const_u8_sentinel_0);
                             }
                         },
src/value.zig
@@ -2076,13 +2076,22 @@ pub const Value = extern union {
     /// For vectors, returns true if comparison is true for ALL elements.
     ///
     /// Note that `!compareAllWithZero(.eq, ...) != compareAllWithZero(.neq, ...)`
-    pub fn compareAllWithZero(lhs: Value, op: std.math.CompareOperator) bool {
-        return compareAllWithZeroAdvanced(lhs, op, null) catch unreachable;
+    pub fn compareAllWithZero(lhs: Value, op: std.math.CompareOperator, mod: *Module) bool {
+        return compareAllWithZeroAdvancedExtra(lhs, op, mod, null) catch unreachable;
     }
 
     pub fn compareAllWithZeroAdvanced(
         lhs: Value,
         op: std.math.CompareOperator,
+        sema: *Sema,
+    ) Module.CompileError!bool {
+        return compareAllWithZeroAdvancedExtra(lhs, op, sema.mod, sema);
+    }
+
+    pub fn compareAllWithZeroAdvancedExtra(
+        lhs: Value,
+        op: std.math.CompareOperator,
+        mod: *Module,
         opt_sema: ?*Sema,
     ) Module.CompileError!bool {
         if (lhs.isInf()) {
@@ -2095,10 +2104,25 @@ pub const Value = extern union {
         }
 
         switch (lhs.tag()) {
-            .repeated => return lhs.castTag(.repeated).?.data.compareAllWithZeroAdvanced(op, opt_sema),
+            .repeated => return lhs.castTag(.repeated).?.data.compareAllWithZeroAdvancedExtra(op, mod, opt_sema),
             .aggregate => {
                 for (lhs.castTag(.aggregate).?.data) |elem_val| {
-                    if (!(try elem_val.compareAllWithZeroAdvanced(op, opt_sema))) return false;
+                    if (!(try elem_val.compareAllWithZeroAdvancedExtra(op, mod, opt_sema))) return false;
+                }
+                return true;
+            },
+            .str_lit => {
+                const str_lit = lhs.castTag(.str_lit).?.data;
+                const bytes = mod.string_literal_bytes.items[str_lit.index..][0..str_lit.len];
+                for (bytes) |byte| {
+                    if (!std.math.compare(byte, op, 0)) return false;
+                }
+                return true;
+            },
+            .bytes => {
+                const bytes = lhs.castTag(.bytes).?.data;
+                for (bytes) |byte| {
+                    if (!std.math.compare(byte, op, 0)) return false;
                 }
                 return true;
             },
@@ -3103,7 +3127,7 @@ pub const Value = extern union {
             .int_i64,
             .int_big_positive,
             .int_big_negative,
-            => compareAllWithZero(self, .eq),
+            => self.orderAgainstZero().compare(.eq),
 
             .undef => unreachable,
             .unreachable_value => unreachable,
test/behavior/vector.zig
@@ -1286,3 +1286,14 @@ test "store to vector in slice" {
     s[i] = s[0];
     try expectEqual(v[1], v[0]);
 }
+
+test "addition of vectors represented as strings" {
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+
+    const V = @Vector(3, u8);
+    const foo: V = "foo".*;
+    const bar: V = @typeName(u32).*;
+    try expectEqual(V{ 219, 162, 161 }, foo + bar);
+}