Commit 93d696e84e

Jacob Young <jacobly0@users.noreply.github.com>
2023-03-03 07:18:23
CBE: implement some big integer and vector unary operations
1 parent a8f4ac2
Changed files (5)
lib
src
codegen
test
lib/zig.h
@@ -1919,7 +1919,7 @@ static inline zig_i128 zig_bit_reverse_i128(zig_i128 val, uint8_t bits) {
 
 /* ========================== Big Integer Support =========================== */
 
-static inline uint16_t zig_big_bytes(uint16_t bits) {
+static inline uint16_t zig_int_bytes(uint16_t bits) {
     uint16_t bytes = (bits + CHAR_BIT - 1) / CHAR_BIT;
     uint16_t alignment = 16;
     while (alignment / 2 >= bytes) alignment /= 2;
@@ -1931,7 +1931,7 @@ static inline int32_t zig_cmp_big(const void *lhs, const void *rhs, bool is_sign
     const uint8_t *rhs_bytes = rhs;
     uint16_t byte_offset = 0;
     bool do_signed = is_signed;
-    uint16_t remaining_bytes = zig_big_bytes(bits);
+    uint16_t remaining_bytes = zig_int_bytes(bits);
 
 #if zig_little_endian
     byte_offset = remaining_bytes;
@@ -1965,7 +1965,7 @@ static inline int32_t zig_cmp_big(const void *lhs, const void *rhs, bool is_sign
         remaining_bytes -= 128 / CHAR_BIT;
 
 #if zig_big_endian
-        byte_offset -= 128 / CHAR_BIT;
+        byte_offset += 128 / CHAR_BIT;
 #endif
     }
 
@@ -1994,7 +1994,7 @@ static inline int32_t zig_cmp_big(const void *lhs, const void *rhs, bool is_sign
         remaining_bytes -= 64 / CHAR_BIT;
 
 #if zig_big_endian
-        byte_offset -= 64 / CHAR_BIT;
+        byte_offset += 64 / CHAR_BIT;
 #endif
     }
 
@@ -2023,7 +2023,7 @@ static inline int32_t zig_cmp_big(const void *lhs, const void *rhs, bool is_sign
         remaining_bytes -= 32 / CHAR_BIT;
 
 #if zig_big_endian
-        byte_offset -= 32 / CHAR_BIT;
+        byte_offset += 32 / CHAR_BIT;
 #endif
     }
 
@@ -2052,7 +2052,7 @@ static inline int32_t zig_cmp_big(const void *lhs, const void *rhs, bool is_sign
         remaining_bytes -= 16 / CHAR_BIT;
 
 #if zig_big_endian
-        byte_offset -= 16 / CHAR_BIT;
+        byte_offset += 16 / CHAR_BIT;
 #endif
     }
 
@@ -2081,13 +2081,368 @@ static inline int32_t zig_cmp_big(const void *lhs, const void *rhs, bool is_sign
         remaining_bytes -= 8 / CHAR_BIT;
 
 #if zig_big_endian
-        byte_offset -= 8 / CHAR_BIT;
+        byte_offset += 8 / CHAR_BIT;
 #endif
     }
 
     return 0;
 }
 
+/* Count leading zero bits of a big integer stored as a byte array of
+ * zig_int_bytes(bits) bytes in native endianness.  Limbs are visited
+ * from most to least significant, widest limb size first, and the scan
+ * stops at the first limb containing a set bit.  skip_bits discards the
+ * zero padding that rounds `bits` up to the storage size.  is_signed is
+ * unused; it exists so all big-int helpers share one call signature. */
+static inline uint16_t zig_clz_big(const void *val, bool is_signed, uint16_t bits) {
+    const uint8_t *val_bytes = val;
+    uint16_t byte_offset = 0;
+    uint16_t remaining_bytes = zig_int_bytes(bits);
+    /* padding bits above `bits` in the most significant limb
+     * (NOTE(review): the `* 8` assumes CHAR_BIT == 8 -- confirm this
+     * matches the 8 / CHAR_BIT limb arithmetic used below) */
+    uint16_t skip_bits = remaining_bytes * 8 - bits;
+    uint16_t total_lz = 0;
+    uint16_t limb_lz;
+    (void)is_signed;
+
+#if zig_little_endian
+    /* little-endian: the most significant limb sits at the end of the
+     * buffer, so walk it backwards (offset is pre-decremented below) */
+    byte_offset = remaining_bytes;
+#endif
+
+    while (remaining_bytes >= 128 / CHAR_BIT) {
+#if zig_little_endian
+        byte_offset -= 128 / CHAR_BIT;
+#endif
+
+        {
+            zig_u128 val_limb;
+
+            /* memcpy avoids unaligned/strict-aliasing UB on the raw bytes */
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_lz = zig_clz_u128(val_limb, 128 - skip_bits);
+        }
+
+        total_lz += limb_lz;
+        /* a set bit was found inside this limb: the count is final */
+        if (limb_lz < 128 - skip_bits) return total_lz;
+        /* only the most significant limb carries padding */
+        skip_bits = 0;
+        remaining_bytes -= 128 / CHAR_BIT;
+
+#if zig_big_endian
+        byte_offset += 128 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 64 / CHAR_BIT) {
+#if zig_little_endian
+        byte_offset -= 64 / CHAR_BIT;
+#endif
+
+        {
+            uint64_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_lz = zig_clz_u64(val_limb, 64 - skip_bits);
+        }
+
+        total_lz += limb_lz;
+        if (limb_lz < 64 - skip_bits) return total_lz;
+        skip_bits = 0;
+        remaining_bytes -= 64 / CHAR_BIT;
+
+#if zig_big_endian
+        byte_offset += 64 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 32 / CHAR_BIT) {
+#if zig_little_endian
+        byte_offset -= 32 / CHAR_BIT;
+#endif
+
+        {
+            uint32_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_lz = zig_clz_u32(val_limb, 32 - skip_bits);
+        }
+
+        total_lz += limb_lz;
+        if (limb_lz < 32 - skip_bits) return total_lz;
+        skip_bits = 0;
+        remaining_bytes -= 32 / CHAR_BIT;
+
+#if zig_big_endian
+        byte_offset += 32 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 16 / CHAR_BIT) {
+#if zig_little_endian
+        byte_offset -= 16 / CHAR_BIT;
+#endif
+
+        {
+            uint16_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_lz = zig_clz_u16(val_limb, 16 - skip_bits);
+        }
+
+        total_lz += limb_lz;
+        if (limb_lz < 16 - skip_bits) return total_lz;
+        skip_bits = 0;
+        remaining_bytes -= 16 / CHAR_BIT;
+
+#if zig_big_endian
+        byte_offset += 16 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 8 / CHAR_BIT) {
+#if zig_little_endian
+        byte_offset -= 8 / CHAR_BIT;
+#endif
+
+        {
+            uint8_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_lz = zig_clz_u8(val_limb, 8 - skip_bits);
+        }
+
+        total_lz += limb_lz;
+        if (limb_lz < 8 - skip_bits) return total_lz;
+        skip_bits = 0;
+        remaining_bytes -= 8 / CHAR_BIT;
+
+#if zig_big_endian
+        byte_offset += 8 / CHAR_BIT;
+#endif
+    }
+
+    /* value is entirely zero: every non-padding bit was a leading zero */
+    return total_lz;
+}
+
+/* Count trailing zero bits of a big integer stored as a byte array of
+ * zig_int_bytes(bits) bytes in native endianness.  Limbs are visited
+ * from least to most significant, widest limb size first, and the scan
+ * stops at the first limb containing a set bit.  No skip_bits handling
+ * is needed here: storage padding sits above the most significant bit
+ * and cannot affect a trailing-zero count of the value's low bits.
+ * is_signed is unused; it exists for a uniform helper signature. */
+static inline uint16_t zig_ctz_big(const void *val, bool is_signed, uint16_t bits) {
+    const uint8_t *val_bytes = val;
+    uint16_t byte_offset = 0;
+    uint16_t remaining_bytes = zig_int_bytes(bits);
+    uint16_t total_tz = 0;
+    uint16_t limb_tz;
+    (void)is_signed;
+
+#if zig_big_endian
+    /* big-endian: the least significant limb sits at the end of the
+     * buffer, so walk it backwards (offset is pre-decremented below) */
+    byte_offset = remaining_bytes;
+#endif
+
+    while (remaining_bytes >= 128 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 128 / CHAR_BIT;
+#endif
+
+        {
+            zig_u128 val_limb;
+
+            /* memcpy avoids unaligned/strict-aliasing UB on the raw bytes */
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_tz = zig_ctz_u128(val_limb, 128);
+        }
+
+        total_tz += limb_tz;
+        /* a set bit was found inside this limb: the count is final */
+        if (limb_tz < 128) return total_tz;
+        remaining_bytes -= 128 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 128 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 64 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 64 / CHAR_BIT;
+#endif
+
+        {
+            uint64_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_tz = zig_ctz_u64(val_limb, 64);
+        }
+
+        total_tz += limb_tz;
+        if (limb_tz < 64) return total_tz;
+        remaining_bytes -= 64 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 64 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 32 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 32 / CHAR_BIT;
+#endif
+
+        {
+            uint32_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_tz = zig_ctz_u32(val_limb, 32);
+        }
+
+        total_tz += limb_tz;
+        if (limb_tz < 32) return total_tz;
+        remaining_bytes -= 32 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 32 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 16 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 16 / CHAR_BIT;
+#endif
+
+        {
+            uint16_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_tz = zig_ctz_u16(val_limb, 16);
+        }
+
+        total_tz += limb_tz;
+        if (limb_tz < 16) return total_tz;
+        remaining_bytes -= 16 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 16 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 8 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 8 / CHAR_BIT;
+#endif
+
+        {
+            uint8_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            limb_tz = zig_ctz_u8(val_limb, 8);
+        }
+
+        total_tz += limb_tz;
+        if (limb_tz < 8) return total_tz;
+        remaining_bytes -= 8 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 8 / CHAR_BIT;
+#endif
+    }
+
+    /* NOTE(review): a ctz of an all-zero value conventionally reports the
+     * full bit width; here it reports the storage width in bits, which may
+     * exceed `bits` when padding rounds the size up -- confirm callers
+     * never pass an all-zero operand or don't rely on the exact value. */
+    return total_tz;
+}
+
+/* Population count (number of set bits) of a big integer stored as a
+ * byte array of zig_int_bytes(bits) bytes in native endianness.  Walks
+ * the value limb by limb, widest limb size first, accumulating each
+ * limb's popcount.  Storage padding bits above `bits` are assumed to be
+ * zero, as maintained by the other big-integer helpers, so no skip_bits
+ * masking is required.  is_signed is unused; it exists for a uniform
+ * helper signature. */
+static inline uint16_t zig_popcount_big(const void *val, bool is_signed, uint16_t bits) {
+    const uint8_t *val_bytes = val;
+    uint16_t byte_offset = 0;
+    uint16_t remaining_bytes = zig_int_bytes(bits);
+    uint16_t total_pc = 0;
+    (void)is_signed;
+
+#if zig_big_endian
+    byte_offset = remaining_bytes;
+#endif
+
+    while (remaining_bytes >= 128 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 128 / CHAR_BIT;
+#endif
+
+        {
+            zig_u128 val_limb;
+
+            /* memcpy avoids unaligned/strict-aliasing UB on the raw bytes */
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            total_pc += zig_popcount_u128(val_limb, 128);
+        }
+
+        remaining_bytes -= 128 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 128 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 64 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 64 / CHAR_BIT;
+#endif
+
+        {
+            uint64_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            total_pc += zig_popcount_u64(val_limb, 64);
+        }
+
+        remaining_bytes -= 64 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 64 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 32 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 32 / CHAR_BIT;
+#endif
+
+        {
+            uint32_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            total_pc += zig_popcount_u32(val_limb, 32);
+        }
+
+        remaining_bytes -= 32 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 32 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 16 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 16 / CHAR_BIT;
+#endif
+
+        {
+            uint16_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            /* accumulate like the wider loops above; previously this
+             * assigned (=), which would discard the counts of any wider
+             * limbs already processed */
+            total_pc += zig_popcount_u16(val_limb, 16);
+        }
+
+        remaining_bytes -= 16 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 16 / CHAR_BIT;
+#endif
+    }
+
+    while (remaining_bytes >= 8 / CHAR_BIT) {
+#if zig_big_endian
+        byte_offset -= 8 / CHAR_BIT;
+#endif
+
+        {
+            uint8_t val_limb;
+
+            memcpy(&val_limb, &val_bytes[byte_offset], sizeof(val_limb));
+            /* accumulate (same fix as the 16-bit loop) */
+            total_pc += zig_popcount_u8(val_limb, 8);
+        }
+
+        remaining_bytes -= 8 / CHAR_BIT;
+
+#if zig_little_endian
+        byte_offset += 8 / CHAR_BIT;
+#endif
+    }
+
+    return total_pc;
+}
+
 /* ========================= Floating Point Support ========================= */
 
 #if _MSC_VER
@@ -2742,7 +3097,7 @@ zig_msvc_atomics_128op(u128, max)
         uint32_t index = 0; \
         const uint8_t *lhs_ptr = lhs; \
         const uint8_t *rhs_ptr = rhs; \
-        uint16_t elem_bytes = zig_big_bytes(elem_bits); \
+        uint16_t elem_bytes = zig_int_bytes(elem_bits); \
  \
         while (index < len) { \
             result[index] = zig_cmp_big(lhs_ptr, rhs_ptr, is_signed, elem_bits) operator 0; \
@@ -2758,6 +3113,57 @@ zig_cmp_vec(le, <=)
 zig_cmp_vec(gt, > )
 zig_cmp_vec(ge, >=)
 
+/* Element-wise @clz over a vector of `len` integers of `elem_bits` bits,
+ * each stored in zig_int_bytes(elem_bits) bytes.  Delegates the per-element
+ * count to zig_clz_big.
+ * NOTE(review): results are stored as u8 when elem_bits <= 128 and u16
+ * otherwise -- confirm this matches the C element type the backend renders
+ * for the @clz result vector. */
+static inline void zig_clz_vec(void *result, const void *val, uint32_t len, bool is_signed, uint16_t elem_bits) {
+    uint32_t index = 0;
+    const uint8_t *val_ptr = val;
+    uint16_t elem_bytes = zig_int_bytes(elem_bits);
+
+    while (index < len) {
+        uint16_t lz = zig_clz_big(val_ptr, is_signed, elem_bits);
+        if (elem_bits <= 128) {
+            ((uint8_t *)result)[index] = (uint8_t)lz;
+        } else {
+            ((uint16_t *)result)[index] = lz;
+        }
+        val_ptr += elem_bytes;
+        index += 1;
+    }
+}
+
+/* Element-wise @ctz over a vector of `len` integers of `elem_bits` bits,
+ * each stored in zig_int_bytes(elem_bits) bytes.  Delegates the per-element
+ * count to zig_ctz_big.
+ * NOTE(review): results are stored as u8 when elem_bits <= 128 and u16
+ * otherwise -- confirm this matches the rendered result element type. */
+static inline void zig_ctz_vec(void *result, const void *val, uint32_t len, bool is_signed, uint16_t elem_bits) {
+    uint32_t index = 0;
+    const uint8_t *val_ptr = val;
+    uint16_t elem_bytes = zig_int_bytes(elem_bits);
+
+    while (index < len) {
+        uint16_t tz = zig_ctz_big(val_ptr, is_signed, elem_bits);
+        if (elem_bits <= 128) {
+            ((uint8_t *)result)[index] = (uint8_t)tz;
+        } else {
+            ((uint16_t *)result)[index] = tz;
+        }
+        val_ptr += elem_bytes;
+        index += 1;
+    }
+}
+
+/* Element-wise @popCount over a vector of `len` integers of `elem_bits`
+ * bits, each stored in zig_int_bytes(elem_bits) bytes.  Delegates the
+ * per-element count to zig_popcount_big.
+ * NOTE(review): results are stored as u8 when elem_bits <= 128 and u16
+ * otherwise -- confirm this matches the rendered result element type. */
+static inline void zig_popcount_vec(void *result, const void *val, uint32_t len, bool is_signed, uint16_t elem_bits) {
+    uint32_t index = 0;
+    const uint8_t *val_ptr = val;
+    uint16_t elem_bytes = zig_int_bytes(elem_bits);
+
+    while (index < len) {
+        uint16_t pc = zig_popcount_big(val_ptr, is_signed, elem_bits);
+        if (elem_bits <= 128) {
+            ((uint8_t *)result)[index] = (uint8_t)pc;
+        } else {
+            ((uint16_t *)result)[index] = pc;
+        }
+        val_ptr += elem_bytes;
+        index += 1;
+    }
+}
+
 /* ======================== Special Case Intrinsics ========================= */
 
 #if (_MSC_VER && _M_X64) || defined(__x86_64__)
src/codegen/c.zig
@@ -2844,7 +2844,7 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail,
             .cmp_vector => blk: {
                 const ty_pl = f.air.instructions.items(.data)[inst].ty_pl;
                 const extra = f.air.extraData(Air.VectorCmp, ty_pl.payload).data;
-                break :blk try cmpBuiltinCall(f, inst, extra, extra.compareOperator(), .operator, .bits);
+                break :blk try airCmpBuiltinCall(f, inst, extra, extra.compareOperator(), .operator, .bits,);
             },
             .cmp_lt_errors_len => try airCmpLtErrorsLen(f, inst),
 
@@ -3837,9 +3837,16 @@ fn airCmpOp(f: *Function, inst: Air.Inst.Index, operator: std.math.CompareOperat
     const target = f.object.dg.module.getTarget();
     const operand_bits = operand_ty.bitSize(target);
     if (operand_ty.isInt() and operand_bits > 64)
-        return cmpBuiltinCall(f, inst, bin_op, operator, .cmp, if (operand_bits > 128) .bits else .none);
+        return airCmpBuiltinCall(
+            f,
+            inst,
+            bin_op,
+            operator,
+            .cmp,
+            if (operand_bits > 128) .bits else .none,
+        );
     if (operand_ty.isRuntimeFloat())
-        return cmpBuiltinCall(f, inst, bin_op, operator, .operator, .none);
+        return airCmpBuiltinCall(f, inst, bin_op, operator, .operator, .none);
 
     const inst_ty = f.air.typeOfIndex(inst);
     const lhs = try f.resolveInst(bin_op.lhs);
@@ -3876,9 +3883,16 @@ fn airEquality(
     const target = f.object.dg.module.getTarget();
     const operand_bits = operand_ty.bitSize(target);
     if (operand_ty.isInt() and operand_bits > 64)
-        return cmpBuiltinCall(f, inst, bin_op, operator, .cmp, if (operand_bits > 128) .bits else .none);
+        return airCmpBuiltinCall(
+            f,
+            inst,
+            bin_op,
+            operator,
+            .cmp,
+            if (operand_bits > 128) .bits else .none,
+        );
     if (operand_ty.isRuntimeFloat())
-        return cmpBuiltinCall(f, inst, bin_op, operator, .operator, .none);
+        return airCmpBuiltinCall(f, inst, bin_op, operator, .operator, .none);
 
     const lhs = try f.resolveInst(bin_op.lhs);
     const rhs = try f.resolveInst(bin_op.rhs);
@@ -5969,14 +5983,25 @@ fn airUnBuiltinCall(
     const inst_ty = f.air.typeOfIndex(inst);
     const operand_ty = f.air.typeOf(ty_op.operand);
 
+    const inst_cty = try f.typeToCType(inst_ty, .complete);
+    const ref_ret = switch (inst_cty.tag()) {
+        else => false,
+        .array, .vector => true,
+    };
+
     const writer = f.object.writer();
     const local = try f.allocLocal(inst, inst_ty);
-    try f.writeCValue(writer, local, .Other);
-    try writer.writeAll(" = zig_");
-    try writer.writeAll(operation);
-    try writer.writeByte('_');
+    if (!ref_ret) {
+        try f.writeCValue(writer, local, .Other);
+        try writer.writeAll(" = ");
+    }
+    try writer.print("zig_{s}_", .{operation});
     try f.object.dg.renderTypeForBuiltinFnName(writer, operand_ty);
     try writer.writeByte('(');
+    if (ref_ret) {
+        try f.writeCValue(writer, local, .FunctionArgument);
+        try writer.writeAll(", ");
+    }
     try f.writeCValue(writer, operand, .FunctionArgument);
     try f.object.dg.renderBuiltinInfo(writer, operand_ty, info);
     try writer.writeAll(");\n");
@@ -6019,7 +6044,7 @@ fn airBinBuiltinCall(
     return local;
 }
 
-fn cmpBuiltinCall(
+fn airCmpBuiltinCall(
     f: *Function,
     inst: Air.Inst.Index,
     data: anytype,
@@ -6034,7 +6059,11 @@ fn cmpBuiltinCall(
     const rhs = try f.resolveInst(data.rhs);
     try reap(f, inst, &.{ data.lhs, data.rhs });
 
-    const ref_ret = inst_ty.tag() != .bool;
+    const inst_cty = try f.typeToCType(inst_ty, .complete);
+    const ref_ret = switch (inst_cty.tag()) {
+        else => false,
+        .array, .vector => true,
+    };
 
     const writer = f.object.writer();
     const local = try f.allocLocal(inst, inst_ty);
test/behavior/bugs/10147.zig
@@ -6,7 +6,6 @@ test "test calling @clz on both vector and scalar inputs" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     var x: u32 = 0x1;
test/behavior/math.zig
@@ -100,7 +100,6 @@ test "@clz vectors" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     try testClzVectors();
@@ -163,7 +162,6 @@ test "@ctz vectors" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
-    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
     if (builtin.zig_backend == .stage2_llvm and builtin.cpu.arch == .aarch64) {
@@ -1562,6 +1560,12 @@ test "signed zeros are represented properly" {
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
 
+    if (builtin.os.tag == .windows and builtin.cpu.arch == .aarch64 and
+        builtin.zig_backend == .stage2_c)
+    {
+        return error.SkipZigTest;
+    }
+
     const S = struct {
         fn doTheTest() !void {
             try testOne(f16);
test/behavior/popcount.zig
@@ -67,7 +67,6 @@ fn testPopCountIntegers() !void {
 }
 
 test "@popCount vectors" {
-    if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO