Commit cb785b9c6b

Andrew Kelley <andrew@ziglang.org>
2021-11-10 06:58:27
Sema: implement coerce_result_ptr for optionals
New AIR instruction: `optional_payload_ptr_set` It's like `optional_payload_ptr` except it sets the non-null bit. When storing to the payload via a result location that is an optional, `optional_payload_ptr_set` is now emitted. There is a new algorithm in `zirCoerceResultPtr` which stores a dummy value through the result pointer into a temporary block, and then pops off the AIR instructions from the temporary block in order to determine how to transform the result location pointer in case any in-between coercions need to happen. Fixes a couple of behavior tests regarding optionals.
1 parent 008b0ec
src/arch/aarch64/CodeGen.zig
@@ -592,6 +592,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
                     .optional_payload           => try self.airOptionalPayload(inst),
                     .optional_payload_ptr       => try self.airOptionalPayloadPtr(inst),
+                    .optional_payload_ptr_set   => try self.airOptionalPayloadPtrSet(inst),
                     .unwrap_errunion_err        => try self.airUnwrapErrErr(inst),
                     .unwrap_errunion_payload    => try self.airUnwrapErrPayload(inst),
                     .unwrap_errunion_err_ptr    => try self.airUnwrapErrErrPtr(inst),
@@ -1010,6 +1011,12 @@ fn airOptionalPayloadPtr(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
+fn airOptionalPayloadPtrSet(self: *Self, inst: Air.Inst.Index) !void {
+    const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+    const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement .optional_payload_ptr_set for {}", .{self.target.cpu.arch});
+    return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
+}
+
 fn airUnwrapErrErr(self: *Self, inst: Air.Inst.Index) !void {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement unwrap error union error for {}", .{self.target.cpu.arch});
src/arch/arm/CodeGen.zig
@@ -510,6 +510,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
                     .optional_payload           => try self.airOptionalPayload(inst),
                     .optional_payload_ptr       => try self.airOptionalPayloadPtr(inst),
+                    .optional_payload_ptr_set   => try self.airOptionalPayloadPtrSet(inst),
                     .unwrap_errunion_err        => try self.airUnwrapErrErr(inst),
                     .unwrap_errunion_payload    => try self.airUnwrapErrPayload(inst),
                     .unwrap_errunion_err_ptr    => try self.airUnwrapErrErrPtr(inst),
@@ -1008,6 +1009,12 @@ fn airOptionalPayloadPtr(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
+fn airOptionalPayloadPtrSet(self: *Self, inst: Air.Inst.Index) !void {
+    const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+    const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement .optional_payload_ptr_set for {}", .{self.target.cpu.arch});
+    return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
+}
+
 fn airUnwrapErrErr(self: *Self, inst: Air.Inst.Index) !void {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement unwrap error union error for {}", .{self.target.cpu.arch});
src/arch/riscv64/CodeGen.zig
@@ -505,6 +505,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
                     .optional_payload           => try self.airOptionalPayload(inst),
                     .optional_payload_ptr       => try self.airOptionalPayloadPtr(inst),
+                    .optional_payload_ptr_set   => try self.airOptionalPayloadPtrSet(inst),
                     .unwrap_errunion_err        => try self.airUnwrapErrErr(inst),
                     .unwrap_errunion_payload    => try self.airUnwrapErrPayload(inst),
                     .unwrap_errunion_err_ptr    => try self.airUnwrapErrErrPtr(inst),
@@ -926,6 +927,12 @@ fn airOptionalPayloadPtr(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
+fn airOptionalPayloadPtrSet(self: *Self, inst: Air.Inst.Index) !void {
+    const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+    const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement .optional_payload_ptr_set for {}", .{self.target.cpu.arch});
+    return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
+}
+
 fn airUnwrapErrErr(self: *Self, inst: Air.Inst.Index) !void {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement unwrap error union error for {}", .{self.target.cpu.arch});
src/arch/x86_64/CodeGen.zig
@@ -578,6 +578,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
 
                     .optional_payload           => try self.airOptionalPayload(inst),
                     .optional_payload_ptr       => try self.airOptionalPayloadPtr(inst),
+                    .optional_payload_ptr_set   => try self.airOptionalPayloadPtrSet(inst),
                     .unwrap_errunion_err        => try self.airUnwrapErrErr(inst),
                     .unwrap_errunion_payload    => try self.airUnwrapErrPayload(inst),
                     .unwrap_errunion_err_ptr    => try self.airUnwrapErrErrPtr(inst),
@@ -1043,6 +1044,15 @@ fn airOptionalPayloadPtr(self: *Self, inst: Air.Inst.Index) !void {
     return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
 }
 
+fn airOptionalPayloadPtrSet(self: *Self, inst: Air.Inst.Index) !void {
+    const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+    const result: MCValue = if (self.liveness.isUnused(inst))
+        .dead
+    else
+        return self.fail("TODO implement .optional_payload_ptr_set for {}", .{self.target.cpu.arch});
+    return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
+}
+
 fn airUnwrapErrErr(self: *Self, inst: Air.Inst.Index) !void {
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const result: MCValue = if (self.liveness.isUnused(inst))
src/codegen/c.zig
@@ -1125,8 +1125,9 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
             .shl, .shl_exact => try airBinOp(f, inst, " << "),
             .not             => try airNot  (f, inst),
 
-            .optional_payload     => try airOptionalPayload(f, inst),
-            .optional_payload_ptr => try airOptionalPayload(f, inst),
+            .optional_payload         => try airOptionalPayload(f, inst),
+            .optional_payload_ptr     => try airOptionalPayload(f, inst),
+            .optional_payload_ptr_set => try airOptionalPayloadPtrSet(f, inst),
 
             .is_err          => try airIsErr(f, inst, "", ".", "!="),
             .is_non_err      => try airIsErr(f, inst, "", ".", "=="),
@@ -2218,6 +2219,33 @@ fn airOptionalPayload(f: *Function, inst: Air.Inst.Index) !CValue {
     return local;
 }
 
+fn airOptionalPayloadPtrSet(f: *Function, inst: Air.Inst.Index) !CValue {
+    const ty_op = f.air.instructions.items(.data)[inst].ty_op;
+    const writer = f.object.writer();
+    const operand = try f.resolveInst(ty_op.operand);
+    const operand_ty = f.air.typeOf(ty_op.operand);
+
+    const opt_ty = operand_ty.elemType();
+
+    if (opt_ty.isPtrLikeOptional()) {
+        // The payload and the optional are the same value.
+        // Setting to non-null will be done when the payload is set.
+        return operand;
+    }
+
+    try writer.writeAll("(");
+    try f.writeCValue(writer, operand);
+    try writer.writeAll(")->is_null = false;\n");
+
+    const inst_ty = f.air.typeOfIndex(inst);
+    const local = try f.allocLocal(inst_ty, .Const);
+    try writer.writeAll(" = &(");
+    try f.writeCValue(writer, operand);
+
+    try writer.writeAll(")->payload;\n");
+    return local;
+}
+
 fn airStructFieldPtr(f: *Function, inst: Air.Inst.Index) !CValue {
     if (f.liveness.isUnused(inst))
         // TODO this @as is needed because of a stage1 bug
src/codegen/llvm.zig
@@ -1796,8 +1796,9 @@ pub const FuncGen = struct {
                 .ptr_elem_val       => try self.airPtrElemVal(inst),
                 .ptr_elem_ptr       => try self.airPtrElemPtr(inst),
 
-                .optional_payload     => try self.airOptionalPayload(inst),
-                .optional_payload_ptr => try self.airOptionalPayloadPtr(inst),
+                .optional_payload         => try self.airOptionalPayload(inst),
+                .optional_payload_ptr     => try self.airOptionalPayloadPtr(inst),
+                .optional_payload_ptr_set => try self.airOptionalPayloadPtrSet(inst),
 
                 .unwrap_errunion_payload     => try self.airErrUnionPayload(inst, false),
                 .unwrap_errunion_payload_ptr => try self.airErrUnionPayload(inst, true),
@@ -2572,6 +2573,41 @@ pub const FuncGen = struct {
         return self.builder.buildInBoundsGEP(operand, &indices, indices.len, "");
     }
 
+    fn airOptionalPayloadPtrSet(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+        const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+        const operand = try self.resolveInst(ty_op.operand);
+        const optional_ty = self.air.typeOf(ty_op.operand).childType();
+        var buf: Type.Payload.ElemType = undefined;
+        const payload_ty = optional_ty.optionalChild(&buf);
+        const non_null_bit = self.context.intType(1).constAllOnes();
+        if (!payload_ty.hasCodeGenBits()) {
+            // We have a pointer to a i1. We need to set it to 1 and then return the same pointer.
+            _ = self.builder.buildStore(non_null_bit, operand);
+            return operand;
+        }
+        if (optional_ty.isPtrLikeOptional()) {
+            // The payload and the optional are the same value.
+            // Setting to non-null will be done when the payload is set.
+            return operand;
+        }
+        const index_type = self.context.intType(32);
+        {
+            // First set the non-null bit.
+            const indices: [2]*const llvm.Value = .{
+                index_type.constNull(), // dereference the pointer
+                index_type.constInt(1, .False), // second field is the payload
+            };
+            const non_null_ptr = self.builder.buildInBoundsGEP(operand, &indices, indices.len, "");
+            _ = self.builder.buildStore(non_null_bit, non_null_ptr);
+        }
+        // Then return the payload pointer.
+        const indices: [2]*const llvm.Value = .{
+            index_type.constNull(), // dereference the pointer
+            index_type.constNull(), // first field is the payload
+        };
+        return self.builder.buildInBoundsGEP(operand, &indices, indices.len, "");
+    }
+
     fn airOptionalPayload(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
         if (self.liveness.isUnused(inst)) return null;
 
src/codegen/wasm.zig
@@ -881,6 +881,7 @@ pub const Context = struct {
 
             .optional_payload => self.airOptionalPayload(inst),
             .optional_payload_ptr => self.airOptionalPayload(inst),
+            .optional_payload_ptr_set => self.airOptionalPayloadPtrSet(inst),
             else => |tag| self.fail("TODO: Implement wasm inst: {s}", .{@tagName(tag)}),
         };
     }
@@ -1702,6 +1703,13 @@ pub const Context = struct {
         return WValue{ .local = operand.multi_value.index + 1 };
     }
 
+    fn airOptionalPayloadPtrSet(self: *Context, inst: Air.Inst.Index) InnerError!WValue {
+        const ty_op = self.air.instructions.items(.data)[inst].ty_op;
+        const operand = self.resolveInst(ty_op.operand);
+        _ = operand;
+        return self.fail("TODO - wasm codegen for optional_payload_ptr_set", .{});
+    }
+
     fn airWrapOptional(self: *Context, inst: Air.Inst.Index) InnerError!WValue {
         const ty_op = self.air.instructions.items(.data)[inst].ty_op;
         return self.resolveInst(ty_op.operand);
src/Air.zig
@@ -336,6 +336,9 @@ pub const Inst = struct {
         /// *?T => *T. If the value is null, undefined behavior.
         /// Uses the `ty_op` field.
         optional_payload_ptr,
+        /// *?T => *T. Sets the value to non-null with an undefined payload value.
+        /// Uses the `ty_op` field.
+        optional_payload_ptr_set,
         /// Given a payload value, wraps it in an optional type.
         /// Uses the `ty_op` field.
         wrap_optional,
@@ -728,6 +731,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
         .trunc,
         .optional_payload,
         .optional_payload_ptr,
+        .optional_payload_ptr_set,
         .wrap_optional,
         .unwrap_errunion_payload,
         .unwrap_errunion_err,
src/Liveness.zig
@@ -292,6 +292,7 @@ fn analyzeInst(
         .trunc,
         .optional_payload,
         .optional_payload_ptr,
+        .optional_payload_ptr_set,
         .wrap_optional,
         .unwrap_errunion_payload,
         .unwrap_errunion_err,
src/print_air.zig
@@ -175,6 +175,7 @@ const Writer = struct {
             .trunc,
             .optional_payload,
             .optional_payload_ptr,
+            .optional_payload_ptr_set,
             .wrap_optional,
             .unwrap_errunion_payload,
             .unwrap_errunion_err,
src/Sema.zig
@@ -1414,9 +1414,10 @@ fn zirCoerceResultPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE
     const pointee_ty = try sema.resolveType(block, src, bin_inst.lhs);
     const ptr = sema.resolveInst(bin_inst.rhs);
 
+    const addr_space = target_util.defaultAddressSpace(sema.mod.getTarget(), .local);
     const ptr_ty = try Type.ptr(sema.arena, .{
         .pointee_type = pointee_ty,
-        .@"addrspace" = target_util.defaultAddressSpace(sema.mod.getTarget(), .local),
+        .@"addrspace" = addr_space,
     });
 
     if (Air.refToIndex(ptr)) |ptr_inst| {
@@ -1430,8 +1431,15 @@ fn zirCoerceResultPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE
                     // for the inferred allocation.
                     // This instruction will not make it to codegen; it is only to participate
                     // in the `stored_inst_list` of the `inferred_alloc`.
-                    const operand = try block.addBitCast(pointee_ty, .void_value);
+                    var trash_block = block.makeSubBlock();
+                    defer trash_block.instructions.deinit(sema.gpa);
+                    const operand = try trash_block.addBitCast(pointee_ty, .void_value);
+
                     try inferred_alloc.stored_inst_list.append(sema.arena, operand);
+
+                    try sema.requireRuntimeBlock(block, src);
+                    const bitcasted_ptr = try block.addBitCast(ptr_ty, ptr);
+                    return bitcasted_ptr;
                 },
                 .inferred_alloc_comptime => {
                     const iac = ptr_val.castTag(.inferred_alloc_comptime).?;
@@ -1456,9 +1464,78 @@ fn zirCoerceResultPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE
             }
         }
     }
+
     try sema.requireRuntimeBlock(block, src);
-    const bitcasted_ptr = try block.addBitCast(ptr_ty, ptr);
-    return bitcasted_ptr;
+
+    // Make a dummy store through the pointer to test the coercion.
+    // We will then use the generated instructions to decide what
+    // kind of transformations to make on the result pointer.
+    var trash_block = block.makeSubBlock();
+    defer trash_block.instructions.deinit(sema.gpa);
+
+    const dummy_operand = try trash_block.addBitCast(pointee_ty, .void_value);
+    try sema.storePtr(&trash_block, src, ptr, dummy_operand);
+
+    {
+        const air_tags = sema.air_instructions.items(.tag);
+
+        //std.debug.print("dummy storePtr instructions:\n", .{});
+        //for (trash_block.instructions.items) |item| {
+        //    std.debug.print("  {s}\n", .{@tagName(air_tags[item])});
+        //}
+
+        // The last one is always `store`.
+        const trash_inst = trash_block.instructions.pop();
+        assert(air_tags[trash_inst] == .store);
+        assert(trash_inst == sema.air_instructions.len - 1);
+        sema.air_instructions.len -= 1;
+    }
+
+    var new_ptr = ptr;
+
+    while (true) {
+        const air_tags = sema.air_instructions.items(.tag);
+        const air_datas = sema.air_instructions.items(.data);
+        const trash_inst = trash_block.instructions.pop();
+        switch (air_tags[trash_inst]) {
+            .bitcast => {
+                if (Air.indexToRef(trash_inst) == dummy_operand) {
+                    return block.addBitCast(ptr_ty, new_ptr);
+                }
+                const ty_op = air_datas[trash_inst].ty_op;
+                const operand_ty = sema.getTmpAir().typeOf(ty_op.operand);
+                const ptr_operand_ty = try Type.ptr(sema.arena, .{
+                    .pointee_type = operand_ty,
+                    .@"addrspace" = addr_space,
+                });
+                new_ptr = try block.addBitCast(ptr_operand_ty, new_ptr);
+            },
+            .wrap_optional => {
+                const ty_op = air_datas[trash_inst].ty_op;
+                const payload_ty = sema.getTmpAir().typeOf(ty_op.operand);
+                const ptr_payload_ty = try Type.ptr(sema.arena, .{
+                    .pointee_type = payload_ty,
+                    .@"addrspace" = addr_space,
+                });
+                new_ptr = try block.addTyOp(.optional_payload_ptr_set, ptr_payload_ty, new_ptr);
+            },
+            .wrap_errunion_err => {
+                return sema.fail(block, src, "TODO coerce_result_ptr wrap_errunion_err", .{});
+            },
+            .wrap_errunion_payload => {
+                return sema.fail(block, src, "TODO coerce_result_ptr wrap_errunion_payload", .{});
+            },
+            else => {
+                if (std.debug.runtime_safety) {
+                    std.debug.panic("unexpected AIR tag for coerce_result_ptr: {s}", .{
+                        air_tags[trash_inst],
+                    });
+                } else {
+                    unreachable;
+                }
+            },
+        }
+    } else unreachable; // TODO should not need else unreachable
 }
 
 pub fn analyzeStructDecl(
@@ -2365,7 +2442,13 @@ fn validateUnionInit(
             // Otherwise, the bitcast should be preserved and a store instruction should be
             // emitted to store the constant union value through the bitcast.
         },
-        else => unreachable,
+        else => |t| {
+            if (std.debug.runtime_safety) {
+                std.debug.panic("unexpected AIR tag for union pointer: {s}", .{@tagName(t)});
+            } else {
+                unreachable;
+            }
+        },
     }
 
     // Otherwise, we set the new union tag now.
test/behavior/optional.zig
@@ -73,3 +73,33 @@ test "optional with void type" {
     var x = Foo{ .x = null };
     try expect(x.x == null);
 }
+
+test "address of unwrap optional" {
+    const S = struct {
+        const Foo = struct {
+            a: i32,
+        };
+
+        var global: ?Foo = null;
+
+        pub fn getFoo() anyerror!*Foo {
+            return &global.?;
+        }
+    };
+    S.global = S.Foo{ .a = 1234 };
+    const foo = S.getFoo() catch unreachable;
+    try expect(foo.a == 1234);
+}
+
+test "nested optional field in struct" {
+    const S2 = struct {
+        y: u8,
+    };
+    const S1 = struct {
+        x: ?S2,
+    };
+    var s = S1{
+        .x = S2{ .y = 127 },
+    };
+    try expect(s.x.?.y == 127);
+}
test/behavior/optional_stage1.zig
@@ -3,23 +3,6 @@ const testing = std.testing;
 const expect = testing.expect;
 const expectEqual = testing.expectEqual;
 
-test "address of unwrap optional" {
-    const S = struct {
-        const Foo = struct {
-            a: i32,
-        };
-
-        var global: ?Foo = null;
-
-        pub fn getFoo() anyerror!*Foo {
-            return &global.?;
-        }
-    };
-    S.global = S.Foo{ .a = 1234 };
-    const foo = S.getFoo() catch unreachable;
-    try expect(foo.a == 1234);
-}
-
 test "equality compare optional with non-optional" {
     try test_cmp_optional_non_optional();
     comptime try test_cmp_optional_non_optional();
@@ -198,16 +181,3 @@ test "array of optional unaligned types" {
     i += 1;
     try expectEqual(Enum.three, values[i].?.Num);
 }
-
-test "nested optional field in struct" {
-    const S2 = struct {
-        y: u8,
-    };
-    const S1 = struct {
-        x: ?S2,
-    };
-    var s = S1{
-        .x = S2{ .y = 127 },
-    };
-    try expect(s.x.?.y == 127);
-}