Commit 569ae762e1

mlugg <mlugg@mlugg.co.uk>
2023-06-25 03:26:48
compiler: allow cast builtins to coerce result to error union or optional
Also improves some error messages
1 parent 8afadee
src/AstGen.zig
@@ -8050,17 +8050,9 @@ fn ptrCast(
     }
 
     // Full cast including result type
-    const need_result_type_builtin = if (flags.ptr_cast)
-        "@ptrCast"
-    else if (flags.align_cast)
-        "@alignCast"
-    else if (flags.addrspace_cast)
-        "@addrSpaceCast"
-    else
-        unreachable;
 
     const cursor = maybeAdvanceSourceCursorToMainToken(gz, root_node);
-    const result_type = try ri.rl.resultType(gz, root_node, need_result_type_builtin);
+    const result_type = try ri.rl.resultType(gz, root_node, flags.needResultTypeBuiltinName());
     const operand = try expr(gz, scope, .{ .rl = .none }, node);
     try emitDbgStmt(gz, cursor);
     const result = try gz.addExtendedPayloadSmall(.ptr_cast_full, flags_i, Zir.Inst.BinNode{
src/Module.zig
@@ -34,6 +34,7 @@ const isUpDir = @import("introspect.zig").isUpDir;
 const clang = @import("clang.zig");
 const InternPool = @import("InternPool.zig");
 const Alignment = InternPool.Alignment;
+const BuiltinFn = @import("BuiltinFn.zig");
 
 comptime {
     @setEvalBranchQuota(4000);
@@ -2273,6 +2274,41 @@ pub const SrcLoc = struct {
             .node_offset_builtin_call_arg3 => |n| return src_loc.byteOffsetBuiltinCallArg(gpa, n, 3),
             .node_offset_builtin_call_arg4 => |n| return src_loc.byteOffsetBuiltinCallArg(gpa, n, 4),
             .node_offset_builtin_call_arg5 => |n| return src_loc.byteOffsetBuiltinCallArg(gpa, n, 5),
+            .node_offset_ptrcast_operand => |node_off| {
+                const tree = try src_loc.file_scope.getTree(gpa);
+                const main_tokens = tree.nodes.items(.main_token);
+                const node_datas = tree.nodes.items(.data);
+                const node_tags = tree.nodes.items(.tag);
+
+                var node = src_loc.declRelativeToNodeIndex(node_off);
+                while (true) {
+                    switch (node_tags[node]) {
+                        .builtin_call_two, .builtin_call_two_comma => {},
+                        else => break,
+                    }
+
+                    if (node_datas[node].lhs == 0) break; // 0 args
+                    if (node_datas[node].rhs != 0) break; // 2 args
+
+                    const builtin_token = main_tokens[node];
+                    const builtin_name = tree.tokenSlice(builtin_token);
+                    const info = BuiltinFn.list.get(builtin_name) orelse break;
+
+                    switch (info.tag) {
+                        else => break,
+                        .ptr_cast,
+                        .align_cast,
+                        .addrspace_cast,
+                        .const_cast,
+                        .volatile_cast,
+                        => {},
+                    }
+
+                    node = node_datas[node].lhs;
+                }
+
+                return nodeToSpan(tree, node);
+            },
             .node_offset_array_access_index => |node_off| {
                 const tree = try src_loc.file_scope.getTree(gpa);
                 const node_datas = tree.nodes.items(.data);
@@ -2887,6 +2923,9 @@ pub const LazySrcLoc = union(enum) {
     node_offset_builtin_call_arg3: i32,
     node_offset_builtin_call_arg4: i32,
     node_offset_builtin_call_arg5: i32,
+    /// Like `node_offset_builtin_call_arg0` but recurses through arbitrarily many calls
+    /// to pointer cast builtins.
+    node_offset_ptrcast_operand: i32,
     /// The source location points to the index expression of an array access
     /// expression, found by taking this AST node index offset from the containing
     /// Decl AST node, which points to an array access AST node. Next, navigate
@@ -3145,6 +3184,7 @@ pub const LazySrcLoc = union(enum) {
             .node_offset_builtin_call_arg3,
             .node_offset_builtin_call_arg4,
             .node_offset_builtin_call_arg5,
+            .node_offset_ptrcast_operand,
             .node_offset_array_access_index,
             .node_offset_slice_ptr,
             .node_offset_slice_start,
src/Sema.zig
@@ -1820,8 +1820,25 @@ pub fn resolveType(sema: *Sema, block: *Block, src: LazySrcLoc, zir_ref: Zir.Ins
     return ty;
 }
 
-fn resolveCastDestType(sema: *Sema, block: *Block, src: LazySrcLoc, zir_ref: Zir.Inst.Ref, builtin_name: []const u8) !Type {
-    return sema.resolveType(block, src, zir_ref) catch |err| switch (err) {
+fn resolveCastDestType(
+    sema: *Sema,
+    block: *Block,
+    src: LazySrcLoc,
+    zir_ref: Zir.Inst.Ref,
+    strat: enum { remove_eu_opt, remove_eu, remove_opt },
+    builtin_name: []const u8,
+) !Type {
+    const mod = sema.mod;
+    const remove_eu = switch (strat) {
+        .remove_eu_opt, .remove_eu => true,
+        .remove_opt => false,
+    };
+    const remove_opt = switch (strat) {
+        .remove_eu_opt, .remove_opt => true,
+        .remove_eu => false,
+    };
+
+    const raw_ty = sema.resolveType(block, src, zir_ref) catch |err| switch (err) {
         error.GenericPoison => {
             // Cast builtins use their result type as the destination type, but
             // it could be an anytype argument, which we can't catch in AstGen.
@@ -1836,6 +1853,18 @@ fn resolveCastDestType(sema: *Sema, block: *Block, src: LazySrcLoc, zir_ref: Zir
         },
         else => |e| return e,
     };
+
+    if (remove_eu and raw_ty.zigTypeTag(mod) == .ErrorUnion) {
+        const eu_child = raw_ty.errorUnionPayload(mod);
+        if (remove_opt and eu_child.zigTypeTag(mod) == .Optional) {
+            return eu_child.childType(mod);
+        }
+        return eu_child;
+    }
+    if (remove_opt and raw_ty.zigTypeTag(mod) == .Optional) {
+        return raw_ty.childType(mod);
+    }
+    return raw_ty;
 }
 
 fn analyzeAsType(
@@ -8304,7 +8333,7 @@ fn zirEnumFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
     const src = inst_data.src();
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@enumFromInt");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu_opt, "@enumFromInt");
     const operand = try sema.resolveInst(extra.rhs);
 
     if (dest_ty.zigTypeTag(mod) != .Enum) {
@@ -9600,7 +9629,7 @@ fn zirIntCast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
 
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@intCast");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu_opt, "@intCast");
     const operand = try sema.resolveInst(extra.rhs);
 
     return sema.intCast(block, inst_data.src(), dest_ty, src, operand, operand_src, true);
@@ -9761,7 +9790,7 @@ fn zirBitcast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
 
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@bitCast");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu_opt, "@bitCast");
     const operand = try sema.resolveInst(extra.rhs);
     const operand_ty = sema.typeOf(operand);
     switch (dest_ty.zigTypeTag(mod)) {
@@ -9904,7 +9933,7 @@ fn zirFloatCast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
 
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@floatCast");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu_opt, "@floatCast");
     const operand = try sema.resolveInst(extra.rhs);
 
     const target = mod.getTarget();
@@ -20706,7 +20735,7 @@ fn zirIntFromFloat(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileErro
     const src = inst_data.src();
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@intFromFloat");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu_opt, "@intFromFloat");
     const operand = try sema.resolveInst(extra.rhs);
     const operand_ty = sema.typeOf(operand);
 
@@ -20746,7 +20775,7 @@ fn zirFloatFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileErro
     const src = inst_data.src();
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@floatFromInt");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu_opt, "@floatFromInt");
     const operand = try sema.resolveInst(extra.rhs);
     const operand_ty = sema.typeOf(operand);
 
@@ -20775,7 +20804,7 @@ fn zirPtrFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!
     const operand_res = try sema.resolveInst(extra.rhs);
     const operand_coerced = try sema.coerce(block, Type.usize, operand_res, operand_src);
 
-    const ptr_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@ptrFromInt");
+    const ptr_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu, "@ptrFromInt");
     try sema.checkPtrType(block, src, ptr_ty);
     const elem_ty = ptr_ty.elemType2(mod);
     const ptr_align = try ptr_ty.ptrAlignmentAdvanced(mod, sema);
@@ -20833,7 +20862,7 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
     const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
     const src = LazySrcLoc.nodeOffset(extra.node);
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@errSetCast");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu_opt, "@errSetCast");
     const operand = try sema.resolveInst(extra.rhs);
     const operand_ty = sema.typeOf(operand);
     try sema.checkErrorSetType(block, src, dest_ty);
@@ -20915,12 +20944,12 @@ fn zirErrSetCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstDat
 }
 
 fn zirPtrCastFull(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
-    const flags = @as(Zir.Inst.FullPtrCastFlags, @bitCast(@as(u5, @truncate(extended.small))));
+    const flags: Zir.Inst.FullPtrCastFlags = @bitCast(@as(u5, @truncate(extended.small)));
     const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
     const src = LazySrcLoc.nodeOffset(extra.node);
-    const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
+    const operand_src: LazySrcLoc = .{ .node_offset_ptrcast_operand = extra.node };
     const operand = try sema.resolveInst(extra.rhs);
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@ptrCast"); // TODO: better error message (builtin name)
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu, flags.needResultTypeBuiltinName());
     return sema.ptrCastFull(
         block,
         flags,
@@ -20936,7 +20965,7 @@ fn zirPtrCast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
     const src = inst_data.src();
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@ptrCast");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu, "@ptrCast");
     const operand = try sema.resolveInst(extra.rhs);
 
     return sema.ptrCastFull(
@@ -21325,7 +21354,7 @@ fn zirPtrCastNoDest(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.Inst
     const flags = @as(Zir.Inst.FullPtrCastFlags, @bitCast(@as(u5, @truncate(extended.small))));
     const extra = sema.code.extraData(Zir.Inst.UnNode, extended.operand).data;
     const src = LazySrcLoc.nodeOffset(extra.node);
-    const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
+    const operand_src: LazySrcLoc = .{ .node_offset_ptrcast_operand = extra.node };
     const operand = try sema.resolveInst(extra.operand);
     const operand_ty = sema.typeOf(operand);
     try sema.checkPtrOperand(block, operand_src, operand_ty);
@@ -21349,7 +21378,7 @@ fn zirTruncate(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
     const src = inst_data.src();
     const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
-    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, "@truncate");
+    const dest_ty = try sema.resolveCastDestType(block, src, extra.lhs, .remove_eu_opt, "@truncate");
     const dest_scalar_ty = try sema.checkIntOrVectorAllowComptime(block, dest_ty, src);
     const operand = try sema.resolveInst(extra.rhs);
     const operand_ty = sema.typeOf(operand);
src/Zir.zig
@@ -2820,6 +2820,13 @@ pub const Inst = struct {
         addrspace_cast: bool = false,
         const_cast: bool = false,
         volatile_cast: bool = false,
+
+        pub inline fn needResultTypeBuiltinName(flags: FullPtrCastFlags) []const u8 {
+            if (flags.ptr_cast) return "@ptrCast";
+            if (flags.align_cast) return "@alignCast";
+            if (flags.addrspace_cast) return "@addrSpaceCast";
+            unreachable;
+        }
     };
 
     /// Trailing: