Commit c2e66d9bab

Andrew Kelley <andrew@ziglang.org>
2021-07-08 05:47:21
stage2: basic inferred error set support
* Inferred error sets are stored in the return Type of the function, owned by the Module.Fn. So it cleans up that memory in deinit(). * Sema: update the inferred error set in zirRetErrValue - Update relevant code in wrapErrorUnion * C backend: improve some some instructions to take advantage of liveness analysis to avoid being emitted when unused. * C backend: when an error union has a payload type with no runtime bits, emit the error union as the same type as the error set.
1 parent 5c8bd44
Changed files (5)
src/codegen/c.zig
@@ -360,6 +360,12 @@ pub const DeclGen = struct {
                 const error_type = t.errorUnionSet();
                 const payload_type = t.errorUnionChild();
                 const data = val.castTag(.error_union).?.data;
+
+                if (!payload_type.hasCodeGenBits()) {
+                    // We use the error type directly as the type.
+                    return dg.renderValue(writer, error_type, data);
+                }
+
                 try writer.writeByte('(');
                 try dg.renderType(writer, t);
                 try writer.writeAll("){");
@@ -604,6 +610,10 @@ pub const DeclGen = struct {
                 const child_type = t.errorUnionChild();
                 const err_set_type = t.errorUnionSet();
 
+                if (!child_type.hasCodeGenBits()) {
+                    return dg.renderType(w, err_set_type);
+                }
+
                 var buffer = std.ArrayList(u8).init(dg.typedefs.allocator);
                 defer buffer.deinit();
                 const bw = buffer.writer();
@@ -613,7 +623,7 @@ pub const DeclGen = struct {
                 try bw.writeAll(" payload; uint16_t error; } ");
                 const name_index = buffer.items.len;
                 if (err_set_type.castTag(.error_set_inferred)) |inf_err_set_payload| {
-                    const func = inf_err_set_payload.data;
+                    const func = inf_err_set_payload.data.func;
                     try bw.print("zig_E_{s};\n", .{func.owner_decl.name});
                 } else {
                     try bw.print("zig_E_{s}_{s};\n", .{
@@ -895,10 +905,10 @@ pub fn genBody(o: *Object, body: ir.Body) error{ AnalysisFail, OutOfMemory }!voi
             .ref => try genRef(o, inst.castTag(.ref).?),
             .struct_field_ptr => try genStructFieldPtr(o, inst.castTag(.struct_field_ptr).?),
 
-            .is_err => try genIsErr(o, inst.castTag(.is_err).?, "", "!="),
-            .is_non_err => try genIsErr(o, inst.castTag(.is_non_err).?, "", "=="),
-            .is_err_ptr => try genIsErr(o, inst.castTag(.is_err_ptr).?, "[0]", "!="),
-            .is_non_err_ptr => try genIsErr(o, inst.castTag(.is_non_err_ptr).?, "[0]", "=="),
+            .is_err => try genIsErr(o, inst.castTag(.is_err).?, "", ".", "!="),
+            .is_non_err => try genIsErr(o, inst.castTag(.is_non_err).?, "", ".", "=="),
+            .is_err_ptr => try genIsErr(o, inst.castTag(.is_err_ptr).?, "*", "->", "!="),
+            .is_non_err_ptr => try genIsErr(o, inst.castTag(.is_non_err_ptr).?, "*", "->", "=="),
 
             .unwrap_errunion_payload => try genUnwrapErrUnionPay(o, inst.castTag(.unwrap_errunion_payload).?),
             .unwrap_errunion_err => try genUnwrapErrUnionErr(o, inst.castTag(.unwrap_errunion_err).?),
@@ -1384,9 +1394,25 @@ fn genStructFieldPtr(o: *Object, inst: *Inst.StructFieldPtr) !CValue {
 
 // *(E!T) -> E NOT *E
 fn genUnwrapErrUnionErr(o: *Object, inst: *Inst.UnOp) !CValue {
+    if (inst.base.isUnused())
+        return CValue.none;
+
     const writer = o.writer();
     const operand = try o.resolveInst(inst.operand);
 
+    const payload_ty = inst.operand.ty.errorUnionChild();
+    if (!payload_ty.hasCodeGenBits()) {
+        if (inst.operand.ty.zigTypeTag() == .Pointer) {
+            const local = try o.allocLocal(inst.base.ty, .Const);
+            try writer.writeAll(" = *");
+            try o.writeCValue(writer, operand);
+            try writer.writeAll(";\n");
+            return local;
+        } else {
+            return operand;
+        }
+    }
+
     const maybe_deref = if (inst.operand.ty.zigTypeTag() == .Pointer) "->" else ".";
 
     const local = try o.allocLocal(inst.base.ty, .Const);
@@ -1396,10 +1422,19 @@ fn genUnwrapErrUnionErr(o: *Object, inst: *Inst.UnOp) !CValue {
     try writer.print("){s}error;\n", .{maybe_deref});
     return local;
 }
+
 fn genUnwrapErrUnionPay(o: *Object, inst: *Inst.UnOp) !CValue {
+    if (inst.base.isUnused())
+        return CValue.none;
+
     const writer = o.writer();
     const operand = try o.resolveInst(inst.operand);
 
+    const payload_ty = inst.operand.ty.errorUnionChild();
+    if (!payload_ty.hasCodeGenBits()) {
+        return CValue.none;
+    }
+
     const maybe_deref = if (inst.operand.ty.zigTypeTag() == .Pointer) "->" else ".";
     const maybe_addrof = if (inst.base.ty.zigTypeTag() == .Pointer) "&" else "";
 
@@ -1448,14 +1483,26 @@ fn genWrapErrUnionPay(o: *Object, inst: *Inst.UnOp) !CValue {
     return local;
 }
 
-fn genIsErr(o: *Object, inst: *Inst.UnOp, deref_suffix: []const u8, op_str: []const u8) !CValue {
+fn genIsErr(
+    o: *Object,
+    inst: *Inst.UnOp,
+    deref_prefix: [*:0]const u8,
+    deref_suffix: [*:0]const u8,
+    op_str: [*:0]const u8,
+) !CValue {
     const writer = o.writer();
     const operand = try o.resolveInst(inst.operand);
-
     const local = try o.allocLocal(Type.initTag(.bool), .Const);
-    try writer.writeAll(" = (");
-    try o.writeCValue(writer, operand);
-    try writer.print("){s}.error {s} 0;\n", .{ deref_suffix, op_str });
+    const payload_ty = inst.operand.ty.errorUnionChild();
+    if (!payload_ty.hasCodeGenBits()) {
+        try writer.print(" = {s}", .{deref_prefix});
+        try o.writeCValue(writer, operand);
+        try writer.print(" {s} 0;\n", .{op_str});
+    } else {
+        try writer.writeAll(" = ");
+        try o.writeCValue(writer, operand);
+        try writer.print("{s}error {s} 0;\n", .{ deref_suffix, op_str });
+    }
     return local;
 }
 
src/Module.zig
@@ -777,8 +777,19 @@ pub const Fn = struct {
     }
 
     pub fn deinit(func: *Fn, gpa: *Allocator) void {
-        _ = func;
-        _ = gpa;
+        if (func.getInferredErrorSet()) |map| {
+            map.deinit(gpa);
+        }
+    }
+
+    pub fn getInferredErrorSet(func: *Fn) ?*std.StringHashMapUnmanaged(void) {
+        const ret_ty = func.owner_decl.ty.fnReturnType();
+        if (ret_ty.zigTypeTag() == .ErrorUnion) {
+            if (ret_ty.errorUnionSet().castTag(.error_set_inferred)) |payload| {
+                return &payload.data.map;
+            }
+        }
+        return null;
     }
 };
 
src/Sema.zig
@@ -3139,7 +3139,10 @@ fn funcCommon(
         }
 
         const return_type = if (!inferred_error_set) bare_return_type else blk: {
-            const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, new_func);
+            const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, .{
+                .func = new_func,
+                .map = .{},
+            });
             break :blk try Type.Tag.error_union.create(sema.arena, .{
                 .error_set = error_set_ty,
                 .payload = bare_return_type,
@@ -5424,12 +5427,8 @@ fn zirRetErrValue(
 
     // Add the error tag to the inferred error set of the in-scope function.
     if (sema.func) |func| {
-        const fn_ty = func.owner_decl.ty;
-        const fn_ret_ty = fn_ty.fnReturnType();
-        if (fn_ret_ty.zigTypeTag() == .ErrorUnion and
-            fn_ret_ty.errorUnionSet().tag() == .error_set_inferred)
-        {
-            return sema.mod.fail(&block.base, src, "TODO: Sema.zirRetErrValue", .{});
+        if (func.getInferredErrorSet()) |map| {
+            _ = try map.getOrPut(sema.gpa, err_name);
         }
     }
     // Return the error code from the function.
@@ -7535,6 +7534,18 @@ fn wrapErrorUnion(sema: *Sema, block: *Scope.Block, dest_type: Type, inst: *Inst
                     );
                 }
             },
+            .error_set_inferred => {
+                const expected_name = val.castTag(.@"error").?.data.name;
+                const map = &err_union.data.error_set.castTag(.error_set_inferred).?.data.map;
+                if (!map.contains(expected_name)) {
+                    return sema.mod.fail(
+                        &block.base,
+                        inst.src,
+                        "expected type '{}', found type '{}'",
+                        .{ err_union.data.error_set, inst.ty },
+                    );
+                }
+            },
             else => unreachable,
         }
 
src/type.zig
@@ -1041,7 +1041,7 @@ pub const Type = extern union {
                     return writer.writeAll(std.mem.spanZ(error_set.owner_decl.name));
                 },
                 .error_set_inferred => {
-                    const func = ty.castTag(.error_set_inferred).?.data;
+                    const func = ty.castTag(.error_set_inferred).?.data.func;
                     return writer.print("(inferred error set of {s})", .{func.owner_decl.name});
                 },
                 .error_set_single => {
@@ -3154,7 +3154,10 @@ pub const Type = extern union {
             pub const base_tag = Tag.error_set_inferred;
 
             base: Payload = Payload{ .tag = base_tag },
-            data: *Module.Fn,
+            data: struct {
+                func: *Module.Fn,
+                map: std.StringHashMapUnmanaged(void),
+            },
         };
 
         pub const Pointer = struct {
test/stage2/cbe.zig
@@ -804,6 +804,26 @@ pub fn addCases(ctx: *TestContext) !void {
         });
     }
 
+    {
+        var case = ctx.exeFromCompiledC("inferred error sets", .{});
+
+        case.addCompareOutput(
+            \\pub export fn main() c_int {
+            \\    if (foo()) |_| {
+            \\        @panic("test fail");
+            \\    } else |err| {
+            \\        if (err != error.ItBroke) {
+            \\            @panic("test fail");
+            \\        }
+            \\    }
+            \\    return 0;
+            \\}
+            \\fn foo() !void {
+            \\    return error.ItBroke;
+            \\}
+        , "");
+    }
+
     ctx.h("simple header", linux_x64,
         \\export fn start() void{}
     ,