Commit 01577a3af4

Shawn Landden <shawn@git.icu>
2019-07-21 17:41:43
`@splat`
1 parent 0048bcb
src/all_types.hpp
@@ -1612,6 +1612,7 @@ enum BuiltinFnId {
     BuiltinFnIdIntType,
     BuiltinFnIdVectorType,
     BuiltinFnIdShuffle,
+    BuiltinFnIdSplat,
     BuiltinFnIdSetCold,
     BuiltinFnIdSetRuntimeSafety,
     BuiltinFnIdSetFloatMode,
@@ -2431,6 +2432,7 @@ enum IrInstructionId {
     IrInstructionIdIntType,
     IrInstructionIdVectorType,
     IrInstructionIdShuffleVector,
+    IrInstructionIdSplat,
     IrInstructionIdBoolNot,
     IrInstructionIdMemset,
     IrInstructionIdMemcpy,
@@ -3681,6 +3683,13 @@ struct IrInstructionShuffleVector {
     IrInstruction *mask; // This is in zig-format, not llvm format
 };
 
+struct IrInstructionSplat {
+    IrInstruction base;
+
+    IrInstruction *len;
+    IrInstruction *scalar;
+};
+
 struct IrInstructionAssertZero {
     IrInstruction base;
 
src/codegen.cpp
@@ -4619,6 +4619,20 @@ static LLVMValueRef ir_render_shuffle_vector(CodeGen *g, IrExecutable *executabl
         llvm_mask_value, "");
 }
 
+static LLVMValueRef ir_render_splat(CodeGen *g, IrExecutable *executable, IrInstructionSplat *instruction) {
+    uint64_t len = bigint_as_u64(&instruction->len->value.data.x_bigint);
+    LLVMValueRef wrapped_scalar_undef = LLVMGetUndef(instruction->base.value.type->llvm_type);
+    LLVMValueRef wrapped_scalar = LLVMBuildInsertElement(g->builder, wrapped_scalar_undef,
+        ir_llvm_value(g, instruction->scalar),
+        LLVMConstInt(LLVMInt32Type(), 0, false),
+        "");
+    return LLVMBuildShuffleVector(g->builder,
+        wrapped_scalar,
+        wrapped_scalar_undef,
+        LLVMConstNull(LLVMVectorType(g->builtin_types.entry_u32->llvm_type, (uint32_t)len)),
+        "");
+}
+
 static LLVMValueRef ir_render_pop_count(CodeGen *g, IrExecutable *executable, IrInstructionPopCount *instruction) {
     ZigType *int_type = instruction->op->value.type;
     LLVMValueRef fn_val = get_int_builtin_fn(g, int_type, BuiltinFnIdPopCount);
@@ -6146,6 +6160,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_spill_end(g, executable, (IrInstructionSpillEnd *)instruction);
         case IrInstructionIdShuffleVector:
             return ir_render_shuffle_vector(g, executable, (IrInstructionShuffleVector *) instruction);
+        case IrInstructionIdSplat:
+            return ir_render_splat(g, executable, (IrInstructionSplat *) instruction);
     }
     zig_unreachable();
 }
@@ -7837,6 +7853,7 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn(g, BuiltinFnIdIntType, "IntType", 2); // TODO rename to Int
     create_builtin_fn(g, BuiltinFnIdVectorType, "Vector", 2);
     create_builtin_fn(g, BuiltinFnIdShuffle, "shuffle", 4);
+    create_builtin_fn(g, BuiltinFnIdSplat, "splat", 2);
     create_builtin_fn(g, BuiltinFnIdSetCold, "setCold", 1);
     create_builtin_fn(g, BuiltinFnIdSetRuntimeSafety, "setRuntimeSafety", 1);
     create_builtin_fn(g, BuiltinFnIdSetFloatMode, "setFloatMode", 1);
src/ir.cpp
@@ -721,6 +721,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionShuffleVector *)
     return IrInstructionIdShuffleVector;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionSplat *) {
+    return IrInstructionIdSplat;
+}
+
 static constexpr IrInstructionId ir_instruction_id(IrInstructionBoolNot *) {
     return IrInstructionIdBoolNot;
 }
@@ -2300,6 +2304,19 @@ static IrInstruction *ir_build_shuffle_vector(IrBuilder *irb, Scope *scope, AstN
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_splat(IrBuilder *irb, Scope *scope, AstNode *source_node,
+    IrInstruction *len, IrInstruction *scalar)
+{
+    IrInstructionSplat *instruction = ir_build_instruction<IrInstructionSplat>(irb, scope, source_node);
+    instruction->len = len;
+    instruction->scalar = scalar;
+
+    ir_ref_instruction(len, irb->current_basic_block);
+    ir_ref_instruction(scalar, irb->current_basic_block);
+
+    return &instruction->base;
+}
+
 static IrInstruction *ir_build_bool_not(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *value) {
     IrInstructionBoolNot *instruction = ir_build_instruction<IrInstructionBoolNot>(irb, scope, source_node);
     instruction->value = value;
@@ -4985,6 +5002,22 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
                     arg0_value, arg1_value, arg2_value, arg3_value);
                 return ir_lval_wrap(irb, scope, shuffle_vector, lval, result_loc);
             }
+        case BuiltinFnIdSplat:
+            {
+                AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
+                IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
+                if (arg0_value == irb->codegen->invalid_instruction)
+                    return arg0_value;
+
+                AstNode *arg1_node = node->data.fn_call_expr.params.at(1);
+                IrInstruction *arg1_value = ir_gen_node(irb, arg1_node, scope);
+                if (arg1_value == irb->codegen->invalid_instruction)
+                    return arg1_value;
+
+                IrInstruction *splat = ir_build_splat(irb, scope, node,
+                    arg0_value, arg1_value);
+                return ir_lval_wrap(irb, scope, splat, lval, result_loc);
+            }
         case BuiltinFnIdMemcpy:
             {
                 AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
@@ -22324,6 +22357,52 @@ static IrInstruction *ir_analyze_instruction_shuffle_vector(IrAnalyze *ira, IrIn
     return ir_analyze_shuffle_vector(ira, &instruction->base, scalar_type, a, b, mask);
 }
 
+static IrInstruction *ir_analyze_instruction_splat(IrAnalyze *ira, IrInstructionSplat *instruction) {
+    IrInstruction *len = instruction->len->child;
+    if (type_is_invalid(len->value.type))
+        return ira->codegen->invalid_instruction;
+
+    IrInstruction *scalar = instruction->scalar->child;
+    if (type_is_invalid(scalar->value.type))
+        return ira->codegen->invalid_instruction;
+
+    uint64_t len_int;
+    if (!ir_resolve_unsigned(ira, len, ira->codegen->builtin_types.entry_u32, &len_int)) {
+        ir_add_error(ira, len,
+            buf_sprintf("splat length must be comptime"));
+        return ira->codegen->invalid_instruction;
+    }
+
+    if (!is_valid_vector_elem_type(scalar->value.type)) {
+        ir_add_error(ira, len,
+            buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid",
+                buf_ptr(&scalar->value.type->name)));
+        return ira->codegen->invalid_instruction;
+    }
+
+    ZigType *return_type = get_vector_type(ira->codegen, len_int, scalar->value.type);
+
+    if (instr_is_comptime(scalar)) {
+        IrInstruction *result = ir_const_undef(ira, scalar, return_type);
+        result->value.data.x_array.data.s_none.elements =
+            allocate<ConstExprValue>(len_int);
+        for (uint32_t i = 0; i < len_int; i++) {
+            result->value.data.x_array.data.s_none.elements[i] =
+                scalar->value;
+        }
+        result->value.type = return_type;
+        result->value.special = ConstValSpecialStatic;
+        return result;
+    }
+
+    IrInstruction *result = ir_build_splat(&ira->new_irb,
+        instruction->base.scope, instruction->base.source_node,
+        instruction->len->child, instruction->scalar->child);
+    result->value.type = return_type;
+    result->value.special = ConstValSpecialRuntime;
+    return result;
+}
+
 static IrInstruction *ir_analyze_instruction_bool_not(IrAnalyze *ira, IrInstructionBoolNot *instruction) {
     IrInstruction *value = instruction->value->child;
     if (type_is_invalid(value->value.type))
@@ -25908,6 +25987,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
             return ir_analyze_instruction_vector_type(ira, (IrInstructionVectorType *)instruction);
         case IrInstructionIdShuffleVector:
             return ir_analyze_instruction_shuffle_vector(ira, (IrInstructionShuffleVector *)instruction);
+         case IrInstructionIdSplat:
+            return ir_analyze_instruction_splat(ira, (IrInstructionSplat *)instruction);
         case IrInstructionIdBoolNot:
             return ir_analyze_instruction_bool_not(ira, (IrInstructionBoolNot *)instruction);
         case IrInstructionIdMemset:
@@ -26244,6 +26325,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdIntType:
         case IrInstructionIdVectorType:
         case IrInstructionIdShuffleVector:
+        case IrInstructionIdSplat:
         case IrInstructionIdBoolNot:
         case IrInstructionIdSliceSrc:
         case IrInstructionIdMemberCount:
src/ir_print.cpp
@@ -44,6 +44,8 @@ static const char* ir_instruction_type_str(IrInstruction* instruction) {
             return "Invalid";
         case IrInstructionIdShuffleVector:
             return "Shuffle";
+        case IrInstructionIdSplat:
+            return "Splat";
         case IrInstructionIdDeclVarSrc:
             return "DeclVarSrc";
         case IrInstructionIdDeclVarGen:
@@ -1222,6 +1224,14 @@ static void ir_print_shuffle_vector(IrPrint *irp, IrInstructionShuffleVector *in
     fprintf(irp->f, ")");
 }
 
+static void ir_print_splat(IrPrint *irp, IrInstructionSplat *instruction) {
+    fprintf(irp->f, "@splat(");
+    ir_print_other_instruction(irp, instruction->len);
+    fprintf(irp->f, ", ");
+    ir_print_other_instruction(irp, instruction->scalar);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_bool_not(IrPrint *irp, IrInstructionBoolNot *instruction) {
     fprintf(irp->f, "! ");
     ir_print_other_instruction(irp, instruction->value);
@@ -2160,6 +2170,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction, bool
         case IrInstructionIdShuffleVector:
             ir_print_shuffle_vector(irp, (IrInstructionShuffleVector *)instruction);
             break;
+        case IrInstructionIdSplat:
+            ir_print_splat(irp, (IrInstructionSplat *)instruction);
+            break;
         case IrInstructionIdBoolNot:
             ir_print_bool_not(irp, (IrInstructionBoolNot *)instruction);
             break;
test/stage1/behavior/vector.zig
@@ -35,12 +35,12 @@ test "vector bin compares with mem.eql" {
         fn doTheTest() void {
             var v: @Vector(4, i32) = [4]i32{ 2147483647, -2, 30, 40 };
             var x: @Vector(4, i32) = [4]i32{ 1, 2147483647, 30, 4 };
-            expect(mem.eql(bool, ([4]bool)(v == x), [4]bool{ false, false,  true, false}));
-            expect(mem.eql(bool, ([4]bool)(v != x), [4]bool{  true,  true, false,  true}));
-            expect(mem.eql(bool, ([4]bool)(v  < x), [4]bool{ false,  true, false, false}));
-            expect(mem.eql(bool, ([4]bool)(v  > x), [4]bool{  true, false, false,  true}));
-            expect(mem.eql(bool, ([4]bool)(v <= x), [4]bool{ false,  true,  true, false}));
-            expect(mem.eql(bool, ([4]bool)(v >= x), [4]bool{  true, false,  true,  true}));
+            expect(mem.eql(bool, ([4]bool)(v == x), [4]bool{ false, false, true, false }));
+            expect(mem.eql(bool, ([4]bool)(v != x), [4]bool{ true, true, false, true }));
+            expect(mem.eql(bool, ([4]bool)(v < x), [4]bool{ false, true, false, false }));
+            expect(mem.eql(bool, ([4]bool)(v > x), [4]bool{ true, false, false, true }));
+            expect(mem.eql(bool, ([4]bool)(v <= x), [4]bool{ false, true, true, false }));
+            expect(mem.eql(bool, ([4]bool)(v >= x), [4]bool{ true, false, true, true }));
         }
     };
     S.doTheTest();
@@ -114,22 +114,22 @@ test "vector casts of sizes not divisable by 8" {
     const S = struct {
         fn doTheTest() void {
             {
-                var v: @Vector(4, u3) = [4]u3{ 5, 2,  3, 0};
+                var v: @Vector(4, u3) = [4]u3{ 5, 2, 3, 0 };
                 var x: [4]u3 = v;
                 expect(mem.eql(u3, x, ([4]u3)(v)));
             }
             {
-                var v: @Vector(4, u2) = [4]u2{ 1, 2,  3, 0};
+                var v: @Vector(4, u2) = [4]u2{ 1, 2, 3, 0 };
                 var x: [4]u2 = v;
                 expect(mem.eql(u2, x, ([4]u2)(v)));
             }
             {
-                var v: @Vector(4, u1) = [4]u1{ 1, 0,  1, 0};
+                var v: @Vector(4, u1) = [4]u1{ 1, 0, 1, 0 };
                 var x: [4]u1 = v;
                 expect(mem.eql(u1, x, ([4]u1)(v)));
             }
             {
-                var v: @Vector(4, bool) = [4]bool{ false, false,  true, false};
+                var v: @Vector(4, bool) = [4]bool{ false, false, true, false };
                 var x: [4]bool = v;
                 expect(mem.eql(bool, x, ([4]bool)(v)));
             }
@@ -138,3 +138,19 @@ test "vector casts of sizes not divisable by 8" {
     S.doTheTest();
     comptime S.doTheTest();
 }
+
+test "vector @splat" {
+    const S = struct {
+        fn doTheTest() void {
+            var v: u32 = 5;
+            var x = @splat(4, v);
+            expect(@typeOf(x) == @Vector(4, u32));
+            expect(x[0] == 5);
+            expect(x[1] == 5);
+            expect(x[2] == 5);
+            expect(x[3] == 5);
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
test/compile_errors.zig
@@ -6507,6 +6507,16 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
         "tmp.zig:2:26: error: vector element type must be integer, float, bool, or pointer; '@Vector(4, u8)' is invalid",
     );
 
+    cases.addTest(
+        "bad @splat type",
+        \\export fn entry() void {
+        \\    const c = 4;
+        \\    var v = @splat(4, c);
+        \\}
+    ,
+        "tmp.zig:3:20: error: vector element type must be integer, float, bool, or pointer; 'comptime_int' is invalid",
+    );
+
     cases.add("compileLog of tagged enum doesn't crash the compiler",
         \\const Bar = union(enum(u32)) {
         \\    X: i32 = 1