Commit 76f5396077

Shawn Landden <shawn@git.icu>
2019-07-14 16:22:37
@byteSwap on vectors
1 parent 86209e1
Changed files (4)
src/all_types.hpp
@@ -1771,6 +1771,7 @@ struct ZigLLVMFnKey {
         } overflow_arithmetic;
         struct {
             uint32_t bit_count;
+            uint32_t vector_len; // 0 means not a vector
         } bswap;
         struct {
             uint32_t bit_count;
src/codegen.cpp
@@ -4505,7 +4505,13 @@ static LLVMValueRef ir_render_optional_unwrap_ptr(CodeGen *g, IrExecutable *exec
     }
 }
 
-static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *int_type, BuiltinFnId fn_id) {
+static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *expr_type, BuiltinFnId fn_id) {
+    bool is_vector = expr_type->id == ZigTypeIdVector;
+    ZigType *int_type = is_vector ? expr_type->data.vector.elem_type : expr_type;
+    assert(int_type->id == ZigTypeIdInt);
+    uint32_t vector_len = 0;
+    if (is_vector)
+        vector_len = expr_type->data.vector.len;
     ZigLLVMFnKey key = {};
     const char *fn_name;
     uint32_t n_args;
@@ -4529,6 +4535,7 @@ static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *int_type, BuiltinFnI
         n_args = 1;
         key.id = ZigLLVMFnIdBswap;
         key.data.bswap.bit_count = (uint32_t)int_type->data.integral.bit_count;
+        key.data.bswap.vector_len = vector_len;
     } else if (fn_id == BuiltinFnIdBitReverse) {
         fn_name = "bitreverse";
         n_args = 1;
@@ -4543,12 +4550,15 @@ static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *int_type, BuiltinFnI
         return existing_entry->value;
 
     char llvm_name[64];
-    sprintf(llvm_name, "llvm.%s.i%" PRIu32, fn_name, int_type->data.integral.bit_count);
+    if (is_vector)
+        sprintf(llvm_name, "llvm.%s.v%" PRIu32 "i%" PRIu32, fn_name, vector_len, int_type->data.integral.bit_count);
+    else
+        sprintf(llvm_name, "llvm.%s.i%" PRIu32, fn_name, int_type->data.integral.bit_count);
     LLVMTypeRef param_types[] = {
-        get_llvm_type(g, int_type),
+        get_llvm_type(g, expr_type),
         LLVMInt1Type(),
     };
-    LLVMTypeRef fn_type = LLVMFunctionType(get_llvm_type(g, int_type), param_types, n_args, false);
+    LLVMTypeRef fn_type = LLVMFunctionType(get_llvm_type(g, expr_type), param_types, n_args, false);
     LLVMValueRef fn_val = LLVMAddFunction(g->module, llvm_name, fn_type);
     assert(LLVMGetIntrinsicID(fn_val));
 
@@ -5542,15 +5552,19 @@ static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrIn
 
 static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInstructionBswap *instruction) {
     LLVMValueRef op = ir_llvm_value(g, instruction->op);
-    ZigType *int_type = instruction->base.value.type;
+    ZigType *expr_type = instruction->base.value.type;
+    bool is_vector = expr_type->id == ZigTypeIdVector;
+    ZigType *int_type = is_vector ? expr_type->data.vector.elem_type : expr_type;
     assert(int_type->id == ZigTypeIdInt);
     if (int_type->data.integral.bit_count % 16 == 0) {
-        LLVMValueRef fn_val = get_int_builtin_fn(g, instruction->base.value.type, BuiltinFnIdBswap);
+        LLVMValueRef fn_val = get_int_builtin_fn(g, expr_type, BuiltinFnIdBswap);
         return LLVMBuildCall(g->builder, fn_val, &op, 1, "");
     }
     // Not an even number of bytes, so we zext 1 byte, then bswap, shift right 1 byte, truncate
     ZigType *extended_type = get_int_type(g, int_type->data.integral.is_signed,
             int_type->data.integral.bit_count + 8);
+    if (is_vector)
+        extended_type = get_vector_type(g, expr_type->data.vector.len, extended_type);
     // aabbcc
     LLVMValueRef extended = LLVMBuildZExt(g->builder, op, get_llvm_type(g, extended_type), "");
     // 00aabbcc
@@ -5560,7 +5574,7 @@ static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInst
     LLVMValueRef shifted = ZigLLVMBuildLShrExact(g->builder, swapped,
             LLVMConstInt(get_llvm_type(g, extended_type), 8, false), "");
     // 00ccbbaa
-    return LLVMBuildTrunc(g->builder, shifted, get_llvm_type(g, int_type), "");
+    return LLVMBuildTrunc(g->builder, shifted, get_llvm_type(g, expr_type), "");
 }
 
 static LLVMValueRef ir_render_bit_reverse(CodeGen *g, IrExecutable *executable, IrInstructionBitReverse *instruction) {
src/ir.cpp
@@ -25253,16 +25253,42 @@ static IrInstruction *ir_analyze_instruction_float_op(IrAnalyze *ira, IrInstruct
 }
 
 static IrInstruction *ir_analyze_instruction_bswap(IrAnalyze *ira, IrInstructionBswap *instruction) {
-    ZigType *int_type = ir_resolve_int_type(ira, instruction->type->child);
-    if (type_is_invalid(int_type))
+    IrInstruction *op = instruction->op->child;
+    ZigType *type_expr = ir_resolve_type(ira, instruction->type->child);
+    if (type_is_invalid(type_expr))
         return ira->codegen->invalid_instruction;
 
-    IrInstruction *op = ir_implicit_cast(ira, instruction->op->child, int_type);
+    if (type_expr->id != ZigTypeIdInt) {
+        ir_add_error(ira, instruction->type,
+            buf_sprintf("expected integer type, found '%s'", buf_ptr(&type_expr->name)));
+        if (type_expr->id == ZigTypeIdVector &&
+            type_expr->data.vector.elem_type->id == ZigTypeIdInt)
+            ir_add_error(ira, instruction->type,
+                buf_sprintf("represent vectors with their scalar types, i.e. '%s'",
+                    buf_ptr(&type_expr->data.vector.elem_type->name)));
+        return ira->codegen->invalid_instruction;
+    }
+    ZigType *int_type = type_expr;
+
+    ZigType *expr_type = op->value.type;
+    bool is_vector = expr_type->id == ZigTypeIdVector;
+    ZigType *ret_type = int_type;
+    if (is_vector)
+        ret_type = get_vector_type(ira->codegen, expr_type->data.vector.len, int_type);
+
+    op = ir_implicit_cast(ira, instruction->op->child, ret_type);
     if (type_is_invalid(op->value.type))
         return ira->codegen->invalid_instruction;
 
     if (int_type->data.integral.bit_count == 0) {
-        IrInstruction *result = ir_const(ira, &instruction->base, int_type);
+        IrInstruction *result = ir_const(ira, &instruction->base, ret_type);
+        if (is_vector) {
+            expand_undef_array(ira->codegen, &result->value);
+            result->value.data.x_array.data.s_none.elements =
+                allocate<ConstExprValue>(expr_type->data.vector.len);
+            for (unsigned i = 0; i < expr_type->data.vector.len; i++)
+                bigint_init_unsigned(&result->value.data.x_array.data.s_none.elements[i].data.x_bigint, 0);
+        }
         bigint_init_unsigned(&result->value.data.x_bigint, 0);
         return result;
     }
@@ -25282,20 +25308,36 @@ static IrInstruction *ir_analyze_instruction_bswap(IrAnalyze *ira, IrInstruction
         if (val == nullptr)
             return ira->codegen->invalid_instruction;
         if (val->special == ConstValSpecialUndef)
-            return ir_const_undef(ira, &instruction->base, int_type);
+            return ir_const_undef(ira, &instruction->base, ret_type);
 
-        IrInstruction *result = ir_const(ira, &instruction->base, int_type);
+        IrInstruction *result = ir_const(ira, &instruction->base, ret_type);
         size_t buf_size = int_type->data.integral.bit_count / 8;
         uint8_t *buf = allocate_nonzero<uint8_t>(buf_size);
-        bigint_write_twos_complement(&val->data.x_bigint, buf, int_type->data.integral.bit_count, true);
-        bigint_read_twos_complement(&result->value.data.x_bigint, buf, int_type->data.integral.bit_count, false,
-                int_type->data.integral.is_signed);
+        if (is_vector) {
+            expand_undef_array(ira->codegen, &result->value);
+            result->value.data.x_array.data.s_none.elements =
+                allocate<ConstExprValue>(expr_type->data.vector.len);
+            for (unsigned i = 0; i < expr_type->data.vector.len; i++) {
+                ConstExprValue *cur = &val->data.x_array.data.s_none.elements[i];
+                result->value.data.x_array.data.s_none.elements[i].special = cur->special;
+                if (cur->special == ConstValSpecialUndef)
+                    continue;
+                bigint_write_twos_complement(&cur->data.x_bigint, buf, int_type->data.integral.bit_count, true);
+                bigint_read_twos_complement(&result->value.data.x_array.data.s_none.elements[i].data.x_bigint,
+                        buf, int_type->data.integral.bit_count, false,
+                        int_type->data.integral.is_signed);
+            }
+        } else {
+            bigint_write_twos_complement(&val->data.x_bigint, buf, int_type->data.integral.bit_count, true);
+            bigint_read_twos_complement(&result->value.data.x_bigint, buf, int_type->data.integral.bit_count, false,
+                    int_type->data.integral.is_signed);
+        }
         return result;
     }
 
     IrInstruction *result = ir_build_bswap(&ira->new_irb, instruction->base.scope,
             instruction->base.source_node, nullptr, op);
-    result->value.type = int_type;
+    result->value.type = ret_type;
     return result;
 }
 
test/stage1/behavior/byteswap.zig
@@ -6,6 +6,11 @@ test "@byteSwap" {
     testByteSwap();
 }
 
+test "@byteSwap on vectors" {
+    comptime testVectorByteSwap();
+    testVectorByteSwap();
+}
+
 fn testByteSwap() void {
     expect(@byteSwap(u0, 0) == 0);
     expect(@byteSwap(u8, 0x12) == 0x12);
@@ -30,3 +35,9 @@ fn testByteSwap() void {
     expect(@byteSwap(i128, @bitCast(i128, u128(0x123456789abcdef11121314151617181))) ==
         @bitCast(i128, u128(0x8171615141312111f1debc9a78563412)));
 }
+
+fn testVectorByteSwap() void {
+    expect((@byteSwap(u8, @Vector(2, u8)([2]u8{0x12, 0x13})) == @Vector(2, u8)([2]u8{0x12, 0x13})).all);
+    expect((@byteSwap(u16, @Vector(2, u16)([2]u16{0x1234, 0x2345})) == @Vector(2, u16)([2]u16{0x3412, 0x4523})).all);
+    expect((@byteSwap(u24, @Vector(2, u24)([2]u24{0x123456, 0x234567})) == @Vector(2, u24)([2]u24{0x563412, 0x674523})).all);
+}