Commit cbaa10fc3b

Andrew Kelley <andrew@ziglang.org>
2019-11-02 04:16:49
implement storing vector elements via runtime index
1 parent 70be308
src/all_types.hpp
@@ -2426,6 +2426,7 @@ enum IrInstructionId {
     IrInstructionIdLoadPtr,
     IrInstructionIdLoadPtrGen,
     IrInstructionIdStorePtr,
+    IrInstructionIdVectorStoreElem,
     IrInstructionIdFieldPtr,
     IrInstructionIdStructFieldPtr,
     IrInstructionIdUnionFieldPtr,
@@ -2770,6 +2771,14 @@ struct IrInstructionStorePtr {
     IrInstruction *value;
 };
 
+struct IrInstructionVectorStoreElem {
+    IrInstruction base;
+
+    IrInstruction *vector_ptr;
+    IrInstruction *index;
+    IrInstruction *value;
+};
+
 struct IrInstructionFieldPtr {
     IrInstruction base;
 
src/codegen.cpp
@@ -3644,6 +3644,19 @@ static LLVMValueRef ir_render_store_ptr(CodeGen *g, IrExecutable *executable, Ir
     return nullptr;
 }
 
+static LLVMValueRef ir_render_vector_store_elem(CodeGen *g, IrExecutable *executable,
+        IrInstructionVectorStoreElem *instruction)
+{
+    LLVMValueRef vector_ptr = ir_llvm_value(g, instruction->vector_ptr);
+    LLVMValueRef index = ir_llvm_value(g, instruction->index);
+    LLVMValueRef value = ir_llvm_value(g, instruction->value);
+
+    LLVMValueRef loaded_vector = gen_load(g, vector_ptr, instruction->vector_ptr->value.type, "");
+    LLVMValueRef modified_vector = LLVMBuildInsertElement(g->builder, loaded_vector, value, index, "");
+    gen_store(g, modified_vector, vector_ptr, instruction->vector_ptr->value.type);
+    return nullptr;
+}
+
 static LLVMValueRef ir_render_var_ptr(CodeGen *g, IrExecutable *executable, IrInstructionVarPtr *instruction) {
     if (instruction->base.value.special != ConstValSpecialRuntime)
         return ir_llvm_value(g, &instruction->base);
@@ -6130,6 +6143,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_load_ptr(g, executable, (IrInstructionLoadPtrGen *)instruction);
         case IrInstructionIdStorePtr:
             return ir_render_store_ptr(g, executable, (IrInstructionStorePtr *)instruction);
+        case IrInstructionIdVectorStoreElem:
+            return ir_render_vector_store_elem(g, executable, (IrInstructionVectorStoreElem *)instruction);
         case IrInstructionIdVarPtr:
             return ir_render_var_ptr(g, executable, (IrInstructionVarPtr *)instruction);
         case IrInstructionIdReturnPtr:
src/ir.cpp
@@ -491,6 +491,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionStorePtr *) {
     return IrInstructionIdStorePtr;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionVectorStoreElem *) {
+    return IrInstructionIdVectorStoreElem;
+}
+
 static constexpr IrInstructionId ir_instruction_id(IrInstructionFieldPtr *) {
     return IrInstructionIdFieldPtr;
 }
@@ -1631,6 +1635,23 @@ static IrInstructionStorePtr *ir_build_store_ptr(IrBuilder *irb, Scope *scope, A
     return instruction;
 }
 
+static IrInstruction *ir_build_vector_store_elem(IrAnalyze *ira, IrInstruction *source_instruction,
+        IrInstruction *vector_ptr, IrInstruction *index, IrInstruction *value)
+{
+    IrInstructionVectorStoreElem *inst = ir_build_instruction<IrInstructionVectorStoreElem>(
+            &ira->new_irb, source_instruction->scope, source_instruction->source_node);
+    inst->base.value.type = ira->codegen->builtin_types.entry_void;
+    inst->vector_ptr = vector_ptr;
+    inst->index = index;
+    inst->value = value;
+
+    ir_ref_instruction(vector_ptr, ira->new_irb.current_basic_block);
+    ir_ref_instruction(index, ira->new_irb.current_basic_block);
+    ir_ref_instruction(value, ira->new_irb.current_basic_block);
+
+    return &inst->base;
+}
+
 static IrInstruction *ir_build_var_decl_src(IrBuilder *irb, Scope *scope, AstNode *source_node,
         ZigVar *var, IrInstruction *align_value, IrInstruction *ptr)
 {
@@ -16126,6 +16147,24 @@ static IrInstruction *ir_analyze_store_ptr(IrAnalyze *ira, IrInstruction *source
         mark_comptime_value_escape(ira, source_instr, &value->value);
     }
 
+    // If this is a store to a pointer with a runtime-known vector index,
+    // we have to figure out the IrInstruction which represents the index and
+    // emit a IrInstructionVectorStoreElem, or emit a compile error
+    // explaining why it is impossible for this store to work. Which is that
+    // the pointer address is of the vector; without the element index being known
+    // we cannot properly perform the insertion.
+    if (ptr->value.type->data.pointer.vector_index == VECTOR_INDEX_RUNTIME) {
+        if (ptr->id == IrInstructionIdElemPtr) {
+            IrInstructionElemPtr *elem_ptr = (IrInstructionElemPtr *)ptr;
+            return ir_build_vector_store_elem(ira, source_instr, elem_ptr->array_ptr,
+                    elem_ptr->elem_index, value);
+        }
+        ir_add_error(ira, ptr,
+            buf_sprintf("unable to determine vector element index of type '%s'",
+                buf_ptr(&ptr->value.type->name)));
+        return ira->codegen->invalid_instruction;
+    }
+
     IrInstructionStorePtr *store_ptr = ir_build_store_ptr(&ira->new_irb, source_instr->scope,
             source_instr->source_node, ptr, value);
     return &store_ptr->base;
@@ -26063,6 +26102,7 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
         case IrInstructionIdAwaitGen:
         case IrInstructionIdSplatGen:
         case IrInstructionIdVectorExtractElem:
+        case IrInstructionIdVectorStoreElem:
             zig_unreachable();
 
         case IrInstructionIdReturn:
@@ -26446,6 +26486,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdDeclVarSrc:
         case IrInstructionIdDeclVarGen:
         case IrInstructionIdStorePtr:
+        case IrInstructionIdVectorStoreElem:
         case IrInstructionIdCallSrc:
         case IrInstructionIdCallGen:
         case IrInstructionIdReturn:
src/ir_print.cpp
@@ -78,6 +78,8 @@ const char* ir_instruction_type_str(IrInstructionId id) {
             return "LoadPtrGen";
         case IrInstructionIdStorePtr:
             return "StorePtr";
+        case IrInstructionIdVectorStoreElem:
+            return "VectorStoreElem";
         case IrInstructionIdFieldPtr:
             return "FieldPtr";
         case IrInstructionIdStructFieldPtr:
@@ -790,6 +792,15 @@ static void ir_print_store_ptr(IrPrint *irp, IrInstructionStorePtr *instruction)
     ir_print_other_instruction(irp, instruction->value);
 }
 
+static void ir_print_vector_store_elem(IrPrint *irp, IrInstructionVectorStoreElem *instruction) {
+    fprintf(irp->f, "vector_ptr=");
+    ir_print_var_instruction(irp, instruction->vector_ptr);
+    fprintf(irp->f, ",index=");
+    ir_print_var_instruction(irp, instruction->index);
+    fprintf(irp->f, ",value=");
+    ir_print_other_instruction(irp, instruction->value);
+}
+
 static void ir_print_typeof(IrPrint *irp, IrInstructionTypeOf *instruction) {
     fprintf(irp->f, "@typeOf(");
     ir_print_other_instruction(irp, instruction->value);
@@ -2047,6 +2058,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction, bool
         case IrInstructionIdStorePtr:
             ir_print_store_ptr(irp, (IrInstructionStorePtr *)instruction);
             break;
+        case IrInstructionIdVectorStoreElem:
+            ir_print_vector_store_elem(irp, (IrInstructionVectorStoreElem *)instruction);
+            break;
         case IrInstructionIdTypeOf:
             ir_print_typeof(irp, (IrInstructionTypeOf *)instruction);
             break;
test/stage1/behavior/vector.zig
@@ -216,3 +216,21 @@ test "load vector elements via runtime index" {
     S.doTheTest();
     comptime S.doTheTest();
 }
+
+test "store vector elements via runtime index" {
+    const S = struct {
+        fn doTheTest() void {
+            var v: @Vector(4, i32) = [_]i32{ 1, 5, 3, undefined };
+            var i: u32 = 2;
+            v[i] = 1;
+            expect(v[1] == 5);
+            expect(v[2] == 1);
+            i += 1;
+            v[i] = -364;
+            expect(-364 == v[3]);
+        }
+    };
+
+    S.doTheTest();
+    comptime S.doTheTest();
+}
test/compile_errors.zig
@@ -26,6 +26,23 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
 
     cases.add(
         "dereference vector pointer with unknown runtime index",
+        "store vector pointer with unknown runtime index",
+        \\export fn entry() void {
+        \\    var v: @Vector(4, i32) = [_]i32{ 1, 5, 3, undefined };
+        \\
+        \\    var i: u32 = 0;
+        \\    storev(&v[i], 42);
+        \\}
+        \\
+        \\fn storev(ptr: var, val: i32) void {
+        \\    ptr.* = val;
+        \\}
+    ,
+        "tmp.zig:9:8: error: unable to determine vector element index of type '*align(16:0:4:?) i32",
+    );
+
+    cases.add(
+        "load vector pointer with unknown runtime index",
         \\export fn entry() void {
         \\    var v: @Vector(4, i32) = [_]i32{ 1, 5, 3, undefined };
         \\