Commit a005ac9d3c

Andrew Kelley <andrew@ziglang.org>
2022-02-13 04:44:30
stage2: implement `@popCount` for SIMD vectors
1 parent 16ec848
src/codegen/llvm.zig
@@ -2205,7 +2205,7 @@ pub const FuncGen = struct {
                 .get_union_tag  => try self.airGetUnionTag(inst),
                 .clz            => try self.airClzCtz(inst, "ctlz"),
                 .ctz            => try self.airClzCtz(inst, "cttz"),
-                .popcount       => try self.airPopCount(inst, "ctpop"),
+                .popcount       => try self.airPopCount(inst),
                 .tag_name       => try self.airTagName(inst),
                 .error_name     => try self.airErrorName(inst),
                 .splat          => try self.airSplat(inst),
@@ -4364,7 +4364,7 @@ pub const FuncGen = struct {
         }
     }
 
-    fn airPopCount(self: *FuncGen, inst: Air.Inst.Index, prefix: [*:0]const u8) !?*const llvm.Value {
+    fn airPopCount(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
         if (self.liveness.isUnused(inst)) return null;
 
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;
@@ -4372,11 +4372,16 @@ pub const FuncGen = struct {
         const operand = try self.resolveInst(ty_op.operand);
         const target = self.dg.module.getTarget();
         const bits = operand_ty.intInfo(target).bits;
+        const vec_len: ?u32 = switch (operand_ty.zigTypeTag()) {
+            .Vector => operand_ty.vectorLen(),
+            else => null,
+        };
 
         var fn_name_buf: [100]u8 = undefined;
-        const llvm_fn_name = std.fmt.bufPrintZ(&fn_name_buf, "llvm.{s}.i{d}", .{
-            prefix, bits,
-        }) catch unreachable;
+        const llvm_fn_name = if (vec_len) |len|
+            std.fmt.bufPrintZ(&fn_name_buf, "llvm.ctpop.v{d}i{d}", .{ len, bits }) catch unreachable
+        else
+            std.fmt.bufPrintZ(&fn_name_buf, "llvm.ctpop.i{d}", .{bits}) catch unreachable;
         const fn_val = self.dg.object.llvm_module.getNamedFunction(llvm_fn_name) orelse blk: {
             const operand_llvm_ty = try self.dg.llvmType(operand_ty);
             const param_types = [_]*const llvm.Type{operand_llvm_ty};
src/Sema.zig
@@ -720,7 +720,6 @@ fn analyzeBodyInner(
             .align_cast                   => try sema.zirAlignCast(block, inst),
             .has_decl                     => try sema.zirHasDecl(block, inst),
             .has_field                    => try sema.zirHasField(block, inst),
-            .pop_count                    => try sema.zirPopCount(block, inst),
             .byte_swap                    => try sema.zirByteSwap(block, inst),
             .bit_reverse                  => try sema.zirBitReverse(block, inst),
             .bit_offset_of                => try sema.zirBitOffsetOf(block, inst),
@@ -743,8 +742,9 @@ fn analyzeBodyInner(
             .await_nosuspend              => try sema.zirAwait(block, inst, true),
             .extended                     => try sema.zirExtended(block, inst),
 
-            .clz => try sema.zirClzCtz(block, inst, .clz, Value.clz),
-            .ctz => try sema.zirClzCtz(block, inst, .ctz, Value.ctz),
+            .clz       => try sema.zirBitCount(block, inst, .clz,      Value.clz),
+            .ctz       => try sema.zirBitCount(block, inst, .ctz,      Value.ctz),
+            .pop_count => try sema.zirBitCount(block, inst, .popcount, Value.popCount),
 
             .sqrt  => try sema.zirUnaryMath(block, inst, .sqrt, Value.sqrt),
             .sin   => try sema.zirUnaryMath(block, inst, .sin, Value.sin),
@@ -11487,7 +11487,7 @@ fn zirAlignCast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
     return sema.coerceCompatiblePtrs(block, dest_ty, ptr, ptr_src);
 }
 
-fn zirClzCtz(
+fn zirBitCount(
     sema: *Sema,
     block: *Block,
     inst: Zir.Inst.Index,
@@ -11550,34 +11550,6 @@ fn zirClzCtz(
     }
 }
 
-fn zirPopCount(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
-    const inst_data = sema.code.instructions.items(.data)[inst].un_node;
-    const ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
-    const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
-    const operand = sema.resolveInst(inst_data.operand);
-    const operand_ty = sema.typeOf(operand);
-    // TODO implement support for vectors
-    if (operand_ty.zigTypeTag() != .Int) {
-        return sema.fail(block, ty_src, "expected integer type, found '{}'", .{
-            operand_ty,
-        });
-    }
-    const target = sema.mod.getTarget();
-    const bits = operand_ty.intInfo(target).bits;
-    if (bits == 0) return Air.Inst.Ref.zero;
-
-    const result_ty = try Type.smallestUnsignedInt(sema.arena, bits);
-
-    const runtime_src = if (try sema.resolveMaybeUndefVal(block, operand_src, operand)) |val| {
-        if (val.isUndef()) return sema.addConstUndef(result_ty);
-        const result_val = try val.popCount(operand_ty, target, sema.arena);
-        return sema.addConstant(result_ty, result_val);
-    } else operand_src;
-
-    try sema.requireRuntimeBlock(block, runtime_src);
-    return block.addTyOp(.popcount, result_ty, operand);
-}
-
 fn zirByteSwap(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const src = inst_data.src();
src/value.zig
@@ -1303,6 +1303,33 @@ pub const Value = extern union {
         }
     }
 
+    pub fn popCount(val: Value, ty: Type, target: Target) u64 { // Number of set bits in `val`, returned as a plain u64.
+        assert(!val.isUndef()); // undef is asserted against, not handled
+        switch (val.tag()) {
+            .zero, .bool_false => return 0,
+            .one, .bool_true => return 1,
+
+            .int_u64 => return @popCount(u64, val.castTag(.int_u64).?.data), // payload is already a u64: count it directly
+
+            else => { // every other tag goes through the big-integer path
+                const info = ty.intInfo(target); // only `info.bits` is used, as the width passed to the big-int popCount
+
+                var buffer: Value.BigIntSpace = undefined;
+                const operand_bigint = val.toBigInt(&buffer);
+
+                var limbs_buffer: [4]std.math.big.Limb = undefined; // NOTE(review): fixed 4 limbs; the earlier arena-based popCount sized this with calcTwosCompLimbCount(info.bits) — confirm 4 is enough for wide operands
+                var result_bigint = BigIntMutable{
+                    .limbs = &limbs_buffer,
+                    .positive = undefined, // left undefined here; presumably filled in by popCount below — TODO confirm
+                    .len = undefined,
+                };
+                result_bigint.popCount(operand_bigint, info.bits);
+
+                return result_bigint.toConst().to(u64) catch unreachable; // a count that does not fit in u64 is treated as impossible
+            },
+        }
+    }
+
     /// Asserts the value is an integer and not undefined.
     /// Returns the number of bits the value requires to represent stored in twos complement form.
     pub fn intBitCountTwosComp(self: Value, target: Target) usize {
@@ -1340,24 +1367,6 @@ pub const Value = extern union {
         }
     }
 
-    pub fn popCount(val: Value, ty: Type, target: Target, arena: Allocator) !Value {
-        assert(!val.isUndef());
-
-        const info = ty.intInfo(target);
-
-        var buffer: Value.BigIntSpace = undefined;
-        const operand_bigint = val.toBigInt(&buffer);
-
-        const limbs = try arena.alloc(
-            std.math.big.Limb,
-            std.math.big.int.calcTwosCompLimbCount(info.bits),
-        );
-        var result_bigint = BigIntMutable{ .limbs = limbs, .positive = undefined, .len = undefined };
-        result_bigint.popCount(operand_bigint, info.bits);
-
-        return fromBigInt(arena, result_bigint.toConst());
-    }
-
     /// Asserts the value is an integer, and the destination type is ComptimeInt or Int.
     pub fn intFitsInType(self: Value, ty: Type, target: Target) bool {
         switch (self.tag()) {
test/behavior/popcount.zig
@@ -1,7 +1,6 @@
 const std = @import("std");
 const expect = std.testing.expect;
 const expectEqual = std.testing.expectEqual;
-const Vector = std.meta.Vector;
 
 test "@popCount integers" {
     comptime try testPopCountIntegers();
@@ -44,3 +43,23 @@ fn testPopCountIntegers() !void {
         try expect(@popCount(i128, @as(i128, 0b11111111000110001100010000100001000011000011100101010001)) == 24);
     }
 }
+
+test "@popCount vectors" {
+    comptime try testPopCountVectors(); // compile-time evaluation of the checks
+    try testPopCountVectors(); // same checks again without `comptime`
+}
+
+fn testPopCountVectors() !void { // shared body: called both with and without `comptime` by the test above
+    {
+        var x: @Vector(8, u32) = [1]u32{0xffffffff} ** 8; // eight lanes, every bit set
+        const expected = [1]u6{32} ** 8; // 32 set bits expected per lane
+        const result: [8]u6 = @popCount(u32, x); // vector result coerced to an array so it can be compared as a slice
+        try expect(std.mem.eql(u6, &expected, &result));
+    }
+    {
+        var x: @Vector(8, i16) = [1]i16{-1} ** 8; // signed element type: -1 in every lane
+        const expected = [1]u5{16} ** 8; // the test expects all 16 bits of each lane to be counted
+        const result: [8]u5 = @popCount(i16, x);
+        try expect(std.mem.eql(u5, &expected, &result));
+    }
+}
test/behavior/popcount_stage1.zig
@@ -1,24 +0,0 @@
-const std = @import("std");
-const expect = std.testing.expect;
-const expectEqual = std.testing.expectEqual;
-const Vector = std.meta.Vector;
-
-test "@popCount vectors" {
-    comptime try testPopCountVectors();
-    try testPopCountVectors();
-}
-
-fn testPopCountVectors() !void {
-    {
-        var x: Vector(8, u32) = [1]u32{0xffffffff} ** 8;
-        const expected = [1]u6{32} ** 8;
-        const result: [8]u6 = @popCount(u32, x);
-        try expect(std.mem.eql(u6, &expected, &result));
-    }
-    {
-        var x: Vector(8, i16) = [1]i16{-1} ** 8;
-        const expected = [1]u5{16} ** 8;
-        const result: [8]u5 = @popCount(i16, x);
-        try expect(std.mem.eql(u5, &expected, &result));
-    }
-}
test/behavior.zig
@@ -153,7 +153,6 @@ test {
                     _ = @import("behavior/ir_block_deps.zig");
                     _ = @import("behavior/misc.zig");
                     _ = @import("behavior/muladd.zig");
-                    _ = @import("behavior/popcount_stage1.zig");
                     _ = @import("behavior/reflection.zig");
                     _ = @import("behavior/select.zig");
                     _ = @import("behavior/shuffle.zig");