Commit 9aea99a999

Andrew Kelley <superjoe30@gmail.com>
2016-01-07 13:29:11
implement array slicing syntax
closes #52
1 parent ea69d6e
doc/langref.md
@@ -148,7 +148,7 @@ CastExpression : CastExpression token(as) Type | PrefixOpExpression
 
 PrefixOpExpression : PrefixOp PrefixOpExpression | SuffixOpExpression
 
-SuffixOpExpression : PrimaryExpression option(FnCallExpression | ArrayAccessExpression | FieldAccessExpression)
+SuffixOpExpression : PrimaryExpression option(FnCallExpression | ArrayAccessExpression | FieldAccessExpression | SliceExpression)
 
 FieldAccessExpression : token(Dot) token(Symbol)
 
@@ -156,6 +156,8 @@ FnCallExpression : token(LParen) list(Expression, token(Comma)) token(RParen)
 
 ArrayAccessExpression : token(LBracket) Expression token(RBracket)
 
+SliceExpression : token(LBracket) Expression token(Ellipsis) option(Expression) token(RBracket) option(token(Const))
+
 PrefixOp : token(Not) | token(Dash) | token(Tilde) | token(Star) | (token(Ampersand) option(token(Const)))
 
 PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
src/analyze.cpp
@@ -23,6 +23,8 @@ static AstNode *first_executing_node(AstNode *node) {
             return first_executing_node(node->data.bin_op_expr.op1);
         case NodeTypeArrayAccessExpr:
             return first_executing_node(node->data.array_access_expr.array_ref_expr);
+        case NodeTypeSliceExpr:
+            return first_executing_node(node->data.slice_expr.array_ref_expr);
         case NodeTypeFieldAccessExpr:
             return first_executing_node(node->data.field_access_expr.struct_expr);
         case NodeTypeCastExpr:
@@ -875,6 +877,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import,
         case NodeTypeBinOpExpr:
         case NodeTypeFnCallExpr:
         case NodeTypeArrayAccessExpr:
+        case NodeTypeSliceExpr:
         case NodeTypeNumberLiteral:
         case NodeTypeStringLiteral:
         case NodeTypeCharLiteral:
@@ -950,6 +953,7 @@ static void preview_types(CodeGen *g, ImportTableEntry *import, AstNode *node) {
         case NodeTypeBinOpExpr:
         case NodeTypeFnCallExpr:
         case NodeTypeArrayAccessExpr:
+        case NodeTypeSliceExpr:
         case NodeTypeNumberLiteral:
         case NodeTypeStringLiteral:
         case NodeTypeCharLiteral:
@@ -1349,6 +1353,50 @@ static TypeTableEntry *analyze_field_access_expr(CodeGen *g, ImportTableEntry *i
     return return_type;
 }
 
+static TypeTableEntry *analyze_slice_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+        AstNode *node)
+{
+    TypeTableEntry *array_type = analyze_expression(g, import, context, nullptr,
+            node->data.slice_expr.array_ref_expr);
+
+    TypeTableEntry *return_type;
+
+    if (array_type->id == TypeTableEntryIdInvalid) {
+        return_type = g->builtin_types.entry_invalid;
+    } else if (array_type->id == TypeTableEntryIdArray) {
+        return_type = get_unknown_size_array_type(g, import, array_type->data.array.child_type,
+                node->data.slice_expr.is_const);
+    } else if (array_type->id == TypeTableEntryIdPointer) {
+        return_type = get_unknown_size_array_type(g, import, array_type->data.pointer.child_type,
+                node->data.slice_expr.is_const);
+    } else if (array_type->id == TypeTableEntryIdStruct &&
+               array_type->data.structure.is_unknown_size_array)
+    {
+        return_type = get_unknown_size_array_type(g, import,
+                array_type->data.structure.fields[0].type_entry,
+                node->data.slice_expr.is_const);
+    } else {
+        add_node_error(g, node,
+            buf_sprintf("slice of non-array type '%s'", buf_ptr(&array_type->name)));
+        return_type = g->builtin_types.entry_invalid;
+    }
+
+    if (return_type->id != TypeTableEntryIdInvalid) {
+        assert(node->codegen_node);
+        node->codegen_node->data.struct_val_expr_node.type_entry = return_type;
+        node->codegen_node->data.struct_val_expr_node.source_node = node;
+        context->struct_val_expr_alloca_list.append(&node->codegen_node->data.struct_val_expr_node);
+    }
+
+    analyze_expression(g, import, context, g->builtin_types.entry_usize, node->data.slice_expr.start);
+
+    if (node->data.slice_expr.end) {
+        analyze_expression(g, import, context, g->builtin_types.entry_usize, node->data.slice_expr.end);
+    }
+
+    return return_type;
+}
+
 static TypeTableEntry *analyze_array_access_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         AstNode *node)
 {
@@ -1363,7 +1411,8 @@ static TypeTableEntry *analyze_array_access_expr(CodeGen *g, ImportTableEntry *i
         return_type = array_type->data.pointer.child_type;
     } else {
         if (array_type->id != TypeTableEntryIdInvalid) {
-            add_node_error(g, node, buf_sprintf("array access of non-array"));
+            add_node_error(g, node,
+                    buf_sprintf("array access of non-array type '%s'", buf_ptr(&array_type->name)));
         }
         return_type = g->builtin_types.entry_invalid;
     }
@@ -2197,6 +2246,9 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
             // for reading array access; assignment handled elsewhere
             return_type = analyze_array_access_expr(g, import, context, node);
             break;
+        case NodeTypeSliceExpr:
+            return_type = analyze_slice_expr(g, import, context, node);
+            break;
         case NodeTypeFieldAccessExpr:
             return_type = analyze_field_access_expr(g, import, context, node);
             break;
@@ -2541,6 +2593,7 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import,
         case NodeTypeBinOpExpr:
         case NodeTypeFnCallExpr:
         case NodeTypeArrayAccessExpr:
+        case NodeTypeSliceExpr:
         case NodeTypeNumberLiteral:
         case NodeTypeStringLiteral:
         case NodeTypeCharLiteral:
src/analyze.hpp
@@ -355,9 +355,11 @@ struct CodeGenNode {
         StructDeclNode struct_decl_node; // for NodeTypeStructDecl
         FieldAccessNode field_access_node; // for NodeTypeFieldAccessExpr
         CastNode cast_node; // for NodeTypeCastExpr
+        // note: I've been using this field on some non-number literal nodes too.
         NumberLiteralNode num_lit_node; // for NodeTypeNumberLiteral
         VarDeclNode var_decl_node; // for NodeTypeVariableDeclaration
         StructValFieldNode struct_val_field_node; // for NodeTypeStructValueField
+        // note: I've been using this field on some non-struct val expressions too.
         StructValExprNode struct_val_expr_node; // for NodeTypeStructValueExpr
         IfVarNode if_var_node; // for NodeTypeStructValueExpr
         ParamDeclNode param_decl_node; // for NodeTypeParamDecl
src/codegen.cpp
@@ -215,26 +215,34 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
     }
 }
 
-static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) {
-    assert(node->type == NodeTypeArrayAccessExpr);
-
-    AstNode *array_expr_node = node->data.array_access_expr.array_ref_expr;
-    TypeTableEntry *type_entry = get_expr_type(array_expr_node);
+static LLVMValueRef gen_array_base_ptr(CodeGen *g, AstNode *node) {
+    TypeTableEntry *type_entry = get_expr_type(node);
 
     LLVMValueRef array_ptr;
-    if (array_expr_node->type == NodeTypeFieldAccessExpr) {
-        array_ptr = gen_field_access_expr(g, array_expr_node, true);
+    if (node->type == NodeTypeFieldAccessExpr) {
+        array_ptr = gen_field_access_expr(g, node, true);
         if (type_entry->id == TypeTableEntryIdPointer) {
             // we have a double pointer so we must dereference it once
             add_debug_source_node(g, node);
             array_ptr = LLVMBuildLoad(g->builder, array_ptr, "");
         }
     } else {
-        array_ptr = gen_expr(g, array_expr_node);
+        array_ptr = gen_expr(g, node);
     }
 
     assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind);
 
+    return array_ptr;
+}
+
+static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeArrayAccessExpr);
+
+    AstNode *array_expr_node = node->data.array_access_expr.array_ref_expr;
+    TypeTableEntry *type_entry = get_expr_type(array_expr_node);
+
+    LLVMValueRef array_ptr = gen_array_base_ptr(g, array_expr_node);
+
     LLVMValueRef subscript_value = gen_expr(g, node->data.array_access_expr.subscript);
     assert(subscript_value);
 
@@ -299,6 +307,48 @@ static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node, TypeTableEntry **ou
     return LLVMBuildStructGEP(g->builder, struct_ptr, codegen_field_access->field_index, "");
 }
 
+static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeSliceExpr);
+
+    AstNode *array_ref_node = node->data.slice_expr.array_ref_expr;
+    TypeTableEntry *array_type = get_expr_type(array_ref_node);
+
+    LLVMValueRef tmp_struct_ptr = node->codegen_node->data.struct_val_expr_node.ptr;
+
+    if (array_type->id == TypeTableEntryIdArray) {
+        LLVMValueRef array_ptr = gen_array_base_ptr(g, array_ref_node);
+        LLVMValueRef start_val = gen_expr(g, node->data.slice_expr.start);
+        LLVMValueRef end_val;
+        if (node->data.slice_expr.end) {
+            end_val = gen_expr(g, node->data.slice_expr.end);
+        } else {
+            end_val = LLVMConstInt(g->builtin_types.entry_usize->type_ref, array_type->data.array.len, false);
+        }
+
+        add_debug_source_node(g, node);
+        LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, "");
+        LLVMValueRef indices[] = {
+            LLVMConstNull(g->builtin_types.entry_usize->type_ref),
+            start_val,
+        };
+        LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, indices, 2, "");
+        LLVMBuildStore(g->builder, slice_start_ptr, ptr_field_ptr);
+
+        LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 1, "");
+        LLVMValueRef len_value = LLVMBuildSub(g->builder, end_val, start_val, "");
+        LLVMBuildStore(g->builder, len_value, len_field_ptr);
+
+        return tmp_struct_ptr;
+    } else if (array_type->id == TypeTableEntryIdPointer) {
+        zig_panic("TODO gen_slice_expr pointer");
+    } else if (array_type->id == TypeTableEntryIdStruct) {
+        assert(array_type->data.structure.is_unknown_size_array);
+        zig_panic("TODO gen_slice_expr unknown size array");
+    } else {
+        zig_unreachable();
+    }
+}
+
 static LLVMValueRef gen_array_access_expr(CodeGen *g, AstNode *node, bool is_lvalue) {
     assert(node->type == NodeTypeArrayAccessExpr);
 
@@ -1443,6 +1493,8 @@ static LLVMValueRef gen_expr_no_cast(CodeGen *g, AstNode *node) {
             return gen_fn_call_expr(g, node);
         case NodeTypeArrayAccessExpr:
             return gen_array_access_expr(g, node, false);
+        case NodeTypeSliceExpr:
+            return gen_slice_expr(g, node);
         case NodeTypeFieldAccessExpr:
             return gen_field_access_expr(g, node, false);
         case NodeTypeUnreachable:
src/parser.cpp
@@ -90,6 +90,8 @@ const char *node_type_str(NodeType node_type) {
             return "FnCallExpr";
         case NodeTypeArrayAccessExpr:
             return "ArrayAccessExpr";
+        case NodeTypeSliceExpr:
+            return "SliceExpr";
         case NodeTypeExternBlock:
             return "ExternBlock";
         case NodeTypeDirective:
@@ -298,6 +300,14 @@ void ast_print(AstNode *node, int indent) {
             ast_print(node->data.array_access_expr.array_ref_expr, indent + 2);
             ast_print(node->data.array_access_expr.subscript, indent + 2);
             break;
+        case NodeTypeSliceExpr:
+            fprintf(stderr, "%s\n", node_type_str(node->type));
+            ast_print(node->data.slice_expr.array_ref_expr, indent + 2);
+            ast_print(node->data.slice_expr.start, indent + 2);
+            if (node->data.slice_expr.end) {
+                ast_print(node->data.slice_expr.end, indent + 2);
+            }
+            break;
         case NodeTypeDirective:
             fprintf(stderr, "%s\n", node_type_str(node->type));
             break;
@@ -1381,9 +1391,10 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
 }
 
 /*
-SuffixOpExpression : PrimaryExpression option(FnCallExpression | ArrayAccessExpression | FieldAccessExpression)
+SuffixOpExpression : PrimaryExpression option(FnCallExpression | ArrayAccessExpression | FieldAccessExpression | SliceExpression)
 FnCallExpression : token(LParen) list(Expression, token(Comma)) token(RParen)
 ArrayAccessExpression : token(LBracket) Expression token(RBracket)
+SliceExpression : token(LBracket) Expression token(Ellipsis) option(Expression) token(RBracket) option(token(Const))
 FieldAccessExpression : token(Dot) token(Symbol)
 */
 static AstNode *ast_parse_suffix_op_expr(ParseContext *pc, int *token_index, bool mandatory) {
@@ -1405,15 +1416,38 @@ static AstNode *ast_parse_suffix_op_expr(ParseContext *pc, int *token_index, boo
         } else if (token->id == TokenIdLBracket) {
             *token_index += 1;
 
-            AstNode *node = ast_create_node(pc, NodeTypeArrayAccessExpr, token);
-            node->data.array_access_expr.array_ref_expr = primary_expr;
-            node->data.array_access_expr.subscript = ast_parse_expression(pc, token_index, true);
+            AstNode *expr_node = ast_parse_expression(pc, token_index, true);
 
-            Token *r_bracket = &pc->tokens->at(*token_index);
-            *token_index += 1;
-            ast_expect_token(pc, r_bracket, TokenIdRBracket);
+            Token *ellipsis_or_r_bracket = &pc->tokens->at(*token_index);
 
-            primary_expr = node;
+            if (ellipsis_or_r_bracket->id == TokenIdEllipsis) {
+                *token_index += 1;
+
+                AstNode *node = ast_create_node(pc, NodeTypeSliceExpr, token);
+                node->data.slice_expr.array_ref_expr = primary_expr;
+                node->data.slice_expr.start = expr_node;
+                node->data.slice_expr.end = ast_parse_expression(pc, token_index, false);
+
+                ast_eat_token(pc, token_index, TokenIdRBracket);
+
+                Token *const_tok = &pc->tokens->at(*token_index);
+                if (const_tok->id == TokenIdKeywordConst) {
+                    *token_index += 1;
+                    node->data.slice_expr.is_const = true;
+                }
+
+                primary_expr = node;
+            } else if (ellipsis_or_r_bracket->id == TokenIdRBracket) {
+                *token_index += 1;
+
+                AstNode *node = ast_create_node(pc, NodeTypeArrayAccessExpr, token);
+                node->data.array_access_expr.array_ref_expr = primary_expr;
+                node->data.array_access_expr.subscript = expr_node;
+
+                primary_expr = node;
+            } else {
+                ast_invalid_token_error(pc, token);
+            }
         } else if (token->id == TokenIdDot) {
             *token_index += 1;
 
src/parser.hpp
@@ -41,6 +41,7 @@ enum NodeType {
     NodeTypePrefixOpExpr,
     NodeTypeFnCallExpr,
     NodeTypeArrayAccessExpr,
+    NodeTypeSliceExpr,
     NodeTypeFieldAccessExpr,
     NodeTypeUse,
     NodeTypeVoid,
@@ -181,6 +182,13 @@ struct AstNodeArrayAccessExpr {
     AstNode *subscript;
 };
 
+struct AstNodeSliceExpr {
+    AstNode *array_ref_expr;
+    AstNode *start;
+    AstNode *end;
+    bool is_const;
+};
+
 struct AstNodeFieldAccessExpr {
     AstNode *struct_expr;
     Buf field_name;
@@ -378,6 +386,7 @@ struct AstNode {
         AstNodePrefixOpExpr prefix_op_expr;
         AstNodeFnCallExpr fn_call_expr;
         AstNodeArrayAccessExpr array_access_expr;
+        AstNodeSliceExpr slice_expr;
         AstNodeUse use;
         AstNodeIfBoolExpr if_bool_expr;
         AstNodeIfVarExpr if_var_expr;
test/run_tests.cpp
@@ -907,6 +907,35 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
         "min i16: -32768\n"
         "min i32: -2147483648\n"
         "min i64: -9223372036854775808\n");
+
+
+    add_simple_case("slicing", R"SOURCE(
+use "std.zig";
+pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
+    var array : [20]i32;
+
+    array[5] = 1234;
+
+    var slice = array[5...10];
+
+    if (slice.len != 5) {
+        print_str("BAD\n");
+    }
+
+    if (slice.ptr[0] != 1234) {
+        print_str("BAD\n");
+    }
+
+    var slice_rest = array[10...];
+    if (slice_rest.len != 10) {
+        print_str("BAD\n");
+    }
+
+    print_str("OK\n");
+    return 0;
+}
+    )SOURCE", "OK\n");
+
 }
 
 ////////////////////////////////////////////////////////////////////////////////////