Commit a3e288ab5b

Andrew Kelley <superjoe30@gmail.com>
2016-01-26 05:56:29
implement compile time string concatenation
See #76
1 parent 1d68150
doc/langref.md
@@ -111,7 +111,7 @@ BitShiftOperator : "<<" | ">>"
 
 AdditionExpression : MultiplyExpression AdditionOperator AdditionExpression | MultiplyExpression
 
-AdditionOperator : "+" | "-"
+AdditionOperator : "+" | "-" | "++"
 
 MultiplyExpression : CurlySuffixExpression MultiplyOperator MultiplyExpression | CurlySuffixExpression
 
@@ -157,7 +157,7 @@ x() x[] x.y
 !x -x ~x *x &x ?x %x %%x
 x{}
 * / %
-+ -
++ - ++
 << >>
 &
 ^
src/all_types.hpp
@@ -296,6 +296,7 @@ enum BinOpType {
     BinOpTypeDiv,
     BinOpTypeMod,
     BinOpTypeUnwrapMaybe,
+    BinOpTypeStrCat,
 };
 
 struct AstNodeBinOpExpr {
src/analyze.cpp
@@ -1437,6 +1437,8 @@ static TypeTableEntry *create_and_analyze_cast_node(CodeGen *g, ImportTableEntry
         BlockContext *context, TypeTableEntry *cast_to_type, AstNode *node)
 {
     AstNode *new_parent_node = create_ast_node(g, import, NodeTypeFnCallExpr);
+    new_parent_node->line = node->line;
+    new_parent_node->column = node->column;
     *node->parent_field = new_parent_node;
     new_parent_node->parent_field = node->parent_field;
 
@@ -2146,6 +2148,7 @@ static bool is_op_allowed(TypeTableEntry *type, BinOpType op) {
         case BinOpTypeDiv:
         case BinOpTypeMod:
         case BinOpTypeUnwrapMaybe:
+        case BinOpTypeStrCat:
             zig_unreachable();
     }
     zig_unreachable();
@@ -2454,6 +2457,68 @@ static TypeTableEntry *analyze_bin_op_expr(CodeGen *g, ImportTableEntry *import,
                     return g->builtin_types.entry_invalid;
                 }
             }
+        case BinOpTypeStrCat:
+            {
+                AstNode **op1 = node->data.bin_op_expr.op1->parent_field;
+                AstNode **op2 = node->data.bin_op_expr.op2->parent_field;
+
+                TypeTableEntry *str_type = get_unknown_size_array_type(g, g->builtin_types.entry_u8, true);
+
+                TypeTableEntry *op1_type = analyze_expression(g, import, context, str_type, *op1);
+                TypeTableEntry *op2_type = analyze_expression(g, import, context, str_type, *op2);
+
+                if (op1_type->id == TypeTableEntryIdInvalid ||
+                    op2_type->id == TypeTableEntryIdInvalid)
+                {
+                    return g->builtin_types.entry_invalid;
+                }
+
+                ConstExprValue *op1_val = &get_resolved_expr(*op1)->const_val;
+                ConstExprValue *op2_val = &get_resolved_expr(*op2)->const_val;
+
+                AstNode *bad_node;
+                if (!op1_val->ok) {
+                    bad_node = *op1;
+                } else if (!op2_val->ok) {
+                    bad_node = *op2;
+                } else {
+                    bad_node = nullptr;
+                }
+                if (bad_node) {
+                    add_node_error(g, bad_node, buf_sprintf("string concatenation requires constant expression"));
+                    return g->builtin_types.entry_invalid;
+                }
+                ConstExprValue *const_val = &get_resolved_expr(node)->const_val;
+                const_val->ok = true;
+
+                ConstExprValue *all_fields = allocate<ConstExprValue>(2);
+                ConstExprValue *ptr_field = &all_fields[0];
+                ConstExprValue *len_field = &all_fields[1];
+
+                const_val->data.x_struct.fields = allocate<ConstExprValue*>(2);
+                const_val->data.x_struct.fields[0] = ptr_field;
+                const_val->data.x_struct.fields[1] = len_field;
+
+                len_field->ok = true;
+                uint64_t op1_len = op1_val->data.x_struct.fields[1]->data.x_bignum.data.x_uint;
+                uint64_t op2_len = op2_val->data.x_struct.fields[1]->data.x_bignum.data.x_uint;
+                uint64_t len = op1_len + op2_len;
+                bignum_init_unsigned(&len_field->data.x_bignum, len);
+
+                ptr_field->ok = true;
+                ptr_field->data.x_ptr.ptr = allocate<ConstExprValue*>(len);
+                ptr_field->data.x_ptr.len = len;
+
+                uint64_t i = 0;
+                for (uint64_t op1_i = 0; op1_i < op1_len; op1_i += 1, i += 1) {
+                    ptr_field->data.x_ptr.ptr[i] = op1_val->data.x_struct.fields[0]->data.x_ptr.ptr[op1_i];
+                }
+                for (uint64_t op2_i = 0; op2_i < op2_len; op2_i += 1, i += 1) {
+                    ptr_field->data.x_ptr.ptr[i] = op2_val->data.x_struct.fields[0]->data.x_ptr.ptr[op2_i];
+                }
+
+                return str_type;
+            }
         case BinOpTypeInvalid:
             zig_unreachable();
     }
src/codegen.cpp
@@ -951,6 +951,7 @@ static LLVMValueRef gen_arithmetic_bin_op(CodeGen *g, AstNode *source_node,
         case BinOpTypeAssignBoolAnd:
         case BinOpTypeAssignBoolOr:
         case BinOpTypeUnwrapMaybe:
+        case BinOpTypeStrCat:
             zig_unreachable();
     }
     zig_unreachable();
@@ -1228,6 +1229,7 @@ static LLVMValueRef gen_unwrap_maybe_expr(CodeGen *g, AstNode *node) {
 static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) {
     switch (node->data.bin_op_expr.bin_op) {
         case BinOpTypeInvalid:
+        case BinOpTypeStrCat:
             zig_unreachable();
         case BinOpTypeAssign:
         case BinOpTypeAssignTimes:
src/parser.cpp
@@ -49,6 +49,7 @@ static const char *bin_op_str(BinOpType bin_op) {
         case BinOpTypeAssignBoolAnd:       return "&&=";
         case BinOpTypeAssignBoolOr:        return "||=";
         case BinOpTypeUnwrapMaybe:         return "??";
+        case BinOpTypeStrCat:              return "++";
     }
     zig_unreachable();
 }
@@ -1769,12 +1770,13 @@ static BinOpType tok_to_add_op(Token *token) {
     switch (token->id) {
         case TokenIdPlus: return BinOpTypeAdd;
         case TokenIdDash: return BinOpTypeSub;
+        case TokenIdPlusPlus: return BinOpTypeStrCat;
         default: return BinOpTypeInvalid;
     }
 }
 
 /*
-AdditionOperator : token(Plus) | token(Minus)
+AdditionOperator : "+" | "-" | "++"
 */
 static BinOpType ast_parse_add_op(ParseContext *pc, int *token_index, bool mandatory) {
     Token *token = &pc->tokens->at(*token_index);
src/tokenizer.cpp
@@ -612,6 +612,11 @@ void tokenize(Buf *buf, Tokenization *out) {
                         end_token(&t);
                         t.state = TokenizeStateStart;
                         break;
+                    case '+':
+                        t.cur_tok->id = TokenIdPlusPlus;
+                        end_token(&t);
+                        t.state = TokenizeStateStart;
+                        break;
                     default:
                         t.pos -= 1;
                         end_token(&t);
@@ -1067,6 +1072,7 @@ const char * token_name(TokenId id) {
         case TokenIdSemicolon: return ";";
         case TokenIdNumberLiteral: return "NumberLiteral";
         case TokenIdPlus: return "+";
+        case TokenIdPlusPlus: return "++";
         case TokenIdColon: return ":";
         case TokenIdArrow: return "->";
         case TokenIdFatArrow: return "=>";
src/tokenizer.hpp
@@ -52,6 +52,7 @@ enum TokenId {
     TokenIdSemicolon,
     TokenIdNumberLiteral,
     TokenIdPlus,
+    TokenIdPlusPlus,
     TokenIdColon,
     TokenIdArrow,
     TokenIdFatArrow,
test/run_tests.cpp
@@ -1279,6 +1279,13 @@ pub fn main(args: [][]u8) -> %void {
     %%stdout.printf("OK\n");
 }
     )SOURCE", "OK\n");
+
+    add_simple_case("string concatenation", R"SOURCE(
+import "std.zig";
+pub fn main(args: [][]u8) -> %void {
+    %%stdout.printf("OK" ++ " IT " ++ "WORKED\n");
+}
+    )SOURCE", "OK IT WORKED\n");
 }
 
 
@@ -1645,6 +1652,12 @@ extern {
 const x = foo();
     )SOURCE", 1, ".tmp_source.zig:5:11: error: global variable initializer requires constant expression");
 
+    add_compile_fail_case("non compile time string concatenation", R"SOURCE(
+fn f(s: []u8) -> []u8 {
+    s ++ "foo"
+}
+    )SOURCE", 1, ".tmp_source.zig:3:5: error: string concatenation requires constant expression");
+
 }
 
 static void print_compiler_invocation(TestCase *test_case) {
README.md
@@ -61,6 +61,11 @@ compromises backward compatibility.
 
 ## Building
 
+### Dependencies
+
+ * LLVM 3.7
+ * libclang 3.7
+
 ### Debug / Development Build
 
 If you have gcc or clang installed, you can find out what `ZIG_LIBC_DIR` should