Commit e027492243

Veikka Tuominen <git@vexu.eu>
2022-02-19 10:35:49
stage2: support anon init through error unions and optionals
1 parent 2f0204a
Changed files (5)
src/AstGen.zig
@@ -1384,7 +1384,8 @@ fn arrayInitExprRlPtr(
     array_ty: Zir.Inst.Ref,
 ) InnerError!Zir.Inst.Ref {
     if (array_ty == .none) {
-        return arrayInitExprRlPtrInner(gz, scope, node, result_ptr, elements);
+        const base_ptr = try gz.addUnNode(.array_base_ptr, result_ptr, node);
+        return arrayInitExprRlPtrInner(gz, scope, node, base_ptr, elements);
     }
 
     var as_scope = try gz.makeCoercionScope(scope, array_ty, result_ptr);
@@ -1493,6 +1494,7 @@ fn structInitExpr(
 
     switch (rl) {
         .discard => {
+            // TODO if a type expr is given the fields should be validated for that type
             if (struct_init.ast.type_expr != 0)
                 _ = try typeExpr(gz, scope, struct_init.ast.type_expr);
             for (struct_init.ast.fields) |field_init| {
@@ -1567,7 +1569,8 @@ fn structInitExprRlPtr(
     result_ptr: Zir.Inst.Ref,
 ) InnerError!Zir.Inst.Ref {
     if (struct_init.ast.type_expr == 0) {
-        return structInitExprRlPtrInner(gz, scope, node, struct_init, result_ptr);
+        const base_ptr = try gz.addUnNode(.field_base_ptr, result_ptr, node);
+        return structInitExprRlPtrInner(gz, scope, node, struct_init, base_ptr);
     }
     const ty_inst = try typeExpr(gz, scope, struct_init.ast.type_expr);
 
@@ -2281,6 +2284,8 @@ fn unusedResultExpr(gz: *GenZir, scope: *Scope, statement: Ast.Node.Index) Inner
             .ret_err_value_code,
             .extended,
             .closure_get,
+            .array_base_ptr,
+            .field_base_ptr,
             => break :b false,
 
             // ZIR instructions that are always `noreturn`.
src/print_zir.zig
@@ -235,6 +235,8 @@ const Writer = struct {
             .fence,
             .switch_cond,
             .switch_cond_ref,
+            .array_base_ptr,
+            .field_base_ptr,
             => try self.writeUnNode(stream, inst),
 
             .ref,
src/Sema.zig
@@ -739,6 +739,8 @@ fn analyzeBodyInner(
             .@"await"                     => try sema.zirAwait(block, inst, false),
             .await_nosuspend              => try sema.zirAwait(block, inst, true),
             .extended                     => try sema.zirExtended(block, inst),
+            .array_base_ptr               => try sema.zirArrayBasePtr(block, inst),
+            .field_base_ptr               => try sema.zirFieldBasePtr(block, inst),
 
             .clz       => try sema.zirBitCount(block, inst, .clz,      Value.clz),
             .ctz       => try sema.zirBitCount(block, inst, .ctz,      Value.ctz),
@@ -2584,6 +2586,59 @@ fn zirResolveInferredAlloc(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Com
     }
 }
 
+fn zirArrayBasePtr(
+    sema: *Sema,
+    block: *Block,
+    inst: Zir.Inst.Index,
+) CompileError!Air.Inst.Ref {
+    const inst_data = sema.code.instructions.items(.data)[inst].un_node;
+    const src = inst_data.src();
+
+    const start_ptr = sema.resolveInst(inst_data.operand);
+    var base_ptr = start_ptr;
+    while (true) switch (sema.typeOf(base_ptr).childType().zigTypeTag()) {
+        .ErrorUnion => base_ptr = try sema.analyzeErrUnionPayloadPtr(block, src, base_ptr, false),
+        .Optional => base_ptr = try sema.analyzeOptionalPayloadPtr(block, src, base_ptr, false),
+        else => break,
+    };
+
+    const elem_ty = sema.typeOf(base_ptr).childType();
+    switch (elem_ty.zigTypeTag()) {
+        .Array, .Vector => return base_ptr,
+        .Struct => if (elem_ty.isTuple()) return base_ptr,
+        else => {},
+    }
+    return sema.fail(block, src, "type '{}' does not support array initialization syntax", .{
+        sema.typeOf(start_ptr).childType(),
+    });
+}
+
+fn zirFieldBasePtr(
+    sema: *Sema,
+    block: *Block,
+    inst: Zir.Inst.Index,
+) CompileError!Air.Inst.Ref {
+    const inst_data = sema.code.instructions.items(.data)[inst].un_node;
+    const src = inst_data.src();
+
+    const start_ptr = sema.resolveInst(inst_data.operand);
+    var base_ptr = start_ptr;
+    while (true) switch (sema.typeOf(base_ptr).childType().zigTypeTag()) {
+        .ErrorUnion => base_ptr = try sema.analyzeErrUnionPayloadPtr(block, src, base_ptr, false),
+        .Optional => base_ptr = try sema.analyzeOptionalPayloadPtr(block, src, base_ptr, false),
+        else => break,
+    };
+
+    const elem_ty = sema.typeOf(base_ptr).childType();
+    switch (elem_ty.zigTypeTag()) {
+        .Struct, .Union => return base_ptr,
+        else => {},
+    }
+    return sema.fail(block, src, "type '{}' does not support struct initialization syntax", .{
+        sema.typeOf(start_ptr).childType(),
+    });
+}
+
 fn zirValidateStructInit(
     sema: *Sema,
     block: *Block,
@@ -4377,7 +4432,9 @@ fn analyzeCall(
                 if (payload.data.error_set.tag() == .error_set_inferred) {
                     const node = try sema.gpa.create(Module.Fn.InferredErrorSetListNode);
                     node.data = .{ .func = module_fn };
-                    parent_func.?.inferred_error_sets.prepend(node);
+                    if (parent_func) |some| {
+                        some.inferred_error_sets.prepend(node);
+                    }
 
                     const error_set_ty = try Type.Tag.error_set_inferred.create(sema.arena, &node.data);
                     break :blk try Type.Tag.error_union.create(sema.arena, .{
@@ -5198,9 +5255,20 @@ fn zirOptionalPayloadPtr(
 
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const optional_ptr = sema.resolveInst(inst_data.operand);
+    const src = inst_data.src();
+
+    return sema.analyzeOptionalPayloadPtr(block, src, optional_ptr, safety_check);
+}
+
+fn analyzeOptionalPayloadPtr(
+    sema: *Sema,
+    block: *Block,
+    src: LazySrcLoc,
+    optional_ptr: Air.Inst.Ref,
+    safety_check: bool,
+) CompileError!Air.Inst.Ref {
     const optional_ptr_ty = sema.typeOf(optional_ptr);
     assert(optional_ptr_ty.zigTypeTag() == .Pointer);
-    const src = inst_data.src();
 
     const opt_type = optional_ptr_ty.elemType();
     if (opt_type.zigTypeTag() != .Optional) {
@@ -5216,8 +5284,10 @@ fn zirOptionalPayloadPtr(
 
     if (try sema.resolveDefinedValue(block, src, optional_ptr)) |pointer_val| {
         if (try sema.pointerDeref(block, src, pointer_val, optional_ptr_ty)) |val| {
-            if (val.isNull()) {
-                return sema.fail(block, src, "unable to unwrap null", .{});
+            if (safety_check) {
+                if (val.isNull()) {
+                    return sema.fail(block, src, "unable to unwrap null", .{});
+                }
             }
             // The same Value represents the pointer to the optional and the payload.
             return sema.addConstant(
@@ -5333,8 +5403,19 @@ fn zirErrUnionPayloadPtr(
     defer tracy.end();
 
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
-    const src = inst_data.src();
     const operand = sema.resolveInst(inst_data.operand);
+    const src = inst_data.src();
+
+    return sema.analyzeErrUnionPayloadPtr(block, src, operand, safety_check);
+}
+
+fn analyzeErrUnionPayloadPtr(
+    sema: *Sema,
+    block: *Block,
+    src: LazySrcLoc,
+    operand: Air.Inst.Ref,
+    safety_check: bool,
+) CompileError!Air.Inst.Ref {
     const operand_ty = sema.typeOf(operand);
     assert(operand_ty.zigTypeTag() == .Pointer);
 
@@ -5350,9 +5431,12 @@ fn zirErrUnionPayloadPtr(
 
     if (try sema.resolveDefinedValue(block, src, operand)) |pointer_val| {
         if (try sema.pointerDeref(block, src, pointer_val, operand_ty)) |val| {
-            if (val.getError()) |name| {
-                return sema.fail(block, src, "caught unexpected error '{s}'", .{name});
+            if (safety_check) {
+                if (val.getError()) |name| {
+                    return sema.fail(block, src, "caught unexpected error '{s}'", .{name});
+                }
             }
+
             return sema.addConstant(
                 operand_pointer_ty,
                 try Value.Tag.eu_payload_ptr.create(sema.arena, pointer_val),
@@ -12859,7 +12943,7 @@ fn zirCUndef(
     extended: Zir.Inst.Extended.InstData,
 ) CompileError!Air.Inst.Ref {
     const extra = sema.code.extraData(Zir.Inst.UnNode, extended.operand).data;
-    const src: LazySrcLoc = .{ .node_offset = extra.node };
+    const src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
 
     const name = try sema.resolveConstString(block, src, extra.operand);
     try block.c_import_buf.?.writer().print("#undefine {s}\n", .{name});
@@ -12872,7 +12956,7 @@ fn zirCInclude(
     extended: Zir.Inst.Extended.InstData,
 ) CompileError!Air.Inst.Ref {
     const extra = sema.code.extraData(Zir.Inst.UnNode, extended.operand).data;
-    const src: LazySrcLoc = .{ .node_offset = extra.node };
+    const src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
 
     const name = try sema.resolveConstString(block, src, extra.operand);
     try block.c_import_buf.?.writer().print("#include <{s}>\n", .{name});
@@ -12885,12 +12969,13 @@ fn zirCDefine(
     extended: Zir.Inst.Extended.InstData,
 ) CompileError!Air.Inst.Ref {
     const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
-    const src: LazySrcLoc = .{ .node_offset = extra.node };
+    const name_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = extra.node };
+    const val_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = extra.node };
 
-    const name = try sema.resolveConstString(block, src, extra.lhs);
+    const name = try sema.resolveConstString(block, name_src, extra.lhs);
     const rhs = sema.resolveInst(extra.rhs);
     if (sema.typeOf(rhs).zigTypeTag() != .Void) {
-        const value = try sema.resolveConstString(block, src, extra.rhs);
+        const value = try sema.resolveConstString(block, val_src, extra.rhs);
         try block.c_import_buf.?.writer().print("#define {s} {s}\n", .{ name, value });
     } else {
         try block.c_import_buf.?.writer().print("#define {s}\n", .{name});
src/Zir.zig
@@ -643,6 +643,18 @@ pub const Inst = struct {
         /// Result is a pointer to the value.
         /// Uses the `switch_capture` field.
         switch_capture_multi_ref,
+        /// Given a
+        ///   *A returns *A
+        ///   *E!A returns *A
+        ///   *?A returns *A
+        /// Uses the `un_node` field.
+        array_base_ptr,
+        /// Given a
+        ///   *S returns *S
+        ///   *E!S returns *S
+        ///   *?S returns *S
+        /// Uses the `un_node` field.
+        field_base_ptr,
         /// Given a set of `field_ptr` instructions, assumes they are all part of a struct
         /// initialization expression, and emits compile errors for duplicate fields
         /// as well as missing fields, if applicable.
@@ -1087,6 +1099,8 @@ pub const Inst = struct {
                 .switch_block,
                 .switch_cond,
                 .switch_cond_ref,
+                .array_base_ptr,
+                .field_base_ptr,
                 .validate_struct_init,
                 .validate_struct_init_comptime,
                 .validate_array_init,
@@ -1340,6 +1354,8 @@ pub const Inst = struct {
                 .switch_capture_ref = .switch_capture,
                 .switch_capture_multi = .switch_capture,
                 .switch_capture_multi_ref = .switch_capture,
+                .array_base_ptr = .un_node,
+                .field_base_ptr = .un_node,
                 .validate_struct_init = .pl_node,
                 .validate_struct_init_comptime = .pl_node,
                 .validate_array_init = .pl_node,
test/behavior/struct.zig
@@ -1199,3 +1199,64 @@ test "for loop over pointers to struct, getting field from struct pointer" {
     };
     try S.doTheTest();
 }
+
+test "anon init through error unions and optionals" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    if (true) return error.SkipZigTest; // TODO
+
+    const S = struct {
+        a: u32,
+
+        fn foo() anyerror!?anyerror!@This() {
+            return @This(){ .a = 1 };
+        }
+        fn bar() ?anyerror![2]u8 {
+            return [2]u8{ 1, 2 };
+        }
+
+        fn doTheTest() !void {
+            var a = ((foo() catch unreachable).?) catch unreachable;
+            var b = (bar().?) catch unreachable;
+            try expect(a.a + b[1] == 3);
+        }
+    };
+
+    try S.doTheTest();
+    comptime try S.doTheTest();
+}
+
+test "anon init through optional" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    // not sure why this is needed, we only do the test at comptime
+    if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest;
+
+    const S = struct {
+        a: u32,
+
+        fn doTheTest() !void {
+            var s: ?@This() = null;
+            s = .{ .a = 1 };
+            try expect(s.?.a == 1);
+        }
+    };
+    // try S.doTheTest(); // TODO
+    comptime try S.doTheTest();
+}
+
+test "anon init through error union" {
+    if (builtin.zig_backend == .stage1) return error.SkipZigTest;
+    // not sure why this is needed, we only do the test at comptime
+    if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest;
+
+    const S = struct {
+        a: u32,
+
+        fn doTheTest() !void {
+            var s: anyerror!@This() = error.Foo;
+            s = .{ .a = 1 };
+            try expect((try s).a == 1);
+        }
+    };
+    // try S.doTheTest(); // TODO
+    comptime try S.doTheTest();
+}