Commit 8ad75a9bf3

Andrew Kelley <superjoe30@gmail.com>
2016-04-24 21:09:51
add compile error for invalid equality operator uses
See #145
1 parent 46ab981
Changed files (2)
src/analyze.cpp
@@ -3014,20 +3014,52 @@ static TypeTableEntry *analyze_bool_bin_op_expr(CodeGen *g, ImportTableEntry *im
     TypeTableEntry *resolved_type = resolve_peer_type_compatibility(g, import, context, node,
             op_nodes, op_types, 2);
 
-    bool type_can_gt_lt_cmp = (resolved_type->id == TypeTableEntryIdNumLitFloat ||
-            resolved_type->id == TypeTableEntryIdNumLitInt ||
-            resolved_type->id == TypeTableEntryIdFloat ||
-            resolved_type->id == TypeTableEntryIdInt);
+    bool is_equality_cmp = (bin_op_type == BinOpTypeCmpEq || bin_op_type == BinOpTypeCmpNotEq);
 
-    if (resolved_type->id == TypeTableEntryIdInvalid) {
-        return g->builtin_types.entry_invalid;
-    } else if (bin_op_type != BinOpTypeCmpEq &&
-               bin_op_type != BinOpTypeCmpNotEq &&
-               !type_can_gt_lt_cmp)
-    {
-        add_node_error(g, node,
-            buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name)));
-        return g->builtin_types.entry_invalid;
+    switch (resolved_type->id) {
+        case TypeTableEntryIdInvalid:
+            return g->builtin_types.entry_invalid;
+
+        case TypeTableEntryIdNumLitFloat:
+        case TypeTableEntryIdNumLitInt:
+        case TypeTableEntryIdInt:
+        case TypeTableEntryIdFloat:
+            break;
+
+        case TypeTableEntryIdBool:
+        case TypeTableEntryIdMetaType:
+        case TypeTableEntryIdVoid:
+        case TypeTableEntryIdPointer:
+        case TypeTableEntryIdPureError:
+        case TypeTableEntryIdFn:
+        case TypeTableEntryIdTypeDecl:
+        case TypeTableEntryIdNamespace:
+        case TypeTableEntryIdGenericFn:
+            if (!is_equality_cmp) {
+                add_node_error(g, node,
+                    buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name)));
+                return g->builtin_types.entry_invalid;
+            }
+            break;
+
+        case TypeTableEntryIdEnum:
+            if (!is_equality_cmp || resolved_type->data.enumeration.gen_field_count != 0) {
+                add_node_error(g, node,
+                    buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name)));
+                return g->builtin_types.entry_invalid;
+            }
+            break;
+
+        case TypeTableEntryIdUnreachable:
+        case TypeTableEntryIdArray:
+        case TypeTableEntryIdStruct:
+        case TypeTableEntryIdUndefLit:
+        case TypeTableEntryIdMaybe:
+        case TypeTableEntryIdErrorUnion:
+        case TypeTableEntryIdUnion:
+            add_node_error(g, node,
+                buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name)));
+            return g->builtin_types.entry_invalid;
     }
 
     ConstExprValue *op1_val = &get_resolved_expr(*op1)->const_val;
test/run_tests.cpp
@@ -1218,6 +1218,21 @@ fn test_a_thing() {
     bad_fn_call();
 }
     )SOURCE", 1, ".tmp_source.zig:6:5: error: use of undeclared identifier 'bad_fn_call'");
+
+    add_compile_fail_case("illegal comparison of types", R"SOURCE(
+fn bad_eql_1(a: []u8, b: []u8) -> bool {
+    a == b
+}
+enum EnumWithData {
+    One,
+    Two: i32,
+}
+fn bad_eql_2(a: EnumWithData, b: EnumWithData) -> bool {
+    a == b
+}
+    )SOURCE", 2,
+            ".tmp_source.zig:3:7: error: operator not allowed for type '[]u8'",
+            ".tmp_source.zig:10:7: error: operator not allowed for type 'EnumWithData'");
 }
 
 //////////////////////////////////////////////////////////////////////////////