Commit a59d31bd28

LemonBoy <thatlemon@gmail.com>
2020-03-03 21:46:30
ir: Support tuple multiplication
1 parent e4eb817
Changed files (2)
src
test
stage1
behavior
src/ir.cpp
@@ -17351,14 +17351,15 @@ static IrInstGen *ir_analyze_tuple_cat(IrAnalyze *ira, IrInst* source_instr,
         ContainerKindStruct, source_instr->source_node, buf_ptr(name), bare_name, ContainerLayoutAuto);
     new_type->data.structure.special = StructSpecialInferredTuple;
     new_type->data.structure.resolve_status = ResolveStatusBeingInferred;
-
-    IrInstGen *new_struct_ptr = ir_resolve_result(ira, source_instr, no_result_loc(),
-            new_type, nullptr, false, true);
     uint32_t new_field_count = op1_field_count + op2_field_count;
 
     new_type->data.structure.src_field_count = new_field_count;
     new_type->data.structure.fields = realloc_type_struct_fields(new_type->data.structure.fields,
             0, new_field_count);
+
+    IrInstGen *new_struct_ptr = ir_resolve_result(ira, source_instr, no_result_loc(),
+            new_type, nullptr, false, true);
+
     for (uint32_t i = 0; i < new_field_count; i += 1) {
         TypeStructField *src_field;
         if (i < op1_field_count) {
@@ -17422,8 +17423,10 @@ static IrInstGen *ir_analyze_tuple_cat(IrAnalyze *ira, IrInst* source_instr,
             ir_analyze_store_ptr(ira, &elem_result_loc->base, elem_result_loc, deref, true);
         }
     }
-    IrInstGen *result = ir_get_deref(ira, source_instr, new_struct_ptr, nullptr);
-    return result;
+
+    const_ptrs.deinit();
+
+    return ir_get_deref(ira, source_instr, new_struct_ptr, nullptr);
 }
 
 static IrInstGen *ir_analyze_array_cat(IrAnalyze *ira, IrInstSrcBinOp *instruction) {
@@ -17480,8 +17483,9 @@ static IrInstGen *ir_analyze_array_cat(IrAnalyze *ira, IrInstSrcBinOp *instructi
         ZigValue *len_val = op1_val->data.x_struct.fields[slice_len_index];
         op1_array_end = op1_array_index + bigint_as_usize(&len_val->data.x_bigint);
         sentinel1 = ptr_type->data.pointer.sentinel;
-    } else if (op1_type->id == ZigTypeIdPointer && op1_type->data.pointer.ptr_len == PtrLenSingle &&
-            op1_type->data.pointer.child_type->id == ZigTypeIdArray)
+    } else if (op1_type->id == ZigTypeIdPointer &&
+               op1_type->data.pointer.ptr_len == PtrLenSingle &&
+               op1_type->data.pointer.child_type->id == ZigTypeIdArray)
     {
         ZigType *array_type = op1_type->data.pointer.child_type;
         child_type = array_type->data.array.child_type;
@@ -17654,6 +17658,103 @@ static IrInstGen *ir_analyze_array_cat(IrAnalyze *ira, IrInstSrcBinOp *instructi
     return result;
 }
 
+static IrInstGen *ir_analyze_tuple_mult(IrAnalyze *ira, IrInst* source_instr,
+                                        IrInstGen *op1, IrInstGen *op2)
+{
+    Error err;
+    ZigType *op1_type = op1->value->type;
+    uint64_t op1_field_count = op1_type->data.structure.src_field_count;
+
+    uint64_t mult_amt;
+    if (!ir_resolve_usize(ira, op2, &mult_amt))
+        return ira->codegen->invalid_inst_gen;
+
+    uint64_t new_field_count;
+    if (mul_u64_overflow(op1_field_count, mult_amt, &new_field_count)) {
+        ir_add_error(ira, source_instr, buf_sprintf("operation results in overflow"));
+        return ira->codegen->invalid_inst_gen;
+    }
+
+    Buf *bare_name = buf_alloc();
+    Buf *name = get_anon_type_name(ira->codegen, nullptr, container_string(ContainerKindStruct),
+        source_instr->scope, source_instr->source_node, bare_name);
+    ZigType *new_type = get_partial_container_type(ira->codegen, source_instr->scope,
+        ContainerKindStruct, source_instr->source_node, buf_ptr(name), bare_name, ContainerLayoutAuto);
+    new_type->data.structure.special = StructSpecialInferredTuple;
+    new_type->data.structure.resolve_status = ResolveStatusBeingInferred;
+    new_type->data.structure.src_field_count = new_field_count;
+    new_type->data.structure.fields = realloc_type_struct_fields(
+        new_type->data.structure.fields, 0, new_field_count);
+
+    IrInstGen *new_struct_ptr = ir_resolve_result(ira, source_instr, no_result_loc(),
+        new_type, nullptr, false, true);
+
+    for (uint64_t i = 0; i < new_field_count; i += 1) {
+        TypeStructField *src_field = op1_type->data.structure.fields[i % op1_field_count];
+        TypeStructField *new_field = new_type->data.structure.fields[i];
+
+        new_field->name = buf_sprintf("%lu", i);
+        new_field->type_entry = src_field->type_entry;
+        new_field->type_val = src_field->type_val;
+        new_field->src_index = i;
+        new_field->decl_node = src_field->decl_node;
+        new_field->init_val = src_field->init_val;
+        new_field->is_comptime = src_field->is_comptime;
+    }
+
+    if ((err = type_resolve(ira->codegen, new_type, ResolveStatusZeroBitsKnown)))
+        return ira->codegen->invalid_inst_gen;
+
+    ZigList<IrInstGen *> const_ptrs = {};
+    for (uint64_t i = 0; i < new_field_count; i += 1) {
+        TypeStructField *src_field = op1_type->data.structure.fields[i % op1_field_count];
+        TypeStructField *dst_field = new_type->data.structure.fields[i];
+
+        IrInstGen *field_value = ir_analyze_struct_value_field_value(
+            ira, source_instr, op1, src_field);
+        if (type_is_invalid(field_value->value->type))
+            return ira->codegen->invalid_inst_gen;
+
+        IrInstGen *dest_ptr = ir_analyze_struct_field_ptr(
+            ira, source_instr, dst_field, new_struct_ptr, new_type, true);
+        if (type_is_invalid(dest_ptr->value->type))
+            return ira->codegen->invalid_inst_gen;
+
+        if (instr_is_comptime(field_value)) {
+            const_ptrs.append(dest_ptr);
+        }
+
+        IrInstGen *store_ptr_inst = ir_analyze_store_ptr(
+            ira, source_instr, dest_ptr, field_value, true);
+        if (type_is_invalid(store_ptr_inst->value->type))
+            return ira->codegen->invalid_inst_gen;
+    }
+
+    if (const_ptrs.length != new_field_count) {
+        new_struct_ptr->value->special = ConstValSpecialRuntime;
+        for (size_t i = 0; i < const_ptrs.length; i += 1) {
+            IrInstGen *elem_result_loc = const_ptrs.at(i);
+            assert(elem_result_loc->value->special == ConstValSpecialStatic);
+            if (elem_result_loc->value->type->data.pointer.inferred_struct_field != nullptr) {
+                // This field will be generated comptime; no need to do this.
+                continue;
+            }
+            IrInstGen *deref = ir_get_deref(ira, &elem_result_loc->base, elem_result_loc, nullptr);
+            if (!type_requires_comptime(ira->codegen, elem_result_loc->value->type->data.pointer.child_type)) {
+                elem_result_loc->value->special = ConstValSpecialRuntime;
+            }
+            IrInstGen *store_ptr_inst = ir_analyze_store_ptr(
+                ira, &elem_result_loc->base, elem_result_loc, deref, true);
+            if (type_is_invalid(store_ptr_inst->value->type))
+                return ira->codegen->invalid_inst_gen;
+        }
+    }
+
+    const_ptrs.deinit();
+
+    return ir_get_deref(ira, source_instr, new_struct_ptr, nullptr);
+}
+
 static IrInstGen *ir_analyze_array_mult(IrAnalyze *ira, IrInstSrcBinOp *instruction) {
     IrInstGen *op1 = instruction->op1->child;
     if (type_is_invalid(op1->value->type))
@@ -17671,8 +17772,9 @@ static IrInstGen *ir_analyze_array_mult(IrAnalyze *ira, IrInstSrcBinOp *instruct
         array_val = ir_resolve_const(ira, op1, UndefOk);
         if (array_val == nullptr)
             return ira->codegen->invalid_inst_gen;
-    } else if (op1->value->type->id == ZigTypeIdPointer && op1->value->type->data.pointer.ptr_len == PtrLenSingle &&
-        op1->value->type->data.pointer.child_type->id == ZigTypeIdArray)
+    } else if (op1->value->type->id == ZigTypeIdPointer &&
+               op1->value->type->data.pointer.ptr_len == PtrLenSingle &&
+               op1->value->type->data.pointer.child_type->id == ZigTypeIdArray)
     {
         array_type = op1->value->type->data.pointer.child_type;
         IrInstGen *array_inst = ir_get_deref(ira, &op1->base, op1, nullptr);
@@ -17682,6 +17784,8 @@ static IrInstGen *ir_analyze_array_mult(IrAnalyze *ira, IrInstSrcBinOp *instruct
         if (array_val == nullptr)
             return ira->codegen->invalid_inst_gen;
         want_ptr_to_array = true;
+    } else if (is_tuple(op1->value->type)) {
+        return ir_analyze_tuple_mult(ira, &instruction->base.base, op1, op2);
     } else {
         ir_add_error(ira, &op1->base, buf_sprintf("expected array type, found '%s'", buf_ptr(&op1->value->type->name)));
         return ira->codegen->invalid_inst_gen;
test/stage1/behavior/tuple.zig
@@ -1,5 +1,7 @@
 const std = @import("std");
-const expect = std.testing.expect;
+const testing = std.testing;
+const expect = testing.expect;
+const expectEqual = testing.expectEqual;
 
 test "tuple concatenation" {
     const S = struct {
@@ -9,8 +11,31 @@ test "tuple concatenation" {
             var x = .{a};
             var y = .{b};
             var c = x ++ y;
-            expect(c[0] == 1);
-            expect(c[1] == 2);
+            expectEqual(@as(i32, 1), c[0]);
+            expectEqual(@as(i32, 2), c[1]);
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
+
+test "tuple multiplication" {
+    const S = struct {
+        fn doTheTest() void {
+            {
+                const t = .{} ** 4;
+                expectEqual(0, @typeInfo(@TypeOf(t)).Struct.fields.len);
+            }
+            {
+                const t = .{'a'} ** 4;
+                expectEqual(4, @typeInfo(@TypeOf(t)).Struct.fields.len);
+                inline for (t) |x| expectEqual('a', x);
+            }
+            {
+                const t = .{ 1, 2, 3 } ** 4;
+                expectEqual(12, @typeInfo(@TypeOf(t)).Struct.fields.len);
+                inline for (t) |x, i| expectEqual(1 + i % 3, x);
+            }
         }
     };
     S.doTheTest();