Commit 187d00ca83

Andrew Kelley <superjoe30@gmail.com>
2016-01-03 03:47:36
ability to access pointers with array indexing syntax
closes #40
1 parent 968b85a
Changed files (5)
example/arrays/arrays.zig
@@ -2,37 +2,26 @@ export executable "arrays";
 
 use "std.zig";
 
-export fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
-    var array : [i32; 5];
+pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
+    var array : [u32; 5];
 
-    var i : i32 = 0;
-loop_start:
-    if i == 5 {
-        goto loop_end;
+    var i : u32 = 0;
+    while (i < 5) {
+        array[i] = i + 1;
+        i = array[i];
     }
-    array[i] = i + 1;
-    i = array[i];
-    goto loop_start;
-
-loop_end:
 
     i = 0;
-    var accumulator : i32 = 0;
-loop_2_start:
-    if i == 5 {
-        goto loop_2_end;
-    }
-
-    accumulator += array[i];
+    var accumulator : u32 = 0;
+    while (i < 5) {
+        accumulator += array[i];
 
-    i = i + 1;
-    goto loop_2_start;
-loop_2_end:
-
-    if accumulator == 15 {
-        print_str("OK\n" as string);
+        i += 1;
     }
 
+    if (accumulator == 15) {
+        print_str("OK\n");
+    }
 
     return 0;
 }
src/analyze.cpp
@@ -1057,6 +1057,8 @@ static TypeTableEntry *analyze_array_access_expr(CodeGen *g, ImportTableEntry *i
 
     if (array_type->id == TypeTableEntryIdArray) {
         return_type = array_type->data.array.child_type;
+    } else if (array_type->id == TypeTableEntryIdPointer) {
+        return_type = array_type->data.pointer.child_type;
     } else {
         if (array_type->id != TypeTableEntryIdInvalid) {
             add_node_error(g, node, buf_sprintf("array access of non-array"));
@@ -1064,14 +1066,7 @@ static TypeTableEntry *analyze_array_access_expr(CodeGen *g, ImportTableEntry *i
         return_type = g->builtin_types.entry_invalid;
     }
 
-    TypeTableEntry *subscript_type = analyze_expression(g, import, context, nullptr,
-            node->data.array_access_expr.subscript);
-    if (subscript_type->id != TypeTableEntryIdInt &&
-        subscript_type->id != TypeTableEntryIdInvalid)
-    {
-        add_node_error(g, node,
-            buf_sprintf("array subscripts must be integers"));
-    }
+    analyze_expression(g, import, context, g->builtin_types.entry_usize, node->data.array_access_expr.subscript);
 
     return return_type;
 }
@@ -1150,7 +1145,7 @@ static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, B
     cast_node->after_type = wanted_type;
 
     // special casing this for now, TODO think about casting and do a general solution
-    if (wanted_type == g->builtin_types.entry_isize &&
+    if ((wanted_type == g->builtin_types.entry_isize || wanted_type == g->builtin_types.entry_usize) &&
         actual_type->id == TypeTableEntryIdPointer)
     {
         cast_node->op = CastOpPtrToInt;
src/codegen.cpp
@@ -199,19 +199,48 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
 static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeArrayAccessExpr);
 
-    // TODO gen_lvalue
-    LLVMValueRef array_ref_value = gen_expr(g, node->data.array_access_expr.array_ref_expr);
+    TypeTableEntry *type_entry = get_expr_type(node->data.array_access_expr.array_ref_expr);
+    AstNode *array_expr_node = node->data.array_access_expr.array_ref_expr;
+
+    LLVMValueRef array_ptr = gen_expr(g, array_expr_node);
+    /*
+    if (array_expr_node->type == NodeTypeSymbol) {
+        VariableTableEntry *var = find_variable(array_expr_node->codegen_node->expr_node.block_context,
+                &array_expr_node->data.symbol);
+        assert(var);
+
+        array_ptr = var->value_ref;
+    } else if (array_expr_node->type == NodeTypeFieldAccessExpr) {
+        zig_panic("TODO gen array ptr field access expr");
+    } else if (array_expr_node->type == NodeTypeArrayAccessExpr) {
+        zig_panic("TODO gen array ptr array access expr");
+    } else {
+        array_ptr = gen_expr(g, array_expr_node);
+    }
+    */
+
     LLVMValueRef subscript_value = gen_expr(g, node->data.array_access_expr.subscript);
 
-    assert(array_ref_value);
+    assert(array_ptr);
     assert(subscript_value);
 
-    LLVMValueRef indices[] = {
-        LLVMConstInt(LLVMInt32Type(), 0, false),
-        subscript_value
-    };
-    add_debug_source_node(g, node);
-    return LLVMBuildInBoundsGEP(g->builder, array_ref_value, indices, 2, "");
+    if (type_entry->id == TypeTableEntryIdArray) {
+        LLVMValueRef indices[] = {
+            LLVMConstNull(g->builtin_types.entry_usize->type_ref),
+            subscript_value
+        };
+        add_debug_source_node(g, node);
+        return LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, "");
+    } else if (type_entry->id == TypeTableEntryIdPointer) {
+        assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind);
+        LLVMValueRef indices[] = {
+            subscript_value
+        };
+        add_debug_source_node(g, node);
+        return LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 1, "");
+    } else {
+        zig_unreachable();
+    }
 }
 
 static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node, TypeTableEntry **out_type_entry) {
@@ -279,6 +308,14 @@ static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node, bool is_lva
         if (buf_eql_str(name, "len")) {
             return LLVMConstInt(g->builtin_types.entry_usize->type_ref,
                     struct_type->data.array.len, false);
+        } else if (buf_eql_str(name, "ptr")) {
+            LLVMValueRef array_val = gen_expr(g, node->data.field_access_expr.struct_expr);
+            LLVMValueRef indices[] = {
+                LLVMConstNull(g->builtin_types.entry_usize->type_ref),
+                LLVMConstNull(g->builtin_types.entry_usize->type_ref),
+            };
+            add_debug_source_node(g, node);
+            return LLVMBuildInBoundsGEP(g->builder, array_val, indices, 2, "");
         } else {
             zig_panic("gen_field_access_expr bad array field");
         }
@@ -314,9 +351,15 @@ static LLVMValueRef gen_lvalue(CodeGen *g, AstNode *expr_node, AstNode *node,
         target_ref = var->value_ref;
     } else if (node->type == NodeTypeArrayAccessExpr) {
         TypeTableEntry *array_type = get_expr_type(node->data.array_access_expr.array_ref_expr);
-        assert(array_type->id == TypeTableEntryIdArray);
-        *out_type_entry = array_type->data.array.child_type;
-        target_ref = gen_array_ptr(g, node);
+        if (array_type->id == TypeTableEntryIdArray) {
+            *out_type_entry = array_type->data.array.child_type;
+            target_ref = gen_array_ptr(g, node);
+        } else if (array_type->id == TypeTableEntryIdPointer) {
+            *out_type_entry = array_type->data.pointer.child_type;
+            target_ref = gen_array_ptr(g, node);
+        } else {
+            zig_unreachable();
+        }
     } else if (node->type == NodeTypeFieldAccessExpr) {
         target_ref = gen_field_ptr(g, node, out_type_entry);
     } else {
@@ -389,28 +432,26 @@ static LLVMValueRef gen_bare_cast(CodeGen *g, AstNode *node, LLVMValueRef expr_v
                 return cast_node->ptr;
             }
         case CastOpPtrToInt:
+            add_debug_source_node(g, node);
             return LLVMBuildPtrToInt(g->builder, expr_val, wanted_type->type_ref, "");
         case CastOpPointerReinterpret:
+            add_debug_source_node(g, node);
             return LLVMBuildBitCast(g->builder, expr_val, wanted_type->type_ref, "");
         case CastOpIntWidenOrShorten:
             if (actual_type->size_in_bits == wanted_type->size_in_bits) {
                 return expr_val;
             } else if (actual_type->size_in_bits < wanted_type->size_in_bits) {
-                if (actual_type->data.integral.is_signed && wanted_type->data.integral.is_signed) {
+                if (actual_type->data.integral.is_signed) {
+                    add_debug_source_node(g, node);
                     return LLVMBuildSExt(g->builder, expr_val, wanted_type->type_ref, "");
-                } else if (!actual_type->data.integral.is_signed && !wanted_type->data.integral.is_signed) {
-                    return LLVMBuildZExt(g->builder, expr_val, wanted_type->type_ref, "");
                 } else {
-                    zig_panic("TODO gen_cast_expr mixing of signness");
+                    add_debug_source_node(g, node);
+                    return LLVMBuildZExt(g->builder, expr_val, wanted_type->type_ref, "");
                 }
             } else {
                 assert(actual_type->size_in_bits > wanted_type->size_in_bits);
-
-                if (actual_type->data.integral.is_signed && wanted_type->data.integral.is_signed) {
-                    return LLVMBuildTrunc(g->builder, expr_val, wanted_type->type_ref, "");
-                } else {
-                    zig_panic("TODO gen_cast_expr shorten unsigned");
-                }
+                add_debug_source_node(g, node);
+                return LLVMBuildTrunc(g->builder, expr_val, wanted_type->type_ref, "");
             }
         case CastOpArrayToString:
             {
@@ -1232,8 +1273,8 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
                 Buf *str = &node->data.string_literal.buf;
                 LLVMValueRef str_val = find_or_create_string(g, str, node->data.string_literal.c);
                 LLVMValueRef indices[] = {
-                    LLVMConstInt(LLVMInt32Type(), 0, false),
-                    LLVMConstInt(LLVMInt32Type(), 0, false)
+                    LLVMConstNull(g->builtin_types.entry_usize->type_ref),
+                    LLVMConstNull(g->builtin_types.entry_usize->type_ref),
                 };
                 LLVMValueRef ptr_val = LLVMBuildInBoundsGEP(g->builder, str_val, indices, 2, "");
                 return ptr_val;
std/std.zig
@@ -1,40 +1,37 @@
-const SYS_write : isize = 1;
-const SYS_exit : isize = 60;
-const SYS_getrandom : isize = 278;
+const SYS_write : usize = 1;
+const SYS_exit : usize = 60;
+const SYS_getrandom : usize = 278;
 
 const stdout_fileno : isize = 1;
 const stderr_fileno : isize = 2;
 
-fn syscall1(number: isize, arg1: isize) -> isize {
+fn syscall1(number: usize, arg1: usize) -> usize {
     asm volatile ("syscall"
-        : [ret] "={rax}" (-> isize)
+        : [ret] "={rax}" (-> usize)
         : [number] "{rax}" (number), [arg1] "{rdi}" (arg1)
         : "rcx", "r11")
 }
 
-fn syscall3(number: isize, arg1: isize, arg2: isize, arg3: isize) -> isize {
+fn syscall3(number: usize, arg1: usize, arg2: usize, arg3: usize) -> usize {
     asm volatile ("syscall"
-        : [ret] "={rax}" (-> isize)
+        : [ret] "={rax}" (-> usize)
         : [number] "{rax}" (number), [arg1] "{rdi}" (arg1), [arg2] "{rsi}" (arg2), [arg3] "{rdx}" (arg3)
         : "rcx", "r11")
 }
 
-/*
 pub fn getrandom(buf: &u8, count: usize, flags: u32) -> isize {
-    return syscall3(SYS_getrandom, buf as isize, count as isize, flags as isize);
+    return syscall3(SYS_getrandom, buf as usize, count, flags as usize) as isize;
 }
-*/
 
 pub fn write(fd: isize, buf: &const u8, count: usize) -> isize {
-    return syscall3(SYS_write, fd, buf as isize, count as isize);
+    return syscall3(SYS_write, fd as usize, buf as usize, count) as isize;
 }
 
 pub fn exit(status: i32) -> unreachable {
-    syscall1(SYS_exit, status as isize);
+    syscall1(SYS_exit, status as usize);
     unreachable;
 }
 
-/*
 fn digit_to_char(digit: u64) -> u8 { '0' + (digit as u8) }
 
 const max_u64_base10_digits: usize = 20;
@@ -66,17 +63,6 @@ fn buf_print_u64(out_buf: &u8, x: u64) -> usize {
     return len;
 }
 
-// TODO handle buffering and flushing (mutex protected)
-// TODO error handling
-pub fn print_u64(x: u64) -> isize {
-    // TODO use max_u64_base10_digits instead of hardcoding 20
-    var buf: [u8; 20];
-    const len = buf_print_u64(buf.ptr, x);
-    return write(stdout_fileno, buf.ptr, len);
-}
-*/
-
-
 // TODO error handling
 // TODO handle buffering and flushing (mutex protected)
 pub fn print_str(str: string) -> isize { fprint_str(stdout_fileno, str) }
@@ -87,9 +73,16 @@ pub fn fprint_str(fd: isize, str: string) -> isize {
     return write(fd, str.ptr, str.len);
 }
 
-/*
+// TODO handle buffering and flushing (mutex protected)
+// TODO error handling
+pub fn print_u64(x: u64) -> isize {
+    // TODO use max_u64_base10_digits instead of hardcoding 20
+    var buf: [u8; 20];
+    const len = buf_print_u64(buf.ptr, x);
+    return write(stdout_fileno, buf.ptr, len);
+}
+
 // TODO error handling
 pub fn os_get_random_bytes(buf: &u8, count: usize) -> isize {
     return getrandom(buf, count, 0);
 }
-*/
test/run_tests.cpp
@@ -319,31 +319,21 @@ done:
 use "std.zig";
 
 pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
-    var array : [i32; 5];
+    var array : [u32; 5];
 
-    var i : i32 = 0;
-loop_start:
-    if (i == 5) {
-        goto loop_end;
+    var i : u32 = 0;
+    while (i < 5) {
+        array[i] = i + 1;
+        i = array[i];
     }
-    array[i] = i + 1;
-    i = array[i];
-    goto loop_start;
-
-loop_end:
 
     i = 0;
-    var accumulator = 0 as i32;
-loop_2_start:
-    if (i == 5) {
-        goto loop_2_end;
-    }
+    var accumulator = 0 as u32;
+    while (i < 5) {
+        accumulator += array[i];
 
-    accumulator = accumulator + array[i];
-
-    i = i + 1;
-    goto loop_2_start;
-loop_2_end:
+        i += 1;
+    }
 
     if (accumulator == 15) {
         print_str("OK\n");
@@ -871,9 +861,9 @@ fn f() {
                  ".tmp_source.zig:4:12: error: use of undeclared identifier 'i'",
                  ".tmp_source.zig:4:14: error: use of undeclared identifier 'i'",
                  ".tmp_source.zig:5:8: error: array access of non-array",
-                 ".tmp_source.zig:5:8: error: array subscripts must be integers",
+                 ".tmp_source.zig:5:9: error: expected type 'usize', got 'bool'",
                  ".tmp_source.zig:5:19: error: array access of non-array",
-                 ".tmp_source.zig:5:19: error: array subscripts must be integers");
+                 ".tmp_source.zig:5:20: error: expected type 'usize', got 'bool'");
 
     add_compile_fail_case("variadic functions only allowed in extern", R"SOURCE(
 fn f(...) {}