Commit 8d0ac6dc4d

Andrew Kelley <andrew@ziglang.org>
2020-03-17 22:33:44
`@ptrCast` supports casting a slice to pointer
1 parent c896c50
Changed files (4)
lib/std/mem.zig
@@ -1750,34 +1750,50 @@ fn BytesAsSliceReturnType(comptime T: type, comptime bytesType: type) type {
 }
 
 pub fn bytesAsSlice(comptime T: type, bytes: var) BytesAsSliceReturnType(T, @TypeOf(bytes)) {
-    const bytesSlice = if (comptime trait.isPtrTo(.Array)(@TypeOf(bytes))) bytes[0..] else bytes;
-
     // let's not give an undefined pointer to @ptrCast
     // it may be equal to zero and fail a null check
-    if (bytesSlice.len == 0) {
+    if (bytes.len == 0) {
         return &[0]T{};
     }
 
-    const bytesType = @TypeOf(bytesSlice);
-    const alignment = comptime meta.alignment(bytesType);
+    const Bytes = @TypeOf(bytes);
+    const alignment = comptime meta.alignment(Bytes);
 
-    const castTarget = if (comptime trait.isConstPtr(bytesType)) [*]align(alignment) const T else [*]align(alignment) T;
+    const cast_target = if (comptime trait.isConstPtr(Bytes)) [*]align(alignment) const T else [*]align(alignment) T;
 
-    return @ptrCast(castTarget, bytesSlice.ptr)[0..@divExact(bytes.len, @sizeOf(T))];
+    return @ptrCast(cast_target, bytes)[0..@divExact(bytes.len, @sizeOf(T))];
 }
 
 test "bytesAsSlice" {
-    const bytes = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF };
-    const slice = bytesAsSlice(u16, bytes[0..]);
-    testing.expect(slice.len == 2);
-    testing.expect(bigToNative(u16, slice[0]) == 0xDEAD);
-    testing.expect(bigToNative(u16, slice[1]) == 0xBEEF);
+    {
+        const bytes = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF };
+        const slice = bytesAsSlice(u16, bytes[0..]);
+        testing.expect(slice.len == 2);
+        testing.expect(bigToNative(u16, slice[0]) == 0xDEAD);
+        testing.expect(bigToNative(u16, slice[1]) == 0xBEEF);
+    }
+    {
+        const bytes = [_]u8{ 0xDE, 0xAD, 0xBE, 0xEF };
+        var runtime_zero: usize = 0;
+        const slice = bytesAsSlice(u16, bytes[runtime_zero..]);
+        testing.expect(slice.len == 2);
+        testing.expect(bigToNative(u16, slice[0]) == 0xDEAD);
+        testing.expect(bigToNative(u16, slice[1]) == 0xBEEF);
+    }
 }
 
 test "bytesAsSlice keeps pointer alignment" {
-    var bytes = [_]u8{ 0x01, 0x02, 0x03, 0x04 };
-    const numbers = bytesAsSlice(u32, bytes[0..]);
-    comptime testing.expect(@TypeOf(numbers) == []align(@alignOf(@TypeOf(bytes))) u32);
+    {
+        var bytes = [_]u8{ 0x01, 0x02, 0x03, 0x04 };
+        const numbers = bytesAsSlice(u32, bytes[0..]);
+        comptime testing.expect(@TypeOf(numbers) == []align(@alignOf(@TypeOf(bytes))) u32);
+    }
+    {
+        var bytes = [_]u8{ 0x01, 0x02, 0x03, 0x04 };
+        var runtime_zero: usize = 0;
+        const numbers = bytesAsSlice(u32, bytes[runtime_zero..]);
+        comptime testing.expect(@TypeOf(numbers) == []align(@alignOf(@TypeOf(bytes))) u32);
+    }
 }
 
 test "bytesAsSlice on a packed struct" {
src/analyze.cpp
@@ -4486,7 +4486,14 @@ static uint32_t get_async_frame_align_bytes(CodeGen *g) {
 }
 
 uint32_t get_ptr_align(CodeGen *g, ZigType *type) {
-    ZigType *ptr_type = get_src_ptr_type(type);
+    ZigType *ptr_type;
+    if (type->id == ZigTypeIdStruct) {
+        assert(type->data.structure.special == StructSpecialSlice);
+        TypeStructField *ptr_field = type->data.structure.fields[slice_ptr_index];
+        ptr_type = resolve_struct_field_type(g, ptr_field);
+    } else {
+        ptr_type = get_src_ptr_type(type);
+    }
     if (ptr_type->id == ZigTypeIdPointer) {
         return (ptr_type->data.pointer.explicit_alignment == 0) ?
             get_abi_alignment(g, ptr_type->data.pointer.child_type) : ptr_type->data.pointer.explicit_alignment;
@@ -4503,8 +4510,15 @@ uint32_t get_ptr_align(CodeGen *g, ZigType *type) {
     }
 }
 
-bool get_ptr_const(ZigType *type) {
-    ZigType *ptr_type = get_src_ptr_type(type);
+bool get_ptr_const(CodeGen *g, ZigType *type) {
+    ZigType *ptr_type;
+    if (type->id == ZigTypeIdStruct) {
+        assert(type->data.structure.special == StructSpecialSlice);
+        TypeStructField *ptr_field = type->data.structure.fields[slice_ptr_index];
+        ptr_type = resolve_struct_field_type(g, ptr_field);
+    } else {
+        ptr_type = get_src_ptr_type(type);
+    }
     if (ptr_type->id == ZigTypeIdPointer) {
         return ptr_type->data.pointer.is_const;
     } else if (ptr_type->id == ZigTypeIdFn) {
src/analyze.hpp
@@ -76,7 +76,7 @@ void resolve_top_level_decl(CodeGen *g, Tld *tld, AstNode *source_node, bool all
 
 ZigType *get_src_ptr_type(ZigType *type);
 uint32_t get_ptr_align(CodeGen *g, ZigType *type);
-bool get_ptr_const(ZigType *type);
+bool get_ptr_const(CodeGen *g, ZigType *type);
 ZigType *validate_var_type(CodeGen *g, AstNode *source_node, ZigType *type_entry);
 ZigType *container_ref_type(ZigType *type_entry);
 bool type_is_complete(ZigType *type_entry);
src/ir.cpp
@@ -25479,11 +25479,22 @@ static IrInstGen *ir_analyze_instruction_err_set_cast(IrAnalyze *ira, IrInstSrcE
 static Error resolve_ptr_align(IrAnalyze *ira, ZigType *ty, uint32_t *result_align) {
     Error err;
 
-    ZigType *ptr_type = get_src_ptr_type(ty);
+    ZigType *ptr_type;
+    if (is_slice(ty)) {
+        TypeStructField *ptr_field = ty->data.structure.fields[slice_ptr_index];
+        ptr_type = resolve_struct_field_type(ira->codegen, ptr_field);
+    } else {
+        ptr_type = get_src_ptr_type(ty);
+    }
     assert(ptr_type != nullptr);
     if (ptr_type->id == ZigTypeIdPointer) {
         if ((err = type_resolve(ira->codegen, ptr_type->data.pointer.child_type, ResolveStatusAlignmentKnown)))
             return err;
+    } else if (is_slice(ptr_type)) {
+        TypeStructField *ptr_field = ptr_type->data.structure.fields[slice_ptr_index];
+        ZigType *slice_ptr_type = resolve_struct_field_type(ira->codegen, ptr_field);
+        if ((err = type_resolve(ira->codegen, slice_ptr_type->data.pointer.child_type, ResolveStatusAlignmentKnown)))
+            return err;
     }
 
     *result_align = get_ptr_align(ira->codegen, ty);
@@ -27615,10 +27626,18 @@ static IrInstGen *ir_analyze_ptr_cast(IrAnalyze *ira, IrInst* source_instr, IrIn
     // We have a check for zero bits later so we use get_src_ptr_type to
     // validate src_type and dest_type.
 
-    ZigType *src_ptr_type = get_src_ptr_type(src_type);
-    if (src_ptr_type == nullptr) {
-        ir_add_error(ira, ptr_src, buf_sprintf("expected pointer, found '%s'", buf_ptr(&src_type->name)));
-        return ira->codegen->invalid_inst_gen;
+    ZigType *if_slice_ptr_type;
+    if (is_slice(src_type)) {
+        TypeStructField *ptr_field = src_type->data.structure.fields[slice_ptr_index];
+        if_slice_ptr_type = resolve_struct_field_type(ira->codegen, ptr_field);
+    } else {
+        if_slice_ptr_type = src_type;
+
+        ZigType *src_ptr_type = get_src_ptr_type(src_type);
+        if (src_ptr_type == nullptr) {
+            ir_add_error(ira, ptr_src, buf_sprintf("expected pointer, found '%s'", buf_ptr(&src_type->name)));
+            return ira->codegen->invalid_inst_gen;
+        }
     }
 
     ZigType *dest_ptr_type = get_src_ptr_type(dest_type);
@@ -27628,7 +27647,7 @@ static IrInstGen *ir_analyze_ptr_cast(IrAnalyze *ira, IrInst* source_instr, IrIn
         return ira->codegen->invalid_inst_gen;
     }
 
-    if (get_ptr_const(src_type) && !get_ptr_const(dest_type)) {
+    if (get_ptr_const(ira->codegen, src_type) && !get_ptr_const(ira->codegen, dest_type)) {
         ir_add_error(ira, source_instr, buf_sprintf("cast discards const qualifier"));
         return ira->codegen->invalid_inst_gen;
     }
@@ -27646,7 +27665,10 @@ static IrInstGen *ir_analyze_ptr_cast(IrAnalyze *ira, IrInst* source_instr, IrIn
     if ((err = type_resolve(ira->codegen, src_type, ResolveStatusZeroBitsKnown)))
         return ira->codegen->invalid_inst_gen;
 
-    if (type_has_bits(ira->codegen, dest_type) && !type_has_bits(ira->codegen, src_type) && safety_check_on) {
+    if (safety_check_on &&
+        type_has_bits(ira->codegen, dest_type) &&
+        !type_has_bits(ira->codegen, if_slice_ptr_type))
+    {
         ErrorMsg *msg = ir_add_error(ira, source_instr,
             buf_sprintf("'%s' and '%s' do not have the same in-memory representation",
                 buf_ptr(&src_type->name), buf_ptr(&dest_type->name)));
@@ -27657,6 +27679,14 @@ static IrInstGen *ir_analyze_ptr_cast(IrAnalyze *ira, IrInst* source_instr, IrIn
         return ira->codegen->invalid_inst_gen;
     }
 
+    // For slices, follow the `ptr` field.
+    if (is_slice(src_type)) {
+        TypeStructField *ptr_field = src_type->data.structure.fields[slice_ptr_index];
+        IrInstGen *ptr_ref = ir_get_ref(ira, source_instr, ptr, true, false);
+        IrInstGen *ptr_ptr = ir_analyze_struct_field_ptr(ira, source_instr, ptr_field, ptr_ref, src_type, false);
+        ptr = ir_get_deref(ira, source_instr, ptr_ptr, nullptr);
+    }
+
     if (instr_is_comptime(ptr)) {
         bool dest_allows_addr_zero = ptr_allows_addr_zero(dest_type);
         UndefAllowed is_undef_allowed = dest_allows_addr_zero ? UndefOk : UndefBad;