Commit a292eb8d64

Andrew Kelley <superjoe30@gmail.com>
2015-12-15 08:46:56
support inline assembly expressions with return type
1 parent 66ca916
doc/langref.md
@@ -84,7 +84,7 @@ AsmOutput : token(Colon) list(AsmOutputItem, token(Comma)) option(AsmInput)
 
 AsmInput : token(Colon) list(AsmInputItem, token(Comma)) option(AsmClobbers)
 
-AsmOutputItem : token(LBracket) token(Symbol) token(RBracket) token(String) token(LParen) token(Symbol) token(RParen)
+AsmOutputItem : token(LBracket) token(Symbol) token(RBracket) token(String) token(LParen) (token(Symbol) | token(Return) Type) token(RParen)
 
 AsmInputItem : token(LBracket) token(Symbol) token(RBracket) token(String) token(LParen) Expression token(RParen)
 
src/analyze.cpp
@@ -1233,16 +1233,27 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import,
             }
         case NodeTypeAsmExpr:
             {
+                node->data.asm_expr.return_count = 0;
+                return_type = g->builtin_types.entry_void;
                 for (int i = 0; i < node->data.asm_expr.output_list.length; i += 1) {
                     AsmOutput *asm_output = node->data.asm_expr.output_list.at(i);
-                    analyze_variable_name(g, import, context, node, &asm_output->variable_name);
+                    if (asm_output->return_type) {
+                        node->data.asm_expr.return_count += 1;
+                        return_type = resolve_type(g, asm_output->return_type);
+                        if (node->data.asm_expr.return_count > 1) {
+                            add_node_error(g, node,
+                                buf_sprintf("inline assembly allows up to one output value"));
+                            break;
+                        }
+                    } else {
+                        analyze_variable_name(g, import, context, node, &asm_output->variable_name);
+                    }
                 }
                 for (int i = 0; i < node->data.asm_expr.input_list.length; i += 1) {
                     AsmInput *asm_input = node->data.asm_expr.input_list.at(i);
                     analyze_expression(g, import, context, nullptr, asm_input->expr);
                 }
 
-                return_type = g->builtin_types.entry_void;
                 break;
             }
         case NodeTypeBinOpExpr:
src/codegen.cpp
@@ -790,33 +790,45 @@ static LLVMValueRef gen_asm_expr(CodeGen *g, AstNode *node) {
 
     Buf constraint_buf = BUF_INIT;
     buf_resize(&constraint_buf, 0);
+
+    assert(asm_expr->return_count == 0 || asm_expr->return_count == 1);
+
     int total_constraint_count = asm_expr->output_list.length +
                                  asm_expr->input_list.length +
                                  asm_expr->clobber_list.length;
     int input_and_output_count = asm_expr->output_list.length +
-                                 asm_expr->input_list.length;
+                                 asm_expr->input_list.length -
+                                 asm_expr->return_count;
     int total_index = 0;
+    int param_index = 0;
     LLVMTypeRef *param_types = allocate<LLVMTypeRef>(input_and_output_count);
     LLVMValueRef *param_values = allocate<LLVMValueRef>(input_and_output_count);
     for (int i = 0; i < asm_expr->output_list.length; i += 1, total_index += 1) {
         AsmOutput *asm_output = asm_expr->output_list.at(i);
+        bool is_return = false;
         if (buf_eql_str(&asm_output->constraint, "=m")) {
             buf_append_str(&constraint_buf, "=*m");
+        } else if (buf_eql_str(&asm_output->constraint, "=r")) {
+            buf_append_str(&constraint_buf, "=r");
+            is_return = true;
         } else {
-            zig_panic("TODO unable to handle anything other than '=m' for outputs");
+            zig_panic("TODO unable to handle anything other than '=m' and '=r' for outputs");
         }
         if (total_index + 1 < total_constraint_count) {
             buf_append_char(&constraint_buf, ',');
         }
 
-        VariableTableEntry *variable = find_variable(
-                node->codegen_node->expr_node.block_context,
-                &asm_output->variable_name);
-        assert(variable);
-        param_types[total_index] = LLVMTypeOf(variable->value_ref);
-        param_values[total_index] = variable->value_ref;
+        if (!is_return) {
+            VariableTableEntry *variable = find_variable(
+                    node->codegen_node->expr_node.block_context,
+                    &asm_output->variable_name);
+            assert(variable);
+            param_types[param_index] = LLVMTypeOf(variable->value_ref);
+            param_values[param_index] = variable->value_ref;
+            param_index += 1;
+        }
     }
-    for (int i = 0; i < asm_expr->input_list.length; i += 1, total_index += 1) {
+    for (int i = 0; i < asm_expr->input_list.length; i += 1, total_index += 1, param_index += 1) {
         AsmInput *asm_input = asm_expr->input_list.at(i);
         buf_append_buf(&constraint_buf, &asm_input->constraint);
         if (total_index + 1 < total_constraint_count) {
@@ -824,8 +836,8 @@ static LLVMValueRef gen_asm_expr(CodeGen *g, AstNode *node) {
         }
 
         TypeTableEntry *expr_type = get_expr_type(asm_input->expr);
-        param_types[total_index] = expr_type->type_ref;
-        param_values[total_index] = gen_expr(g, asm_input->expr);
+        param_types[param_index] = expr_type->type_ref;
+        param_values[param_index] = gen_expr(g, asm_input->expr);
     }
     for (int i = 0; i < asm_expr->clobber_list.length; i += 1, total_index += 1) {
         Buf *clobber_buf = asm_expr->clobber_list.at(i);
@@ -835,7 +847,13 @@ static LLVMValueRef gen_asm_expr(CodeGen *g, AstNode *node) {
         }
     }
 
-    LLVMTypeRef function_type = LLVMFunctionType(LLVMVoidType(), param_types, input_and_output_count, false);
+    LLVMTypeRef ret_type;
+    if (asm_expr->return_count == 0) {
+        ret_type = LLVMVoidType();
+    } else {
+        ret_type = get_expr_type(node)->type_ref;
+    }
+    LLVMTypeRef function_type = LLVMFunctionType(ret_type, param_types, input_and_output_count, false);
 
     bool is_volatile = asm_expr->is_volatile || (asm_expr->output_list.length == 0);
     LLVMValueRef asm_fn = LLVMConstInlineAsm(function_type, buf_ptr(&llvm_template),
src/parser.cpp
@@ -1731,7 +1731,7 @@ static void ast_parse_asm_input_item(ParseContext *pc, int *token_index, AstNode
 }
 
 /*
-AsmOutputItem : token(LBracket) token(Symbol) token(RBracket) token(String) token(LParen) token(Symbol) token(RParen)
+AsmOutputItem : token(LBracket) token(Symbol) token(RBracket) token(String) token(LParen) (token(Symbol) | token(Return) Type) token(RParen)
 */
 static void ast_parse_asm_output_item(ParseContext *pc, int *token_index, AstNode *node) {
     ast_eat_token(pc, token_index, TokenIdLBracket);
@@ -1740,14 +1740,24 @@ static void ast_parse_asm_output_item(ParseContext *pc, int *token_index, AstNod
 
     Token *constraint = ast_eat_token(pc, token_index, TokenIdStringLiteral);
 
+    AsmOutput *asm_output = allocate<AsmOutput>(1);
+
     ast_eat_token(pc, token_index, TokenIdLParen);
-    Token *out_symbol = ast_eat_token(pc, token_index, TokenIdSymbol);
+
+    Token *token = &pc->tokens->at(*token_index);
+    *token_index += 1;
+    if (token->id == TokenIdSymbol) {
+        ast_buf_from_token(pc, token, &asm_output->variable_name);
+    } else if (token->id == TokenIdKeywordReturn) {
+        asm_output->return_type = ast_parse_type(pc, token_index);
+    } else {
+        ast_invalid_token_error(pc, token);
+    }
+
     ast_eat_token(pc, token_index, TokenIdRParen);
 
-    AsmOutput *asm_output = allocate<AsmOutput>(1);
     ast_buf_from_token(pc, alias, &asm_output->asm_symbolic_name);
     parse_string_literal(pc, constraint, &asm_output->constraint, nullptr, nullptr);
-    ast_buf_from_token(pc, out_symbol, &asm_output->variable_name);
     node->data.asm_expr.output_list.append(asm_output);
 }
 
src/parser.hpp
@@ -228,6 +228,7 @@ struct AsmOutput {
     Buf asm_symbolic_name;
     Buf constraint;
     Buf variable_name;
+    AstNode *return_type; // null unless "=r" and return
 };
 
 struct AsmInput {
@@ -249,6 +250,7 @@ struct AstNodeAsmExpr {
     ZigList<AsmOutput*> output_list;
     ZigList<AsmInput*> input_list;
     ZigList<Buf*> clobber_list;
+    int return_count; // populated by analyze
 };
 
 struct AstNodeStructDecl {
std/std.zig
@@ -3,20 +3,17 @@ const SYS_exit : isize = 60;
 const stdout_fileno : isize = 1;
 
 fn syscall1(number: isize, arg1: isize) -> isize {
-    var result : isize;
     asm volatile ("
         mov %[number], %%rax
         mov %[arg1], %%rdi
         syscall
         mov %%rax, %[ret]"
-        : [ret] "=m" (result)
+        : [ret] "=r" (return isize)
         : [number] "r" (number), [arg1] "r" (arg1)
-        : "rcx", "r11", "rax", "rdi");
-    return result;
+        : "rcx", "r11", "rax", "rdi")
 }
 
 fn syscall3(number: isize, arg1: isize, arg2: isize, arg3: isize) -> isize {
-    var result : isize;
     asm volatile ("
         mov %[number], %%rax
         mov %[arg1], %%rdi
@@ -24,10 +21,9 @@ fn syscall3(number: isize, arg1: isize, arg2: isize, arg3: isize) -> isize {
         mov %[arg3], %%rdx
         syscall
         mov %%rax, %[ret]"
-        : [ret] "=m" (result)
+        : [ret] "=r" (return isize)
         : [number] "r" (number), [arg1] "r" (arg1), [arg2] "r" (arg2), [arg3] "r" (arg3)
-        : "rcx", "r11", "rax", "rdi", "rsi", "rdx");
-    return result;
+        : "rcx", "r11", "rax", "rdi", "rsi", "rdx")
 }
 
 pub fn write(fd: isize, buf: &const u8, count: usize) -> isize {