Commit 682cdeceaa

Andrew Kelley <andrew@ziglang.org>
2021-10-16 03:37:09
stage2: optional comparison and 0-bit payloads
* Sema: implement peer type resolution for optionals and null. * Rename `Module.optionalType` to `Type.optional`. * LLVM backend: re-use anonymous values. This is especially useful when isByRef()=true because it means re-using the same generated LLVM globals. * LLVM backend: rework the implementation of is_null and is_non_null AIR instructions. Generate slightly better LLVM code, and also fix the behavior for optionals whose payload type is 0-bit. * LLVM backend: improve `cmp` AIR instruction lowering to support pointer-like optionals. * `Value`: implement support for equality-checking optionals.
1 parent 186126c
src/codegen/llvm.zig
@@ -1430,7 +1430,7 @@ pub const FuncGen = struct {
 
     /// This stores the LLVM values used in a function, such that they can be referred to
     /// in other instructions. This table is cleared before every function is generated.
-    func_inst_table: std.AutoHashMapUnmanaged(Air.Inst.Index, *const llvm.Value),
+    func_inst_table: std.AutoHashMapUnmanaged(Air.Inst.Ref, *const llvm.Value),
 
     /// If the return type isByRef, this is the result pointer. Otherwise null.
     ret_ptr: ?*const llvm.Value,
@@ -1472,23 +1472,27 @@ pub const FuncGen = struct {
     }
 
     fn resolveInst(self: *FuncGen, inst: Air.Inst.Ref) !*const llvm.Value {
-        if (self.air.value(inst)) |val| {
-            const ty = self.air.typeOf(inst);
-            const llvm_val = try self.dg.genTypedValue(.{ .ty = ty, .val = val });
-            if (!isByRef(ty)) return llvm_val;
+        const gop = try self.func_inst_table.getOrPut(self.dg.gpa, inst);
+        if (gop.found_existing) return gop.value_ptr.*;
 
-            // We have an LLVM value but we need to create a global constant and
-            // set the value as its initializer, and then return a pointer to the global.
-            const target = self.dg.module.getTarget();
-            const global = self.dg.object.llvm_module.addGlobal(llvm_val.typeOf(), "");
-            global.setInitializer(llvm_val);
-            global.setLinkage(.Private);
-            global.setGlobalConstant(.True);
-            global.setAlignment(ty.abiAlignment(target));
-            return global;
+        const val = self.air.value(inst).?;
+        const ty = self.air.typeOf(inst);
+        const llvm_val = try self.dg.genTypedValue(.{ .ty = ty, .val = val });
+        if (!isByRef(ty)) {
+            gop.value_ptr.* = llvm_val;
+            return llvm_val;
         }
-        const inst_index = Air.refToIndex(inst).?;
-        return self.func_inst_table.get(inst_index).?;
+
+        // We have an LLVM value but we need to create a global constant and
+        // set the value as its initializer, and then return a pointer to the global.
+        const target = self.dg.module.getTarget();
+        const global = self.dg.object.llvm_module.addGlobal(llvm_val.typeOf(), "");
+        global.setInitializer(llvm_val);
+        global.setLinkage(.Private);
+        global.setGlobalConstant(.True);
+        global.setAlignment(ty.abiAlignment(target));
+        gop.value_ptr.* = global;
+        return global;
     }
 
     fn genBody(self: *FuncGen, body: []const Air.Inst.Index) Error!void {
@@ -1528,10 +1532,11 @@ pub const FuncGen = struct {
                 .cmp_lte => try self.airCmp(inst, .lte),
                 .cmp_neq => try self.airCmp(inst, .neq),
 
-                .is_non_null     => try self.airIsNonNull(inst, false),
-                .is_non_null_ptr => try self.airIsNonNull(inst, true),
-                .is_null         => try self.airIsNull(inst, false),
-                .is_null_ptr     => try self.airIsNull(inst, true),
+                .is_non_null     => try self.airIsNonNull(inst, false, false, .NE),
+                .is_non_null_ptr => try self.airIsNonNull(inst, true , false, .NE),
+                .is_null         => try self.airIsNonNull(inst, false, true , .EQ),
+                .is_null_ptr     => try self.airIsNonNull(inst, true , true , .EQ),
+
                 .is_non_err      => try self.airIsErr(inst, .EQ, false),
                 .is_non_err_ptr  => try self.airIsErr(inst, .EQ, true),
                 .is_err          => try self.airIsErr(inst, .NE, false),
@@ -1618,7 +1623,10 @@ pub const FuncGen = struct {
                 },
                 // zig fmt: on
             };
-            if (opt_value) |val| try self.func_inst_table.putNoClobber(self.gpa, inst, val);
+            if (opt_value) |val| {
+                const ref = Air.indexToRef(inst);
+                try self.func_inst_table.putNoClobber(self.gpa, ref, val);
+            }
         }
     }
 
@@ -1722,8 +1730,7 @@ pub const FuncGen = struct {
     }
 
     fn airCmp(self: *FuncGen, inst: Air.Inst.Index, op: math.CompareOperator) !?*const llvm.Value {
-        if (self.liveness.isUnused(inst))
-            return null;
+        if (self.liveness.isUnused(inst)) return null;
 
         const bin_op = self.air.instructions.items(.data)[inst].bin_op;
         const lhs = try self.resolveInst(bin_op.lhs);
@@ -1733,7 +1740,7 @@ pub const FuncGen = struct {
 
         const int_ty = switch (operand_ty.zigTypeTag()) {
             .Enum => operand_ty.intTagType(&buffer),
-            .Int, .Bool, .Pointer, .ErrorSet => operand_ty,
+            .Int, .Bool, .Pointer, .Optional, .ErrorSet => operand_ty,
             .Float => {
                 const operation: llvm.RealPredicate = switch (op) {
                     .eq => .OEQ,
@@ -2227,45 +2234,57 @@ pub const FuncGen = struct {
         );
     }
 
-    fn airIsNonNull(self: *FuncGen, inst: Air.Inst.Index, operand_is_ptr: bool) !?*const llvm.Value {
-        if (self.liveness.isUnused(inst))
-            return null;
+    fn airIsNonNull(
+        self: *FuncGen,
+        inst: Air.Inst.Index,
+        operand_is_ptr: bool,
+        invert: bool,
+        pred: llvm.IntPredicate,
+    ) !?*const llvm.Value {
+        if (self.liveness.isUnused(inst)) return null;
 
         const un_op = self.air.instructions.items(.data)[inst].un_op;
         const operand = try self.resolveInst(un_op);
-
-        if (operand_is_ptr) {
-            const operand_ty = self.air.typeOf(un_op).elemType();
-            if (operand_ty.isPtrLikeOptional()) {
-                const operand_llvm_ty = try self.dg.llvmType(operand_ty);
-                const loaded = self.builder.buildLoad(operand, "");
-                return self.builder.buildICmp(.NE, loaded, operand_llvm_ty.constNull(), "");
+        const operand_ty = self.air.typeOf(un_op);
+        const optional_ty = if (operand_is_ptr) operand_ty.childType() else operand_ty;
+        var buf: Type.Payload.ElemType = undefined;
+        const payload_ty = optional_ty.optionalChild(&buf);
+        if (!payload_ty.hasCodeGenBits()) {
+            if (invert) {
+                return self.builder.buildNot(operand, "");
+            } else {
+                return operand;
             }
+        }
+        if (optional_ty.isPtrLikeOptional()) {
+            const optional_llvm_ty = try self.dg.llvmType(optional_ty);
+            const loaded = if (operand_is_ptr) self.builder.buildLoad(operand, "") else operand;
+            return self.builder.buildICmp(pred, loaded, optional_llvm_ty.constNull(), "");
+        }
 
+        if (operand_is_ptr or isByRef(optional_ty)) {
             const index_type = self.context.intType(32);
 
-            var indices: [2]*const llvm.Value = .{
+            const indices: [2]*const llvm.Value = .{
                 index_type.constNull(),
                 index_type.constInt(1, .False),
             };
 
-            return self.builder.buildLoad(self.builder.buildInBoundsGEP(operand, &indices, indices.len, ""), "");
+            const field_ptr = self.builder.buildInBoundsGEP(operand, &indices, indices.len, "");
+            const non_null_bit = self.builder.buildLoad(field_ptr, "");
+            if (invert) {
+                return self.builder.buildNot(non_null_bit, "");
+            } else {
+                return non_null_bit;
+            }
         }
 
-        const operand_ty = self.air.typeOf(un_op);
-        if (operand_ty.isPtrLikeOptional()) {
-            const operand_llvm_ty = try self.dg.llvmType(operand_ty);
-            return self.builder.buildICmp(.NE, operand, operand_llvm_ty.constNull(), "");
+        const non_null_bit = self.builder.buildExtractValue(operand, 1, "");
+        if (invert) {
+            return self.builder.buildNot(non_null_bit, "");
+        } else {
+            return non_null_bit;
         }
-
-        return self.builder.buildExtractValue(operand, 1, "");
-    }
-
-    fn airIsNull(self: *FuncGen, inst: Air.Inst.Index, operand_is_ptr: bool) !?*const llvm.Value {
-        if (self.liveness.isUnused(inst))
-            return null;
-
-        return self.builder.buildNot((try self.airIsNonNull(inst, operand_is_ptr)).?, "");
     }
 
     fn airIsErr(
src/Module.zig
@@ -4249,20 +4249,6 @@ pub fn errNoteNonLazy(
     };
 }
 
-pub fn optionalType(arena: *Allocator, child_type: Type) Allocator.Error!Type {
-    switch (child_type.tag()) {
-        .single_const_pointer => return Type.Tag.optional_single_const_pointer.create(
-            arena,
-            child_type.elemType(),
-        ),
-        .single_mut_pointer => return Type.Tag.optional_single_mut_pointer.create(
-            arena,
-            child_type.elemType(),
-        ),
-        else => return Type.Tag.optional.create(arena, child_type),
-    }
-}
-
 pub fn errorUnionType(
     arena: *Allocator,
     error_set: Type,
src/Sema.zig
@@ -4108,7 +4108,7 @@ fn zirOptionalType(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileErro
     const inst_data = sema.code.instructions.items(.data)[inst].un_node;
     const src = inst_data.src();
     const child_type = try sema.resolveType(block, src, inst_data.operand);
-    const opt_type = try Module.optionalType(sema.arena, child_type);
+    const opt_type = try Type.optional(sema.arena, child_type);
 
     return sema.addType(opt_type);
 }
@@ -9675,7 +9675,7 @@ fn zirCmpxchg(
         return sema.fail(block, failure_order_src, "failure atomic ordering must not be Release or AcqRel", .{});
     }
 
-    const result_ty = try Module.optionalType(sema.arena, elem_ty);
+    const result_ty = try Type.optional(sema.arena, elem_ty);
 
     // special case zero bit types
     if ((try sema.typeHasOnePossibleValue(block, elem_ty_src, elem_ty)) != null) {
@@ -10517,7 +10517,7 @@ fn panicWithMsg(
         .@"addrspace" = target_util.defaultAddressSpace(mod.getTarget(), .global_constant), // TODO might need a place that is more dynamic
     });
     const null_stack_trace = try sema.addConstant(
-        try Module.optionalType(arena, ptr_stack_trace_ty),
+        try Type.optional(arena, ptr_stack_trace_ty),
         Value.initTag(.null_value),
     );
     const args = try arena.create([2]Air.Inst.Ref);
@@ -12797,6 +12797,7 @@ fn resolvePeerTypes(
     const target = sema.mod.getTarget();
 
     var chosen = instructions[0];
+    var any_are_null = false;
     var chosen_i: usize = 0;
     for (instructions[1..]) |candidate, candidate_i| {
         const candidate_ty = sema.typeOf(candidate);
@@ -12878,6 +12879,44 @@ fn resolvePeerTypes(
             continue;
         }
 
+        if (chosen_ty_tag == .Null) {
+            any_are_null = true;
+            chosen = candidate;
+            chosen_i = candidate_i + 1;
+            continue;
+        }
+        if (candidate_ty_tag == .Null) {
+            any_are_null = true;
+            continue;
+        }
+
+        if (chosen_ty_tag == .Optional) {
+            var opt_child_buf: Type.Payload.ElemType = undefined;
+            const opt_child_ty = chosen_ty.optionalChild(&opt_child_buf);
+            if (coerceInMemoryAllowed(opt_child_ty, candidate_ty, false, target) == .ok) {
+                continue;
+            }
+            if (coerceInMemoryAllowed(candidate_ty, opt_child_ty, false, target) == .ok) {
+                any_are_null = true;
+                chosen = candidate;
+                chosen_i = candidate_i + 1;
+                continue;
+            }
+        }
+        if (candidate_ty_tag == .Optional) {
+            var opt_child_buf: Type.Payload.ElemType = undefined;
+            const opt_child_ty = candidate_ty.optionalChild(&opt_child_buf);
+            if (coerceInMemoryAllowed(opt_child_ty, chosen_ty, false, target) == .ok) {
+                chosen = candidate;
+                chosen_i = candidate_i + 1;
+                continue;
+            }
+            if (coerceInMemoryAllowed(chosen_ty, opt_child_ty, false, target) == .ok) {
+                any_are_null = true;
+                continue;
+            }
+        }
+
         // At this point, we hit a compile error. We need to recover
         // the source locations.
         const chosen_src = candidate_srcs.resolve(
@@ -12906,7 +12945,16 @@ fn resolvePeerTypes(
         return sema.failWithOwnedErrorMsg(msg);
     }
 
-    return sema.typeOf(chosen);
+    const chosen_ty = sema.typeOf(chosen);
+
+    if (any_are_null) {
+        switch (chosen_ty.zigTypeTag()) {
+            .Null, .Optional => return chosen_ty,
+            else => return Type.optional(sema.arena, chosen_ty),
+        }
+    }
+
+    return chosen_ty;
 }
 
 pub fn resolveTypeLayout(
src/type.zig
@@ -4031,6 +4031,20 @@ pub const Type = extern union {
         });
     }
 
+    pub fn optional(arena: *Allocator, child_type: Type) Allocator.Error!Type {
+        switch (child_type.tag()) {
+            .single_const_pointer => return Type.Tag.optional_single_const_pointer.create(
+                arena,
+                child_type.elemType(),
+            ),
+            .single_mut_pointer => return Type.Tag.optional_single_mut_pointer.create(
+                arena,
+                child_type.elemType(),
+            ),
+            else => return Type.Tag.optional.create(arena, child_type),
+        }
+    }
+
     pub fn smallestUnsignedBits(max: u64) u16 {
         if (max == 0) return 0;
         const base = std.math.log2(max);
src/value.zig
@@ -1365,12 +1365,20 @@ pub const Value = extern union {
                     const b_field_index = b.castTag(.enum_field_index).?.data;
                     return a_field_index == b_field_index;
                 },
+                .opt_payload => {
+                    const a_payload = a.castTag(.opt_payload).?.data;
+                    const b_payload = b.castTag(.opt_payload).?.data;
+                    var buffer: Type.Payload.ElemType = undefined;
+                    return eql(a_payload, b_payload, ty.optionalChild(&buffer));
+                },
                 .elem_ptr => @panic("TODO: Implement more pointer eql cases"),
                 .field_ptr => @panic("TODO: Implement more pointer eql cases"),
                 .eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
                 .opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
                 else => {},
             }
+        } else if (a_tag == .null_value or b_tag == .null_value) {
+            return false;
         }
 
         if (a.pointerDecl()) |a_decl| {
test/behavior/optional.zig
@@ -44,3 +44,32 @@ test "optional pointer to size zero struct" {
     var o: ?*EmptyStruct = &e;
     try expect(o != null);
 }
+
+test "equality compare optional pointers" {
+    try testNullPtrsEql();
+    comptime try testNullPtrsEql();
+}
+
+fn testNullPtrsEql() !void {
+    var number: i32 = 1234;
+
+    var x: ?*i32 = null;
+    var y: ?*i32 = null;
+    try expect(x == y);
+    y = &number;
+    try expect(x != y);
+    try expect(x != &number);
+    try expect(&number != x);
+    x = &number;
+    try expect(x == y);
+    try expect(x == &number);
+    try expect(&number == x);
+}
+
+test "optional with void type" {
+    const Foo = struct {
+        x: ?void,
+    };
+    var x = Foo{ .x = null };
+    try expect(x.x == null);
+}
test/behavior/optional_stage1.zig
@@ -3,27 +3,6 @@ const testing = std.testing;
 const expect = testing.expect;
 const expectEqual = testing.expectEqual;
 
-test "equality compare nullable pointers" {
-    try testNullPtrsEql();
-    comptime try testNullPtrsEql();
-}
-
-fn testNullPtrsEql() !void {
-    var number: i32 = 1234;
-
-    var x: ?*i32 = null;
-    var y: ?*i32 = null;
-    try expect(x == y);
-    y = &number;
-    try expect(x != y);
-    try expect(x != &number);
-    try expect(&number != x);
-    x = &number;
-    try expect(x == y);
-    try expect(x == &number);
-    try expect(&number == x);
-}
-
 test "address of unwrap optional" {
     const S = struct {
         const Foo = struct {
@@ -143,14 +122,6 @@ test "coerce an anon struct literal to optional struct" {
     comptime try S.doTheTest();
 }
 
-test "optional with void type" {
-    const Foo = struct {
-        x: ?void,
-    };
-    var x = Foo{ .x = null };
-    try expect(x.x == null);
-}
-
 test "0-bit child type coerced to optional return ptr result location" {
     const S = struct {
         fn doTheTest() !void {