Commit 1158bc3ead

Andrew Kelley <superjoe30@gmail.com>
2016-01-23 02:05:22
support statically initialized structs
1 parent 7bd9c82
src/all_types.hpp
@@ -45,6 +45,10 @@ struct ConstEnumValue {
     ConstExprValue *payload;
 };
 
+struct ConstStructValue {
+    ConstExprValue **fields;
+};
+
 struct ConstExprValue {
     bool ok; // true if constant expression evalution worked
     bool depends_on_compile_var;
@@ -56,6 +60,7 @@ struct ConstExprValue {
         TypeTableEntry *x_type;
         ConstExprValue *x_maybe;
         ConstEnumValue x_enum;
+        ConstStructValue x_struct;
     } data;
 };
 
@@ -721,7 +726,8 @@ struct TypeStructField {
 struct TypeTableEntryStruct {
     AstNode *decl_node;
     bool is_packed;
-    uint32_t field_count;
+    uint32_t src_field_count;
+    uint32_t gen_field_count;
     TypeStructField *fields;
     uint64_t size_bytes;
     bool is_invalid; // true if any fields are invalid
src/analyze.cpp
@@ -34,6 +34,8 @@ static AstNode *first_executing_node(AstNode *node) {
             return first_executing_node(node->data.field_access_expr.struct_expr);
         case NodeTypeSwitchRange:
             return first_executing_node(node->data.switch_range.start);
+        case NodeTypeContainerInitExpr:
+            return first_executing_node(node->data.container_init_expr.type);
         case NodeTypeRoot:
         case NodeTypeRootExportDecl:
         case NodeTypeFnProto:
@@ -69,7 +71,6 @@ static AstNode *first_executing_node(AstNode *node) {
         case NodeTypeForExpr:
         case NodeTypeSwitchExpr:
         case NodeTypeSwitchProng:
-        case NodeTypeContainerInitExpr:
         case NodeTypeArrayType:
             return node;
     }
@@ -303,7 +304,8 @@ static void unknown_size_array_type_common_init(CodeGen *g, TypeTableEntry *chil
     entry->align_in_bits = g->pointer_size_bytes * 8;
     entry->data.structure.is_packed = false;
     entry->data.structure.is_unknown_size_array = true;
-    entry->data.structure.field_count = element_count;
+    entry->data.structure.src_field_count = element_count;
+    entry->data.structure.gen_field_count = element_count;
     entry->data.structure.fields = allocate<TypeStructField>(element_count);
     entry->data.structure.fields[0].name = buf_create_from_str("ptr");
     entry->data.structure.fields[0].type_entry = pointer_type;
@@ -764,7 +766,7 @@ static void resolve_struct_type(CodeGen *g, ImportTableEntry *import, TypeTableE
 
     int field_count = decl_node->data.struct_decl.fields.length;
 
-    struct_type->data.structure.field_count = field_count;
+    struct_type->data.structure.src_field_count = field_count;
     struct_type->data.structure.fields = allocate<TypeStructField>(field_count);
 
     // we possibly allocate too much here since gen_field_count can be lower than field_count.
@@ -823,6 +825,8 @@ static void resolve_struct_type(CodeGen *g, ImportTableEntry *import, TypeTableE
     }
     struct_type->data.structure.embedded_in_current = false;
 
+    struct_type->data.structure.gen_field_count = gen_field_index;
+
     if (!struct_type->data.structure.is_invalid) {
 
         LLVMStructSetBody(struct_type->type_ref, element_types, gen_field_index, false);
@@ -1083,6 +1087,9 @@ static void add_global_const_expr(CodeGen *g, Expr *expr) {
 }
 
 static bool num_lit_fits_in_other_type(CodeGen *g, AstNode *literal_node, TypeTableEntry *other_type) {
+    if (other_type->id == TypeTableEntryIdInvalid) {
+        return false;
+    }
     Expr *expr = get_resolved_expr(literal_node);
     ConstExprValue *const_val = &expr->const_val;
     assert(const_val->ok);
@@ -1438,16 +1445,6 @@ static TypeEnumField *get_enum_field(TypeTableEntry *enum_type, Buf *name) {
     return nullptr;
 }
 
-static TypeStructField *get_struct_field(TypeTableEntry *struct_type, Buf *name) {
-    for (uint32_t i = 0; i < struct_type->data.structure.field_count; i += 1) {
-        TypeStructField *type_struct_field = &struct_type->data.structure.fields[i];
-        if (buf_eql_buf(type_struct_field->name, name)) {
-            return type_struct_field;
-        }
-    }
-    return nullptr;
-}
-
 static TypeTableEntry *analyze_enum_value_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         AstNode *field_access_node, AstNode *value_node, TypeTableEntry *enum_type, Buf *field_name)
 {
@@ -1484,12 +1481,11 @@ static TypeTableEntry *analyze_enum_value_expr(CodeGen *g, ImportTableEntry *imp
     return enum_type;
 }
 
-static TypeStructField *find_struct_type_field(TypeTableEntry *type_entry, Buf *name, int *index) {
+static TypeStructField *find_struct_type_field(TypeTableEntry *type_entry, Buf *name) {
     assert(type_entry->id == TypeTableEntryIdStruct);
-    for (uint32_t i = 0; i < type_entry->data.structure.field_count; i += 1) {
+    for (uint32_t i = 0; i < type_entry->data.structure.src_field_count; i += 1) {
         TypeStructField *field = &type_entry->data.structure.fields[i];
         if (buf_eql_buf(field->name, name)) {
-            *index = i;
             return field;
         }
     }
@@ -1530,16 +1526,18 @@ static TypeTableEntry *analyze_container_init_expr(CodeGen *g, ImportTableEntry
 
 
         int expr_field_count = container_init_expr->entries.length;
-        int actual_field_count = container_type->data.structure.field_count;
+        int actual_field_count = container_type->data.structure.src_field_count;
 
         int *field_use_counts = allocate<int>(actual_field_count);
+        ConstExprValue *const_val = &get_resolved_expr(node)->const_val;
+        const_val->ok = true;
+        const_val->data.x_struct.fields = allocate<ConstExprValue*>(actual_field_count);
         for (int i = 0; i < expr_field_count; i += 1) {
             AstNode *val_field_node = container_init_expr->entries.at(i);
             assert(val_field_node->type == NodeTypeStructValueField);
 
-            int field_index;
             TypeStructField *type_field = find_struct_type_field(container_type,
-                    &val_field_node->data.struct_val_field.name, &field_index);
+                    &val_field_node->data.struct_val_field.name);
 
             if (!type_field) {
                 add_node_error(g, val_field_node,
@@ -1548,6 +1546,7 @@ static TypeTableEntry *analyze_container_init_expr(CodeGen *g, ImportTableEntry
                 continue;
             }
 
+            int field_index = type_field->src_index;
             field_use_counts[field_index] += 1;
             if (field_use_counts[field_index] > 1) {
                 add_node_error(g, val_field_node, buf_sprintf("duplicate field"));
@@ -1558,6 +1557,16 @@ static TypeTableEntry *analyze_container_init_expr(CodeGen *g, ImportTableEntry
 
             analyze_expression(g, import, context, type_field->type_entry,
                     val_field_node->data.struct_val_field.expr);
+
+            if (const_val->ok) {
+                ConstExprValue *field_val =
+                    &get_resolved_expr(val_field_node->data.struct_val_field.expr)->const_val;
+                if (field_val->ok) {
+                    const_val->data.x_struct.fields[field_index] = field_val;
+                } else {
+                    const_val->ok = false;
+                }
+            }
         }
 
         for (int i = 0; i < actual_field_count; i += 1) {
@@ -1634,7 +1643,7 @@ static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *i
         TypeTableEntry *bare_struct_type = (struct_type->id == TypeTableEntryIdStruct) ?
             struct_type : struct_type->data.pointer.child_type;
 
-        node->data.field_access_expr.type_struct_field = get_struct_field(bare_struct_type, field_name);
+        node->data.field_access_expr.type_struct_field = find_struct_type_field(bare_struct_type, field_name);
         if (node->data.field_access_expr.type_struct_field) {
             return node->data.field_access_expr.type_struct_field->type_entry;
         } else {
src/codegen.cpp
@@ -1527,13 +1527,13 @@ static LLVMValueRef gen_container_init_expr(CodeGen *g, AstNode *node) {
     if (type_entry->id == TypeTableEntryIdStruct) {
         assert(node->data.container_init_expr.kind == ContainerInitKindStruct);
 
-        int field_count = type_entry->data.structure.field_count;
-        assert(field_count == node->data.container_init_expr.entries.length);
+        int src_field_count = type_entry->data.structure.src_field_count;
+        assert(src_field_count == node->data.container_init_expr.entries.length);
 
         StructValExprCodeGen *struct_val_expr_node = &node->data.container_init_expr.resolved_struct_val_expr;
         LLVMValueRef tmp_struct_ptr = struct_val_expr_node->ptr;
 
-        for (int i = 0; i < field_count; i += 1) {
+        for (int i = 0; i < src_field_count; i += 1) {
             AstNode *field_node = node->data.container_init_expr.entries.at(i);
             assert(field_node->type == NodeTypeStructValueField);
             TypeStructField *type_struct_field = field_node->data.struct_val_field.type_struct_field;
@@ -2109,7 +2109,13 @@ static LLVMValueRef gen_const_val(CodeGen *g, TypeTableEntry *type_entry, ConstE
         };
         return LLVMConstStruct(fields, 2, false);
     } else if (type_entry->id == TypeTableEntryIdStruct) {
-        zig_panic("TODO");
+        LLVMValueRef *fields = allocate<LLVMValueRef>(type_entry->data.structure.gen_field_count);
+        for (int i = 0; i < type_entry->data.structure.src_field_count; i += 1) {
+            TypeStructField *type_struct_field = &type_entry->data.structure.fields[i];
+            fields[type_struct_field->gen_index] = gen_const_val(g, type_struct_field->type_entry,
+                    const_val->data.x_struct.fields[i]);
+        }
+        return LLVMConstNamedStruct(type_entry->type_ref, fields, type_entry->data.structure.gen_field_count);
     } else if (type_entry->id == TypeTableEntryIdArray) {
         zig_panic("TODO");
     } else if (type_entry->id == TypeTableEntryIdEnum) {
@@ -2142,10 +2148,10 @@ static void gen_const_globals(CodeGen *g) {
         TypeTableEntry *type_entry = expr->type_entry;
 
         if (handle_is_ptr(type_entry)) {
-            LLVMValueRef global_value = LLVMAddGlobal(g->module, type_entry->type_ref, "");
-            LLVMSetLinkage(global_value, LLVMPrivateLinkage);
             LLVMValueRef init_val = gen_const_val(g, type_entry, const_val);
+            LLVMValueRef global_value = LLVMAddGlobal(g->module, LLVMTypeOf(init_val), "");
             LLVMSetInitializer(global_value, init_val);
+            LLVMSetLinkage(global_value, LLVMPrivateLinkage);
             LLVMSetGlobalConstant(global_value, true);
             LLVMSetUnnamedAddr(global_value, true);
             expr->const_llvm_val = global_value;
@@ -2171,15 +2177,22 @@ static void do_code_gen(CodeGen *g) {
         }
 
         // TODO if the global is exported, set external linkage
-        LLVMValueRef global_value = LLVMAddGlobal(g->module, var->type->type_ref, "");
-        LLVMSetLinkage(global_value, LLVMPrivateLinkage);
-
-        if (var->is_const) {
-            LLVMValueRef init_val = gen_expr(g, var->decl_node->data.variable_declaration.expr);
-            LLVMSetInitializer(global_value, init_val);
+        LLVMValueRef init_val;
+
+        assert(var->decl_node);
+        assert(var->decl_node->type == NodeTypeVariableDeclaration);
+        AstNode *expr_node = var->decl_node->data.variable_declaration.expr;
+        if (expr_node) {
+            Expr *expr = get_resolved_expr(expr_node);
+            ConstExprValue *const_val = &expr->const_val;
+            assert(const_val->ok);
+            TypeTableEntry *type_entry = expr->type_entry;
+            init_val = gen_const_val(g, type_entry, const_val);
         } else {
-            LLVMSetInitializer(global_value, LLVMConstNull(var->type->type_ref));
+            init_val = LLVMConstNull(var->type->type_ref);
         }
+        LLVMValueRef global_value = LLVMAddGlobal(g->module, LLVMTypeOf(init_val), "");
+        LLVMSetInitializer(global_value, init_val);
         LLVMSetGlobalConstant(global_value, var->is_const);
         LLVMSetUnnamedAddr(global_value, true);
 
test/run_tests.cpp
@@ -1226,6 +1226,24 @@ pub fn main(args: [][]u8) i32 => {
 }
     )SOURCE", "OK\n");
 
+    add_simple_case("statically initialized struct", R"SOURCE(
+import "std.zig";
+struct Foo {
+    x: i32,
+    y: bool,
+}
+var foo = Foo { .x = 13, .y = true, };
+pub fn main(args: [][]u8) i32 => {
+    foo.x += 1;
+    if (foo.x != 14) {
+        print_str("BAD\n");
+    }
+
+    print_str("OK\n");
+    return 0;
+}
+    )SOURCE", "OK\n");
+
 }