Commit e95e7651ec

Alex Rønne Petersen <alex@alexrp.com>
2024-10-31 09:18:46
llvm: Set null_pointer_is_valid attribute when accessing allowzero pointers.
This informs optimization passes that they shouldn't assume that a load from a null pointer invokes undefined behavior. Closes #15816.
1 parent 5c2e300
Changed files (1)
src
codegen
src/codegen/llvm.zig
@@ -1419,8 +1419,6 @@ pub const Object = struct {
             }
         }
 
-        function_index.setAttributes(try attributes.finish(&o.builder), &o.builder);
-
         const file, const subprogram = if (!wip.strip) debug_info: {
             const file = try o.getDebugFile(file_scope);
 
@@ -1517,6 +1515,17 @@ pub const Object = struct {
             else => |e| return e,
         };
 
+        // If we saw any loads or stores involving `allowzero` pointers, we need to mark the whole
+        // function as considering null pointers valid so that LLVM's optimizers don't remove these
+        // operations on the assumption that they're undefined behavior.
+        if (fg.allowzero_access) {
+            try attributes.addFnAttr(.null_pointer_is_valid, &o.builder);
+        } else {
+            _ = try attributes.removeFnAttr(.null_pointer_is_valid);
+        }
+
+        function_index.setAttributes(try attributes.finish(&o.builder), &o.builder);
+
         if (fg.fuzz) |*f| {
             {
                 const array_llvm_ty = try o.builder.arrayType(f.pcs.items.len, .i8);
@@ -4667,6 +4676,15 @@ pub const FuncGen = struct {
 
     disable_intrinsics: bool,
 
+    /// Have we seen loads or stores involving `allowzero` pointers?
+    allowzero_access: bool = false,
+
+    pub fn maybeMarkAllowZeroAccess(self: *FuncGen, info: InternPool.Key.PtrType) void {
+        // LLVM already considers null pointers to be valid in non-generic address spaces, so avoid
+        // pessimizing optimization for functions with accesses to such pointers.
+        if (info.flags.address_space == .generic and info.flags.is_allowzero) self.allowzero_access = true;
+    }
+
     const Fuzz = struct {
         counters_variable: Builder.Variable.Index,
         pcs: std.ArrayListUnmanaged(Builder.Constant),
@@ -6206,6 +6224,9 @@ pub const FuncGen = struct {
         const body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra.end..][0..extra.data.body_len]);
         const err_union_ty = self.typeOf(extra.data.ptr).childType(zcu);
         const is_unused = self.liveness.isUnused(inst);
+
+        self.maybeMarkAllowZeroAccess(self.typeOf(extra.data.ptr).ptrInfo(zcu));
+
         return lowerTry(self, err_union_ptr, body, err_union_ty, true, true, is_unused, err_cold);
     }
 
@@ -6751,10 +6772,14 @@ pub const FuncGen = struct {
             if (self.canElideLoad(body_tail))
                 return ptr;
 
+            self.maybeMarkAllowZeroAccess(slice_ty.ptrInfo(zcu));
+
             const elem_alignment = elem_ty.abiAlignment(zcu).toLlvm();
             return self.loadByRef(ptr, elem_ty, elem_alignment, .normal);
         }
 
+        self.maybeMarkAllowZeroAccess(slice_ty.ptrInfo(zcu));
+
         return self.load(ptr, slice_ty);
     }
 
@@ -6824,10 +6849,15 @@ pub const FuncGen = struct {
             &.{rhs}, "");
         if (isByRef(elem_ty, zcu)) {
             if (self.canElideLoad(body_tail)) return ptr;
+
+            self.maybeMarkAllowZeroAccess(ptr_ty.ptrInfo(zcu));
+
             const elem_alignment = elem_ty.abiAlignment(zcu).toLlvm();
             return self.loadByRef(ptr, elem_ty, elem_alignment, .normal);
         }
 
+        self.maybeMarkAllowZeroAccess(ptr_ty.ptrInfo(zcu));
+
         return self.load(ptr, ptr_ty);
     }
 
@@ -7235,6 +7265,8 @@ pub const FuncGen = struct {
                     }),
                 }
 
+                self.maybeMarkAllowZeroAccess(output_ty.ptrInfo(zcu));
+
                 // Pass any non-return outputs indirectly, if the constraint accepts a memory location
                 is_indirect.* = constraintAllowsMemory(constraint);
                 if (is_indirect.*) {
@@ -7341,10 +7373,11 @@ pub const FuncGen = struct {
 
             // In the case of indirect inputs, LLVM requires the callsite to have
             // an elementtype(<ty>) attribute.
-            llvm_param_attrs[llvm_param_i] = if (constraint[0] == '*')
-                try o.lowerPtrElemTy(if (is_by_ref) arg_ty else arg_ty.childType(zcu))
-            else
-                .none;
+            llvm_param_attrs[llvm_param_i] = if (constraint[0] == '*') blk: {
+                if (!is_by_ref) self.maybeMarkAllowZeroAccess(arg_ty.ptrInfo(zcu));
+
+                break :blk try o.lowerPtrElemTy(if (is_by_ref) arg_ty else arg_ty.childType(zcu));
+            } else .none;
 
             llvm_param_i += 1;
             total_i += 1;
@@ -7530,7 +7563,6 @@ pub const FuncGen = struct {
             if (output != .none) {
                 const output_ptr = try self.resolveInst(output);
                 const output_ptr_ty = self.typeOf(output);
-
                 const alignment = output_ptr_ty.ptrAlignment(zcu).toLlvm();
                 _ = try self.wip.store(.normal, output_value, output_ptr, alignment);
             } else {
@@ -7557,6 +7589,9 @@ pub const FuncGen = struct {
         const optional_ty = if (operand_is_ptr) operand_ty.childType(zcu) else operand_ty;
         const optional_llvm_ty = try o.lowerType(optional_ty);
         const payload_ty = optional_ty.optionalChild(zcu);
+
+        if (operand_is_ptr) self.maybeMarkAllowZeroAccess(operand_ty.ptrInfo(zcu));
+
         if (optional_ty.optionalReprIsPayload(zcu)) {
             const loaded = if (operand_is_ptr)
                 try self.wip.load(.normal, optional_llvm_ty, operand, .default, "")
@@ -7613,6 +7648,8 @@ pub const FuncGen = struct {
             return val.toValue();
         }
 
+        if (operand_is_ptr) self.maybeMarkAllowZeroAccess(operand_ty.ptrInfo(zcu));
+
         if (!payload_ty.hasRuntimeBitsIgnoreComptime(zcu)) {
             const loaded = if (operand_is_ptr)
                 try self.wip.load(.normal, try o.lowerType(err_union_ty), operand, .default, "")
@@ -7664,6 +7701,8 @@ pub const FuncGen = struct {
         const payload_ty = optional_ty.optionalChild(zcu);
         const non_null_bit = try o.builder.intValue(.i8, 1);
         if (!payload_ty.hasRuntimeBitsIgnoreComptime(zcu)) {
+            self.maybeMarkAllowZeroAccess(self.typeOf(ty_op.operand).ptrInfo(zcu));
+
             // We have a pointer to a i8. We need to set it to 1 and then return the same pointer.
             _ = try self.wip.store(.normal, non_null_bit, operand, .default);
             return operand;
@@ -7677,6 +7716,9 @@ pub const FuncGen = struct {
         // First set the non-null bit.
         const optional_llvm_ty = try o.lowerType(optional_ty);
         const non_null_ptr = try self.wip.gepStruct(optional_llvm_ty, operand, 1, "");
+
+        self.maybeMarkAllowZeroAccess(self.typeOf(ty_op.operand).ptrInfo(zcu));
+
         // TODO set alignment on this store
         _ = try self.wip.store(.normal, non_null_bit, non_null_ptr, .default);
 
@@ -7767,12 +7809,17 @@ pub const FuncGen = struct {
         const payload_ty = err_union_ty.errorUnionPayload(zcu);
         if (!payload_ty.hasRuntimeBitsIgnoreComptime(zcu)) {
             if (!operand_is_ptr) return operand;
+
+            self.maybeMarkAllowZeroAccess(operand_ty.ptrInfo(zcu));
+
             return self.wip.load(.normal, error_type, operand, .default, "");
         }
 
         const offset = try errUnionErrorOffset(payload_ty, pt);
 
         if (operand_is_ptr or isByRef(err_union_ty, zcu)) {
+            if (operand_is_ptr) self.maybeMarkAllowZeroAccess(operand_ty.ptrInfo(zcu));
+
             const err_union_llvm_ty = try o.lowerType(err_union_ty);
             const err_field_ptr = try self.wip.gepStruct(err_union_llvm_ty, operand, offset, "");
             return self.wip.load(.normal, error_type, err_field_ptr, .default, "");
@@ -7792,11 +7839,15 @@ pub const FuncGen = struct {
         const payload_ty = err_union_ty.errorUnionPayload(zcu);
         const non_error_val = try o.builder.intValue(try o.errorIntType(), 0);
         if (!payload_ty.hasRuntimeBitsIgnoreComptime(zcu)) {
+            self.maybeMarkAllowZeroAccess(self.typeOf(ty_op.operand).ptrInfo(zcu));
+
             _ = try self.wip.store(.normal, non_error_val, operand, .default);
             return operand;
         }
         const err_union_llvm_ty = try o.lowerType(err_union_ty);
         {
+            self.maybeMarkAllowZeroAccess(self.typeOf(ty_op.operand).ptrInfo(zcu));
+
             const err_int_ty = try pt.errorIntType();
             const error_alignment = err_int_ty.abiAlignment(zcu).toLlvm();
             const error_offset = try errUnionErrorOffset(payload_ty, pt);
@@ -8017,6 +8068,8 @@ pub const FuncGen = struct {
         const index = try self.resolveInst(extra.lhs);
         const operand = try self.resolveInst(extra.rhs);
 
+        self.maybeMarkAllowZeroAccess(vector_ptr_ty.ptrInfo(zcu));
+
         const access_kind: Builder.MemoryAccessKind =
             if (vector_ptr_ty.isVolatilePtr(zcu)) .@"volatile" else .normal;
         const elem_llvm_ty = try o.lowerType(vector_ptr_ty.childType(zcu));
@@ -9482,6 +9535,8 @@ pub const FuncGen = struct {
                 return .none;
             }
 
+            self.maybeMarkAllowZeroAccess(ptr_info);
+
             const len = try o.builder.intValue(try o.lowerType(Type.usize), operand_ty.abiSize(zcu));
             _ = try self.wip.callMemSet(
                 dest_ptr,
@@ -9497,6 +9552,8 @@ pub const FuncGen = struct {
             return .none;
         }
 
+        self.maybeMarkAllowZeroAccess(ptr_ty.ptrInfo(zcu));
+
         const src_operand = try self.resolveInst(bin_op.rhs);
         try self.store(dest_ptr, ptr_ty, src_operand, .none);
         return .none;
@@ -9539,6 +9596,9 @@ pub const FuncGen = struct {
             if (!canElideLoad(fg, body_tail)) break :elide;
             return ptr;
         }
+
+        fg.maybeMarkAllowZeroAccess(ptr_info);
+
         return fg.load(ptr, ptr_ty);
     }
 
@@ -9598,6 +9658,8 @@ pub const FuncGen = struct {
             new_value = try self.wip.conv(signedness, new_value, llvm_abi_ty, "");
         }
 
+        self.maybeMarkAllowZeroAccess(ptr_ty.ptrInfo(zcu));
+
         const result = try self.wip.cmpxchg(
             kind,
             if (ptr_ty.isVolatilePtr(zcu)) .@"volatile" else .normal,
@@ -9649,6 +9711,8 @@ pub const FuncGen = struct {
             if (ptr_ty.isVolatilePtr(zcu)) .@"volatile" else .normal;
         const ptr_alignment = ptr_ty.ptrAlignment(zcu).toLlvm();
 
+        self.maybeMarkAllowZeroAccess(ptr_ty.ptrInfo(zcu));
+
         if (llvm_abi_ty != .none) {
             // operand needs widening and truncating or bitcasting.
             return self.wip.cast(if (is_float) .bitcast else .trunc, try self.wip.atomicrmw(
@@ -9712,6 +9776,8 @@ pub const FuncGen = struct {
             if (info.flags.is_volatile) .@"volatile" else .normal;
         const elem_llvm_ty = try o.lowerType(elem_ty);
 
+        self.maybeMarkAllowZeroAccess(info);
+
         if (llvm_abi_ty != .none) {
             // operand needs widening and truncating
             const loaded = try self.wip.loadAtomic(
@@ -9761,6 +9827,9 @@ pub const FuncGen = struct {
                 "",
             );
         }
+
+        self.maybeMarkAllowZeroAccess(ptr_ty.ptrInfo(zcu));
+
         try self.store(ptr, ptr_ty, element, ordering);
         return .none;
     }
@@ -9778,6 +9847,8 @@ pub const FuncGen = struct {
         const access_kind: Builder.MemoryAccessKind =
             if (ptr_ty.isVolatilePtr(zcu)) .@"volatile" else .normal;
 
+        self.maybeMarkAllowZeroAccess(ptr_ty.ptrInfo(zcu));
+
         if (try self.air.value(bin_op.rhs, pt)) |elem_val| {
             if (elem_val.isUndefDeep(zcu)) {
                 // Even if safety is disabled, we still emit a memset to undefined since it conveys
@@ -9916,6 +9987,9 @@ pub const FuncGen = struct {
         const access_kind: Builder.MemoryAccessKind = if (src_ptr_ty.isVolatilePtr(zcu) or
             dest_ptr_ty.isVolatilePtr(zcu)) .@"volatile" else .normal;
 
+        self.maybeMarkAllowZeroAccess(dest_ptr_ty.ptrInfo(zcu));
+        self.maybeMarkAllowZeroAccess(src_ptr_ty.ptrInfo(zcu));
+
         _ = try self.wip.callMemCpy(
             dest_ptr,
             dest_ptr_ty.ptrAlignment(zcu).toLlvm(),
@@ -9962,6 +10036,9 @@ pub const FuncGen = struct {
         const un_ty = self.typeOf(bin_op.lhs).childType(zcu);
         const layout = un_ty.unionGetLayout(zcu);
         if (layout.tag_size == 0) return .none;
+
+        self.maybeMarkAllowZeroAccess(self.typeOf(bin_op.lhs).ptrInfo(zcu));
+
         const union_ptr = try self.resolveInst(bin_op.lhs);
         const new_tag = try self.resolveInst(bin_op.rhs);
         if (layout.payload_size == 0) {