Commit 005a54a853

Andrew Kelley <andrew@ziglang.org>
2019-09-19 16:48:04
fixups for `@splat`
* Fix codegen for splat - instead of giving vectors of length N to shufflevector for both of the operands, it gives vectors of length 1. The mask vector is the only one that needs N elements. * Separate Splat into SplatSrc and SplatGen; the `len` is not needed once it gets to codegen since it is redundant with the result type. * Refactor compile error for wrong vector element type so that the compile error message is not duplicated in zig source code * Improve implementation to correctly handle comptime values such as undefined and lazy values. * Improve compile error for bad vector element type to point to the correct place. * Delete dead code. * Modify behavior test to use an array cast instead of vector element indexing since I'm merging this splat commit out-of-order from Shawn's patch set.
1 parent 01577a3
src/all_types.hpp
@@ -2432,7 +2432,8 @@ enum IrInstructionId {
     IrInstructionIdIntType,
     IrInstructionIdVectorType,
     IrInstructionIdShuffleVector,
-    IrInstructionIdSplat,
+    IrInstructionIdSplatSrc,
+    IrInstructionIdSplatGen,
     IrInstructionIdBoolNot,
     IrInstructionIdMemset,
     IrInstructionIdMemcpy,
@@ -3683,13 +3684,19 @@ struct IrInstructionShuffleVector {
     IrInstruction *mask; // This is in zig-format, not llvm format
 };
 
-struct IrInstructionSplat {
+struct IrInstructionSplatSrc {
     IrInstruction base;
 
     IrInstruction *len;
     IrInstruction *scalar;
 };
 
+struct IrInstructionSplatGen {
+    IrInstruction base;
+
+    IrInstruction *scalar;
+};
+
 struct IrInstructionAssertZero {
     IrInstruction base;
 
src/codegen.cpp
@@ -4619,18 +4619,16 @@ static LLVMValueRef ir_render_shuffle_vector(CodeGen *g, IrExecutable *executabl
         llvm_mask_value, "");
 }
 
-static LLVMValueRef ir_render_splat(CodeGen *g, IrExecutable *executable, IrInstructionSplat *instruction) {
-    uint64_t len = bigint_as_u64(&instruction->len->value.data.x_bigint);
-    LLVMValueRef wrapped_scalar_undef = LLVMGetUndef(instruction->base.value.type->llvm_type);
-    LLVMValueRef wrapped_scalar = LLVMBuildInsertElement(g->builder, wrapped_scalar_undef,
-        ir_llvm_value(g, instruction->scalar),
-        LLVMConstInt(LLVMInt32Type(), 0, false),
-        "");
-    return LLVMBuildShuffleVector(g->builder,
-        wrapped_scalar,
-        wrapped_scalar_undef,
-        LLVMConstNull(LLVMVectorType(g->builtin_types.entry_u32->llvm_type, (uint32_t)len)),
-        "");
+static LLVMValueRef ir_render_splat(CodeGen *g, IrExecutable *executable, IrInstructionSplatGen *instruction) {
+    ZigType *result_type = instruction->base.value.type;
+    src_assert(result_type->id == ZigTypeIdVector, instruction->base.source_node);
+    uint32_t len = result_type->data.vector.len;
+    LLVMTypeRef op_llvm_type = LLVMVectorType(get_llvm_type(g, instruction->scalar->value.type), 1);
+    LLVMTypeRef mask_llvm_type = LLVMVectorType(LLVMInt32Type(), len);
+    LLVMValueRef undef_vector = LLVMGetUndef(op_llvm_type);
+    LLVMValueRef op_vector = LLVMBuildInsertElement(g->builder, undef_vector,
+            ir_llvm_value(g, instruction->scalar), LLVMConstInt(LLVMInt32Type(), 0, false), "");
+    return LLVMBuildShuffleVector(g->builder, op_vector, undef_vector, LLVMConstNull(mask_llvm_type), "");
 }
 
 static LLVMValueRef ir_render_pop_count(CodeGen *g, IrExecutable *executable, IrInstructionPopCount *instruction) {
@@ -6000,6 +5998,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
         case IrInstructionIdFrameSizeSrc:
         case IrInstructionIdAllocaGen:
         case IrInstructionIdAwaitSrc:
+        case IrInstructionIdSplatSrc:
             zig_unreachable();
 
         case IrInstructionIdDeclVarGen:
@@ -6160,8 +6159,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_spill_end(g, executable, (IrInstructionSpillEnd *)instruction);
         case IrInstructionIdShuffleVector:
             return ir_render_shuffle_vector(g, executable, (IrInstructionShuffleVector *) instruction);
-        case IrInstructionIdSplat:
-            return ir_render_splat(g, executable, (IrInstructionSplat *) instruction);
+        case IrInstructionIdSplatGen:
+            return ir_render_splat(g, executable, (IrInstructionSplatGen *) instruction);
     }
     zig_unreachable();
 }
src/ir.cpp
@@ -721,8 +721,12 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionShuffleVector *)
     return IrInstructionIdShuffleVector;
 }
 
-static constexpr IrInstructionId ir_instruction_id(IrInstructionSplat *) {
-    return IrInstructionIdSplat;
+static constexpr IrInstructionId ir_instruction_id(IrInstructionSplatSrc *) {
+    return IrInstructionIdSplatSrc;
+}
+
+static constexpr IrInstructionId ir_instruction_id(IrInstructionSplatGen *) {
+    return IrInstructionIdSplatGen;
 }
 
 static constexpr IrInstructionId ir_instruction_id(IrInstructionBoolNot *) {
@@ -2304,10 +2308,10 @@ static IrInstruction *ir_build_shuffle_vector(IrBuilder *irb, Scope *scope, AstN
     return &instruction->base;
 }
 
-static IrInstruction *ir_build_splat(IrBuilder *irb, Scope *scope, AstNode *source_node,
+static IrInstruction *ir_build_splat_src(IrBuilder *irb, Scope *scope, AstNode *source_node,
     IrInstruction *len, IrInstruction *scalar)
 {
-    IrInstructionSplat *instruction = ir_build_instruction<IrInstructionSplat>(irb, scope, source_node);
+    IrInstructionSplatSrc *instruction = ir_build_instruction<IrInstructionSplatSrc>(irb, scope, source_node);
     instruction->len = len;
     instruction->scalar = scalar;
 
@@ -2373,6 +2377,19 @@ static IrInstruction *ir_build_slice_src(IrBuilder *irb, Scope *scope, AstNode *
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_splat_gen(IrAnalyze *ira, IrInstruction *source_instruction, ZigType *result_type,
+    IrInstruction *scalar)
+{
+    IrInstructionSplatGen *instruction = ir_build_instruction<IrInstructionSplatGen>(
+            &ira->new_irb, source_instruction->scope, source_instruction->source_node);
+    instruction->base.value.type = result_type;
+    instruction->scalar = scalar;
+
+    ir_ref_instruction(scalar, ira->new_irb.current_basic_block);
+
+    return &instruction->base;
+}
+
 static IrInstruction *ir_build_slice_gen(IrAnalyze *ira, IrInstruction *source_instruction, ZigType *slice_type,
     IrInstruction *ptr, IrInstruction *start, IrInstruction *end, bool safety_check_on, IrInstruction *result_loc)
 {
@@ -5014,7 +5031,7 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
                 if (arg1_value == irb->codegen->invalid_instruction)
                     return arg1_value;
 
-                IrInstruction *splat = ir_build_splat(irb, scope, node,
+                IrInstruction *splat = ir_build_splat_src(irb, scope, node,
                     arg0_value, arg1_value);
                 return ir_lval_wrap(irb, scope, splat, lval, result_loc);
             }
@@ -11082,16 +11099,23 @@ static ZigType *ir_resolve_type(IrAnalyze *ira, IrInstruction *type_value) {
     return ir_resolve_const_type(ira->codegen, ira->new_irb.exec, type_value->source_node, val);
 }
 
+static Error ir_validate_vector_elem_type(IrAnalyze *ira, IrInstruction *source_instr, ZigType *elem_type) {
+    if (!is_valid_vector_elem_type(elem_type)) {
+        ir_add_error(ira, source_instr,
+            buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid",
+                buf_ptr(&elem_type->name)));
+        return ErrorSemanticAnalyzeFail;
+    }
+    return ErrorNone;
+}
+
 static ZigType *ir_resolve_vector_elem_type(IrAnalyze *ira, IrInstruction *elem_type_value) {
+    Error err;
     ZigType *elem_type = ir_resolve_type(ira, elem_type_value);
     if (type_is_invalid(elem_type))
         return ira->codegen->builtin_types.entry_invalid;
-    if (!is_valid_vector_elem_type(elem_type)) {
-        ir_add_error(ira, elem_type_value,
-            buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid",
-                buf_ptr(&elem_type->name)));
+    if ((err = ir_validate_vector_elem_type(ira, elem_type_value, elem_type)))
         return ira->codegen->builtin_types.entry_invalid;
-    }
     return elem_type;
 }
 
@@ -22357,7 +22381,9 @@ static IrInstruction *ir_analyze_instruction_shuffle_vector(IrAnalyze *ira, IrIn
     return ir_analyze_shuffle_vector(ira, &instruction->base, scalar_type, a, b, mask);
 }
 
-static IrInstruction *ir_analyze_instruction_splat(IrAnalyze *ira, IrInstructionSplat *instruction) {
+static IrInstruction *ir_analyze_instruction_splat(IrAnalyze *ira, IrInstructionSplatSrc *instruction) {
+    Error err;
+
     IrInstruction *len = instruction->len->child;
     if (type_is_invalid(len->value.type))
         return ira->codegen->invalid_instruction;
@@ -22366,41 +22392,32 @@ static IrInstruction *ir_analyze_instruction_splat(IrAnalyze *ira, IrInstruction
     if (type_is_invalid(scalar->value.type))
         return ira->codegen->invalid_instruction;
 
-    uint64_t len_int;
-    if (!ir_resolve_unsigned(ira, len, ira->codegen->builtin_types.entry_u32, &len_int)) {
-        ir_add_error(ira, len,
-            buf_sprintf("splat length must be comptime"));
+    uint64_t len_u64;
+    if (!ir_resolve_unsigned(ira, len, ira->codegen->builtin_types.entry_u32, &len_u64))
         return ira->codegen->invalid_instruction;
-    }
+    uint32_t len_int = len_u64;
 
-    if (!is_valid_vector_elem_type(scalar->value.type)) {
-        ir_add_error(ira, len,
-            buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid",
-                buf_ptr(&scalar->value.type->name)));
+    if ((err = ir_validate_vector_elem_type(ira, scalar, scalar->value.type)))
         return ira->codegen->invalid_instruction;
-    }
 
     ZigType *return_type = get_vector_type(ira->codegen, len_int, scalar->value.type);
 
     if (instr_is_comptime(scalar)) {
-        IrInstruction *result = ir_const_undef(ira, scalar, return_type);
-        result->value.data.x_array.data.s_none.elements =
-            allocate<ConstExprValue>(len_int);
-        for (uint32_t i = 0; i < len_int; i++) {
-            result->value.data.x_array.data.s_none.elements[i] =
-                scalar->value;
+        ConstExprValue *scalar_val = ir_resolve_const(ira, scalar, UndefOk);
+        if (scalar_val == nullptr)
+            return ira->codegen->invalid_instruction;
+        if (scalar_val->special == ConstValSpecialUndef)
+            return ir_const_undef(ira, &instruction->base, return_type);
+
+        IrInstruction *result = ir_const(ira, &instruction->base, return_type);
+        result->value.data.x_array.data.s_none.elements = create_const_vals(len_int);
+        for (uint32_t i = 0; i < len_int; i += 1) {
+            copy_const_val(&result->value.data.x_array.data.s_none.elements[i], scalar_val, false);
         }
-        result->value.type = return_type;
-        result->value.special = ConstValSpecialStatic;
         return result;
     }
 
-    IrInstruction *result = ir_build_splat(&ira->new_irb,
-        instruction->base.scope, instruction->base.source_node,
-        instruction->len->child, instruction->scalar->child);
-    result->value.type = return_type;
-    result->value.special = ConstValSpecialRuntime;
-    return result;
+    return ir_build_splat_gen(ira, &instruction->base, return_type, scalar);
 }
 
 static IrInstruction *ir_analyze_instruction_bool_not(IrAnalyze *ira, IrInstructionBoolNot *instruction) {
@@ -25857,6 +25874,7 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
         case IrInstructionIdTestErrGen:
         case IrInstructionIdFrameSizeGen:
         case IrInstructionIdAwaitGen:
+        case IrInstructionIdSplatGen:
             zig_unreachable();
 
         case IrInstructionIdReturn:
@@ -25987,8 +26005,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
             return ir_analyze_instruction_vector_type(ira, (IrInstructionVectorType *)instruction);
         case IrInstructionIdShuffleVector:
             return ir_analyze_instruction_shuffle_vector(ira, (IrInstructionShuffleVector *)instruction);
-         case IrInstructionIdSplat:
-            return ir_analyze_instruction_splat(ira, (IrInstructionSplat *)instruction);
+         case IrInstructionIdSplatSrc:
+            return ir_analyze_instruction_splat(ira, (IrInstructionSplatSrc *)instruction);
         case IrInstructionIdBoolNot:
             return ir_analyze_instruction_bool_not(ira, (IrInstructionBoolNot *)instruction);
         case IrInstructionIdMemset:
@@ -26325,7 +26343,8 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdIntType:
         case IrInstructionIdVectorType:
         case IrInstructionIdShuffleVector:
-        case IrInstructionIdSplat:
+        case IrInstructionIdSplatSrc:
+        case IrInstructionIdSplatGen:
         case IrInstructionIdBoolNot:
         case IrInstructionIdSliceSrc:
         case IrInstructionIdMemberCount:
src/ir_print.cpp
@@ -44,8 +44,10 @@ static const char* ir_instruction_type_str(IrInstruction* instruction) {
             return "Invalid";
         case IrInstructionIdShuffleVector:
             return "Shuffle";
-        case IrInstructionIdSplat:
-            return "Splat";
+        case IrInstructionIdSplatSrc:
+            return "SplatSrc";
+        case IrInstructionIdSplatGen:
+            return "SplatGen";
         case IrInstructionIdDeclVarSrc:
             return "DeclVarSrc";
         case IrInstructionIdDeclVarGen:
@@ -1224,7 +1226,7 @@ static void ir_print_shuffle_vector(IrPrint *irp, IrInstructionShuffleVector *in
     fprintf(irp->f, ")");
 }
 
-static void ir_print_splat(IrPrint *irp, IrInstructionSplat *instruction) {
+static void ir_print_splat_src(IrPrint *irp, IrInstructionSplatSrc *instruction) {
     fprintf(irp->f, "@splat(");
     ir_print_other_instruction(irp, instruction->len);
     fprintf(irp->f, ", ");
@@ -1232,6 +1234,12 @@ static void ir_print_splat(IrPrint *irp, IrInstructionSplat *instruction) {
     fprintf(irp->f, ")");
 }
 
+static void ir_print_splat_gen(IrPrint *irp, IrInstructionSplatGen *instruction) {
+    fprintf(irp->f, "@splat(");
+    ir_print_other_instruction(irp, instruction->scalar);
+    fprintf(irp->f, ")");
+}
+
 static void ir_print_bool_not(IrPrint *irp, IrInstructionBoolNot *instruction) {
     fprintf(irp->f, "! ");
     ir_print_other_instruction(irp, instruction->value);
@@ -2170,8 +2178,11 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction, bool
         case IrInstructionIdShuffleVector:
             ir_print_shuffle_vector(irp, (IrInstructionShuffleVector *)instruction);
             break;
-        case IrInstructionIdSplat:
-            ir_print_splat(irp, (IrInstructionSplat *)instruction);
+        case IrInstructionIdSplatSrc:
+            ir_print_splat_src(irp, (IrInstructionSplatSrc *)instruction);
+            break;
+        case IrInstructionIdSplatGen:
+            ir_print_splat_gen(irp, (IrInstructionSplatGen *)instruction);
             break;
         case IrInstructionIdBoolNot:
             ir_print_bool_not(irp, (IrInstructionBoolNot *)instruction);
test/stage1/behavior/vector.zig
@@ -145,10 +145,11 @@ test "vector @splat" {
             var v: u32 = 5;
             var x = @splat(4, v);
             expect(@typeOf(x) == @Vector(4, u32));
-            expect(x[0] == 5);
-            expect(x[1] == 5);
-            expect(x[2] == 5);
-            expect(x[3] == 5);
+            var array_x: [4]u32 = x;
+            expect(array_x[0] == 5);
+            expect(array_x[1] == 5);
+            expect(array_x[2] == 5);
+            expect(array_x[3] == 5);
         }
     };
     S.doTheTest();
test/compile_errors.zig
@@ -6514,7 +6514,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
         \\    var v = @splat(4, c);
         \\}
     ,
-        "tmp.zig:3:20: error: vector element type must be integer, float, bool, or pointer; 'comptime_int' is invalid",
+        "tmp.zig:3:23: error: vector element type must be integer, float, bool, or pointer; 'comptime_int' is invalid",
     );
 
     cases.add("compileLog of tagged enum doesn't crash the compiler",