Commit 85e427e4b2

Andrew Kelley <andrew@ziglang.org>
2021-11-16 21:59:53
stage2: equality compare optional with non-optional
1 parent 4eaf3c6
Changed files (5)
src/codegen/llvm/bindings.zig
@@ -98,7 +98,12 @@ pub const Value = opaque {
     extern fn LLVMAppendExistingBasicBlock(Fn: *const Value, BB: *const BasicBlock) void;
 
     pub const addIncoming = LLVMAddIncoming;
-    extern fn LLVMAddIncoming(PhiNode: *const Value, IncomingValues: [*]*const Value, IncomingBlocks: [*]*const BasicBlock, Count: c_uint) void;
+    extern fn LLVMAddIncoming(
+        PhiNode: *const Value,
+        IncomingValues: [*]const *const Value,
+        IncomingBlocks: [*]const *const BasicBlock,
+        Count: c_uint,
+    ) void;
 
     pub const getNextInstruction = LLVMGetNextInstruction;
     extern fn LLVMGetNextInstruction(Inst: *const Value) ?*const Value;
src/codegen/llvm.zig
@@ -1931,11 +1931,93 @@ pub const FuncGen = struct {
         const lhs = try self.resolveInst(bin_op.lhs);
         const rhs = try self.resolveInst(bin_op.rhs);
         const operand_ty = self.air.typeOf(bin_op.lhs);
-        var buffer: Type.Payload.Bits = undefined;
+
+        return self.cmp(lhs, rhs, operand_ty, op);
+    }
+
+    fn cmp(
+        self: *FuncGen,
+        lhs: *const llvm.Value,
+        rhs: *const llvm.Value,
+        operand_ty: Type,
+        op: math.CompareOperator,
+    ) *const llvm.Value {
+        var int_buffer: Type.Payload.Bits = undefined;
+        var opt_buffer: Type.Payload.ElemType = undefined;
 
         const int_ty = switch (operand_ty.zigTypeTag()) {
-            .Enum => operand_ty.intTagType(&buffer),
-            .Int, .Bool, .Pointer, .Optional, .ErrorSet => operand_ty,
+            .Enum => operand_ty.intTagType(&int_buffer),
+            .Int, .Bool, .Pointer, .ErrorSet => operand_ty,
+            .Optional => blk: {
+                const payload_ty = operand_ty.optionalChild(&opt_buffer);
+                if (!payload_ty.hasCodeGenBits() or operand_ty.isPtrLikeOptional()) {
+                    break :blk operand_ty;
+                }
+                // We need to emit instructions to check for equality/inequality
+                // of optionals that are not pointers.
+                const is_by_ref = isByRef(operand_ty);
+                const lhs_non_null = self.optIsNonNull(lhs, is_by_ref);
+                const rhs_non_null = self.optIsNonNull(rhs, is_by_ref);
+                const llvm_i2 = self.context.intType(2);
+                const lhs_non_null_i2 = self.builder.buildZExt(lhs_non_null, llvm_i2, "");
+                const rhs_non_null_i2 = self.builder.buildZExt(rhs_non_null, llvm_i2, "");
+                const lhs_shifted = self.builder.buildShl(lhs_non_null_i2, llvm_i2.constInt(1, .False), "");
+                const lhs_rhs_ored = self.builder.buildOr(lhs_shifted, rhs_non_null_i2, "");
+                const both_null_block = self.context.appendBasicBlock(self.llvm_func, "BothNull");
+                const mixed_block = self.context.appendBasicBlock(self.llvm_func, "Mixed");
+                const both_pl_block = self.context.appendBasicBlock(self.llvm_func, "BothNonNull");
+                const end_block = self.context.appendBasicBlock(self.llvm_func, "End");
+                const llvm_switch = self.builder.buildSwitch(lhs_rhs_ored, mixed_block, 2);
+                const llvm_i2_00 = llvm_i2.constInt(0b00, .False);
+                const llvm_i2_11 = llvm_i2.constInt(0b11, .False);
+                llvm_switch.addCase(llvm_i2_00, both_null_block);
+                llvm_switch.addCase(llvm_i2_11, both_pl_block);
+
+                self.builder.positionBuilderAtEnd(both_null_block);
+                _ = self.builder.buildBr(end_block);
+
+                self.builder.positionBuilderAtEnd(mixed_block);
+                _ = self.builder.buildBr(end_block);
+
+                self.builder.positionBuilderAtEnd(both_pl_block);
+                const lhs_payload = self.optPayloadHandle(lhs, is_by_ref);
+                const rhs_payload = self.optPayloadHandle(rhs, is_by_ref);
+                const payload_cmp = self.cmp(lhs_payload, rhs_payload, payload_ty, op);
+                _ = self.builder.buildBr(end_block);
+                const both_pl_block_end = self.builder.getInsertBlock();
+
+                self.builder.positionBuilderAtEnd(end_block);
+                const incoming_blocks: [3]*const llvm.BasicBlock = .{
+                    both_null_block,
+                    mixed_block,
+                    both_pl_block_end,
+                };
+                const llvm_i1 = self.context.intType(1);
+                const llvm_i1_0 = llvm_i1.constInt(0, .False);
+                const llvm_i1_1 = llvm_i1.constInt(1, .False);
+                const incoming_values: [3]*const llvm.Value = .{
+                    switch (op) {
+                        .eq => llvm_i1_1,
+                        .neq => llvm_i1_0,
+                        else => unreachable,
+                    },
+                    switch (op) {
+                        .eq => llvm_i1_0,
+                        .neq => llvm_i1_1,
+                        else => unreachable,
+                    },
+                    payload_cmp,
+                };
+
+                const phi_node = self.builder.buildPhi(llvm_i1, "");
+                comptime assert(incoming_values.len == incoming_blocks.len);
+                phi_node.addIncoming(
+                    &incoming_values,
+                    &incoming_blocks,
+                    incoming_values.len,
+                );
+                return phi_node;
+            },
             .Float => {
                 const operation: llvm.RealPredicate = switch (op) {
                     .eq => .OEQ,
@@ -2493,24 +2575,8 @@ pub const FuncGen = struct {
             }
         }
 
-        if (operand_is_ptr or isByRef(optional_ty)) {
-            const index_type = self.context.intType(32);
-
-            const indices: [2]*const llvm.Value = .{
-                index_type.constNull(),
-                index_type.constInt(1, .False),
-            };
-
-            const field_ptr = self.builder.buildInBoundsGEP(operand, &indices, indices.len, "");
-            const non_null_bit = self.builder.buildLoad(field_ptr, "");
-            if (invert) {
-                return self.builder.buildNot(non_null_bit, "");
-            } else {
-                return non_null_bit;
-            }
-        }
-
-        const non_null_bit = self.builder.buildExtractValue(operand, 1, "");
+        const is_by_ref = operand_is_ptr or isByRef(optional_ty);
+        const non_null_bit = self.optIsNonNull(operand, is_by_ref);
         if (invert) {
             return self.builder.buildNot(non_null_bit, "");
         } else {
@@ -2622,17 +2688,7 @@ pub const FuncGen = struct {
             return operand;
         }
 
-        if (isByRef(payload_ty)) {
-            // We have a pointer and we need to return a pointer to the first field.
-            const index_type = self.context.intType(32);
-            const indices: [2]*const llvm.Value = .{
-                index_type.constNull(), // dereference the pointer
-                index_type.constNull(), // first field is the payload
-            };
-            return self.builder.buildInBoundsGEP(operand, &indices, indices.len, "");
-        }
-
-        return self.builder.buildExtractValue(operand, 0, "");
+        return self.optPayloadHandle(operand, isByRef(payload_ty));
     }
 
     fn airErrUnionPayload(
@@ -3748,6 +3804,38 @@ pub const FuncGen = struct {
         }
     }
 
+    /// Assumes the optional is not pointer-like and payload has bits.
+    fn optIsNonNull(self: *FuncGen, opt_handle: *const llvm.Value, is_by_ref: bool) *const llvm.Value {
+        if (is_by_ref) {
+            const index_type = self.context.intType(32);
+
+            const indices: [2]*const llvm.Value = .{
+                index_type.constNull(),
+                index_type.constInt(1, .False),
+            };
+
+            const field_ptr = self.builder.buildInBoundsGEP(opt_handle, &indices, indices.len, "");
+            return self.builder.buildLoad(field_ptr, "");
+        }
+
+        return self.builder.buildExtractValue(opt_handle, 1, "");
+    }
+
+    /// Assumes the optional is not pointer-like and payload has bits.
+    fn optPayloadHandle(self: *FuncGen, opt_handle: *const llvm.Value, is_by_ref: bool) *const llvm.Value {
+        if (is_by_ref) {
+            // We have a pointer and we need to return a pointer to the first field.
+            const index_type = self.context.intType(32);
+            const indices: [2]*const llvm.Value = .{
+                index_type.constNull(), // dereference the pointer
+                index_type.constNull(), // first field is the payload
+            };
+            return self.builder.buildInBoundsGEP(opt_handle, &indices, indices.len, "");
+        }
+
+        return self.builder.buildExtractValue(opt_handle, 0, "");
+    }
+
     fn callFloor(self: *FuncGen, arg: *const llvm.Value, ty: Type) !*const llvm.Value {
         return self.callFloatUnary(arg, ty, "floor");
     }
src/type.zig
@@ -175,7 +175,11 @@ pub const Type = extern union {
             => false,
 
             .Pointer => is_equality_cmp or ty.isCPtr(),
-            .Optional => is_equality_cmp and ty.isPtrLikeOptional(),
+            .Optional => {
+                if (!is_equality_cmp) return false;
+                var buf: Payload.ElemType = undefined;
+                return ty.optionalChild(&buf).isSelfComparable(is_equality_cmp);
+            },
         };
     }
 
test/behavior/optional.zig
@@ -103,3 +103,37 @@ test "nested optional field in struct" {
     };
     try expect(s.x.?.y == 127);
 }
+
+test "equality compare optional with non-optional" {
+    try test_cmp_optional_non_optional();
+    comptime try test_cmp_optional_non_optional();
+}
+
+fn test_cmp_optional_non_optional() !void {
+    var ten: i32 = 10;
+    var opt_ten: ?i32 = 10;
+    var five: i32 = 5;
+    var int_n: ?i32 = null;
+
+    try expect(int_n != ten);
+    try expect(opt_ten == ten);
+    try expect(opt_ten != five);
+
+    // test evaluation is always lexical
+    // ensure that the optional isn't always computed before the non-optional
+    var mutable_state: i32 = 0;
+    _ = blk1: {
+        mutable_state += 1;
+        break :blk1 @as(?f64, 10.0);
+    } != blk2: {
+        try expect(mutable_state == 1);
+        break :blk2 @as(f64, 5.0);
+    };
+    _ = blk1: {
+        mutable_state += 1;
+        break :blk1 @as(f64, 10.0);
+    } != blk2: {
+        try expect(mutable_state == 2);
+        break :blk2 @as(?f64, 5.0);
+    };
+}
test/behavior/optional_stage1.zig
@@ -3,40 +3,6 @@ const testing = std.testing;
 const expect = testing.expect;
 const expectEqual = testing.expectEqual;
 
-test "equality compare optional with non-optional" {
-    try test_cmp_optional_non_optional();
-    comptime try test_cmp_optional_non_optional();
-}
-
-fn test_cmp_optional_non_optional() !void {
-    var ten: i32 = 10;
-    var opt_ten: ?i32 = 10;
-    var five: i32 = 5;
-    var int_n: ?i32 = null;
-
-    try expect(int_n != ten);
-    try expect(opt_ten == ten);
-    try expect(opt_ten != five);
-
-    // test evaluation is always lexical
-    // ensure that the optional isn't always computed before the non-optional
-    var mutable_state: i32 = 0;
-    _ = blk1: {
-        mutable_state += 1;
-        break :blk1 @as(?f64, 10.0);
-    } != blk2: {
-        try expect(mutable_state == 1);
-        break :blk2 @as(f64, 5.0);
-    };
-    _ = blk1: {
-        mutable_state += 1;
-        break :blk1 @as(f64, 10.0);
-    } != blk2: {
-        try expect(mutable_state == 2);
-        break :blk2 @as(?f64, 5.0);
-    };
-}
-
 test "unwrap function call with optional pointer return value" {
     const S = struct {
         fn entry() !void {