Commit b7dd88ad68

Andrew Kelley <superjoe30@gmail.com>
2016-01-09 07:41:40
suport checked arithmetic operations via intrinsics
closes #32
1 parent 14b9cbd
doc/langref.md
@@ -160,7 +160,7 @@ SliceExpression : token(LBracket) Expression token(Ellipsis) option(Expression)
 
 PrefixOp : token(Not) | token(Dash) | token(Tilde) | token(Star) | (token(Ampersand) option(token(Const)))
 
-PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
+PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType | (token(AtSign) token(Symbol) FnCallExpression)
 
 StructValueExpression : token(Type) token(LBrace) list(StructValueExpressionField, token(Comma)) token(RBrace)
 
src/analyze.cpp
@@ -2064,6 +2064,41 @@ static TypeTableEntry *analyze_compiler_fn_type(CodeGen *g, ImportTableEntry *im
     }
 }
 
+static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+        TypeTableEntry *expected_type, AstNode *node)
+{
+    AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
+    Buf *name = &fn_ref_expr->data.symbol;
+
+    auto entry = g->builtin_fn_table.maybe_get(name);
+
+    if (entry) {
+        BuiltinFnEntry *builtin_fn = entry->value;
+        int actual_param_count = node->data.fn_call_expr.params.length;
+
+        assert(node->codegen_node);
+        node->codegen_node->data.fn_call_node.builtin_fn = builtin_fn;
+
+        if (builtin_fn->param_count != actual_param_count) {
+            add_node_error(g, node,
+                    buf_sprintf("expected %d arguments, got %d",
+                        builtin_fn->param_count, actual_param_count));
+        }
+
+        for (int i = 0; i < actual_param_count; i += 1) {
+            AstNode *child = node->data.fn_call_expr.params.at(i);
+            TypeTableEntry *expected_param_type = builtin_fn->param_types[i];
+            analyze_expression(g, import, context, expected_param_type, child);
+        }
+
+        return builtin_fn->return_type;
+    } else {
+        add_node_error(g, node,
+                buf_sprintf("invalid builtin function: '%s'", buf_ptr(name)));
+        return g->builtin_types.entry_invalid;
+    }
+}
+
 static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -2091,6 +2126,9 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
             return g->builtin_types.entry_invalid;
         }
     } else if (fn_ref_expr->type == NodeTypeSymbol) {
+        if (node->data.fn_call_expr.is_builtin) {
+            return analyze_builtin_fn_call_expr(g, import, context, expected_type, node);
+        }
         name = &fn_ref_expr->data.symbol;
     } else {
         add_node_error(g, node,
@@ -2126,12 +2164,12 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
         if (fn_proto->is_var_args) {
             if (actual_param_count < expected_param_count) {
                 add_node_error(g, node,
-                        buf_sprintf("wrong number of arguments. Expected at least %d, got %d.",
+                        buf_sprintf("expected at least %d arguments, got %d",
                             expected_param_count, actual_param_count));
             }
         } else if (expected_param_count != actual_param_count) {
             add_node_error(g, node,
-                    buf_sprintf("wrong number of arguments. Expected %d, got %d.",
+                    buf_sprintf("expected %d arguments, got %d",
                         expected_param_count, actual_param_count));
         }
 
src/analyze.hpp
@@ -148,6 +148,20 @@ struct FnTableEntry {
     HashMap<Buf *, LabelTableEntry *, buf_hash, buf_eql_buf> label_table;
 };
 
+enum BuiltinFnId {
+    BuiltinFnIdInvalid,
+    BuiltinFnIdArithmeticWithOverflow,
+};
+
+struct BuiltinFnEntry {
+    BuiltinFnId id;
+    Buf name;
+    int param_count;
+    TypeTableEntry *return_type;
+    TypeTableEntry **param_types;
+    LLVMValueRef fn_val;
+};
+
 struct CodeGen {
     LLVMModuleRef module;
     ZigList<ErrorMsg*> errors;
@@ -161,6 +175,7 @@ struct CodeGen {
     HashMap<Buf *, LLVMValueRef, buf_hash, buf_eql_buf> str_table;
     HashMap<Buf *, bool, buf_hash, buf_eql_buf> link_table;
     HashMap<Buf *, ImportTableEntry *, buf_hash, buf_eql_buf> import_table;
+    HashMap<Buf *, BuiltinFnEntry *, buf_hash, buf_eql_buf> builtin_fn_table;
 
     struct {
         TypeTableEntry *entry_bool;
@@ -342,6 +357,10 @@ struct WhileNode {
     bool contains_break;
 };
 
+struct FnCallNode {
+    BuiltinFnEntry *builtin_fn;
+};
+
 struct CodeGenNode {
     union {
         TypeNode type_node; // for NodeTypeType
@@ -363,17 +382,11 @@ struct CodeGenNode {
         ParamDeclNode param_decl_node; // for NodeTypeParamDecl
         ImportNode import_node; // for NodeTypeUse
         WhileNode while_node; // for NodeTypeWhileExpr
+        FnCallNode fn_call_node; // for NodeTypeFnCallExpr
     } data;
     ExprNode expr_node; // for all the expression nodes
 };
 
-static inline Buf *hack_get_fn_call_name(CodeGen *g, AstNode *node) {
-    // Assume that the expression evaluates to a simple name and return the buf
-    // TODO after type checking works we should be able to remove this hack
-    assert(node->type == NodeTypeSymbol);
-    return &node->data.symbol;
-}
-
 void semantic_analyze(CodeGen *g);
 void add_node_error(CodeGen *g, AstNode *node, Buf *msg);
 void alloc_codegen_node(AstNode *node);
src/codegen.cpp
@@ -22,6 +22,7 @@ CodeGen *codegen_create(Buf *root_source_dir) {
     g->str_table.init(32);
     g->link_table.init(32);
     g->import_table.init(32);
+    g->builtin_fn_table.init(32);
     g->build_type = CodeGenBuildTypeDebug;
     g->root_source_dir = root_source_dir;
 
@@ -139,6 +140,41 @@ static TypeTableEntry *get_expr_type(AstNode *node) {
     return cast_type ? cast_type : node->codegen_node->expr_node.type_entry;
 }
 
+static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeFnCallExpr);
+    AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
+    assert(fn_ref_expr->type == NodeTypeSymbol);
+    BuiltinFnEntry *builtin_fn = node->codegen_node->data.fn_call_node.builtin_fn;
+
+    switch (builtin_fn->id) {
+        case BuiltinFnIdInvalid:
+            zig_unreachable();
+        case BuiltinFnIdArithmeticWithOverflow:
+            {
+                int fn_call_param_count = node->data.fn_call_expr.params.length;
+                assert(fn_call_param_count == 3);
+
+                LLVMValueRef op1 = gen_expr(g, node->data.fn_call_expr.params.at(0));
+                LLVMValueRef op2 = gen_expr(g, node->data.fn_call_expr.params.at(1));
+                LLVMValueRef ptr_result = gen_expr(g, node->data.fn_call_expr.params.at(2));
+
+                LLVMValueRef params[] = {
+                    op1,
+                    op2,
+                };
+
+                add_debug_source_node(g, node);
+                LLVMValueRef result_struct = LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 2, "");
+                LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
+                LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
+                LLVMBuildStore(g->builder, result, ptr_result);
+
+                return overflow_bit;
+            }
+    }
+    zig_unreachable();
+}
+
 static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeFnCallExpr);
 
@@ -159,7 +195,15 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
             zig_unreachable();
         }
     } else if (fn_ref_expr->type == NodeTypeSymbol) {
-        Buf *name = hack_get_fn_call_name(g, fn_ref_expr);
+        if (node->data.fn_call_expr.is_builtin) {
+            return gen_builtin_fn_call_expr(g, node);
+        }
+
+        // Assume that the expression evaluates to a simple name and return the buf
+        // TODO after we support function pointers we can make this generic
+        assert(fn_ref_expr->type == NodeTypeSymbol);
+        Buf *name = &fn_ref_expr->data.symbol;
+
         struct_type = nullptr;
         first_param_expr = nullptr;
         fn_table_entry = g->cur_fn->import_entry->fn_table.get(name);
@@ -2167,6 +2211,64 @@ static void define_builtin_types(CodeGen *g) {
     }
 }
 
+static void define_builtin_fns_int(CodeGen *g, TypeTableEntry *type_entry) {
+    assert(type_entry->id == TypeTableEntryIdInt);
+    struct OverflowFn {
+        const char *bare_name;
+        const char *signed_name;
+        const char *unsigned_name;
+    };
+    OverflowFn overflow_fns[] = {
+        {"add", "sadd", "uadd"},
+        {"sub", "ssub", "usub"},
+        {"mul", "smul", "umul"},
+    };
+    for (int i = 0; i < sizeof(overflow_fns)/sizeof(overflow_fns[0]); i += 1) {
+        OverflowFn *overflow_fn = &overflow_fns[i];
+        BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);
+        buf_resize(&builtin_fn->name, 0);
+        buf_appendf(&builtin_fn->name, "%s_with_overflow_%s", overflow_fn->bare_name, buf_ptr(&type_entry->name));
+        builtin_fn->id = BuiltinFnIdArithmeticWithOverflow;
+        builtin_fn->return_type = g->builtin_types.entry_bool;
+        builtin_fn->param_count = 3;
+        builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
+        builtin_fn->param_types[0] = type_entry;
+        builtin_fn->param_types[1] = type_entry;
+        builtin_fn->param_types[2] = get_pointer_to_type(g, type_entry, false, false);
+
+
+        const char *signed_str = type_entry->data.integral.is_signed ?
+            overflow_fn->signed_name : overflow_fn->unsigned_name;
+        Buf *llvm_name = buf_sprintf("llvm.%s.with.overflow.i%" PRIu64, signed_str, type_entry->size_in_bits);
+
+        LLVMTypeRef return_elem_types[] = {
+            type_entry->type_ref,
+            LLVMInt1Type(),
+        };
+        LLVMTypeRef param_types[] = {
+            type_entry->type_ref,
+            type_entry->type_ref,
+        };
+        LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false);
+        LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false);
+        builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(llvm_name), fn_type);
+        assert(LLVMGetIntrinsicID(builtin_fn->fn_val));
+
+        g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
+    }
+}
+
+static void define_builtin_fns(CodeGen *g) {
+    define_builtin_fns_int(g, g->builtin_types.entry_u8);
+    define_builtin_fns_int(g, g->builtin_types.entry_u16);
+    define_builtin_fns_int(g, g->builtin_types.entry_u32);
+    define_builtin_fns_int(g, g->builtin_types.entry_u64);
+    define_builtin_fns_int(g, g->builtin_types.entry_i8);
+    define_builtin_fns_int(g, g->builtin_types.entry_i16);
+    define_builtin_fns_int(g, g->builtin_types.entry_i32);
+    define_builtin_fns_int(g, g->builtin_types.entry_i64);
+}
+
 
 
 static void init(CodeGen *g, Buf *source_path) {
@@ -2228,9 +2330,10 @@ static void init(CodeGen *g, Buf *source_path) {
             "", 0, !g->strip_debug_symbols);
 
     // This is for debug stuff that doesn't have a real file.
-    g->dummy_di_file = nullptr; //LLVMZigCreateFile(g->dbuilder, "", "");
+    g->dummy_di_file = nullptr;
 
     define_builtin_types(g);
+    define_builtin_fns(g);
 
 }
 
src/parser.cpp
@@ -1313,7 +1313,7 @@ static AstNode *ast_parse_struct_val_expr(ParseContext *pc, int *token_index) {
 }
 
 /*
-PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
+PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType | (token(AtSign) token(Symbol) FnCallExpression)
 KeywordLiteral : token(Unreachable) | token(Void) | token(True) | token(False) | token(Null)
 */
 static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) {
@@ -1356,6 +1356,18 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
         AstNode *node = ast_create_node(pc, NodeTypeNullLiteral, token);
         *token_index += 1;
         return node;
+    } else if (token->id == TokenIdAtSign) {
+        *token_index += 1;
+        Token *name_tok = ast_eat_token(pc, token_index, TokenIdSymbol);
+        AstNode *name_node = ast_create_node(pc, NodeTypeSymbol, name_tok);
+        ast_buf_from_token(pc, name_tok, &name_node->data.symbol);
+
+        AstNode *node = ast_create_node(pc, NodeTypeFnCallExpr, token);
+        node->data.fn_call_expr.fn_ref_expr = name_node;
+        ast_eat_token(pc, token_index, TokenIdLParen);
+        ast_parse_fn_call_param_list(pc, token_index, &node->data.fn_call_expr.params);
+        node->data.fn_call_expr.is_builtin = true;
+        return node;
     } else if (token->id == TokenIdSymbol) {
         Token *next_token = &pc->tokens->at(*token_index + 1);
 
src/parser.hpp
@@ -176,6 +176,7 @@ struct AstNodeBinOpExpr {
 struct AstNodeFnCallExpr {
     AstNode *fn_ref_expr;
     ZigList<AstNode *> params;
+    bool is_builtin;
 };
 
 struct AstNodeArrayAccessExpr {
src/tokenizer.cpp
@@ -376,6 +376,10 @@ void tokenize(Buf *buf, Tokenization *out) {
                         begin_token(&t, TokenIdTilde);
                         end_token(&t);
                         break;
+                    case '@':
+                        begin_token(&t, TokenIdAtSign);
+                        end_token(&t);
+                        break;
                     case '-':
                         begin_token(&t, TokenIdDash);
                         t.state = TokenizeStateSawDash;
@@ -1074,6 +1078,7 @@ static const char * token_name(Token *token) {
         case TokenIdMaybe: return "Maybe";
         case TokenIdDoubleQuestion: return "DoubleQuestion";
         case TokenIdMaybeAssign: return "MaybeAssign";
+        case TokenIdAtSign: return "AtSign";
     }
     return "(invalid token)";
 }
src/tokenizer.hpp
@@ -89,6 +89,7 @@ enum TokenId {
     TokenIdMaybe,
     TokenIdDoubleQuestion,
     TokenIdMaybeAssign,
+    TokenIdAtSign,
 };
 
 struct Token {
std/std.zig
@@ -63,10 +63,6 @@ pub fn parse_u64(buf: []u8, radix: u8, result: &u64) -> bool {
             return true;
         }
 
-        x *= radix;
-        x += digit;
-
-        /* TODO intrinsics mul and add with overflow
         // x *= radix
         if (@mul_with_overflow_u64(x, radix, &x)) {
             return true;
@@ -76,7 +72,6 @@ pub fn parse_u64(buf: []u8, radix: u8, result: &u64) -> bool {
         if (@add_with_overflow_u64(x, digit, &x)) {
             return true;
         }
-        */
 
         i += 1;
     }
test/run_tests.cpp
@@ -953,6 +953,24 @@ fn f(c: u8) -> u8 {
     } else {
         2
     }
+}
+    )SOURCE", "OK\n");
+
+    add_simple_case("overflow intrinsics", R"SOURCE(
+use "std.zig";
+pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
+    var result: u8;
+    if (!@add_with_overflow_u8(250, 100, &result)) {
+        print_str("BAD\n");
+    }
+    if (@add_with_overflow_u8(100, 150, &result)) {
+        print_str("BAD\n");
+    }
+    if (result != 250) {
+        print_str("BAD\n");
+    }
+    print_str("OK\n");
+    return 0;
 }
     )SOURCE", "OK\n");
 }
@@ -995,7 +1013,7 @@ fn a() {
     b(1);
 }
 fn b(a: i32, b: i32, c: i32) { }
-    )SOURCE", 1, ".tmp_source.zig:3:6: error: wrong number of arguments. Expected 3, got 1.");
+    )SOURCE", 1, ".tmp_source.zig:3:6: error: expected 3 arguments, got 1");
 
     add_compile_fail_case("invalid type", R"SOURCE(
 fn a() -> bogus {}