Commit b7dd88ad68
Changed files (10)
doc/langref.md
@@ -160,7 +160,7 @@ SliceExpression : token(LBracket) Expression token(Ellipsis) option(Expression)
PrefixOp : token(Not) | token(Dash) | token(Tilde) | token(Star) | (token(Ampersand) option(token(Const)))
-PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
+PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType | (token(AtSign) token(Symbol) FnCallExpression)
StructValueExpression : token(Type) token(LBrace) list(StructValueExpressionField, token(Comma)) token(RBrace)
src/analyze.cpp
@@ -2064,6 +2064,41 @@ static TypeTableEntry *analyze_compiler_fn_type(CodeGen *g, ImportTableEntry *im
}
}
+static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
+ TypeTableEntry *expected_type, AstNode *node)
+{
+ AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
+ Buf *name = &fn_ref_expr->data.symbol;
+
+ auto entry = g->builtin_fn_table.maybe_get(name);
+
+ if (entry) {
+ BuiltinFnEntry *builtin_fn = entry->value;
+ int actual_param_count = node->data.fn_call_expr.params.length;
+
+ assert(node->codegen_node);
+ node->codegen_node->data.fn_call_node.builtin_fn = builtin_fn;
+
+ if (builtin_fn->param_count != actual_param_count) {
+ add_node_error(g, node,
+ buf_sprintf("expected %d arguments, got %d",
+ builtin_fn->param_count, actual_param_count));
+ }
+
+ for (int i = 0; i < actual_param_count; i += 1) {
+ AstNode *child = node->data.fn_call_expr.params.at(i);
+ TypeTableEntry *expected_param_type = builtin_fn->param_types[i];
+ analyze_expression(g, import, context, expected_param_type, child);
+ }
+
+ return builtin_fn->return_type;
+ } else {
+ add_node_error(g, node,
+ buf_sprintf("invalid builtin function: '%s'", buf_ptr(name)));
+ return g->builtin_types.entry_invalid;
+ }
+}
+
static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
TypeTableEntry *expected_type, AstNode *node)
{
@@ -2091,6 +2126,9 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
return g->builtin_types.entry_invalid;
}
} else if (fn_ref_expr->type == NodeTypeSymbol) {
+ if (node->data.fn_call_expr.is_builtin) {
+ return analyze_builtin_fn_call_expr(g, import, context, expected_type, node);
+ }
name = &fn_ref_expr->data.symbol;
} else {
add_node_error(g, node,
@@ -2126,12 +2164,12 @@ static TypeTableEntry *analyze_fn_call_expr(CodeGen *g, ImportTableEntry *import
if (fn_proto->is_var_args) {
if (actual_param_count < expected_param_count) {
add_node_error(g, node,
- buf_sprintf("wrong number of arguments. Expected at least %d, got %d.",
+ buf_sprintf("expected at least %d arguments, got %d",
expected_param_count, actual_param_count));
}
} else if (expected_param_count != actual_param_count) {
add_node_error(g, node,
- buf_sprintf("wrong number of arguments. Expected %d, got %d.",
+ buf_sprintf("expected %d arguments, got %d",
expected_param_count, actual_param_count));
}
src/analyze.hpp
@@ -148,6 +148,20 @@ struct FnTableEntry {
HashMap<Buf *, LabelTableEntry *, buf_hash, buf_eql_buf> label_table;
};
+enum BuiltinFnId {
+ BuiltinFnIdInvalid,
+ BuiltinFnIdArithmeticWithOverflow,
+};
+
+struct BuiltinFnEntry {
+ BuiltinFnId id;
+ Buf name;
+ int param_count;
+ TypeTableEntry *return_type;
+ TypeTableEntry **param_types;
+ LLVMValueRef fn_val;
+};
+
struct CodeGen {
LLVMModuleRef module;
ZigList<ErrorMsg*> errors;
@@ -161,6 +175,7 @@ struct CodeGen {
HashMap<Buf *, LLVMValueRef, buf_hash, buf_eql_buf> str_table;
HashMap<Buf *, bool, buf_hash, buf_eql_buf> link_table;
HashMap<Buf *, ImportTableEntry *, buf_hash, buf_eql_buf> import_table;
+ HashMap<Buf *, BuiltinFnEntry *, buf_hash, buf_eql_buf> builtin_fn_table;
struct {
TypeTableEntry *entry_bool;
@@ -342,6 +357,10 @@ struct WhileNode {
bool contains_break;
};
+struct FnCallNode {
+ BuiltinFnEntry *builtin_fn;
+};
+
struct CodeGenNode {
union {
TypeNode type_node; // for NodeTypeType
@@ -363,17 +382,11 @@ struct CodeGenNode {
ParamDeclNode param_decl_node; // for NodeTypeParamDecl
ImportNode import_node; // for NodeTypeUse
WhileNode while_node; // for NodeTypeWhileExpr
+ FnCallNode fn_call_node; // for NodeTypeFnCallExpr
} data;
ExprNode expr_node; // for all the expression nodes
};
-static inline Buf *hack_get_fn_call_name(CodeGen *g, AstNode *node) {
- // Assume that the expression evaluates to a simple name and return the buf
- // TODO after type checking works we should be able to remove this hack
- assert(node->type == NodeTypeSymbol);
- return &node->data.symbol;
-}
-
void semantic_analyze(CodeGen *g);
void add_node_error(CodeGen *g, AstNode *node, Buf *msg);
void alloc_codegen_node(AstNode *node);
src/codegen.cpp
@@ -22,6 +22,7 @@ CodeGen *codegen_create(Buf *root_source_dir) {
g->str_table.init(32);
g->link_table.init(32);
g->import_table.init(32);
+ g->builtin_fn_table.init(32);
g->build_type = CodeGenBuildTypeDebug;
g->root_source_dir = root_source_dir;
@@ -139,6 +140,41 @@ static TypeTableEntry *get_expr_type(AstNode *node) {
return cast_type ? cast_type : node->codegen_node->expr_node.type_entry;
}
+static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
+ assert(node->type == NodeTypeFnCallExpr);
+ AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
+ assert(fn_ref_expr->type == NodeTypeSymbol);
+ BuiltinFnEntry *builtin_fn = node->codegen_node->data.fn_call_node.builtin_fn;
+
+ switch (builtin_fn->id) {
+ case BuiltinFnIdInvalid:
+ zig_unreachable();
+ case BuiltinFnIdArithmeticWithOverflow:
+ {
+ int fn_call_param_count = node->data.fn_call_expr.params.length;
+ assert(fn_call_param_count == 3);
+
+ LLVMValueRef op1 = gen_expr(g, node->data.fn_call_expr.params.at(0));
+ LLVMValueRef op2 = gen_expr(g, node->data.fn_call_expr.params.at(1));
+ LLVMValueRef ptr_result = gen_expr(g, node->data.fn_call_expr.params.at(2));
+
+ LLVMValueRef params[] = {
+ op1,
+ op2,
+ };
+
+ add_debug_source_node(g, node);
+ LLVMValueRef result_struct = LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 2, "");
+ LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
+ LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
+ LLVMBuildStore(g->builder, result, ptr_result);
+
+ return overflow_bit;
+ }
+ }
+ zig_unreachable();
+}
+
static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeFnCallExpr);
@@ -159,7 +195,15 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
zig_unreachable();
}
} else if (fn_ref_expr->type == NodeTypeSymbol) {
- Buf *name = hack_get_fn_call_name(g, fn_ref_expr);
+ if (node->data.fn_call_expr.is_builtin) {
+ return gen_builtin_fn_call_expr(g, node);
+ }
+
+ // Assume that the expression evaluates to a simple name and return the buf
+ // TODO after we support function pointers we can make this generic
+ assert(fn_ref_expr->type == NodeTypeSymbol);
+ Buf *name = &fn_ref_expr->data.symbol;
+
struct_type = nullptr;
first_param_expr = nullptr;
fn_table_entry = g->cur_fn->import_entry->fn_table.get(name);
@@ -2167,6 +2211,64 @@ static void define_builtin_types(CodeGen *g) {
}
}
+static void define_builtin_fns_int(CodeGen *g, TypeTableEntry *type_entry) {
+ assert(type_entry->id == TypeTableEntryIdInt);
+ struct OverflowFn {
+ const char *bare_name;
+ const char *signed_name;
+ const char *unsigned_name;
+ };
+ OverflowFn overflow_fns[] = {
+ {"add", "sadd", "uadd"},
+ {"sub", "ssub", "usub"},
+ {"mul", "smul", "umul"},
+ };
+ for (int i = 0; i < sizeof(overflow_fns)/sizeof(overflow_fns[0]); i += 1) {
+ OverflowFn *overflow_fn = &overflow_fns[i];
+ BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);
+ buf_resize(&builtin_fn->name, 0);
+ buf_appendf(&builtin_fn->name, "%s_with_overflow_%s", overflow_fn->bare_name, buf_ptr(&type_entry->name));
+ builtin_fn->id = BuiltinFnIdArithmeticWithOverflow;
+ builtin_fn->return_type = g->builtin_types.entry_bool;
+ builtin_fn->param_count = 3;
+ builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
+ builtin_fn->param_types[0] = type_entry;
+ builtin_fn->param_types[1] = type_entry;
+ builtin_fn->param_types[2] = get_pointer_to_type(g, type_entry, false, false);
+
+
+ const char *signed_str = type_entry->data.integral.is_signed ?
+ overflow_fn->signed_name : overflow_fn->unsigned_name;
+ Buf *llvm_name = buf_sprintf("llvm.%s.with.overflow.i%" PRIu64, signed_str, type_entry->size_in_bits);
+
+ LLVMTypeRef return_elem_types[] = {
+ type_entry->type_ref,
+ LLVMInt1Type(),
+ };
+ LLVMTypeRef param_types[] = {
+ type_entry->type_ref,
+ type_entry->type_ref,
+ };
+ LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false);
+ LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false);
+ builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(llvm_name), fn_type);
+ assert(LLVMGetIntrinsicID(builtin_fn->fn_val));
+
+ g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
+ }
+}
+
+static void define_builtin_fns(CodeGen *g) {
+ define_builtin_fns_int(g, g->builtin_types.entry_u8);
+ define_builtin_fns_int(g, g->builtin_types.entry_u16);
+ define_builtin_fns_int(g, g->builtin_types.entry_u32);
+ define_builtin_fns_int(g, g->builtin_types.entry_u64);
+ define_builtin_fns_int(g, g->builtin_types.entry_i8);
+ define_builtin_fns_int(g, g->builtin_types.entry_i16);
+ define_builtin_fns_int(g, g->builtin_types.entry_i32);
+ define_builtin_fns_int(g, g->builtin_types.entry_i64);
+}
+
static void init(CodeGen *g, Buf *source_path) {
@@ -2228,9 +2330,10 @@ static void init(CodeGen *g, Buf *source_path) {
"", 0, !g->strip_debug_symbols);
// This is for debug stuff that doesn't have a real file.
- g->dummy_di_file = nullptr; //LLVMZigCreateFile(g->dbuilder, "", "");
+ g->dummy_di_file = nullptr;
define_builtin_types(g);
+ define_builtin_fns(g);
}
src/parser.cpp
@@ -1313,7 +1313,7 @@ static AstNode *ast_parse_struct_val_expr(ParseContext *pc, int *token_index) {
}
/*
-PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType
+PrimaryExpression : token(Number) | token(String) | token(CharLiteral) | KeywordLiteral | GroupedExpression | Goto | token(Break) | token(Continue) | BlockExpression | token(Symbol) | StructValueExpression | CompilerFnType | (token(AtSign) token(Symbol) FnCallExpression)
KeywordLiteral : token(Unreachable) | token(Void) | token(True) | token(False) | token(Null)
*/
static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) {
@@ -1356,6 +1356,18 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool
AstNode *node = ast_create_node(pc, NodeTypeNullLiteral, token);
*token_index += 1;
return node;
+ } else if (token->id == TokenIdAtSign) {
+ *token_index += 1;
+ Token *name_tok = ast_eat_token(pc, token_index, TokenIdSymbol);
+ AstNode *name_node = ast_create_node(pc, NodeTypeSymbol, name_tok);
+ ast_buf_from_token(pc, name_tok, &name_node->data.symbol);
+
+ AstNode *node = ast_create_node(pc, NodeTypeFnCallExpr, token);
+ node->data.fn_call_expr.fn_ref_expr = name_node;
+ ast_eat_token(pc, token_index, TokenIdLParen);
+ ast_parse_fn_call_param_list(pc, token_index, &node->data.fn_call_expr.params);
+ node->data.fn_call_expr.is_builtin = true;
+ return node;
} else if (token->id == TokenIdSymbol) {
Token *next_token = &pc->tokens->at(*token_index + 1);
src/parser.hpp
@@ -176,6 +176,7 @@ struct AstNodeBinOpExpr {
struct AstNodeFnCallExpr {
AstNode *fn_ref_expr;
ZigList<AstNode *> params;
+ bool is_builtin;
};
struct AstNodeArrayAccessExpr {
src/tokenizer.cpp
@@ -376,6 +376,10 @@ void tokenize(Buf *buf, Tokenization *out) {
begin_token(&t, TokenIdTilde);
end_token(&t);
break;
+ case '@':
+ begin_token(&t, TokenIdAtSign);
+ end_token(&t);
+ break;
case '-':
begin_token(&t, TokenIdDash);
t.state = TokenizeStateSawDash;
@@ -1074,6 +1078,7 @@ static const char * token_name(Token *token) {
case TokenIdMaybe: return "Maybe";
case TokenIdDoubleQuestion: return "DoubleQuestion";
case TokenIdMaybeAssign: return "MaybeAssign";
+ case TokenIdAtSign: return "AtSign";
}
return "(invalid token)";
}
src/tokenizer.hpp
@@ -89,6 +89,7 @@ enum TokenId {
TokenIdMaybe,
TokenIdDoubleQuestion,
TokenIdMaybeAssign,
+ TokenIdAtSign,
};
struct Token {
std/std.zig
@@ -63,10 +63,6 @@ pub fn parse_u64(buf: []u8, radix: u8, result: &u64) -> bool {
return true;
}
- x *= radix;
- x += digit;
-
- /* TODO intrinsics mul and add with overflow
// x *= radix
if (@mul_with_overflow_u64(x, radix, &x)) {
return true;
@@ -76,7 +72,6 @@ pub fn parse_u64(buf: []u8, radix: u8, result: &u64) -> bool {
if (@add_with_overflow_u64(x, digit, &x)) {
return true;
}
- */
i += 1;
}
test/run_tests.cpp
@@ -953,6 +953,24 @@ fn f(c: u8) -> u8 {
} else {
2
}
+}
+ )SOURCE", "OK\n");
+
+ add_simple_case("overflow intrinsics", R"SOURCE(
+use "std.zig";
+pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
+ var result: u8;
+ if (!@add_with_overflow_u8(250, 100, &result)) {
+ print_str("BAD\n");
+ }
+ if (@add_with_overflow_u8(100, 150, &result)) {
+ print_str("BAD\n");
+ }
+ if (result != 250) {
+ print_str("BAD\n");
+ }
+ print_str("OK\n");
+ return 0;
}
)SOURCE", "OK\n");
}
@@ -995,7 +1013,7 @@ fn a() {
b(1);
}
fn b(a: i32, b: i32, c: i32) { }
- )SOURCE", 1, ".tmp_source.zig:3:6: error: wrong number of arguments. Expected 3, got 1.");
+ )SOURCE", 1, ".tmp_source.zig:3:6: error: expected 3 arguments, got 1");
add_compile_fail_case("invalid type", R"SOURCE(
fn a() -> bogus {}