Commit 07691db3ae

Andrew Kelley <andrew@ziglang.org>
2021-10-18 00:36:12
stage2: fix handling of error unions as return type
* LLVM backend: fix phi instruction not respecting `isByRef` - Also fix `is_non_null` not respecting `isByRef` * Type: implement abiSize for error unions
1 parent 6534f2e
Changed files (4)
src/codegen/llvm.zig
@@ -1804,14 +1804,16 @@ pub const FuncGen = struct {
 
         const raw_llvm_ty = try self.dg.llvmType(inst_ty);
 
-        // If the zig tag type is a function, this represents an actual function body; not
-        // a pointer to it. LLVM IR allows the call instruction to use function bodies instead
-        // of function pointers, however the phi makes it a runtime value and therefore
-        // the LLVM type has to be wrapped in a pointer.
-        const llvm_ty = if (inst_ty.zigTypeTag() == .Fn)
-            raw_llvm_ty.pointerType(0)
-        else
-            raw_llvm_ty;
+        const llvm_ty = ty: {
+            // If the zig tag type is a function, this represents an actual function body; not
+            // a pointer to it. LLVM IR allows the call instruction to use function bodies instead
+            // of function pointers, however the phi makes it a runtime value and therefore
+            // the LLVM type has to be wrapped in a pointer.
+            if (inst_ty.zigTypeTag() == .Fn or isByRef(inst_ty)) {
+                break :ty raw_llvm_ty.pointerType(0);
+            }
+            break :ty raw_llvm_ty;
+        };
 
         const phi_node = self.builder.buildPhi(llvm_ty, "");
         phi_node.addIncoming(
@@ -2315,7 +2317,7 @@ pub const FuncGen = struct {
             return self.builder.buildICmp(op, loaded, zero, "");
         }
 
-        if (operand_is_ptr) {
+        if (operand_is_ptr or isByRef(err_union_ty)) {
             const err_field_ptr = self.builder.buildStructGEP(operand, 0, "");
             const loaded = self.builder.buildLoad(err_field_ptr, "");
             return self.builder.buildICmp(op, loaded, zero, "");
src/type.zig
@@ -1693,15 +1693,15 @@ pub const Type = extern union {
             },
 
             .error_union => {
-                const payload = self.castTag(.error_union).?.data;
-                if (!payload.error_set.hasCodeGenBits()) {
-                    return payload.payload.abiAlignment(target);
-                } else if (!payload.payload.hasCodeGenBits()) {
-                    return payload.error_set.abiAlignment(target);
+                const data = self.castTag(.error_union).?.data;
+                if (!data.error_set.hasCodeGenBits()) {
+                    return data.payload.abiAlignment(target);
+                } else if (!data.payload.hasCodeGenBits()) {
+                    return data.error_set.abiAlignment(target);
                 }
-                return std.math.max(
-                    payload.payload.abiAlignment(target),
-                    payload.error_set.abiAlignment(target),
+                return @maximum(
+                    data.payload.abiAlignment(target),
+                    data.error_set.abiAlignment(target),
                 );
             },
 
@@ -1942,15 +1942,25 @@ pub const Type = extern union {
             },
 
             .error_union => {
-                const payload = self.castTag(.error_union).?.data;
-                if (!payload.error_set.hasCodeGenBits() and !payload.payload.hasCodeGenBits()) {
+                const data = self.castTag(.error_union).?.data;
+                if (!data.error_set.hasCodeGenBits() and !data.payload.hasCodeGenBits()) {
                     return 0;
-                } else if (!payload.error_set.hasCodeGenBits()) {
-                    return payload.payload.abiSize(target);
-                } else if (!payload.payload.hasCodeGenBits()) {
-                    return payload.error_set.abiSize(target);
+                } else if (!data.error_set.hasCodeGenBits()) {
+                    return data.payload.abiSize(target);
+                } else if (!data.payload.hasCodeGenBits()) {
+                    return data.error_set.abiSize(target);
                 }
-                std.debug.panic("TODO abiSize error union {}", .{self});
+                const code_align = abiAlignment(data.error_set, target);
+                const payload_align = abiAlignment(data.payload, target);
+                const big_align = @maximum(code_align, payload_align);
+                const payload_size = abiSize(data.payload, target);
+
+                var size: u64 = 0;
+                size += abiSize(data.error_set, target);
+                size = std.mem.alignForwardGeneric(u64, size, payload_align);
+                size += payload_size;
+                size = std.mem.alignForwardGeneric(u64, size, big_align);
+                return size;
             },
         };
     }
test/behavior/error.zig
@@ -49,3 +49,69 @@ pub fn baz() anyerror!i32 {
 test "error wrapping" {
     try expect((baz() catch unreachable) == 15);
 }
+
+test "unwrap simple value from error" {
+    const i = unwrapSimpleValueFromErrorDo() catch unreachable;
+    try expect(i == 13);
+}
+fn unwrapSimpleValueFromErrorDo() anyerror!isize {
+    return 13;
+}
+
+test "error return in assignment" {
+    doErrReturnInAssignment() catch unreachable;
+}
+
+fn doErrReturnInAssignment() anyerror!void {
+    var x: i32 = undefined;
+    x = try makeANonErr();
+}
+
+fn makeANonErr() anyerror!i32 {
+    return 1;
+}
+
+test "syntax: optional operator in front of error union operator" {
+    comptime {
+        try expect(?(anyerror!i32) == ?(anyerror!i32));
+    }
+}
+
+test "widen cast integer payload of error union function call" {
+    const S = struct {
+        fn errorable() !u64 {
+            var x = @as(u64, try number());
+            return x;
+        }
+
+        fn number() anyerror!u32 {
+            return 1234;
+        }
+    };
+    try expect((try S.errorable()) == 1234);
+}
+
+test "debug info for optional error set" {
+    const SomeError = error{Hello};
+    var a_local_variable: ?SomeError = null;
+    _ = a_local_variable;
+}
+
+test "implicit cast to optional to error union to return result loc" {
+    const S = struct {
+        fn entry() !void {
+            var x: Foo = undefined;
+            if (func(&x)) |opt| {
+                try expect(opt != null);
+            } else |_| @panic("expected non error");
+        }
+        fn func(f: *Foo) anyerror!?*Foo {
+            return f;
+        }
+        const Foo = struct {
+            field: i32,
+        };
+    };
+    try S.entry();
+    //comptime S.entry(); TODO
+}
test/behavior/error_stage1.zig
@@ -4,52 +4,14 @@ const expectError = std.testing.expectError;
 const expectEqual = std.testing.expectEqual;
 const mem = std.mem;
 
-pub fn foo() anyerror!i32 {
-    const x = try bar();
-    return x + 1;
-}
-
-pub fn bar() anyerror!i32 {
-    return 13;
-}
-
-pub fn baz() anyerror!i32 {
-    const y = foo() catch 1234;
-    return y + 1;
-}
-
-test "error wrapping" {
-    try expect((baz() catch unreachable) == 15);
-}
-
-fn gimmeItBroke() []const u8 {
-    return @errorName(error.ItBroke);
+fn gimmeItBroke() anyerror {
+    return error.ItBroke;
 }
 
 test "@errorName" {
     try expect(mem.eql(u8, @errorName(error.AnError), "AnError"));
     try expect(mem.eql(u8, @errorName(error.ALongerErrorName), "ALongerErrorName"));
-}
-
-test "unwrap simple value from error" {
-    const i = unwrapSimpleValueFromErrorDo() catch unreachable;
-    try expect(i == 13);
-}
-fn unwrapSimpleValueFromErrorDo() anyerror!isize {
-    return 13;
-}
-
-test "error return in assignment" {
-    doErrReturnInAssignment() catch unreachable;
-}
-
-fn doErrReturnInAssignment() anyerror!void {
-    var x: i32 = undefined;
-    x = try makeANonErr();
-}
-
-fn makeANonErr() anyerror!i32 {
-    return 1;
+    try expect(mem.eql(u8, @errorName(gimmeItBroke()), "ItBroke"));
 }
 
 test "error union type " {
@@ -116,12 +78,6 @@ fn testComptimeTestErrorEmptySet(x: EmptyErrorSet!i32) !void {
     }
 }
 
-test "syntax: optional operator in front of error union operator" {
-    comptime {
-        try expect(?(anyerror!i32) == ?(anyerror!i32));
-    }
-}
-
 test "comptime err to int of error set with only 1 possible value" {
     testErrToIntWithOnePossibleValue(error.A, @errorToInt(error.A));
     comptime testErrToIntWithOnePossibleValue(error.A, @errorToInt(error.A));
@@ -268,20 +224,6 @@ test "nested error union function call in optional unwrap" {
     }
 }
 
-test "widen cast integer payload of error union function call" {
-    const S = struct {
-        fn errorable() !u64 {
-            var x = @as(u64, try number());
-            return x;
-        }
-
-        fn number() anyerror!u32 {
-            return 1234;
-        }
-    };
-    try expect((try S.errorable()) == 1234);
-}
-
 test "return function call to error set from error union function" {
     const S = struct {
         fn errorable() anyerror!i32 {
@@ -307,12 +249,6 @@ test "optional error set is the same size as error set" {
     comptime try expect(S.returnsOptErrSet() == null);
 }
 
-test "debug info for optional error set" {
-    const SomeError = error{Hello};
-    var a_local_variable: ?SomeError = null;
-    _ = a_local_variable;
-}
-
 test "nested catch" {
     const S = struct {
         fn entry() !void {
@@ -335,25 +271,6 @@ test "nested catch" {
     comptime try S.entry();
 }
 
-test "implicit cast to optional to error union to return result loc" {
-    const S = struct {
-        fn entry() !void {
-            var x: Foo = undefined;
-            if (func(&x)) |opt| {
-                try expect(opt != null);
-            } else |_| @panic("expected non error");
-        }
-        fn func(f: *Foo) anyerror!?*Foo {
-            return f;
-        }
-        const Foo = struct {
-            field: i32,
-        };
-    };
-    try S.entry();
-    //comptime S.entry(); TODO
-}
-
 test "function pointer with return type that is error union with payload which is pointer of parent struct" {
     const S = struct {
         const Foo = struct {