Commit ecca829bcb

Auguste Rame <auguste.rame@gmail.com>
2021-07-26 02:35:55
Add vector support for @popCount
1 parent 653c851
Changed files (6)
doc/langref.html.in
@@ -8105,12 +8105,14 @@ test "@wasmMemoryGrow" {
       {#header_close#}
 
       {#header_open|@popCount#}
-      <pre>{#syntax#}@popCount(comptime T: type, integer: T){#endsyntax#}</pre>
+      <pre>{#syntax#}@popCount(comptime T: type, operand: T){#endsyntax#}</pre>
+      <p>{#syntax#}T{#endsyntax#} must be an integer type.</p>
+      <p>{#syntax#}operand{#endsyntax#} may be an {#link|integer|Integers#} or {#link|vector|Vectors#}.</p>
       <p>Counts the number of bits set in an integer. For a vector operand, the count is computed for each element.</p>
       <p>
-      If {#syntax#}integer{#endsyntax#} is known at {#link|comptime#},
+      If {#syntax#}operand{#endsyntax#} is a {#link|comptime#}-known integer,
       the return type is {#syntax#}comptime_int{#endsyntax#}.
-      Otherwise, the return type is an unsigned integer with the minimum number
+      Otherwise, the return type is an unsigned integer or vector of unsigned integers with the minimum number
       of bits that can represent the bit count of the integer type.
       </p>
       {#see_also|@ctz|@clz#}
src/stage1/all_types.hpp
@@ -1913,6 +1913,7 @@ struct ZigLLVMFnKey {
         } clz;
         struct {
             uint32_t bit_count;
+            uint32_t vector_len; // 0 means not a vector
         } pop_count;
         struct {
             BuiltinFnId op;
src/stage1/analyze.cpp
@@ -7887,7 +7887,8 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey const *x) {
         case ZigLLVMFnIdClz:
             return (uint32_t)(x->data.clz.bit_count) * (uint32_t)2428952817;
         case ZigLLVMFnIdPopCount:
-            return (uint32_t)(x->data.clz.bit_count) * (uint32_t)101195049;
+            return (uint32_t)(x->data.pop_count.bit_count) * (uint32_t)101195049 +
+                   (uint32_t)(x->data.pop_count.vector_len) * (((uint32_t)x->id << 5) + 1025);
         case ZigLLVMFnIdFloatOp:
             return (uint32_t)(x->data.floating.bit_count) * ((uint32_t)x->id + 1025) +
                    (uint32_t)(x->data.floating.vector_len) * (((uint32_t)x->id << 5) + 1025) +
src/stage1/codegen.cpp
@@ -5053,6 +5053,7 @@ static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *expr_type, BuiltinFn
         n_args = 1;
         key.id = ZigLLVMFnIdPopCount;
         key.data.pop_count.bit_count = (uint32_t)int_type->data.integral.bit_count;
+        key.data.pop_count.vector_len = vector_len;
     } else if (fn_id == BuiltinFnIdBswap) {
         fn_name = "bswap";
         n_args = 1;
src/stage1/ir.cpp
@@ -15997,33 +15997,87 @@ static Stage1AirInst *ir_analyze_instruction_clz(IrAnalyze *ira, Stage1ZirInstCl
 }
 
 static Stage1AirInst *ir_analyze_instruction_pop_count(IrAnalyze *ira, Stage1ZirInstPopCount *instruction) {
+    // Analyzes @popCount(T, operand):
+    //   1. resolves T to an integer type;
+    //   2. casts the operand to T, or to a vector of T when the operand is a
+    //      vector or an array whose child type is a valid vector element type;
+    //   3. folds comptime-known operands into a constant;
+    //   4. otherwise builds a runtime pop-count instruction whose result type is
+    //      the smallest unsigned integer (or vector thereof) for T's bit count.
+    // Returns invalid_inst_gen whenever a type or value fails to resolve.
+    Error err;
+
     ZigType *int_type = ir_resolve_int_type(ira, instruction->type->child);
     if (type_is_invalid(int_type))
         return ira->codegen->invalid_inst_gen;
 
-    Stage1AirInst *op = ir_implicit_cast(ira, instruction->op->child, int_type);
+    Stage1AirInst *uncasted_op = instruction->op->child;
+    if (type_is_invalid(uncasted_op->value->type))
+        return ira->codegen->invalid_inst_gen;
+
+    // UINT32_MAX is the "scalar operand" sentinel; any other value is the lane count.
+    uint32_t vector_len = UINT32_MAX; // means not a vector
+    if (uncasted_op->value->type->id == ZigTypeIdArray) {
+        // Arrays are treated as vectors only if their child type may be a vector element.
+        bool can_be_vec_elem;
+        if ((err = is_valid_vector_elem_type(ira->codegen, uncasted_op->value->type->data.array.child_type,
+                        &can_be_vec_elem)))
+        {
+            return ira->codegen->invalid_inst_gen;
+        }
+        if (can_be_vec_elem) {
+            // NOTE(review): the array length is narrowed to uint32_t here - confirm
+            // that lengths >= UINT32_MAX cannot reach this point.
+            vector_len = uncasted_op->value->type->data.array.len;
+        }
+    } else if (uncasted_op->value->type->id == ZigTypeIdVector) {
+        vector_len = uncasted_op->value->type->data.vector.len;
+    }
+
+    bool is_vector = (vector_len != UINT32_MAX);
+    ZigType *op_type = is_vector ? get_vector_type(ira->codegen, vector_len, int_type) : int_type;
+
+    Stage1AirInst *op = ir_implicit_cast(ira, uncasted_op, op_type);
     if (type_is_invalid(op->value->type))
         return ira->codegen->invalid_inst_gen;
 
+    // NOTE(review): this early return yields a scalar constant even when the
+    // operand is a vector - confirm that is intended for zero-bit element types.
     if (int_type->data.integral.bit_count == 0)
         return ir_const_unsigned(ira, instruction->base.scope, instruction->base.source_node, 0);
 
+    ZigType *smallest_type = get_smallest_unsigned_int_type(ira->codegen, int_type->data.integral.bit_count);
+
     if (instr_is_comptime(op)) {
         ZigValue *val = ir_resolve_const(ira, op, UndefOk);
         if (val == nullptr)
             return ira->codegen->invalid_inst_gen;
+        // NOTE(review): a wholly-undefined operand produces an undefined
+        // comptime_int, also when the operand is a vector - confirm intended.
         if (val->special == ConstValSpecialUndef)
             return ir_const_undef(ira, instruction->base.scope, instruction->base.source_node, ira->codegen->builtin_types.entry_num_lit_int);
+
+        if (is_vector) {
+            // Fold element by element into a constant vector of the smallest unsigned type.
+            ZigType *smallest_vec_type = get_vector_type(ira->codegen, vector_len, smallest_type);
+            Stage1AirInst *result = ir_const(ira, instruction->base.scope, instruction->base.source_node, smallest_vec_type);
+            expand_undef_array(ira->codegen, val);
+            result->value->data.x_array.data.s_none.elements = ira->codegen->pass1_arena->allocate<ZigValue>(smallest_vec_type->data.vector.len);
+            for (unsigned i = 0; i < smallest_vec_type->data.vector.len; i += 1) {
+                ZigValue *op_elem_val = &val->data.x_array.data.s_none.elements[i];
+                if ((err = ir_resolve_const_val(ira->codegen, ira->new_irb.exec, instruction->base.source_node,
+                    op_elem_val, UndefOk)))
+                {
+                    return ira->codegen->invalid_inst_gen;
+                }
+                ZigValue *result_elem_val = &result->value->data.x_array.data.s_none.elements[i];
+                result_elem_val->type = smallest_type;
+                result_elem_val->special = op_elem_val->special;
+                // Undefined lanes stay undefined in the result.
+                if (op_elem_val->special == ConstValSpecialUndef)
+                    continue;
 
-        if (bigint_cmp_zero(&val->data.x_bigint) != CmpLT) {
-            size_t result = bigint_popcount_unsigned(&val->data.x_bigint);
+                // Mirror the scalar path: non-negative values use the unsigned
+                // count, negative values use the signed count at the operand's
+                // bit width. Exactly one of the two is computed and stored.
+                size_t value;
+                if (bigint_cmp_zero(&op_elem_val->data.x_bigint) != CmpLT) {
+                    value = bigint_popcount_unsigned(&op_elem_val->data.x_bigint);
+                } else {
+                    value = bigint_popcount_signed(&op_elem_val->data.x_bigint, int_type->data.integral.bit_count);
+                }
+                bigint_init_unsigned(&result_elem_val->data.x_bigint, value);
+            }
+            return result;
+        } else {
+            if (bigint_cmp_zero(&val->data.x_bigint) != CmpLT) {
+                size_t result = bigint_popcount_unsigned(&val->data.x_bigint);
+                return ir_const_unsigned(ira, instruction->base.scope, instruction->base.source_node, result);
+            }
+            size_t result = bigint_popcount_signed(&val->data.x_bigint, int_type->data.integral.bit_count);
             return ir_const_unsigned(ira, instruction->base.scope, instruction->base.source_node, result);
         }
-        size_t result = bigint_popcount_signed(&val->data.x_bigint, int_type->data.integral.bit_count);
-        return ir_const_unsigned(ira, instruction->base.scope, instruction->base.source_node, result);
     }
 
-    ZigType *return_type = get_smallest_unsigned_int_type(ira->codegen, int_type->data.integral.bit_count);
+    // Runtime operand: the result has the same shape (scalar or vector) as the operand.
+    ZigType *return_type = is_vector ? get_vector_type(ira->codegen, vector_len, smallest_type) : smallest_type;
     return ir_build_pop_count_gen(ira, instruction->base.scope, instruction->base.source_node, return_type, op);
 }
 
test/behavior/popcount.zig
@@ -1,11 +1,14 @@
-const expect = @import("std").testing.expect;
+const std = @import("std");
+const expect = std.testing.expect;
+const expectEqual = std.testing.expectEqual;
+const Vector = std.meta.Vector;
 
-test "@popCount" {
-    comptime try testPopCount();
-    try testPopCount();
+test "@popCount integers" {
+    // Run the same integer checks once at compile time and once at run time.
+    comptime try testPopCountIntegers();
+    try testPopCountIntegers();
 }
 
-fn testPopCount() !void {
+fn testPopCountIntegers() !void {
     {
         var x: u32 = 0xffffffff;
         try expect(@popCount(u32, x) == 32);
@@ -41,3 +44,22 @@ fn testPopCount() !void {
         try expect(@popCount(i128, 0b11111111000110001100010000100001000011000011100101010001) == 24);
     }
 }
+
+test "@popCount vectors" {
+    // Skipped on MIPS targets; see https://github.com/ziglang/zig/issues/3317
+    switch (std.Target.current.cpu.arch) {
+        .mipsel, .mips => return error.SkipZigTest,
+        else => {},
+    }
+
+    // Run the same vector checks once at compile time and once at run time.
+    comptime try testPopCountVectors();
+    try testPopCountVectors();
+}
+
+/// Checks @popCount on vector operands: every element of the result must hold
+/// the bit count of the corresponding input element.
+fn testPopCountVectors() !void {
+    // Unsigned elements with all 32 bits set.
+    {
+        var all_ones: Vector(8, u32) = [1]u32{0xffffffff} ** 8;
+        const counts = @as([8]u6, @popCount(u32, all_ones));
+        try expectEqual([1]u6{32} ** 8, counts);
+    }
+    // Signed elements equal to -1, i.e. all 16 bits set.
+    {
+        var minus_ones: Vector(8, i16) = [1]i16{-1} ** 8;
+        const counts = @as([8]u5, @popCount(i16, minus_ones));
+        try expectEqual([1]u5{16} ** 8, counts);
+    }
+}