Commit 74db8c2e83

Andrew Kelley <andrew@ziglang.org>
2023-02-18 23:34:00
omit safety checks for element access in for loops
One of the main points of for loops is that you can safety check the length once, before entering the loop, and then safely assume that every element inside the loop is in bounds. In master branch, the safety checks are incorrectly intact even inside for loops. This commit fixes it. It's especially nice with multi-object loops because the number of elided checks is N * M where N is how many iterations and M is how many objects.
1 parent 7abeb52
Changed files (3)
src
test
behavior
src/Sema.zig
@@ -4544,7 +4544,7 @@ fn zirValidateArrayInit(
         // any ZIR instructions at comptime; we need to do that here.
         if (array_ty.sentinel()) |sentinel_val| {
             const array_len_ref = try sema.addIntUnsigned(Type.usize, array_len);
-            const sentinel_ptr = try sema.elemPtrArray(block, init_src, init_src, array_ptr, init_src, array_len_ref, true);
+            const sentinel_ptr = try sema.elemPtrArray(block, init_src, init_src, array_ptr, init_src, array_len_ref, true, true);
             const sentinel = try sema.addConstant(array_ty.childType(), sentinel_val);
             try sema.storePtr2(block, init_src, sentinel_ptr, init_src, sentinel, init_src, .store);
         }
@@ -9691,7 +9691,7 @@ fn zirElemVal(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
     const array = try sema.resolveInst(extra.lhs);
     const elem_index = try sema.resolveInst(extra.rhs);
-    return sema.elemVal(block, src, array, elem_index, src);
+    return sema.elemVal(block, src, array, elem_index, src, false);
 }
 
 fn zirElemValNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -9704,7 +9704,7 @@ fn zirElemValNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
     const array = try sema.resolveInst(extra.lhs);
     const elem_index = try sema.resolveInst(extra.rhs);
-    return sema.elemVal(block, src, array, elem_index, elem_index_src);
+    return sema.elemVal(block, src, array, elem_index, elem_index_src, true);
 }
 
 fn zirElemPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -9731,7 +9731,7 @@ fn zirElemPtr(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air
         };
         return sema.failWithOwnedErrorMsg(msg);
     }
-    return sema.elemPtrOneLayerOnly(block, src, array_ptr, elem_index, src, false);
+    return sema.elemPtrOneLayerOnly(block, src, array_ptr, elem_index, src, false, false);
 }
 
 fn zirElemPtrNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -9744,7 +9744,7 @@ fn zirElemPtrNode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError
     const extra = sema.code.extraData(Zir.Inst.Bin, inst_data.payload_index).data;
     const array_ptr = try sema.resolveInst(extra.lhs);
     const elem_index = try sema.resolveInst(extra.rhs);
-    return sema.elemPtr(block, src, array_ptr, elem_index, elem_index_src, false);
+    return sema.elemPtr(block, src, array_ptr, elem_index, elem_index_src, false, true);
 }
 
 fn zirElemPtrImm(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -9756,7 +9756,7 @@ fn zirElemPtrImm(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!
     const extra = sema.code.extraData(Zir.Inst.ElemPtrImm, inst_data.payload_index).data;
     const array_ptr = try sema.resolveInst(extra.ptr);
     const elem_index = try sema.addIntUnsigned(Type.usize, extra.index);
-    return sema.elemPtr(block, src, array_ptr, elem_index, src, true);
+    return sema.elemPtr(block, src, array_ptr, elem_index, src, true, true);
 }
 
 fn zirSliceStart(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
@@ -12521,14 +12521,14 @@ fn zirArrayCat(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
         while (elem_i < lhs_len) : (elem_i += 1) {
             const elem_index = try sema.addIntUnsigned(Type.usize, elem_i);
             const elem_ptr = try block.addPtrElemPtr(alloc, elem_index, elem_ptr_ty);
-            const init = try sema.elemVal(block, lhs_src, lhs, elem_index, src);
+            const init = try sema.elemVal(block, lhs_src, lhs, elem_index, src, true);
             try sema.storePtr2(block, src, elem_ptr, src, init, lhs_src, .store);
         }
         while (elem_i < result_len) : (elem_i += 1) {
             const elem_index = try sema.addIntUnsigned(Type.usize, elem_i);
             const rhs_index = try sema.addIntUnsigned(Type.usize, elem_i - lhs_len);
             const elem_ptr = try block.addPtrElemPtr(alloc, elem_index, elem_ptr_ty);
-            const init = try sema.elemVal(block, rhs_src, rhs, rhs_index, src);
+            const init = try sema.elemVal(block, rhs_src, rhs, rhs_index, src, true);
             try sema.storePtr2(block, src, elem_ptr, src, init, rhs_src, .store);
         }
         if (res_sent_val) |sent_val| {
@@ -12546,12 +12546,12 @@ fn zirArrayCat(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
         var elem_i: usize = 0;
         while (elem_i < lhs_len) : (elem_i += 1) {
             const index = try sema.addIntUnsigned(Type.usize, elem_i);
-            const init = try sema.elemVal(block, lhs_src, lhs, index, src);
+            const init = try sema.elemVal(block, lhs_src, lhs, index, src, true);
             element_refs[elem_i] = try sema.coerce(block, resolved_elem_ty, init, lhs_src);
         }
         while (elem_i < result_len) : (elem_i += 1) {
             const index = try sema.addIntUnsigned(Type.usize, elem_i - lhs_len);
-            const init = try sema.elemVal(block, rhs_src, rhs, index, src);
+            const init = try sema.elemVal(block, rhs_src, rhs, index, src, true);
             element_refs[elem_i] = try sema.coerce(block, resolved_elem_ty, init, rhs_src);
         }
     }
@@ -12771,7 +12771,7 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
                 elem_i += 1;
                 const lhs_index = try sema.addIntUnsigned(Type.usize, lhs_i);
                 const elem_ptr = try block.addPtrElemPtr(alloc, elem_index, elem_ptr_ty);
-                const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src);
+                const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src, true);
                 try sema.storePtr2(block, src, elem_ptr, src, init, lhs_src, .store);
             }
         }
@@ -12791,7 +12791,7 @@ fn zirArrayMul(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Ai
         var lhs_i: usize = 0;
         while (lhs_i < lhs_len) : (lhs_i += 1) {
             const lhs_index = try sema.addIntUnsigned(Type.usize, lhs_i);
-            const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src);
+            const init = try sema.elemVal(block, lhs_src, lhs, lhs_index, src, true);
             element_refs[elem_i] = init;
             elem_i += 1;
         }
@@ -24145,6 +24145,7 @@ fn elemPtr(
     elem_index: Air.Inst.Ref,
     elem_index_src: LazySrcLoc,
     init: bool,
+    oob_safety: bool,
 ) CompileError!Air.Inst.Ref {
     const indexable_ptr_src = src; // TODO better source location
     const indexable_ptr_ty = sema.typeOf(indexable_ptr);
@@ -24154,7 +24155,7 @@ fn elemPtr(
         else => return sema.fail(block, indexable_ptr_src, "expected pointer, found '{}'", .{indexable_ptr_ty.fmt(sema.mod)}),
     };
     switch (indexable_ty.zigTypeTag()) {
-        .Array, .Vector => return sema.elemPtrArray(block, src, indexable_ptr_src, indexable_ptr, elem_index_src, elem_index, init),
+        .Array, .Vector => return sema.elemPtrArray(block, src, indexable_ptr_src, indexable_ptr, elem_index_src, elem_index, init, oob_safety),
         .Struct => {
             // Tuple field access.
             const index_val = try sema.resolveConstValue(block, elem_index_src, elem_index, "tuple field access index must be comptime-known");
@@ -24163,11 +24164,12 @@ fn elemPtr(
         },
         else => {
             const indexable = try sema.analyzeLoad(block, indexable_ptr_src, indexable_ptr, indexable_ptr_src);
-            return elemPtrOneLayerOnly(sema, block, src, indexable, elem_index, elem_index_src, init);
+            return elemPtrOneLayerOnly(sema, block, src, indexable, elem_index, elem_index_src, init, oob_safety);
         },
     }
 }
 
+/// Asserts that the type of indexable is pointer.
 fn elemPtrOneLayerOnly(
     sema: *Sema,
     block: *Block,
@@ -24176,6 +24178,7 @@ fn elemPtrOneLayerOnly(
     elem_index: Air.Inst.Ref,
     elem_index_src: LazySrcLoc,
     init: bool,
+    oob_safety: bool,
 ) CompileError!Air.Inst.Ref {
     const indexable_src = src; // TODO better source location
     const indexable_ty = sema.typeOf(indexable);
@@ -24184,33 +24187,28 @@ fn elemPtrOneLayerOnly(
     }
     const target = sema.mod.getTarget();
 
-    switch (indexable_ty.zigTypeTag()) {
-        .Pointer => {
-            switch (indexable_ty.ptrSize()) {
-                .Slice => return sema.elemPtrSlice(block, src, indexable_src, indexable, elem_index_src, elem_index),
-                .Many, .C => {
-                    const maybe_ptr_val = try sema.resolveDefinedValue(block, indexable_src, indexable);
-                    const maybe_index_val = try sema.resolveDefinedValue(block, elem_index_src, elem_index);
-                    const runtime_src = rs: {
-                        const ptr_val = maybe_ptr_val orelse break :rs indexable_src;
-                        const index_val = maybe_index_val orelse break :rs elem_index_src;
-                        const index = @intCast(usize, index_val.toUnsignedInt(target));
-                        const elem_ptr = try ptr_val.elemPtr(indexable_ty, sema.arena, index, sema.mod);
-                        const result_ty = try sema.elemPtrType(indexable_ty, index);
-                        return sema.addConstant(result_ty, elem_ptr);
-                    };
-                    const result_ty = try sema.elemPtrType(indexable_ty, null);
+    switch (indexable_ty.ptrSize()) {
+        .Slice => return sema.elemPtrSlice(block, src, indexable_src, indexable, elem_index_src, elem_index, oob_safety),
+        .Many, .C => {
+            const maybe_ptr_val = try sema.resolveDefinedValue(block, indexable_src, indexable);
+            const maybe_index_val = try sema.resolveDefinedValue(block, elem_index_src, elem_index);
+            const runtime_src = rs: {
+                const ptr_val = maybe_ptr_val orelse break :rs indexable_src;
+                const index_val = maybe_index_val orelse break :rs elem_index_src;
+                const index = @intCast(usize, index_val.toUnsignedInt(target));
+                const elem_ptr = try ptr_val.elemPtr(indexable_ty, sema.arena, index, sema.mod);
+                const result_ty = try sema.elemPtrType(indexable_ty, index);
+                return sema.addConstant(result_ty, elem_ptr);
+            };
+            const result_ty = try sema.elemPtrType(indexable_ty, null);
 
-                    try sema.requireRuntimeBlock(block, src, runtime_src);
-                    return block.addPtrElemPtr(indexable, elem_index, result_ty);
-                },
-                .One => {
-                    assert(indexable_ty.childType().zigTypeTag() == .Array); // Guaranteed by isIndexable
-                    return sema.elemPtrArray(block, src, indexable_src, indexable, elem_index_src, elem_index, init);
-                },
-            }
+            try sema.requireRuntimeBlock(block, src, runtime_src);
+            return block.addPtrElemPtr(indexable, elem_index, result_ty);
+        },
+        .One => {
+            assert(indexable_ty.childType().zigTypeTag() == .Array); // Guaranteed by isIndexable
+            return sema.elemPtrArray(block, src, indexable_src, indexable, elem_index_src, elem_index, init, oob_safety);
         },
-        else => unreachable,
     }
 }
 
@@ -24221,6 +24219,7 @@ fn elemVal(
     indexable: Air.Inst.Ref,
     elem_index_uncasted: Air.Inst.Ref,
     elem_index_src: LazySrcLoc,
+    oob_safety: bool,
 ) CompileError!Air.Inst.Ref {
     const indexable_src = src; // TODO better source location
     const indexable_ty = sema.typeOf(indexable);
@@ -24236,7 +24235,7 @@ fn elemVal(
 
     switch (indexable_ty.zigTypeTag()) {
         .Pointer => switch (indexable_ty.ptrSize()) {
-            .Slice => return sema.elemValSlice(block, src, indexable_src, indexable, elem_index_src, elem_index),
+            .Slice => return sema.elemValSlice(block, src, indexable_src, indexable, elem_index_src, elem_index, oob_safety),
             .Many, .C => {
                 const maybe_indexable_val = try sema.resolveDefinedValue(block, indexable_src, indexable);
                 const maybe_index_val = try sema.resolveDefinedValue(block, elem_index_src, elem_index);
@@ -24257,14 +24256,14 @@ fn elemVal(
             },
             .One => {
                 assert(indexable_ty.childType().zigTypeTag() == .Array); // Guaranteed by isIndexable
-                const elem_ptr = try sema.elemPtr(block, indexable_src, indexable, elem_index, elem_index_src, false);
+                const elem_ptr = try sema.elemPtr(block, indexable_src, indexable, elem_index, elem_index_src, false, oob_safety);
                 return sema.analyzeLoad(block, indexable_src, elem_ptr, elem_index_src);
             },
         },
-        .Array => return sema.elemValArray(block, src, indexable_src, indexable, elem_index_src, elem_index),
+        .Array => return sema.elemValArray(block, src, indexable_src, indexable, elem_index_src, elem_index, oob_safety),
         .Vector => {
             // TODO: If the index is a vector, the result should be a vector.
-            return sema.elemValArray(block, src, indexable_src, indexable, elem_index_src, elem_index);
+            return sema.elemValArray(block, src, indexable_src, indexable, elem_index_src, elem_index, oob_safety);
         },
         .Struct => {
             // Tuple field access.
@@ -24409,6 +24408,7 @@ fn elemValArray(
     array: Air.Inst.Ref,
     elem_index_src: LazySrcLoc,
     elem_index: Air.Inst.Ref,
+    oob_safety: bool,
 ) CompileError!Air.Inst.Ref {
     const array_ty = sema.typeOf(array);
     const array_sent = array_ty.sentinel();
@@ -24452,7 +24452,7 @@ fn elemValArray(
 
     const runtime_src = if (maybe_undef_array_val != null) elem_index_src else array_src;
     try sema.requireRuntimeBlock(block, src, runtime_src);
-    if (block.wantSafety()) {
+    if (oob_safety and block.wantSafety()) {
         // Runtime check is only needed if unable to comptime check
         if (maybe_index_val == null) {
             const len_inst = try sema.addIntUnsigned(Type.usize, array_len);
@@ -24472,6 +24472,7 @@ fn elemPtrArray(
     elem_index_src: LazySrcLoc,
     elem_index: Air.Inst.Ref,
     init: bool,
+    oob_safety: bool,
 ) CompileError!Air.Inst.Ref {
     const target = sema.mod.getTarget();
     const array_ptr_ty = sema.typeOf(array_ptr);
@@ -24515,7 +24516,7 @@ fn elemPtrArray(
     try sema.requireRuntimeBlock(block, src, runtime_src);
 
     // Runtime check is only needed if unable to comptime check.
-    if (block.wantSafety() and offset == null) {
+    if (oob_safety and block.wantSafety() and offset == null) {
         const len_inst = try sema.addIntUnsigned(Type.usize, array_len);
         const cmp_op: Air.Inst.Tag = if (array_sent) .cmp_lte else .cmp_lt;
         try sema.panicIndexOutOfBounds(block, elem_index, len_inst, cmp_op);
@@ -24532,6 +24533,7 @@ fn elemValSlice(
     slice: Air.Inst.Ref,
     elem_index_src: LazySrcLoc,
     elem_index: Air.Inst.Ref,
+    oob_safety: bool,
 ) CompileError!Air.Inst.Ref {
     const slice_ty = sema.typeOf(slice);
     const slice_sent = slice_ty.sentinel() != null;
@@ -24568,7 +24570,7 @@ fn elemValSlice(
     try sema.validateRuntimeElemAccess(block, elem_index_src, elem_ty, slice_ty, slice_src);
 
     try sema.requireRuntimeBlock(block, src, runtime_src);
-    if (block.wantSafety()) {
+    if (oob_safety and block.wantSafety()) {
         const len_inst = if (maybe_slice_val) |slice_val|
             try sema.addIntUnsigned(Type.usize, slice_val.sliceLen(sema.mod))
         else
@@ -24588,6 +24590,7 @@ fn elemPtrSlice(
     slice: Air.Inst.Ref,
     elem_index_src: LazySrcLoc,
     elem_index: Air.Inst.Ref,
+    oob_safety: bool,
 ) CompileError!Air.Inst.Ref {
     const target = sema.mod.getTarget();
     const slice_ty = sema.typeOf(slice);
@@ -24625,7 +24628,7 @@ fn elemPtrSlice(
 
     const runtime_src = if (maybe_undef_slice_val != null) elem_index_src else slice_src;
     try sema.requireRuntimeBlock(block, src, runtime_src);
-    if (block.wantSafety()) {
+    if (oob_safety and block.wantSafety()) {
         const len_inst = len: {
             if (maybe_undef_slice_val) |slice_val|
                 if (!slice_val.isUndef())
@@ -26330,7 +26333,7 @@ fn storePtr2(
             const elem_src = operand_src; // TODO better source location
             const elem = try sema.tupleField(block, operand_src, uncasted_operand, elem_src, i);
             const elem_index = try sema.addIntUnsigned(Type.usize, i);
-            const elem_ptr = try sema.elemPtr(block, ptr_src, ptr, elem_index, elem_src, false);
+            const elem_ptr = try sema.elemPtr(block, ptr_src, ptr, elem_index, elem_src, false, true);
             try sema.storePtr2(block, src, elem_ptr, elem_src, elem, elem_src, .store);
         }
         return;
@@ -27782,7 +27785,7 @@ fn coerceArrayLike(
         );
         const src = inst_src; // TODO better source location
         const elem_src = inst_src; // TODO better source location
-        const elem_ref = try sema.elemValArray(block, src, inst_src, inst, elem_src, index_ref);
+        const elem_ref = try sema.elemValArray(block, src, inst_src, inst, elem_src, index_ref, true);
         const coerced = try sema.coerce(block, dest_elem_ty, elem_ref, elem_src);
         element_refs[i] = coerced;
         if (runtime_src == null) {
src/Zir.zig
@@ -382,7 +382,9 @@ pub const Inst = struct {
         /// Uses the `pl_node` union field. AST node is a[b] syntax. Payload is `Bin`.
         elem_ptr_node,
         /// Same as `elem_ptr_node` but used only for for loop.
-        /// Uses the `pl_node` union field. AST node is the condition of a for loop. Payload is `Bin`.
+        /// Uses the `pl_node` union field. AST node is the condition of a for loop.
+        /// Payload is `Bin`.
+        /// No OOB safety check is emitted.
         elem_ptr,
         /// Same as `elem_ptr_node` except the index is stored immediately rather than
         /// as a reference to another ZIR instruction.
@@ -395,7 +397,9 @@ pub const Inst = struct {
         /// Uses the `pl_node` union field. AST node is a[b] syntax. Payload is `Bin`.
         elem_val_node,
         /// Same as `elem_val_node` but used only for for loop.
-        /// Uses the `pl_node` union field. AST node is the condition of a for loop. Payload is `Bin`.
+        /// Uses the `pl_node` union field. AST node is the condition of a for loop.
+        /// Payload is `Bin`.
+        /// No OOB safety check is emitted.
         elem_val,
         /// Emits a compile error if the operand is not `void`.
         /// Uses the `un_node` field.
test/behavior/for.zig
@@ -314,6 +314,7 @@ test "slice and two counters, one is offset and one is runtime" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
 
     const slice: []const u8 = "blah";
     var start: usize = 0;
@@ -342,6 +343,7 @@ test "two slices, one captured by-ref" {
     if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
     if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+    if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
 
     var buf: [10]u8 = undefined;
     const slice1: []const u8 = "blah";