Commit 342bca7f46

Andrew Kelley <andrew@ziglang.org>
2019-02-11 21:31:09
C pointer comparison and arithmetic
See #1059
1 parent d9e01be
src/analyze.cpp
@@ -434,7 +434,7 @@ ZigType *get_pointer_to_type_extra(CodeGen *g, ZigType *child_type, bool is_cons
         uint32_t bit_offset_in_host, uint32_t host_int_bytes)
 {
     assert(!type_is_invalid(child_type));
-    assert(ptr_len == PtrLenSingle || child_type->id != ZigTypeIdOpaque);
+    assert(ptr_len != PtrLenUnknown || child_type->id != ZigTypeIdOpaque);
 
     if (byte_alignment != 0) {
         uint32_t abi_alignment = get_abi_alignment(g, child_type);
src/codegen.cpp
@@ -2657,7 +2657,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
         (op1->value.type->id == ZigTypeIdErrorSet && op2->value.type->id == ZigTypeIdErrorSet) ||
         (op1->value.type->id == ZigTypeIdPointer &&
             (op_id == IrBinOpAdd || op_id == IrBinOpSub) &&
-            op1->value.type->data.pointer.ptr_len == PtrLenUnknown)
+            op1->value.type->data.pointer.ptr_len != PtrLenSingle)
     );
     ZigType *operand_type = op1->value.type;
     ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? operand_type->data.vector.elem_type : operand_type;
@@ -2716,7 +2716,7 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutable *executable,
                 AddSubMulMul;
 
             if (scalar_type->id == ZigTypeIdPointer) {
-                assert(scalar_type->data.pointer.ptr_len == PtrLenUnknown);
+                assert(scalar_type->data.pointer.ptr_len != PtrLenSingle);
                 LLVMValueRef subscript_value;
                 if (operand_type->id == ZigTypeIdVector)
                     zig_panic("TODO: Implement vector operations on pointers.");
src/ir.cpp
@@ -8943,7 +8943,9 @@ static void update_errors_helper(CodeGen *g, ErrorTableEntry ***errors, size_t *
     *errors = reallocate(*errors, old_errors_count, *errors_count);
 }
 
-static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigType *expected_type, IrInstruction **instructions, size_t instruction_count) {
+static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigType *expected_type,
+        IrInstruction **instructions, size_t instruction_count)
+{
     Error err;
     assert(instruction_count >= 1);
     IrInstruction *prev_inst = instructions[0];
@@ -9260,6 +9262,19 @@ static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigT
             continue;
         }
 
+        if (prev_type->id == ZigTypeIdPointer && prev_type->data.pointer.ptr_len == PtrLenC &&
+            (cur_type->id == ZigTypeIdComptimeInt || cur_type->id == ZigTypeIdInt))
+        {
+            continue;
+        }
+
+        if (cur_type->id == ZigTypeIdPointer && cur_type->data.pointer.ptr_len == PtrLenC &&
+            (prev_type->id == ZigTypeIdComptimeInt || prev_type->id == ZigTypeIdInt))
+        {
+            prev_inst = cur_inst;
+            continue;
+        }
+
         if (types_match_const_cast_only(ira, prev_type, cur_type, source_node, false).id == ConstCastResultIdOk) {
             continue;
         }
@@ -11852,7 +11867,6 @@ static IrInstruction *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp *
         case ZigTypeIdBool:
         case ZigTypeIdMetaType:
         case ZigTypeIdVoid:
-        case ZigTypeIdPointer:
         case ZigTypeIdErrorSet:
         case ZigTypeIdFn:
         case ZigTypeIdOpaque:
@@ -11864,6 +11878,10 @@ static IrInstruction *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp *
             operator_allowed = is_equality_cmp;
             break;
 
+        case ZigTypeIdPointer:
+            operator_allowed = is_equality_cmp || (resolved_type->data.pointer.ptr_len != PtrLenSingle);
+            break;
+
         case ZigTypeIdUnreachable:
         case ZigTypeIdArray:
         case ZigTypeIdStruct:
@@ -12324,6 +12342,26 @@ static bool ok_float_op(IrBinOp op) {
     zig_unreachable();
 }
 
+static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) {
+    if (lhs_type->id != ZigTypeIdPointer)
+        return false;
+    switch (op) {
+        case IrBinOpAdd:
+        case IrBinOpSub:
+            break;
+        default:
+            return false;
+    }
+    switch (lhs_type->data.pointer.ptr_len) {
+        case PtrLenSingle:
+            return false;
+        case PtrLenUnknown:
+        case PtrLenC:
+            break;
+    }
+    return true;
+}
+
 static IrInstruction *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstructionBinOp *instruction) {
     IrInstruction *op1 = instruction->op1->child;
     if (type_is_invalid(op1->value.type))
@@ -12336,9 +12374,7 @@ static IrInstruction *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstructionBinOp
     IrBinOp op_id = instruction->op_id;
 
     // look for pointer math
-    if (op1->value.type->id == ZigTypeIdPointer && op1->value.type->data.pointer.ptr_len == PtrLenUnknown &&
-        (op_id == IrBinOpAdd || op_id == IrBinOpSub))
-    {
+    if (is_pointer_arithmetic_allowed(op1->value.type, op_id)) {
         IrInstruction *casted_op2 = ir_implicit_cast(ira, op2, ira->codegen->builtin_types.entry_usize);
         if (casted_op2 == ira->codegen->invalid_instruction)
             return ira->codegen->invalid_instruction;
src/translate_c.cpp
@@ -1677,7 +1677,7 @@ static AstNode *trans_implicit_cast_expr(Context *c, TransScope *scope, const Im
                 return node;
             }
         case CK_NullToPointer:
-            return trans_create_node(c, NodeTypeNullLiteral);
+            return trans_create_node_unsigned(c, 0);
         case CK_Dependent:
             emit_warning(c, stmt->getLocStart(), "TODO handle C translation cast CK_Dependent");
             return nullptr;
@@ -2409,7 +2409,8 @@ static AstNode *trans_bool_expr(Context *c, ResultUsed result_used, TransScope *
                 case BuiltinType::Float16:
                     return trans_create_node_bin_op(c, res, BinOpTypeCmpNotEq, trans_create_node_unsigned_negative(c, 0, false));
                 case BuiltinType::NullPtr:
-                    return trans_create_node_bin_op(c, res, BinOpTypeCmpNotEq, trans_create_node(c, NodeTypeNullLiteral));
+                    return trans_create_node_bin_op(c, res, BinOpTypeCmpNotEq,
+                            trans_create_node_unsigned(c, 0));
 
                 case BuiltinType::Void:
                 case BuiltinType::Half:
@@ -2494,7 +2495,8 @@ static AstNode *trans_bool_expr(Context *c, ResultUsed result_used, TransScope *
             break;
         }
         case Type::Pointer:
-            return trans_create_node_bin_op(c, res, BinOpTypeCmpNotEq, trans_create_node(c, NodeTypeNullLiteral));
+            return trans_create_node_bin_op(c, res, BinOpTypeCmpNotEq,
+                    trans_create_node_unsigned(c, 0));
 
         case Type::Typedef:
         {
test/stage1/behavior/pointers.zig
@@ -56,3 +56,23 @@ test "implicit cast single item pointer to C pointer and back" {
     z.* += 1;
     expect(y == 12);
 }
+
+test "C pointer comparison and arithmetic" {
+    var one: usize = 1;
+    var ptr1: [*c]u8 = 0;
+    var ptr2 = ptr1 + 10;
+    expect(ptr1 == 0);
+    expect(ptr1 >= 0);
+    expect(ptr1 <= 0);
+    expect(ptr1 < 1);
+    expect(ptr1 < one);
+    expect(1 > ptr1);
+    expect(one > ptr1);
+    expect(ptr1 < ptr2);
+    expect(ptr2 > ptr1);
+    expect(ptr2 >= 10);
+    expect(ptr2 == 10);
+    expect(ptr2 <= 10);
+    ptr2 -= 10;
+    expect(ptr1 == ptr2);
+}
test/translate_c.zig
@@ -610,11 +610,11 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
     ,
         \\pub export fn and_or_none_bool(a: c_int, b: f32, c: [*c]c_void) c_int {
         \\    if ((a != 0) and (b != 0)) return 0;
-        \\    if ((b != 0) and (c != null)) return 1;
-        \\    if ((a != 0) and (c != null)) return 2;
+        \\    if ((b != 0) and (c != 0)) return 1;
+        \\    if ((a != 0) and (c != 0)) return 2;
         \\    if ((a != 0) or (b != 0)) return 3;
-        \\    if ((b != 0) or (c != null)) return 4;
-        \\    if ((a != 0) or (c != null)) return 5;
+        \\    if ((b != 0) or (c != 0)) return 4;
+        \\    if ((a != 0) or (c != 0)) return 5;
         \\    return 6;
         \\}
     );
@@ -778,7 +778,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\}
     ,
         \\pub export fn foo() [*c]c_int {
-        \\    return null;
+        \\    return 0;
         \\}
     );
 
@@ -1280,7 +1280,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\    return !(a == 0);
         \\    return !(a != 0);
         \\    return !(b != 0);
-        \\    return !(c != null);
+        \\    return !(c != 0);
         \\}
     );
 
@@ -1337,7 +1337,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\pub fn if_none_bool(a: c_int, b: f32, c: [*c]c_void, d: enum_SomeEnum) c_int {
         \\    if (a != 0) return 0;
         \\    if (b != 0) return 1;
-        \\    if (c != null) return 2;
+        \\    if (c != 0) return 2;
         \\    if (d != @bitCast(enum_SomeEnum, @TagType(enum_SomeEnum)(0))) return 3;
         \\    return 4;
         \\}
@@ -1354,7 +1354,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\pub fn while_none_bool(a: c_int, b: f32, c: [*c]c_void) c_int {
         \\    while (a != 0) return 0;
         \\    while (b != 0) return 1;
-        \\    while (c != null) return 2;
+        \\    while (c != 0) return 2;
         \\    return 3;
         \\}
     );
@@ -1370,7 +1370,7 @@ pub fn addCases(cases: *tests.TranslateCContext) void {
         \\pub fn for_none_bool(a: c_int, b: f32, c: [*c]c_void) c_int {
         \\    while (a != 0) return 0;
         \\    while (b != 0) return 1;
-        \\    while (c != null) return 2;
+        \\    while (c != 0) return 2;
         \\    return 3;
         \\}
     );