Commit 799fedf612

Andrew Kelley <andrew@ziglang.org>
2021-08-08 05:34:28
stage2: pass some error union tests
* Value: rename `error_union` to `eu_payload` and clarify the intended usage in the doc comments. The way error unions is represented with Value is fixed to not have ambiguous values. * Fix codegen for error union constants in all the backends. * Implement the AIR instructions having to do with error unions in the LLVM backend.
1 parent f81b253
src/codegen/c.zig
@@ -350,32 +350,25 @@ pub const DeclGen = struct {
             .ErrorUnion => {
                 const error_type = t.errorUnionSet();
                 const payload_type = t.errorUnionPayload();
-                const sub_val = val.castTag(.error_union).?.data;
 
                 if (!payload_type.hasCodeGenBits()) {
                     // We use the error type directly as the type.
-                    return dg.renderValue(writer, error_type, sub_val);
+                    const err_val = if (val.errorUnionIsPayload()) Value.initTag(.zero) else val;
+                    return dg.renderValue(writer, error_type, err_val);
                 }
 
                 try writer.writeByte('(');
                 try dg.renderType(writer, t);
                 try writer.writeAll("){");
-                if (val.getError()) |_| {
-                    try writer.writeAll(" .error = ");
-                    try dg.renderValue(
-                        writer,
-                        error_type,
-                        sub_val,
-                    );
-                    try writer.writeAll(" }");
-                } else {
+                if (val.castTag(.eu_payload)) |pl| {
+                    const payload_val = pl.data;
                     try writer.writeAll(" .payload = ");
-                    try dg.renderValue(
-                        writer,
-                        payload_type,
-                        sub_val,
-                    );
+                    try dg.renderValue(writer, payload_type, payload_val);
                     try writer.writeAll(", .error = 0 }");
+                } else {
+                    try writer.writeAll(" .error = ");
+                    try dg.renderValue(writer, error_type, val);
+                    try writer.writeAll(" }");
                 }
             },
             .Enum => {
src/codegen/llvm.zig
@@ -593,7 +593,7 @@ pub const DeclGen = struct {
                         try self.llvmType(ptr_type),
                         try self.llvmType(Type.initTag(.usize)),
                     };
-                    return self.context.structType(&fields, 2, .False);
+                    return self.context.structType(&fields, fields.len, .False);
                 } else {
                     const elem_type = try self.llvmType(t.elemType());
                     return elem_type.pointerType(0);
@@ -621,10 +621,14 @@ pub const DeclGen = struct {
             .ErrorUnion => {
                 const error_type = t.errorUnionSet();
                 const payload_type = t.errorUnionPayload();
+                const llvm_error_type = try self.llvmType(error_type);
                 if (!payload_type.hasCodeGenBits()) {
-                    return self.llvmType(error_type);
+                    return llvm_error_type;
                 }
-                return self.todo("implement llvmType for error unions", .{});
+                const llvm_payload_type = try self.llvmType(payload_type);
+
+                const fields: [2]*const llvm.Type = .{ llvm_error_type, llvm_payload_type };
+                return self.context.structType(&fields, fields.len, .False);
             },
             .ErrorSet => {
                 return self.context.intType(16);
@@ -846,14 +850,25 @@ pub const DeclGen = struct {
             .ErrorUnion => {
                 const error_type = tv.ty.errorUnionSet();
                 const payload_type = tv.ty.errorUnionPayload();
-                const sub_val = tv.val.castTag(.error_union).?.data;
+                const is_pl = tv.val.errorUnionIsPayload();
 
                 if (!payload_type.hasCodeGenBits()) {
                     // We use the error type directly as the type.
-                    return self.genTypedValue(.{ .ty = error_type, .val = sub_val });
+                    const err_val = if (!is_pl) tv.val else Value.initTag(.zero);
+                    return self.genTypedValue(.{ .ty = error_type, .val = err_val });
                 }
 
-                return self.todo("implement error union const of type '{}'", .{tv.ty});
+                const fields: [2]*const llvm.Value = .{
+                    try self.genTypedValue(.{
+                        .ty = error_type,
+                        .val = if (is_pl) Value.initTag(.zero) else tv.val,
+                    }),
+                    try self.genTypedValue(.{
+                        .ty = payload_type,
+                        .val = if (tv.val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef),
+                    }),
+                };
+                return self.context.constStruct(&fields, fields.len, .False);
             },
             .Struct => {
                 const fields_len = tv.ty.structFieldCount();
@@ -984,10 +999,10 @@ pub const FuncGen = struct {
                 .is_non_null_ptr => try self.airIsNonNull(inst, true),
                 .is_null         => try self.airIsNull(inst, false),
                 .is_null_ptr     => try self.airIsNull(inst, true),
-                .is_non_err      => try self.airIsErr(inst, true, false),
-                .is_non_err_ptr  => try self.airIsErr(inst, true, true),
-                .is_err          => try self.airIsErr(inst, false, false),
-                .is_err_ptr      => try self.airIsErr(inst, false, true),
+                .is_non_err      => try self.airIsErr(inst, .EQ, false),
+                .is_non_err_ptr  => try self.airIsErr(inst, .EQ, true),
+                .is_err          => try self.airIsErr(inst, .NE, false),
+                .is_err_ptr      => try self.airIsErr(inst, .NE, true),
 
                 .alloc      => try self.airAlloc(inst),
                 .arg        => try self.airArg(inst),
@@ -1098,7 +1113,7 @@ pub const FuncGen = struct {
         const inst_ty = self.air.typeOfIndex(inst);
 
         switch (self.air.typeOf(bin_op.lhs).zigTypeTag()) {
-            .Int, .Bool, .Pointer => {
+            .Int, .Bool, .Pointer, .ErrorSet => {
                 const is_signed = inst_ty.isSignedInt();
                 const operation = switch (op) {
                     .eq => .EQ,
@@ -1256,12 +1271,7 @@ pub const FuncGen = struct {
         const rhs = try self.resolveInst(bin_op.rhs);
 
         const base_ptr = ptr: {
-            const index_type = self.context.intType(32);
-            const indices: [2]*const llvm.Value = .{
-                index_type.constNull(),
-                index_type.constInt(0, .False),
-            };
-            const ptr_field_ptr = self.builder.buildInBoundsGEP(lhs, &indices, 2, "");
+            const ptr_field_ptr = self.builder.buildStructGEP(lhs, 0, "");
             break :ptr self.builder.buildLoad(ptr_field_ptr, "");
         };
 
@@ -1472,7 +1482,7 @@ pub const FuncGen = struct {
                 index_type.constInt(1, .False),
             };
 
-            return self.builder.buildLoad(self.builder.buildInBoundsGEP(operand, &indices, 2, ""), "");
+            return self.builder.buildLoad(self.builder.buildInBoundsGEP(operand, &indices, indices.len, ""), "");
         } else {
             return self.builder.buildExtractValue(operand, 1, "");
         }
@@ -1488,7 +1498,7 @@ pub const FuncGen = struct {
     fn airIsErr(
         self: *FuncGen,
         inst: Air.Inst.Index,
-        invert_logic: bool,
+        op: llvm.IntPredicate,
         operand_is_ptr: bool,
     ) !?*const llvm.Value {
         if (self.liveness.isUnused(inst))
@@ -1498,16 +1508,22 @@ pub const FuncGen = struct {
         const operand = try self.resolveInst(un_op);
         const err_union_ty = self.air.typeOf(un_op);
         const payload_ty = err_union_ty.errorUnionPayload();
+        const err_set_ty = try self.dg.llvmType(Type.initTag(.anyerror));
+        const zero = err_set_ty.constNull();
 
         if (!payload_ty.hasCodeGenBits()) {
             const loaded = if (operand_is_ptr) self.builder.buildLoad(operand, "") else operand;
-            const op: llvm.IntPredicate = if (invert_logic) .EQ else .NE;
-            const err_set_ty = try self.dg.llvmType(Type.initTag(.anyerror));
-            const zero = err_set_ty.constNull();
             return self.builder.buildICmp(op, loaded, zero, "");
         }
 
-        return self.todo("implement 'airIsErr' for error unions with nonzero payload", .{});
+        if (operand_is_ptr) {
+            const err_field_ptr = self.builder.buildStructGEP(operand, 0, "");
+            const loaded = self.builder.buildLoad(err_field_ptr, "");
+            return self.builder.buildICmp(op, loaded, zero, "");
+        }
+
+        const loaded = self.builder.buildExtractValue(operand, 0, "");
+        return self.builder.buildICmp(op, loaded, zero, "");
     }
 
     fn airOptionalPayload(
@@ -1552,9 +1568,11 @@ pub const FuncGen = struct {
             return null;
         }
 
-        _ = operand;
-        _ = operand_is_ptr;
-        return self.todo("implement llvm codegen for 'airErrUnionPayload' for type {}", .{self.air.typeOf(ty_op.operand)});
+        if (operand_is_ptr) {
+            return self.builder.buildStructGEP(operand, 1, "");
+        }
+
+        return self.builder.buildExtractValue(operand, 1, "");
     }
 
     fn airErrUnionErr(
@@ -1574,7 +1592,13 @@ pub const FuncGen = struct {
             if (!operand_is_ptr) return operand;
             return self.builder.buildLoad(operand, "");
         }
-        return self.todo("implement llvm codegen for 'airErrUnionErr'", .{});
+
+        if (operand_is_ptr) {
+            const err_field_ptr = self.builder.buildStructGEP(operand, 0, "");
+            return self.builder.buildLoad(err_field_ptr, "");
+        }
+
+        return self.builder.buildExtractValue(operand, 0, "");
     }
 
     fn airWrapOptional(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
src/codegen/wasm.zig
@@ -1167,12 +1167,18 @@ pub const Context = struct {
                 try leb.writeULEB128(writer, error_index);
             },
             .ErrorUnion => {
-                const data = val.castTag(.error_union).?.data;
                 const error_type = ty.errorUnionSet();
                 const payload_type = ty.errorUnionPayload();
-                if (val.getError()) |_| {
+                if (val.castTag(.eu_payload)) |pl| {
+                    const payload_val = pl.data;
+                    // no error, so write a '0' const
+                    try writer.writeByte(wasm.opcode(.i32_const));
+                    try leb.writeULEB128(writer, @as(u32, 0));
+                    // after the error code, we emit the payload
+                    try self.emitConstant(payload_val, payload_type);
+                } else {
                     // write the error val
-                    try self.emitConstant(data, error_type);
+                    try self.emitConstant(val, error_type);
 
                     // no payload, so write a '0' const
                     const opcode: wasm.Opcode = buildOpcode(.{
@@ -1181,12 +1187,6 @@ pub const Context = struct {
                     });
                     try writer.writeByte(wasm.opcode(opcode));
                     try leb.writeULEB128(writer, @as(u32, 0));
-                } else {
-                    // no error, so write a '0' const
-                    try writer.writeByte(wasm.opcode(.i32_const));
-                    try leb.writeULEB128(writer, @as(u32, 0));
-                    // after the error code, we emit the payload
-                    try self.emitConstant(data, payload_type);
                 }
             },
             .Optional => {
src/codegen.zig
@@ -4815,7 +4815,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type {
                 .ErrorUnion => {
                     const error_type = typed_value.ty.errorUnionSet();
                     const payload_type = typed_value.ty.errorUnionPayload();
-                    const sub_val = typed_value.val.castTag(.error_union).?.data;
+                    const sub_val = typed_value.val.castTag(.eu_payload).?.data;
 
                     if (!payload_type.hasCodeGenBits()) {
                         // We use the error type directly as the type.
src/Sema.zig
@@ -3468,7 +3468,7 @@ fn zirErrUnionPayload(
         if (val.getError()) |name| {
             return sema.mod.fail(&block.base, src, "caught unexpected error '{s}'", .{name});
         }
-        const data = val.castTag(.error_union).?.data;
+        const data = val.castTag(.eu_payload).?.data;
         const result_ty = operand_ty.errorUnionPayload();
         return sema.addConstant(result_ty, data);
     }
@@ -3539,8 +3539,7 @@ fn zirErrUnionCode(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) Compi
 
     if (try sema.resolveDefinedValue(block, src, operand)) |val| {
         assert(val.getError() != null);
-        const data = val.castTag(.error_union).?.data;
-        return sema.addConstant(result_ty, data);
+        return sema.addConstant(result_ty, val);
     }
 
     try sema.requireRuntimeBlock(block, src);
@@ -3566,8 +3565,7 @@ fn zirErrUnionCodePtr(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) Co
     if (try sema.resolveDefinedValue(block, src, operand)) |pointer_val| {
         if (try pointer_val.pointerDeref(sema.arena)) |val| {
             assert(val.getError() != null);
-            const data = val.castTag(.error_union).?.data;
-            return sema.addConstant(result_ty, data);
+            return sema.addConstant(result_ty, val);
         }
     }
 
@@ -8900,7 +8898,9 @@ fn wrapErrorUnion(
     if (try sema.resolveMaybeUndefVal(block, inst_src, inst)) |val| {
         if (inst_ty.zigTypeTag() != .ErrorSet) {
             _ = try sema.coerce(block, dest_payload_ty, inst, inst_src);
-        } else switch (dest_err_set_ty.tag()) {
+            return sema.addConstant(dest_type, try Value.Tag.eu_payload.create(sema.arena, val));
+        }
+        switch (dest_err_set_ty.tag()) {
             .anyerror => {},
             .error_set_single => {
                 const expected_name = val.castTag(.@"error").?.data.name;
@@ -8946,9 +8946,7 @@ fn wrapErrorUnion(
             },
             else => unreachable,
         }
-
-        // Create a SubValue for the error_union payload.
-        return sema.addConstant(dest_type, try Value.Tag.error_union.create(sema.arena, val));
+        return sema.addConstant(dest_type, val);
     }
 
     try sema.requireRuntimeBlock(block, inst_src);
src/value.zig
@@ -129,7 +129,13 @@ pub const Value = extern union {
         /// A specific enum tag, indicated by the field index (declaration order).
         enum_field_index,
         @"error",
-        error_union,
+        /// When the type is error union:
+        /// * If the tag is `.@"error"`, the error union is an error.
+        /// * If the tag is `.eu_payload`, the error union is a payload.
+        /// * A nested error such as `((anyerror!T1)!T2)` in which the the outer error union
+        ///   is non-error, but the inner error union is an error, is represented as
+        ///   a tag of `.eu_payload`, with a sub-tag of `.@"error"`.
+        eu_payload,
         /// A pointer to the payload of an error union, based on a pointer to an error union.
         eu_payload_ptr,
         /// An instance of a struct.
@@ -228,7 +234,7 @@ pub const Value = extern union {
                 => Payload.Decl,
 
                 .repeated,
-                .error_union,
+                .eu_payload,
                 .eu_payload_ptr,
                 => Payload.SubValue,
 
@@ -450,7 +456,7 @@ pub const Value = extern union {
                 return Value{ .ptr_otherwise = &new_payload.base };
             },
             .bytes => return self.copyPayloadShallow(allocator, Payload.Bytes),
-            .repeated, .error_union, .eu_payload_ptr => {
+            .repeated, .eu_payload, .eu_payload_ptr => {
                 const payload = self.cast(Payload.SubValue).?;
                 const new_payload = try allocator.create(Payload.SubValue);
                 new_payload.* = .{
@@ -642,7 +648,10 @@ pub const Value = extern union {
             .float_128 => return out_stream.print("{}", .{val.castTag(.float_128).?.data}),
             .@"error" => return out_stream.print("error.{s}", .{val.castTag(.@"error").?.data.name}),
             // TODO to print this it should be error{ Set, Items }!T(val), but we need the type for that
-            .error_union => return out_stream.print("error_union_val({})", .{val.castTag(.error_union).?.data}),
+            .eu_payload => {
+                try out_stream.writeAll("(eu_payload) ");
+                val = val.castTag(.eu_payload).?.data;
+            },
             .inferred_alloc => return out_stream.writeAll("(inferred allocation value)"),
             .inferred_alloc_comptime => return out_stream.writeAll("(inferred comptime allocation value)"),
             .eu_payload_ptr => {
@@ -1241,7 +1250,7 @@ pub const Value = extern union {
             .eu_payload_ptr => blk: {
                 const err_union_ptr = self.castTag(.eu_payload_ptr).?.data;
                 const err_union_val = (try err_union_ptr.pointerDeref(allocator)) orelse return null;
-                break :blk err_union_val.castTag(.error_union).?.data;
+                break :blk err_union_val.castTag(.eu_payload).?.data;
             },
 
             .zero,
@@ -1351,16 +1360,16 @@ pub const Value = extern union {
     }
 
     /// Valid for all types. Asserts the value is not undefined and not unreachable.
+    /// Prefer `errorUnionIsPayload` to find out whether something is an error or not
+    /// because it works without having to figure out the string.
     pub fn getError(self: Value) ?[]const u8 {
         return switch (self.tag()) {
-            .error_union => {
-                const data = self.castTag(.error_union).?.data;
-                return if (data.tag() == .@"error")
-                    data.castTag(.@"error").?.data.name
-                else
-                    null;
-            },
             .@"error" => self.castTag(.@"error").?.data.name,
+            .int_u64 => @panic("TODO"),
+            .int_i64 => @panic("TODO"),
+            .int_big_positive => @panic("TODO"),
+            .int_big_negative => @panic("TODO"),
+            .one => @panic("TODO"),
             .undef => unreachable,
             .unreachable_value => unreachable,
             .inferred_alloc => unreachable,
@@ -1369,6 +1378,16 @@ pub const Value = extern union {
             else => null,
         };
     }
+
+    /// Assumes the type is an error union. Returns true if and only if the value is
+    /// the error union payload, not an error.
+    pub fn errorUnionIsPayload(val: Value) bool {
+        return switch (val.tag()) {
+            .eu_payload => true,
+            else => false,
+        };
+    }
+
     /// Valid for all types. Asserts the value is not undefined.
     pub fn isFloat(self: Value) bool {
         return switch (self.tag()) {
test/behavior/if.zig
@@ -65,45 +65,3 @@ test "labeled break inside comptime if inside runtime if" {
     }
     try expect(answer == 42);
 }
-
-test "const result loc, runtime if cond, else unreachable" {
-    const Num = enum {
-        One,
-        Two,
-    };
-
-    var t = true;
-    const x = if (t) Num.Two else unreachable;
-    try expect(x == .Two);
-}
-
-test "if prongs cast to expected type instead of peer type resolution" {
-    const S = struct {
-        fn doTheTest(f: bool) !void {
-            var x: i32 = 0;
-            x = if (f) 1 else 2;
-            try expect(x == 2);
-
-            var b = true;
-            const y: i32 = if (b) 1 else 2;
-            try expect(y == 1);
-        }
-    };
-    try S.doTheTest(false);
-    comptime try S.doTheTest(false);
-}
-
-test "while copies its payload" {
-    const S = struct {
-        fn doTheTest() !void {
-            var tmp: ?i32 = 10;
-            if (tmp) |value| {
-                // Modify the original variable
-                tmp = null;
-                try expectEqual(@as(i32, 10), value);
-            } else unreachable;
-        }
-    };
-    try S.doTheTest();
-    comptime try S.doTheTest();
-}
test/behavior/if_stage1.zig
@@ -0,0 +1,45 @@
+const std = @import("std");
+const expect = std.testing.expect;
+const expectEqual = std.testing.expectEqual;
+
+test "const result loc, runtime if cond, else unreachable" {
+    const Num = enum {
+        One,
+        Two,
+    };
+
+    var t = true;
+    const x = if (t) Num.Two else unreachable;
+    try expect(x == .Two);
+}
+
+test "if prongs cast to expected type instead of peer type resolution" {
+    const S = struct {
+        fn doTheTest(f: bool) !void {
+            var x: i32 = 0;
+            x = if (f) 1 else 2;
+            try expect(x == 2);
+
+            var b = true;
+            const y: i32 = if (b) 1 else 2;
+            try expect(y == 1);
+        }
+    };
+    try S.doTheTest(false);
+    comptime try S.doTheTest(false);
+}
+
+test "while copies its payload" {
+    const S = struct {
+        fn doTheTest() !void {
+            var tmp: ?i32 = 10;
+            if (tmp) |value| {
+                // Modify the original variable
+                tmp = null;
+                try expectEqual(@as(i32, 10), value);
+            } else unreachable;
+        }
+    };
+    try S.doTheTest();
+    comptime try S.doTheTest();
+}
test/behavior.zig
@@ -7,6 +7,7 @@ test {
     _ = @import("behavior/generics.zig");
     _ = @import("behavior/eval.zig");
     _ = @import("behavior/pointers.zig");
+    _ = @import("behavior/if.zig");
 
     if (!builtin.zig_is_stage2) {
         // Tests that only pass for stage1.
@@ -100,7 +101,7 @@ test {
         _ = @import("behavior/generics_stage1.zig");
         _ = @import("behavior/hasdecl.zig");
         _ = @import("behavior/hasfield.zig");
-        _ = @import("behavior/if.zig");
+        _ = @import("behavior/if_stage1.zig");
         _ = @import("behavior/import.zig");
         _ = @import("behavior/incomplete_struct_param_tld.zig");
         _ = @import("behavior/inttoptr.zig");