Commit 1066004b79

Andrew Kelley <andrew@ziglang.org>
2019-02-21 20:44:14
better handling of arrays in packed structs
* Separate LoadPtr IR instructions into pass1 and pass2 variants. * Define `type_size_bits` for extern structs to be the same as their `@sizeOf(T) * 8` and allow them in packed structs. * More helpful error messages when trying to use types in packed structs that are not allowed. * Support arrays in packed structs even when they are not byte-aligned. * Add compile error for using arrays in packed structs when the padding bits would be problematic. This is necessary since we do not have packed arrays. closes #677
1 parent 2bb795d
src/all_types.hpp
@@ -2119,6 +2119,7 @@ enum IrInstructionId {
     IrInstructionIdUnOp,
     IrInstructionIdBinOp,
     IrInstructionIdLoadPtr,
+    IrInstructionIdLoadPtrGen,
     IrInstructionIdStorePtr,
     IrInstructionIdFieldPtr,
     IrInstructionIdStructFieldPtr,
@@ -2414,6 +2415,13 @@ struct IrInstructionLoadPtr {
     IrInstruction *ptr;
 };
 
+struct IrInstructionLoadPtrGen {
+    IrInstruction base;
+
+    IrInstruction *ptr;
+    LLVMValueRef tmp_ptr;
+};
+
 struct IrInstructionStorePtr {
     IrInstruction base;
 
src/analyze.cpp
@@ -365,19 +365,19 @@ uint64_t type_size_bits(CodeGen *g, ZigType *type_entry) {
     if (!type_has_bits(type_entry))
         return 0;
 
-    if (type_entry->id == ZigTypeIdStruct && type_entry->data.structure.layout == ContainerLayoutPacked) {
-        uint64_t result = 0;
-        for (size_t i = 0; i < type_entry->data.structure.src_field_count; i += 1) {
-            result += type_size_bits(g, type_entry->data.structure.fields[i].type_entry);
+    if (type_entry->id == ZigTypeIdStruct) {
+        if (type_entry->data.structure.layout == ContainerLayoutPacked) {
+            uint64_t result = 0;
+            for (size_t i = 0; i < type_entry->data.structure.src_field_count; i += 1) {
+                result += type_size_bits(g, type_entry->data.structure.fields[i].type_entry);
+            }
+            return result;
+        } else if (type_entry->data.structure.layout == ContainerLayoutExtern) {
+            return type_size(g, type_entry) * 8;
         }
-        return result;
     } else if (type_entry->id == ZigTypeIdArray) {
         ZigType *child_type = type_entry->data.array.child_type;
-        if (child_type->id == ZigTypeIdStruct &&
-            child_type->data.structure.layout == ContainerLayoutPacked)
-        {
-            return type_entry->data.array.len * type_size_bits(g, child_type);
-        }
+        return type_entry->data.array.len * type_size_bits(g, child_type);
     }
 
     return LLVMSizeOfTypeInBits(g->target_data_ref, type_entry->type_ref);
@@ -1444,7 +1444,10 @@ static bool analyze_const_string(CodeGen *g, Scope *scope, AstNode *node, Buf **
     return true;
 }
 
-static bool type_allowed_in_packed_struct(ZigType *type_entry) {
+static Error emit_error_unless_type_allowed_in_packed_struct(CodeGen *g, ZigType *type_entry,
+        AstNode *source_node)
+{
+    Error err;
     switch (type_entry->id) {
         case ZigTypeIdInvalid:
             zig_unreachable();
@@ -1461,27 +1464,74 @@ static bool type_allowed_in_packed_struct(ZigType *type_entry) {
         case ZigTypeIdArgTuple:
         case ZigTypeIdOpaque:
         case ZigTypeIdPromise:
-            return false;
+            add_node_error(g, source_node,
+                    buf_sprintf("type '%s' not allowed in packed struct; no guaranteed in-memory representation",
+                        buf_ptr(&type_entry->name)));
+            return ErrorSemanticAnalyzeFail;
         case ZigTypeIdVoid:
         case ZigTypeIdBool:
         case ZigTypeIdInt:
         case ZigTypeIdFloat:
         case ZigTypeIdPointer:
-        case ZigTypeIdArray:
         case ZigTypeIdFn:
         case ZigTypeIdVector:
-            return true;
+            return ErrorNone;
+        case ZigTypeIdArray: {
+            ZigType *elem_type = type_entry->data.array.child_type;
+            if ((err = emit_error_unless_type_allowed_in_packed_struct(g, elem_type, source_node)))
+                return err;
+            if (type_size(g, type_entry) * 8 == type_size_bits(g, type_entry))
+                return ErrorNone;
+            add_node_error(g, source_node,
+                buf_sprintf("array of '%s' not allowed in packed struct due to padding bits",
+                    buf_ptr(&elem_type->name)));
+            return ErrorSemanticAnalyzeFail;
+        }
         case ZigTypeIdStruct:
-            return type_entry->data.structure.layout == ContainerLayoutPacked;
+            switch (type_entry->data.structure.layout) {
+                case ContainerLayoutPacked:
+                case ContainerLayoutExtern:
+                    return ErrorNone;
+                case ContainerLayoutAuto:
+                    add_node_error(g, source_node,
+                        buf_sprintf("non-packed, non-extern struct '%s' not allowed in packed struct; no guaranteed in-memory representation",
+                            buf_ptr(&type_entry->name)));
+                    return ErrorSemanticAnalyzeFail;
+            }
+            zig_unreachable();
         case ZigTypeIdUnion:
-            return type_entry->data.unionation.layout == ContainerLayoutPacked;
+            switch (type_entry->data.unionation.layout) {
+                case ContainerLayoutPacked:
+                case ContainerLayoutExtern:
+                    return ErrorNone;
+                case ContainerLayoutAuto:
+                    add_node_error(g, source_node,
+                        buf_sprintf("non-packed, non-extern union '%s' not allowed in packed struct; no guaranteed in-memory representation",
+                            buf_ptr(&type_entry->name)));
+                    return ErrorSemanticAnalyzeFail;
+            }
+            zig_unreachable();
         case ZigTypeIdOptional:
-            {
-                ZigType *child_type = type_entry->data.maybe.child_type;
-                return type_is_codegen_pointer(child_type);
+            if (get_codegen_ptr_type(type_entry) != nullptr) {
+                return ErrorNone;
+            } else {
+                add_node_error(g, source_node,
+                    buf_sprintf("type '%s' not allowed in packed struct; no guaranteed in-memory representation",
+                        buf_ptr(&type_entry->name)));
+                return ErrorSemanticAnalyzeFail;
             }
-        case ZigTypeIdEnum:
-            return type_entry->data.enumeration.decl_node->data.container_decl.init_arg_expr != nullptr;
+        case ZigTypeIdEnum: {
+            AstNode *decl_node = type_entry->data.enumeration.decl_node;
+            if (decl_node->data.container_decl.init_arg_expr != nullptr) {
+                return ErrorNone;
+            }
+            ErrorMsg *msg = add_node_error(g, source_node,
+                buf_sprintf("type '%s' not allowed in packed struct; no guaranteed in-memory representation",
+                    buf_ptr(&type_entry->name)));
+            add_error_note(g, msg, decl_node,
+                    buf_sprintf("enum declaration does not specify an integer tag type"));
+            return ErrorSemanticAnalyzeFail;
+        }
     }
     zig_unreachable();
 }
@@ -2051,11 +2101,8 @@ static Error resolve_struct_type(CodeGen *g, ZigType *struct_type) {
         type_struct_field->gen_index = gen_field_index;
 
         if (packed) {
-            if (!type_allowed_in_packed_struct(field_type)) {
-                AstNode *field_source_node = decl_node->data.container_decl.fields.at(i);
-                add_node_error(g, field_source_node,
-                        buf_sprintf("packed structs cannot contain fields of type '%s'",
-                            buf_ptr(&field_type->name)));
+            AstNode *field_source_node = decl_node->data.container_decl.fields.at(i);
+            if ((err = emit_error_unless_type_allowed_in_packed_struct(g, field_type, field_source_node))) {
                 struct_type->data.structure.resolve_status = ResolveStatusInvalid;
                 break;
             }
src/codegen.cpp
@@ -3281,7 +3281,7 @@ static LLVMValueRef ir_render_decl_var(CodeGen *g, IrExecutable *executable,
     return nullptr;
 }
 
-static LLVMValueRef ir_render_load_ptr(CodeGen *g, IrExecutable *executable, IrInstructionLoadPtr *instruction) {
+static LLVMValueRef ir_render_load_ptr(CodeGen *g, IrExecutable *executable, IrInstructionLoadPtrGen *instruction) {
     ZigType *child_type = instruction->base.value.type;
     if (!type_has_bits(child_type))
         return nullptr;
@@ -3296,7 +3296,6 @@ static LLVMValueRef ir_render_load_ptr(CodeGen *g, IrExecutable *executable, IrI
 
     bool big_endian = g->is_big_endian;
 
-    assert(!handle_is_ptr(child_type));
     LLVMValueRef containing_int = gen_load(g, ptr, ptr_type, "");
     uint32_t host_bit_count = LLVMGetIntTypeWidth(LLVMTypeOf(containing_int));
     assert(host_bit_count == host_int_bytes * 8);
@@ -3308,7 +3307,16 @@ static LLVMValueRef ir_render_load_ptr(CodeGen *g, IrExecutable *executable, IrI
     LLVMValueRef shift_amt_val = LLVMConstInt(LLVMTypeOf(containing_int), shift_amt, false);
     LLVMValueRef shifted_value = LLVMBuildLShr(g->builder, containing_int, shift_amt_val, "");
 
-    return LLVMBuildTrunc(g->builder, shifted_value, child_type->type_ref, "");
+    if (!handle_is_ptr(child_type))
+        return LLVMBuildTrunc(g->builder, shifted_value, child_type->type_ref, "");
+
+    assert(instruction->tmp_ptr != nullptr);
+    LLVMTypeRef same_size_int = LLVMIntType(size_in_bits);
+    LLVMValueRef truncated_int = LLVMBuildTrunc(g->builder, shifted_value, same_size_int, "");
+    LLVMValueRef bitcasted_ptr = LLVMBuildBitCast(g->builder, instruction->tmp_ptr,
+            LLVMPointerType(same_size_int, 0), "");
+    LLVMBuildStore(g->builder, truncated_int, bitcasted_ptr);
+    return instruction->tmp_ptr;
 }
 
 static bool value_is_all_undef_array(ConstExprValue *const_val, size_t len) {
@@ -5460,6 +5468,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
         case IrInstructionIdDeclVarSrc:
         case IrInstructionIdPtrCastSrc:
         case IrInstructionIdCmpxchgSrc:
+        case IrInstructionIdLoadPtr:
             zig_unreachable();
 
         case IrInstructionIdDeclVarGen:
@@ -5478,8 +5487,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
             return ir_render_br(g, executable, (IrInstructionBr *)instruction);
         case IrInstructionIdUnOp:
             return ir_render_un_op(g, executable, (IrInstructionUnOp *)instruction);
-        case IrInstructionIdLoadPtr:
-            return ir_render_load_ptr(g, executable, (IrInstructionLoadPtr *)instruction);
+        case IrInstructionIdLoadPtrGen:
+            return ir_render_load_ptr(g, executable, (IrInstructionLoadPtrGen *)instruction);
         case IrInstructionIdStorePtr:
             return ir_render_store_ptr(g, executable, (IrInstructionStorePtr *)instruction);
         case IrInstructionIdVarPtr:
@@ -5836,8 +5845,32 @@ static LLVMValueRef pack_const_int(CodeGen *g, LLVMTypeRef big_int_type_ref, Con
                 LLVMValueRef ptr_size_int_val = LLVMConstPtrToInt(ptr_val, g->builtin_types.entry_usize->type_ref);
                 return LLVMConstZExt(ptr_size_int_val, big_int_type_ref);
             }
-        case ZigTypeIdArray:
-            zig_panic("TODO bit pack an array");
+        case ZigTypeIdArray: {
+            LLVMValueRef val = LLVMConstInt(big_int_type_ref, 0, false);
+            if (const_val->data.x_array.special == ConstArraySpecialUndef) {
+                return val;
+            }
+            expand_undef_array(g, const_val);
+            bool is_big_endian = g->is_big_endian; // TODO get endianness from struct type
+            uint32_t packed_bits_size = type_size_bits(g, type_entry->data.array.child_type);
+            size_t used_bits = 0;
+            for (size_t i = 0; i < type_entry->data.array.len; i += 1) {
+                ConstExprValue *elem_val = &const_val->data.x_array.data.s_none.elements[i];
+                LLVMValueRef child_val = pack_const_int(g, big_int_type_ref, elem_val);
+
+                if (is_big_endian) {
+                    LLVMValueRef shift_amt = LLVMConstInt(big_int_type_ref, packed_bits_size, false);
+                    val = LLVMConstShl(val, shift_amt);
+                    val = LLVMConstOr(val, child_val);
+                } else {
+                    LLVMValueRef shift_amt = LLVMConstInt(big_int_type_ref, used_bits, false);
+                    LLVMValueRef child_val_shifted = LLVMConstShl(child_val, shift_amt);
+                    val = LLVMConstOr(val, child_val_shifted);
+                    used_bits += packed_bits_size;
+                }
+            }
+            return val;
+        }
         case ZigTypeIdVector:
             zig_panic("TODO bit pack a vector");
         case ZigTypeIdUnion:
@@ -6728,6 +6761,9 @@ static void do_code_gen(CodeGen *g) {
             } else if (instruction->id == IrInstructionIdResizeSlice) {
                 IrInstructionResizeSlice *resize_slice_instruction = (IrInstructionResizeSlice *)instruction;
                 slot = &resize_slice_instruction->tmp_ptr;
+            } else if (instruction->id == IrInstructionIdLoadPtrGen) {
+                IrInstructionLoadPtrGen *load_ptr_inst = (IrInstructionLoadPtrGen *)instruction;
+                slot = &load_ptr_inst->tmp_ptr;
             } else if (instruction->id == IrInstructionIdVectorToArray) {
                 IrInstructionVectorToArray *vector_to_array_instruction = (IrInstructionVectorToArray *)instruction;
                 alignment_bytes = get_abi_alignment(g, vector_to_array_instruction->vector->value.type);
src/ir.cpp
@@ -416,6 +416,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionLoadPtr *) {
     return IrInstructionIdLoadPtr;
 }
 
+static constexpr IrInstructionId ir_instruction_id(IrInstructionLoadPtrGen *) {
+    return IrInstructionIdLoadPtrGen;
+}
+
 static constexpr IrInstructionId ir_instruction_id(IrInstructionStorePtr *) {
     return IrInstructionIdStorePtr;
 }
@@ -2292,6 +2296,19 @@ static IrInstruction *ir_build_ptr_cast_gen(IrAnalyze *ira, IrInstruction *sourc
     return &instruction->base;
 }
 
+static IrInstruction *ir_build_load_ptr_gen(IrAnalyze *ira, IrInstruction *source_instruction,
+        IrInstruction *ptr, ZigType *ty)
+{
+    IrInstructionLoadPtrGen *instruction = ir_build_instruction<IrInstructionLoadPtrGen>(
+            &ira->new_irb, source_instruction->scope, source_instruction->source_node);
+    instruction->base.value.type = ty;
+    instruction->ptr = ptr;
+
+    ir_ref_instruction(ptr, ira->new_irb.current_basic_block);
+
+    return &instruction->base;
+}
+
 static IrInstruction *ir_build_bit_cast(IrBuilder *irb, Scope *scope, AstNode *source_node,
         IrInstruction *dest_type, IrInstruction *value)
 {
@@ -11534,10 +11551,11 @@ static IrInstruction *ir_get_deref(IrAnalyze *ira, IrInstruction *source_instruc
             IrInstructionRef *ref_inst = reinterpret_cast<IrInstructionRef *>(ptr);
             return ref_inst->value;
         }
-        IrInstruction *load_ptr_instruction = ir_build_load_ptr(&ira->new_irb, source_instruction->scope,
-                source_instruction->source_node, ptr);
-        load_ptr_instruction->value.type = child_type;
-        return load_ptr_instruction;
+        IrInstruction *result = ir_build_load_ptr_gen(ira, source_instruction, ptr, child_type);
+        if (type_entry->data.pointer.host_int_bytes != 0 && handle_is_ptr(child_type)) {
+            ir_add_alloca(ira, result, child_type);
+        }
+        return result;
     } else {
         ir_add_error_node(ira, source_instruction->source_node,
             buf_sprintf("attempt to dereference non-pointer type '%s'",
@@ -13398,8 +13416,8 @@ static IrInstruction *ir_analyze_instruction_export(IrAnalyze *ira, IrInstructio
     }
 
     // TODO audit the various ways to use @export
-    if (want_var_export && target->id == IrInstructionIdLoadPtr) {
-        IrInstructionLoadPtr *load_ptr = reinterpret_cast<IrInstructionLoadPtr *>(target);
+    if (want_var_export && target->id == IrInstructionIdLoadPtrGen) {
+        IrInstructionLoadPtrGen *load_ptr = reinterpret_cast<IrInstructionLoadPtrGen *>(target);
         if (load_ptr->ptr->id == IrInstructionIdVarPtr) {
             IrInstructionVarPtr *var_ptr = reinterpret_cast<IrInstructionVarPtr *>(load_ptr->ptr);
             ZigVar *var = var_ptr->var;
@@ -22316,6 +22334,7 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio
         case IrInstructionIdVectorToArray:
         case IrInstructionIdAssertZero:
         case IrInstructionIdResizeSlice:
+        case IrInstructionIdLoadPtrGen:
             zig_unreachable();
 
         case IrInstructionIdReturn:
@@ -22722,6 +22741,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
         case IrInstructionIdUnOp:
         case IrInstructionIdBinOp:
         case IrInstructionIdLoadPtr:
+        case IrInstructionIdLoadPtrGen:
         case IrInstructionIdConst:
         case IrInstructionIdCast:
         case IrInstructionIdContainerInitList:
src/ir_print.cpp
@@ -336,6 +336,11 @@ static void ir_print_load_ptr(IrPrint *irp, IrInstructionLoadPtr *instruction) {
     fprintf(irp->f, ".*");
 }
 
+static void ir_print_load_ptr_gen(IrPrint *irp, IrInstructionLoadPtrGen *instruction) {
+    ir_print_other_instruction(irp, instruction->ptr);
+    fprintf(irp->f, ".*");
+}
+
 static void ir_print_store_ptr(IrPrint *irp, IrInstructionStorePtr *instruction) {
     fprintf(irp->f, "*");
     ir_print_var_instruction(irp, instruction->ptr);
@@ -1468,6 +1473,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
         case IrInstructionIdLoadPtr:
             ir_print_load_ptr(irp, (IrInstructionLoadPtr *)instruction);
             break;
+        case IrInstructionIdLoadPtrGen:
+            ir_print_load_ptr_gen(irp, (IrInstructionLoadPtrGen *)instruction);
+            break;
         case IrInstructionIdStorePtr:
             ir_print_store_ptr(irp, (IrInstructionStorePtr *)instruction);
             break;
test/stage1/behavior/struct.zig
@@ -1,5 +1,6 @@
 const std = @import("std");
 const expect = std.testing.expect;
+const expectEqualSlices = std.testing.expectEqualSlices;
 const builtin = @import("builtin");
 const maxInt = std.math.maxInt;
 
@@ -103,19 +104,20 @@ fn structInitializer() void {
 }
 
 test "fn call of struct field" {
-    expect(callStructField(Foo{ .ptr = aFunc }) == 13);
-}
-
-const Foo = struct {
-    ptr: fn () i32,
-};
+    const Foo = struct {
+        ptr: fn () i32,
+    };
+    const S = struct {
+        fn aFunc() i32 {
+            return 13;
+        }
 
-fn aFunc() i32 {
-    return 13;
-}
+        fn callStructField(foo: Foo) i32 {
+            return foo.ptr();
+        }
+    };
 
-fn callStructField(foo: Foo) i32 {
-    return foo.ptr();
+    expect(S.callStructField(Foo{ .ptr = S.aFunc }) == 13);
 }
 
 test "store member function in variable" {
@@ -468,3 +470,24 @@ test "pointer to packed struct member in a stack variable" {
     b_ptr.* = 2;
     expect(s.b == 2);
 }
+
+test "non-byte-aligned array inside packed struct" {
+    const Foo = packed struct {
+        a: bool,
+        b: [0x16]u8,
+    };
+    const S = struct {
+        fn bar(slice: []const u8) void {
+            expectEqualSlices(u8, slice, "abcdefghijklmnopqurstu");
+        }
+        fn doTheTest() void {
+            var foo = Foo{
+                .a = true,
+                .b = "abcdefghijklmnopqurstu",
+            };
+            bar(foo.b);
+        }
+    };
+    S.doTheTest();
+    comptime S.doTheTest();
+}
test/compile_errors.zig
@@ -1,6 +1,60 @@
 const tests = @import("tests.zig");
 
 pub fn addCases(cases: *tests.CompileErrorContext) void {
+    cases.addTest(
+        "packed struct with fields of not allowed types",
+        \\const A = packed struct {
+        \\    x: anyerror,
+        \\};
+        \\const B = packed struct {
+        \\    x: [2]u24,
+        \\};
+        \\const C = packed struct {
+        \\    x: [1]anyerror,
+        \\};
+        \\const D = packed struct {
+        \\    x: [1]S,
+        \\};
+        \\const E = packed struct {
+        \\    x: [1]U,
+        \\};
+        \\const F = packed struct {
+        \\    x: ?anyerror,
+        \\};
+        \\const G = packed struct {
+        \\    x: Enum,
+        \\};
+        \\export fn entry() void {
+        \\    var a: A = undefined;
+        \\    var b: B = undefined;
+        \\    var r: C = undefined;
+        \\    var d: D = undefined;
+        \\    var e: E = undefined;
+        \\    var f: F = undefined;
+        \\    var g: G = undefined;
+        \\}
+        \\const S = struct {
+        \\    x: i32,
+        \\};
+        \\const U = struct {
+        \\    A: i32,
+        \\    B: u32,
+        \\};
+        \\const Enum = enum {
+        \\    A,
+        \\    B,
+        \\};
+    ,
+        ".tmp_source.zig:2:5: error: type 'anyerror' not allowed in packed struct; no guaranteed in-memory representation",
+        ".tmp_source.zig:5:5: error: array of 'u24' not allowed in packed struct due to padding bits",
+        ".tmp_source.zig:8:5: error: type 'anyerror' not allowed in packed struct; no guaranteed in-memory representation",
+        ".tmp_source.zig:11:5: error: non-packed, non-extern struct 'S' not allowed in packed struct; no guaranteed in-memory representation",
+        ".tmp_source.zig:14:5: error: non-packed, non-extern struct 'U' not allowed in packed struct; no guaranteed in-memory representation",
+        ".tmp_source.zig:17:5: error: type '?anyerror' not allowed in packed struct; no guaranteed in-memory representation",
+        ".tmp_source.zig:20:5: error: type 'Enum' not allowed in packed struct; no guaranteed in-memory representation",
+        ".tmp_source.zig:38:14: note: enum declaration does not specify an integer tag type",
+    );
+
     cases.addCase(x: {
         var tc = cases.create(
             "deduplicate undeclared identifier",