Commit 5354d1f5fc

Andrew Kelley <superjoe30@gmail.com>
2018-07-13 18:34:42
allow == for comparing optional pointers
closes #658
1 parent ac096c2
Changed files (3)
src/codegen.cpp
@@ -2249,10 +2249,8 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
                 return LLVMBuildICmp(g->builder, pred, op1_value, op2_value, "");
             } else if (type_entry->id == TypeTableEntryIdEnum ||
                     type_entry->id == TypeTableEntryIdErrorSet ||
-                    type_entry->id == TypeTableEntryIdPointer ||
                     type_entry->id == TypeTableEntryIdBool ||
-                    type_entry->id == TypeTableEntryIdPromise ||
-                    type_entry->id == TypeTableEntryIdFn)
+                    get_codegen_ptr_type(type_entry) != nullptr)
             {
                 LLVMIntPredicate pred = cmp_op_to_int_predicate(op_id, false);
                 return LLVMBuildICmp(g->builder, pred, op1_value, op2_value, "");
src/ir.cpp
@@ -11147,7 +11147,7 @@ static TypeTableEntry *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp
     if (type_is_invalid(resolved_type))
         return resolved_type;
 
-
+    bool operator_allowed;
     switch (resolved_type->id) {
         case TypeTableEntryIdInvalid:
             zig_unreachable(); // handled above
@@ -11156,6 +11156,7 @@ static TypeTableEntry *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp
         case TypeTableEntryIdComptimeInt:
         case TypeTableEntryIdInt:
         case TypeTableEntryIdFloat:
+            operator_allowed = true;
             break;
 
         case TypeTableEntryIdBool:
@@ -11170,19 +11171,8 @@ static TypeTableEntry *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp
         case TypeTableEntryIdBoundFn:
         case TypeTableEntryIdArgTuple:
         case TypeTableEntryIdPromise:
-            if (!is_equality_cmp) {
-                ir_add_error_node(ira, source_node,
-                    buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name)));
-                return ira->codegen->builtin_types.entry_invalid;
-            }
-            break;
-
         case TypeTableEntryIdEnum:
-            if (!is_equality_cmp) {
-                ir_add_error_node(ira, source_node,
-                    buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name)));
-                return ira->codegen->builtin_types.entry_invalid;
-            }
+            operator_allowed = is_equality_cmp;
             break;
 
         case TypeTableEntryIdUnreachable:
@@ -11190,12 +11180,18 @@ static TypeTableEntry *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp
         case TypeTableEntryIdStruct:
         case TypeTableEntryIdUndefined:
         case TypeTableEntryIdNull:
-        case TypeTableEntryIdOptional:
         case TypeTableEntryIdErrorUnion:
         case TypeTableEntryIdUnion:
-            ir_add_error_node(ira, source_node,
-                buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name)));
-            return ira->codegen->builtin_types.entry_invalid;
+            operator_allowed = false;
+            break;
+        case TypeTableEntryIdOptional:
+            operator_allowed = is_equality_cmp && get_codegen_ptr_type(resolved_type) != nullptr;
+            break;
+    }
+    if (!operator_allowed) {
+        ir_add_error_node(ira, source_node,
+            buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name)));
+        return ira->codegen->builtin_types.entry_invalid;
     }
 
     IrInstruction *casted_op1 = ir_implicit_cast(ira, op1, resolved_type);
test/cases/optional.zig
@@ -7,3 +7,24 @@ test "optional pointer to size zero struct" {
     var o: ?*EmptyStruct = &e;
     assert(o != null);
 }
+
+test "equality compare nullable pointers" {
+    testNullPtrsEql();
+    comptime testNullPtrsEql();
+}
+
+fn testNullPtrsEql() void {
+    var number: i32 = 1234;
+
+    var x: ?*i32 = null;
+    var y: ?*i32 = null;
+    assert(x == y);
+    y = &number;
+    assert(x != y);
+    assert(x != &number);
+    assert(&number != x);
+    x = &number;
+    assert(x == y);
+    assert(x == &number);
+    assert(&number == x);
+}