Commit 18ed87c695

Andrew Kelley <superjoe30@gmail.com>
2016-05-08 09:59:21
ability to cast u8 slice to bigger slice
1 parent aed96e3
src/analyze.cpp
@@ -4240,17 +4240,16 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
         return resolve_cast(g, context, node, expr_node, wanted_type, CastOpToUnknownSizeArray, true);
     }
 
-    // explicit cast from []T to []u8
-    if (is_slice(wanted_type) &&
-        is_u8(wanted_type->data.structure.fields[0].type_entry->data.pointer.child_type) &&
-        is_slice(actual_type) &&
+    // explicit cast from []T to []u8 or []u8 to []T
+    if (is_slice(wanted_type) && is_slice(actual_type) &&
+        (is_u8(wanted_type->data.structure.fields[0].type_entry->data.pointer.child_type) ||
+        is_u8(actual_type->data.structure.fields[0].type_entry->data.pointer.child_type)) &&
         (wanted_type->data.structure.fields[0].type_entry->data.pointer.is_const ||
          !actual_type->data.structure.fields[0].type_entry->data.pointer.is_const))
     {
         return resolve_cast(g, context, node, expr_node, wanted_type, CastOpResizeSlice, true);
     }
 
-
     // explicit cast from pointer to another pointer
     if ((actual_type->id == TypeTableEntryIdPointer || actual_type->id == TypeTableEntryIdFn) &&
         (wanted_type->id == TypeTableEntryIdPointer || wanted_type->id == TypeTableEntryIdFn))
src/codegen.cpp
@@ -888,6 +888,8 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
 
                 TypeTableEntry *actual_pointer_type = actual_type->data.structure.fields[0].type_entry;
                 TypeTableEntry *actual_child_type = actual_pointer_type->data.pointer.child_type;
+                TypeTableEntry *wanted_pointer_type = wanted_type->data.structure.fields[0].type_entry;
+                TypeTableEntry *wanted_child_type = wanted_pointer_type->data.pointer.child_type;
 
                 set_debug_source_node(g, node);
 
@@ -896,15 +898,6 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
                 int wanted_ptr_index = wanted_type->data.structure.fields[0].gen_index;
                 int wanted_len_index = wanted_type->data.structure.fields[1].gen_index;
 
-                LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, expr_val, actual_len_index, "");
-                LLVMValueRef src_len = LLVMBuildLoad(g->builder, src_len_ptr, "");
-                LLVMValueRef src_size = LLVMConstInt(g->builtin_types.entry_isize->type_ref,
-                        type_size(g, actual_child_type), false);
-                LLVMValueRef new_len = LLVMBuildMul(g->builder, src_len, src_size, "");
-                LLVMValueRef dest_len_ptr = LLVMBuildStructGEP(g->builder, cast_expr->tmp_ptr,
-                        wanted_len_index, "");
-                LLVMBuildStore(g->builder, new_len, dest_len_ptr);
-
                 LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, expr_val, actual_ptr_index, "");
                 LLVMValueRef src_ptr = LLVMBuildLoad(g->builder, src_ptr_ptr, "");
                 LLVMValueRef src_ptr_casted = LLVMBuildBitCast(g->builder, src_ptr,
@@ -913,6 +906,40 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
                         wanted_ptr_index, "");
                 LLVMBuildStore(g->builder, src_ptr_casted, dest_ptr_ptr);
 
+                LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, expr_val, actual_len_index, "");
+                LLVMValueRef src_len = LLVMBuildLoad(g->builder, src_len_ptr, "");
+                uint64_t src_size = type_size(g, actual_child_type);
+                uint64_t dest_size = type_size(g, wanted_child_type);
+
+                LLVMValueRef new_len;
+                if (dest_size == 1) {
+                    LLVMValueRef src_size_val = LLVMConstInt(g->builtin_types.entry_isize->type_ref, src_size, false);
+                    new_len = LLVMBuildMul(g->builder, src_len, src_size_val, "");
+                } else if (src_size == 1) {
+                    LLVMValueRef dest_size_val = LLVMConstInt(g->builtin_types.entry_isize->type_ref, dest_size, false);
+                    if (want_debug_safety(g, node)) {
+                        LLVMValueRef remainder_val = LLVMBuildURem(g->builder, src_len, dest_size_val, "");
+                        LLVMValueRef zero = LLVMConstNull(g->builtin_types.entry_isize->type_ref);
+                        LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
+                        LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "SliceWidenOk");
+                        LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "SliceWidenFail");
+                        LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
+
+                        LLVMPositionBuilderAtEnd(g->builder, fail_block);
+                        gen_debug_safety_crash(g);
+
+                        LLVMPositionBuilderAtEnd(g->builder, ok_block);
+                    }
+                    new_len = ZigLLVMBuildExactUDiv(g->builder, src_len, dest_size_val, "");
+                } else {
+                    zig_unreachable();
+                }
+
+                LLVMValueRef dest_len_ptr = LLVMBuildStructGEP(g->builder, cast_expr->tmp_ptr,
+                        wanted_len_index, "");
+                LLVMBuildStore(g->builder, new_len, dest_len_ptr);
+
+
                 return cast_expr->tmp_ptr;
             }
         case CastOpIntToFloat:
test/run_tests.cpp
@@ -1466,6 +1466,16 @@ fn div_exact(a: i32, b: i32) -> i32 {
 }
     )SOURCE");
 
+    add_debug_safety_case("cast []u8 to bigger slice of wrong size", R"SOURCE(
+pub fn main(args: [][]u8) -> %void {
+    widen_slice([]u8{1, 2, 3, 4, 5});
+}
+#static_eval_enable(false)
+fn widen_slice(slice: []u8) -> []i32 {
+    ([]i32)(slice)
+}
+    )SOURCE");
+
 }
 
 //////////////////////////////////////////////////////////////////////////////
test/self_hosted.zig
@@ -1593,6 +1593,13 @@ fn cast_slice_to_u8_slice() {
     bytes[6] = 0;
     bytes[7] = 0;
     assert(big_thing_slice[1] == 0);
+    const big_thing_again = ([]i32)(bytes);
+    assert(big_thing_again[2] == 3);
+    big_thing_again[2] = -1;
+    assert(bytes[8] == @max_value(u8));
+    assert(bytes[9] == @max_value(u8));
+    assert(bytes[10] == @max_value(u8));
+    assert(bytes[11] == @max_value(u8));
 }
 
 #attribute("test")