Commit e26ccd5166

Andrew Kelley <superjoe30@gmail.com>
2017-11-17 03:15:15
debug safety for unions
1 parent f12d366
src/all_types.hpp
@@ -1317,6 +1317,7 @@ enum PanicMsgId {
     PanicMsgIdUnwrapMaybeFail,
     PanicMsgIdInvalidErrorCode,
     PanicMsgIdIncorrectAlignment,
+    PanicMsgIdBadUnionField,
 
     PanicMsgIdCount,
 };
src/codegen.cpp
@@ -810,6 +810,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
             return buf_create_from_str("invalid error code");
         case PanicMsgIdIncorrectAlignment:
             return buf_create_from_str("incorrect alignment");
+        case PanicMsgIdBadUnionField:
+            return buf_create_from_str("access of inactive union field");
     }
     zig_unreachable();
 }
@@ -2415,6 +2417,23 @@ static LLVMValueRef ir_render_union_field_ptr(CodeGen *g, IrExecutable *executab
         return bitcasted_union_field_ptr;
     }
 
+    if (ir_want_debug_safety(g, &instruction->base)) {
+        LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_tag_index, "");
+        LLVMValueRef tag_value = gen_load_untyped(g, tag_field_ptr, 0, false, "");
+        LLVMValueRef expected_tag_value = LLVMConstInt(union_type->data.unionation.tag_type->type_ref,
+                field->value, false);
+
+        LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnionCheckOk");
+        LLVMBasicBlockRef bad_block = LLVMAppendBasicBlock(g->cur_fn_val, "UnionCheckFail");
+        LLVMValueRef ok_val = LLVMBuildICmp(g->builder, LLVMIntEQ, tag_value, expected_tag_value, "");
+        LLVMBuildCondBr(g->builder, ok_val, ok_block, bad_block);
+
+        LLVMPositionBuilderAtEnd(g->builder, bad_block);
+        gen_debug_safety_crash(g, PanicMsgIdBadUnionField);
+
+        LLVMPositionBuilderAtEnd(g->builder, ok_block);
+    }
+
     LLVMValueRef union_field_ptr = LLVMBuildStructGEP(g->builder, union_ptr, union_type->data.unionation.gen_union_index, "");
     LLVMValueRef bitcasted_union_field_ptr = LLVMBuildBitCast(g->builder, union_field_ptr, field_type_ref, "");
     return bitcasted_union_field_ptr;
@@ -3977,21 +3996,17 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val) {
 
                 LLVMValueRef union_value_ref;
                 {
-                    unsigned field_count;
-                    LLVMValueRef fields[2];
-                    fields[0] = correctly_typed_value;
                     if (pad_bytes == 0) {
-                        field_count = 1;
+                        union_value_ref = correctly_typed_value;
                     } else {
+                        LLVMValueRef fields[2];
                         fields[0] = correctly_typed_value;
                         fields[1] = LLVMGetUndef(LLVMArrayType(LLVMInt8Type(), (unsigned)pad_bytes));
-                        field_count = 2;
-                    }
-
-                    if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) {
-                        union_value_ref = LLVMConstStruct(fields, field_count, false);
-                    } else {
-                        union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, field_count);
+                        if (make_unnamed_struct || type_entry->data.unionation.gen_tag_index != SIZE_MAX) {
+                            union_value_ref = LLVMConstStruct(fields, 2, false);
+                        } else {
+                            union_value_ref = LLVMConstNamedStruct(union_type_ref, fields, 2);
+                        }
                     }
                 }
 
test/cases/union.zig
@@ -41,7 +41,7 @@ const Foo = union {
 test "basic unions" {
     var foo = Foo { .int = 1 };
     assert(foo.int == 1);
-    foo.float = 12.34;
+    foo = Foo {.float = 12.34};
     assert(foo.float == 12.34);
 }
 
test/debug_safety.zig
@@ -260,4 +260,24 @@ pub fn addCases(cases: &tests.CompareOutputContext) {
         \\    return int_slice[0];
         \\}
     );
+
+    cases.addDebugSafety("bad union field access",
+        \\pub fn panic(message: []const u8) -> noreturn {
+        \\    @import("std").os.exit(126);
+        \\}
+        \\
+        \\const Foo = union {
+        \\    float: f32,
+        \\    int: u32,
+        \\};
+        \\
+        \\pub fn main() -> %void {
+        \\    var f = Foo { .int = 42 };
+        \\    bar(&f);
+        \\}
+        \\
+        \\fn bar(f: &Foo) {
+        \\    f.float = 12.34;
+        \\}
+    );
 }