Commit 6e3e23a941

mlugg <mlugg@mlugg.co.uk>
2024-08-31 03:20:12
compiler: implement decl literals
Resolves: #9938
1 parent 9e683f0
lib/std/zig/AstGen.zig
@@ -1028,7 +1028,18 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE
             const statements = tree.extra_data[node_datas[node].lhs..node_datas[node].rhs];
             return blockExpr(gz, scope, ri, node, statements, .normal);
         },
-        .enum_literal => return simpleStrTok(gz, ri, main_tokens[node], node, .enum_literal),
+        .enum_literal => if (try ri.rl.resultType(gz, node)) |res_ty| {
+            const str_index = try astgen.identAsString(main_tokens[node]);
+            const res = try gz.addPlNode(.decl_literal, node, Zir.Inst.Field{
+                .lhs = res_ty,
+                .field_name_start = str_index,
+            });
+            switch (ri.rl) {
+                .discard, .none, .ref => unreachable, // no result type
+                .ty, .coerced_ty => return res, // `decl_literal` does the coercion for us
+                .ref_coerced_ty, .ptr, .inferred_ptr, .destructure => return rvalue(gz, ri, res, node),
+            }
+        } else return simpleStrTok(gz, ri, main_tokens[node], node, .enum_literal),
         .error_value => return simpleStrTok(gz, ri, node_datas[node].rhs, node, .error_value),
         // TODO restore this when implementing https://github.com/ziglang/zig/issues/6025
         // .anyframe_literal => return rvalue(gz, ri, .anyframe_type, node),
@@ -2752,6 +2763,8 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
             .err_union_code_ptr,
             .ptr_type,
             .enum_literal,
+            .decl_literal,
+            .decl_literal_no_coerce,
             .merge_error_sets,
             .error_union_type,
             .bit_not,
@@ -5889,22 +5902,21 @@ fn tryExpr(
     }
     const try_lc = LineColumn{ astgen.source_line - parent_gz.decl_line, astgen.source_column };
 
-    const operand_ri: ResultInfo = .{
-        .rl = switch (ri.rl) {
-            .ref => .ref,
-            .ref_coerced_ty => |payload_ptr_ty| .{
-                .ref_coerced_ty = try parent_gz.addUnNode(.try_ref_operand_ty, payload_ptr_ty, node),
-            },
-            else => if (try ri.rl.resultType(parent_gz, node)) |payload_ty| .{
-                // `coerced_ty` is OK due to the `rvalue` call below
-                .coerced_ty = try parent_gz.addUnNode(.try_operand_ty, payload_ty, node),
-            } else .none,
+    const operand_rl: ResultInfo.Loc, const block_tag: Zir.Inst.Tag = switch (ri.rl) {
+        .ref => .{ .ref, .try_ptr },
+        .ref_coerced_ty => |payload_ptr_ty| .{
+            .{ .ref_coerced_ty = try parent_gz.addUnNode(.try_ref_operand_ty, payload_ptr_ty, node) },
+            .try_ptr,
         },
-        .ctx = .error_handling_expr,
+        else => if (try ri.rl.resultType(parent_gz, node)) |payload_ty| .{
+            // `coerced_ty` is OK due to the `rvalue` call below
+            .{ .coerced_ty = try parent_gz.addUnNode(.try_operand_ty, payload_ty, node) },
+            .@"try",
+        } else .{ .none, .@"try" },
     };
+    const operand_ri: ResultInfo = .{ .rl = operand_rl, .ctx = .error_handling_expr };
     // This could be a pointer or value depending on the `ri` parameter.
     const operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, node);
-    const block_tag: Zir.Inst.Tag = if (operand_ri.rl == .ref) .try_ptr else .@"try";
     const try_inst = try parent_gz.makeBlockInst(block_tag, node);
     try parent_gz.instructions.append(astgen.gpa, try_inst);
 
@@ -9916,7 +9928,7 @@ fn callExpr(
 ) InnerError!Zir.Inst.Ref {
     const astgen = gz.astgen;
 
-    const callee = try calleeExpr(gz, scope, call.ast.fn_expr);
+    const callee = try calleeExpr(gz, scope, ri.rl, call.ast.fn_expr);
     const modifier: std.builtin.CallModifier = blk: {
         if (gz.is_comptime) {
             break :blk .compile_time;
@@ -10044,6 +10056,7 @@ const Callee = union(enum) {
 fn calleeExpr(
     gz: *GenZir,
     scope: *Scope,
+    call_rl: ResultInfo.Loc,
     node: Ast.Node.Index,
 ) InnerError!Callee {
     const astgen = gz.astgen;
@@ -10070,6 +10083,19 @@ fn calleeExpr(
                 .field_name_start = str_index,
             } };
         },
+        .enum_literal => if (try call_rl.resultType(gz, node)) |res_ty| {
+            // Decl literal call syntax, e.g.
+            // `const foo: T = .init();`
+            // Look up `init` in `T`, but don't try and coerce it.
+            const str_index = try astgen.identAsString(tree.nodes.items(.main_token)[node]);
+            const callee = try gz.addPlNode(.decl_literal_no_coerce, node, Zir.Inst.Field{
+                .lhs = res_ty,
+                .field_name_start = str_index,
+            });
+            return .{ .direct = callee };
+        } else {
+            return .{ .direct = try expr(gz, scope, .{ .rl = .none }, node) };
+        },
         else => return .{ .direct = try expr(gz, scope, .{ .rl = .none }, node) },
     }
 }
lib/std/zig/Zir.zig
@@ -651,6 +651,14 @@ pub const Inst = struct {
         err_union_code_ptr,
         /// An enum literal. Uses the `str_tok` union field.
         enum_literal,
+        /// A decl literal. This is similar to `field`, but unwraps error unions and optionals,
+        /// and coerces the result to the given type.
+        /// Uses the `pl_node` union field. Payload is `Field`.
+        decl_literal,
+        /// The same as `decl_literal`, but the coercion is omitted. This is used for decl literal
+        /// function call syntax, i.e. `.foo()`.
+        /// Uses the `pl_node` union field. Payload is `Field`.
+        decl_literal_no_coerce,
         /// A switch expression. Uses the `pl_node` union field.
         /// AST node is the switch, payload is `SwitchBlock`.
         switch_block,
@@ -1144,6 +1152,8 @@ pub const Inst = struct {
                 .err_union_code_ptr,
                 .ptr_type,
                 .enum_literal,
+                .decl_literal,
+                .decl_literal_no_coerce,
                 .merge_error_sets,
                 .error_union_type,
                 .bit_not,
@@ -1442,6 +1452,8 @@ pub const Inst = struct {
                 .err_union_code_ptr,
                 .ptr_type,
                 .enum_literal,
+                .decl_literal,
+                .decl_literal_no_coerce,
                 .merge_error_sets,
                 .error_union_type,
                 .bit_not,
@@ -1697,6 +1709,8 @@ pub const Inst = struct {
                 .err_union_code = .un_node,
                 .err_union_code_ptr = .un_node,
                 .enum_literal = .str_tok,
+                .decl_literal = .pl_node,
+                .decl_literal_no_coerce = .pl_node,
                 .switch_block = .pl_node,
                 .switch_block_ref = .pl_node,
                 .switch_block_err_union = .pl_node,
@@ -3842,6 +3856,8 @@ fn findDeclsInner(
         .err_union_code,
         .err_union_code_ptr,
         .enum_literal,
+        .decl_literal,
+        .decl_literal_no_coerce,
         .validate_deref,
         .validate_destructure,
         .field_type_ref,
src/print_zir.zig
@@ -462,6 +462,8 @@ const Writer = struct {
 
             .field_val,
             .field_ptr,
+            .decl_literal,
+            .decl_literal_no_coerce,
             => try self.writePlNodeField(stream, inst),
 
             .field_ptr_named,
src/Sema.zig
@@ -1072,6 +1072,8 @@ fn analyzeBodyInner(
             .indexable_ptr_elem_type      => try sema.zirIndexablePtrElemType(block, inst),
             .vector_elem_type             => try sema.zirVectorElemType(block, inst),
             .enum_literal                 => try sema.zirEnumLiteral(block, inst),
+            .decl_literal                 => try sema.zirDeclLiteral(block, inst, true),
+            .decl_literal_no_coerce       => try sema.zirDeclLiteral(block, inst, false),
             .int_from_enum                => try sema.zirIntFromEnum(block, inst),
             .enum_from_int                => try sema.zirEnumFromInt(block, inst),
             .err_union_code               => try sema.zirErrUnionCode(block, inst),
@@ -8874,6 +8876,54 @@ fn zirEnumLiteral(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     })));
 }
 
+fn zirDeclLiteral(sema: *Sema, block: *Block, inst: Zir.Inst.Index, do_coerce: bool) CompileError!Air.Inst.Ref {
+    const tracy = trace(@src());
+    defer tracy.end();
+
+    const pt = sema.pt;
+    const zcu = pt.zcu;
+    const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].pl_node;
+    const src = block.nodeOffset(inst_data.src_node);
+    const extra = sema.code.extraData(Zir.Inst.Field, inst_data.payload_index).data;
+    const name = try zcu.intern_pool.getOrPutString(
+        sema.gpa,
+        pt.tid,
+        sema.code.nullTerminatedString(extra.field_name_start),
+        .no_embedded_nulls,
+    );
+    const orig_ty = sema.resolveType(block, src, extra.lhs) catch |err| switch (err) {
+        error.GenericPoison => {
+            // Treat this as a normal enum literal.
+            return Air.internedToRef(try pt.intern(.{ .enum_literal = name }));
+        },
+        else => |e| return e,
+    };
+
+    var ty = orig_ty;
+    while (true) switch (ty.zigTypeTag(zcu)) {
+        .error_union => ty = ty.errorUnionPayload(zcu),
+        .optional => ty = ty.optionalChild(zcu),
+        .enum_literal, .error_set => {
+            // Treat this as a normal enum literal.
+            return Air.internedToRef(try pt.intern(.{ .enum_literal = name }));
+        },
+        else => break,
+    };
+
+    const result = try sema.fieldVal(block, src, Air.internedToRef(ty.toIntern()), name, src);
+
+    // Decl literals cannot lookup runtime `var`s.
+    if (!try sema.isComptimeKnown(result)) {
+        return sema.fail(block, src, "decl literal must be comptime-known", .{});
+    }
+
+    if (do_coerce) {
+        return sema.coerce(block, orig_ty, result, src);
+    } else {
+        return result;
+    }
+}
+
 fn zirIntFromEnum(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
     const pt = sema.pt;
     const zcu = pt.zcu;
test/behavior/decl_literals.zig
@@ -0,0 +1,38 @@
+const builtin = @import("builtin");
+const std = @import("std");
+const expect = std.testing.expect;
+
+test "decl literal" {
+    const S = struct {
+        x: u32,
+        const foo: @This() = .{ .x = 123 };
+    };
+
+    const val: S = .foo;
+    try expect(val.x == 123);
+}
+
+test "call decl literal" {
+    const S = struct {
+        x: u32,
+        fn init() @This() {
+            return .{ .x = 123 };
+        }
+    };
+
+    const val: S = .init();
+    try expect(val.x == 123);
+}
+
+test "call decl literal with error union" {
+    const S = struct {
+        x: u32,
+        fn init(err: bool) !@This() {
+            if (err) return error.Bad;
+            return .{ .x = 123 };
+        }
+    };
+
+    const val: S = try .init(false);
+    try expect(val.x == 123);
+}
test/cases/compile_errors/cast_enum_literal_to_enum_but_it_doesnt_match.zig
@@ -11,5 +11,5 @@ export fn entry() void {
 // backend=stage2
 // target=native
 //
-// :6:21: error: no field named 'c' in enum 'tmp.Foo'
+// :6:21: error: enum 'tmp.Foo' has no member named 'c'
 // :1:13: note: enum declared here
test/cases/compile_errors/comptime_arg_to_generic_fn_callee_error.zig
@@ -17,5 +17,5 @@ pub export fn entry() void {
 // backend=stage2
 // target=native
 //
-// :7:28: error: no field named 'c' in enum 'meta.FieldEnum(tmp.MyStruct)'
+// :7:28: error: enum 'meta.FieldEnum(tmp.MyStruct)' has no member named 'c'
 // :?:?: note: enum declared here
test/behavior.zig
@@ -21,6 +21,7 @@ test {
     _ = @import("behavior/cast_int.zig");
     _ = @import("behavior/comptime_memory.zig");
     _ = @import("behavior/const_slice_child.zig");
+    _ = @import("behavior/decl_literals.zig");
     _ = @import("behavior/decltest.zig");
     _ = @import("behavior/duplicated_test_names.zig");
     _ = @import("behavior/defer.zig");