Commit 855edd2949

LemonBoy <thatlemon@gmail.com>
2020-03-22 20:20:36
ir: Rewrite the bound checks in slice operator
Closes #4777
1 parent 0dbf8aa
Changed files (1)
src/codegen.cpp
@@ -5408,6 +5408,8 @@ static LLVMValueRef ir_render_memcpy(CodeGen *g, IrExecutableGen *executable, Ir
 }
 
 static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutableGen *executable, IrInstGenSlice *instruction) {
+    Error err;
+
     LLVMValueRef array_ptr_ptr = ir_llvm_value(g, instruction->ptr);
     ZigType *array_ptr_type = instruction->ptr->value->type;
     assert(array_ptr_type->id == ZigTypeIdPointer);
@@ -5416,15 +5418,16 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutableGen *executable, IrI
 
     bool want_runtime_safety = instruction->safety_check_on && ir_want_runtime_safety(g, &instruction->base);
 
+    // The result is either a slice or a pointer to an array
     ZigType *result_type = instruction->base.value->type;
-    if (!type_has_bits(g, result_type)) {
-        return nullptr;
-    }
 
     // This is not whether the result type has a sentinel, but whether there should be a sentinel check,
     // e.g. if they used [a..b :s] syntax.
     ZigValue *sentinel = instruction->sentinel;
 
+    LLVMValueRef slice_start_ptr = nullptr;
+    LLVMValueRef len_value = nullptr;
+
     if (array_type->id == ZigTypeIdArray ||
         (array_type->id == ZigTypeIdPointer && array_type->data.pointer.ptr_len == PtrLenSingle))
     {
@@ -5438,111 +5441,86 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutableGen *executable, IrI
         } else {
             end_val = LLVMConstInt(g->builtin_types.entry_usize->llvm_type, array_type->data.array.len, false);
         }
+
         if (want_runtime_safety) {
+            // Safety check: start <= end
             if (instruction->start->value->special == ConstValSpecialRuntime || instruction->end) {
                 add_bounds_check(g, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val);
             }
-            if (instruction->end) {
-                LLVMValueRef array_end = LLVMConstInt(g->builtin_types.entry_usize->llvm_type,
-                        array_type->data.array.len, false);
-                add_bounds_check(g, end_val, LLVMIntEQ, nullptr, LLVMIntULE, array_end);
 
-                if (sentinel != nullptr) {
-                    LLVMValueRef indices[] = {
-                        LLVMConstNull(g->builtin_types.entry_usize->llvm_type),
-                        end_val,
-                    };
-                    LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, "");
-                    add_sentinel_check(g, sentinel_elem_ptr, sentinel);
-                }
-            }
-        }
-        if (!type_has_bits(g, array_type)) {
-            LLVMValueRef tmp_struct_ptr = ir_llvm_value(g, instruction->result_loc);
+            // Safety check: the last element of the slice (the sentinel if
+            // requested) must be inside the array
+            // XXX: Overflow is not checked here...
+            const size_t full_len = array_type->data.array.len +
+                (array_type->data.array.sentinel != nullptr);
+            LLVMValueRef array_end = LLVMConstInt(g->builtin_types.entry_usize->llvm_type,
+                    full_len, false);
 
-            LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, slice_len_index, "");
-
-            // TODO if runtime safety is on, store 0xaaaaaaa in ptr field
-            LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, "");
-            gen_store_untyped(g, len_value, len_field_ptr, 0, false);
-            return tmp_struct_ptr;
+            LLVMValueRef check_end_val = end_val;
+            if (sentinel != nullptr) {
+                LLVMValueRef usize_one = LLVMConstInt(g->builtin_types.entry_usize->llvm_type, 1, false);
+                check_end_val = LLVMBuildNUWAdd(g->builder, end_val, usize_one, "");
+            }
+            add_bounds_check(g, check_end_val, LLVMIntEQ, nullptr, LLVMIntULE, array_end);
         }
 
-        LLVMValueRef indices[] = {
-            LLVMConstNull(g->builtin_types.entry_usize->llvm_type),
-            start_val,
-        };
-        LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, "");
-        if (result_type->id == ZigTypeIdPointer) {
-            ir_assert(instruction->result_loc == nullptr, &instruction->base);
-            LLVMTypeRef result_ptr_type = get_llvm_type(g, result_type);
-            return LLVMBuildBitCast(g->builder, slice_start_ptr, result_ptr_type, "");
-        } else {
-            LLVMValueRef tmp_struct_ptr = ir_llvm_value(g, instruction->result_loc);
-            LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, slice_ptr_index, "");
-            gen_store_untyped(g, slice_start_ptr, ptr_field_ptr, 0, false);
+        bool value_has_bits;
+        if ((err = type_has_bits2(g, array_type, &value_has_bits)))
+            codegen_report_errors_and_exit(g);
 
-            LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, slice_len_index, "");
-            LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, "");
-            gen_store_untyped(g, len_value, len_field_ptr, 0, false);
+        if (value_has_bits) {
+            if (want_runtime_safety && sentinel != nullptr) {
+                LLVMValueRef indices[] = {
+                    LLVMConstNull(g->builtin_types.entry_usize->llvm_type),
+                    end_val,
+                };
+                LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, "");
+                add_sentinel_check(g, sentinel_elem_ptr, sentinel);
+            }
 
-            return tmp_struct_ptr;
+            LLVMValueRef indices[] = {
+                LLVMConstNull(g->builtin_types.entry_usize->llvm_type),
+                start_val,
+            };
+            slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, "");
         }
+
+        len_value = LLVMBuildNUWSub(g->builder, end_val, start_val, "");
     } else if (array_type->id == ZigTypeIdPointer) {
         assert(array_type->data.pointer.ptr_len != PtrLenSingle);
         LLVMValueRef start_val = ir_llvm_value(g, instruction->start);
         LLVMValueRef end_val = ir_llvm_value(g, instruction->end);
 
         if (want_runtime_safety) {
+            // Safety check: start <= end
             add_bounds_check(g, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val);
-            if (sentinel != nullptr) {
+        }
+
+        bool value_has_bits;
+        if ((err = type_has_bits2(g, array_type, &value_has_bits)))
+            codegen_report_errors_and_exit(g);
+
+        if (value_has_bits) {
+            if (want_runtime_safety && sentinel != nullptr) {
                 LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, &end_val, 1, "");
                 add_sentinel_check(g, sentinel_elem_ptr, sentinel);
             }
-        }
 
-        if (!type_has_bits(g, array_type)) {
-            LLVMValueRef tmp_struct_ptr = ir_llvm_value(g, instruction->result_loc);
-            size_t gen_len_index = result_type->data.structure.fields[slice_len_index]->gen_index;
-            LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, gen_len_index, "");
-            LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, "");
-            gen_store_untyped(g, len_value, len_field_ptr, 0, false);
-            return tmp_struct_ptr;
+            slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, &start_val, 1, "");
         }
 
-        LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, &start_val, 1, "");
-        if (result_type->id == ZigTypeIdPointer) {
-            ir_assert(instruction->result_loc == nullptr, &instruction->base);
-            LLVMTypeRef result_ptr_type = get_llvm_type(g, result_type);
-            return LLVMBuildBitCast(g->builder, slice_start_ptr, result_ptr_type, "");
-        }
-
-        LLVMValueRef tmp_struct_ptr = ir_llvm_value(g, instruction->result_loc);
-
-        size_t gen_ptr_index = result_type->data.structure.fields[slice_ptr_index]->gen_index;
-        LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, gen_ptr_index, "");
-        gen_store_untyped(g, slice_start_ptr, ptr_field_ptr, 0, false);
-
-        size_t gen_len_index = result_type->data.structure.fields[slice_len_index]->gen_index;
-        LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, gen_len_index, "");
-        LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, "");
-        gen_store_untyped(g, len_value, len_field_ptr, 0, false);
-
-        return tmp_struct_ptr;
-
+        len_value = LLVMBuildNUWSub(g->builder, end_val, start_val, "");
     } else if (array_type->id == ZigTypeIdStruct) {
         assert(array_type->data.structure.special == StructSpecialSlice);
         assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind);
         assert(LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(array_ptr))) == LLVMStructTypeKind);
 
-        size_t ptr_index = array_type->data.structure.fields[slice_ptr_index]->gen_index;
-        assert(ptr_index != SIZE_MAX);
-        size_t len_index = array_type->data.structure.fields[slice_len_index]->gen_index;
-        assert(len_index != SIZE_MAX);
+        const size_t gen_len_index = array_type->data.structure.fields[slice_len_index]->gen_index;
+        assert(gen_len_index != SIZE_MAX);
 
         LLVMValueRef prev_end = nullptr;
         if (!instruction->end || want_runtime_safety) {
-            LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, (unsigned)len_index, "");
+            LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, gen_len_index, "");
             prev_end = gen_load_untyped(g, src_len_ptr, 0, false, "");
         }
 
@@ -5554,41 +5532,104 @@ static LLVMValueRef ir_render_slice(CodeGen *g, IrExecutableGen *executable, IrI
             end_val = prev_end;
         }
 
-        LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, (unsigned)ptr_index, "");
-        LLVMValueRef src_ptr = gen_load_untyped(g, src_ptr_ptr, 0, false, "");
+        ZigType *ptr_field_type = array_type->data.structure.fields[slice_ptr_index]->type_entry;
 
         if (want_runtime_safety) {
             assert(prev_end);
+            // Safety check: start <= end
             add_bounds_check(g, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val);
-            if (instruction->end) {
-                add_bounds_check(g, end_val, LLVMIntEQ, nullptr, LLVMIntULE, prev_end);
 
-                if (sentinel != nullptr) {
-                    LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &end_val, 1, "");
-                    add_sentinel_check(g, sentinel_elem_ptr, sentinel);
-                }
+            // Safety check: the sentinel counts as one more element
+            // XXX: Overflow is not checked here...
+            LLVMValueRef check_prev_end = prev_end;
+            if (ptr_field_type->data.pointer.sentinel != nullptr) {
+                LLVMValueRef usize_one = LLVMConstInt(g->builtin_types.entry_usize->llvm_type, 1, false);
+                check_prev_end = LLVMBuildNUWAdd(g->builder, prev_end, usize_one, "");
+            }
+            LLVMValueRef check_end_val = end_val;
+            if (sentinel != nullptr) {
+                LLVMValueRef usize_one = LLVMConstInt(g->builtin_types.entry_usize->llvm_type, 1, false);
+                check_end_val = LLVMBuildNUWAdd(g->builder, end_val, usize_one, "");
             }
+
+            add_bounds_check(g, check_end_val, LLVMIntEQ, nullptr, LLVMIntULE, check_prev_end);
         }
 
-        LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &start_val, 1, "");
-        if (result_type->id == ZigTypeIdPointer) {
-            ir_assert(instruction->result_loc == nullptr, &instruction->base);
-            LLVMTypeRef result_ptr_type = get_llvm_type(g, result_type);
-            return LLVMBuildBitCast(g->builder, slice_start_ptr, result_ptr_type, "");
-        } else {
-            LLVMValueRef tmp_struct_ptr = ir_llvm_value(g, instruction->result_loc);
-            LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, (unsigned)ptr_index, "");
-            gen_store_untyped(g, slice_start_ptr, ptr_field_ptr, 0, false);
+        bool ptr_has_bits;
+        if ((err = type_has_bits2(g, ptr_field_type, &ptr_has_bits)))
+            codegen_report_errors_and_exit(g);
+
+        if (ptr_has_bits) {
+            const size_t gen_ptr_index = array_type->data.structure.fields[slice_ptr_index]->gen_index;
+            assert(gen_ptr_index != SIZE_MAX);
 
-            LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, (unsigned)len_index, "");
-            LLVMValueRef len_value = LLVMBuildNSWSub(g->builder, end_val, start_val, "");
-            gen_store_untyped(g, len_value, len_field_ptr, 0, false);
+            LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, gen_ptr_index, "");
+            LLVMValueRef src_ptr = gen_load_untyped(g, src_ptr_ptr, 0, false, "");
 
-            return tmp_struct_ptr;
+            if (sentinel != nullptr) {
+                LLVMValueRef sentinel_elem_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &end_val, 1, "");
+                add_sentinel_check(g, sentinel_elem_ptr, sentinel);
+            }
+
+            slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, src_ptr, &start_val, 1, "");
         }
+
+        len_value = LLVMBuildNUWSub(g->builder, end_val, start_val, "");
     } else {
         zig_unreachable();
     }
+
+    bool result_has_bits;
+    if ((err = type_has_bits2(g, result_type, &result_has_bits)))
+        codegen_report_errors_and_exit(g);
+
+    // Nothing to do, we're only interested in the bound checks emitted above
+    if (!result_has_bits)
+        return nullptr;
+
+    // The starting pointer for the slice may be null in case of zero-sized
+    // arrays, the length value is always defined.
+    assert(len_value != nullptr);
+
+    // The slice decays into a pointer to an array, the size is tracked in the
+    // type itself
+    if (result_type->id == ZigTypeIdPointer) {
+        ir_assert(instruction->result_loc == nullptr, &instruction->base);
+        LLVMTypeRef result_ptr_type = get_llvm_type(g, result_type);
+
+        if (slice_start_ptr != nullptr) {
+            return LLVMBuildBitCast(g->builder, slice_start_ptr, result_ptr_type, "");
+        }
+
+        return LLVMGetUndef(result_ptr_type);
+    }
+
+    ir_assert(instruction->result_loc != nullptr, &instruction->base);
+    // Create a new slice
+    LLVMValueRef tmp_struct_ptr = ir_llvm_value(g, instruction->result_loc);
+
+    ZigType *slice_ptr_type = result_type->data.structure.fields[slice_ptr_index]->type_entry;
+
+    // The slice may not have a pointer at all if it points to a zero-sized type
+    const size_t gen_ptr_index = result_type->data.structure.fields[slice_ptr_index]->gen_index;
+    if (gen_ptr_index != SIZE_MAX) {
+        LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, gen_ptr_index, "");
+        if (slice_start_ptr != nullptr) {
+            gen_store_untyped(g, slice_start_ptr, ptr_field_ptr, 0, false);
+        } else if (want_runtime_safety) {
+            gen_undef_init(g, slice_ptr_type->abi_align, slice_ptr_type, ptr_field_ptr);
+        } else {
+            gen_store_untyped(g, LLVMGetUndef(get_llvm_type(g, slice_ptr_type)), ptr_field_ptr, 0, false);
+        }
+    }
+
+    const size_t gen_len_index = result_type->data.structure.fields[slice_len_index]->gen_index;
+    assert(gen_len_index != SIZE_MAX);
+
+    LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, gen_len_index, "");
+    gen_store_untyped(g, len_value, len_field_ptr, 0, false);
+
+    return tmp_struct_ptr;
 }
 
 static LLVMValueRef get_trap_fn_val(CodeGen *g) {