Commit c95e497857

Andrew Kelley <superjoe30@gmail.com>
2016-05-05 03:19:49
add cmpxchg builtin function
1 parent f2bc5cc
src/all_types.hpp
@@ -1115,6 +1115,7 @@ enum BuiltinFnId {
     BuiltinFnIdErrName,
     BuiltinFnIdBreakpoint,
     BuiltinFnIdEmbedFile,
+    BuiltinFnIdCmpExchange,
 };
 
 struct BuiltinFnEntry {
@@ -1183,6 +1184,7 @@ struct CodeGen {
         TypeTableEntry *entry_os_enum;
         TypeTableEntry *entry_arch_enum;
         TypeTableEntry *entry_environ_enum;
+        TypeTableEntry *entry_mem_order_enum;
     } builtin_types;
 
     ZigTarget zig_target;
@@ -1322,6 +1324,14 @@ struct BlockContext {
     bool safety_off;
 };
 
+enum AtomicOrder {
+    AtomicOrderUnordered,
+    AtomicOrderMonotonic,
+    AtomicOrderAcquire,
+    AtomicOrderRelease,
+    AtomicOrderAcqRel,
+    AtomicOrderSeqCst,
+};
 
 
 #endif
src/analyze.cpp
@@ -4414,6 +4414,77 @@ static TypeTableEntry *analyze_embed_file(CodeGen *g, ImportTableEntry *import,
     return resolve_expr_const_val_as_string_lit(g, node, &file_contents);
 }
 
+static TypeTableEntry *analyze_cmpxchg(CodeGen *g, ImportTableEntry *import,
+        BlockContext *context, AstNode *node)
+{
+    AstNode **ptr_arg = &node->data.fn_call_expr.params.at(0);
+    AstNode **cmp_arg = &node->data.fn_call_expr.params.at(1);
+    AstNode **new_arg = &node->data.fn_call_expr.params.at(2);
+    AstNode **success_order_arg = &node->data.fn_call_expr.params.at(3);
+    AstNode **failure_order_arg = &node->data.fn_call_expr.params.at(4);
+
+    TypeTableEntry *ptr_type = analyze_expression(g, import, context, nullptr, *ptr_arg);
+    if (ptr_type->id == TypeTableEntryIdInvalid) {
+        return g->builtin_types.entry_invalid;
+    } else if (ptr_type->id != TypeTableEntryIdPointer) {
+        add_node_error(g, *ptr_arg,
+            buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&ptr_type->name)));
+        return g->builtin_types.entry_invalid;
+    }
+
+    TypeTableEntry *child_type = ptr_type->data.pointer.child_type;
+    TypeTableEntry *cmp_type = analyze_expression(g, import, context, child_type, *cmp_arg);
+    TypeTableEntry *new_type = analyze_expression(g, import, context, child_type, *new_arg);
+
+    TypeTableEntry *success_order_type = analyze_expression(g, import, context,
+            g->builtin_types.entry_mem_order_enum, *success_order_arg);
+    TypeTableEntry *failure_order_type = analyze_expression(g, import, context,
+            g->builtin_types.entry_mem_order_enum, *failure_order_arg);
+
+    if (cmp_type->id == TypeTableEntryIdInvalid ||
+        new_type->id == TypeTableEntryIdInvalid ||
+        success_order_type->id == TypeTableEntryIdInvalid ||
+        failure_order_type->id == TypeTableEntryIdInvalid)
+    {
+        return g->builtin_types.entry_invalid;
+    }
+
+    ConstExprValue *success_order_val = &get_resolved_expr(*success_order_arg)->const_val;
+    ConstExprValue *failure_order_val = &get_resolved_expr(*failure_order_arg)->const_val;
+    if (!success_order_val->ok) {
+        add_node_error(g, *success_order_arg, buf_sprintf("unable to evaluate constant expression"));
+        return g->builtin_types.entry_invalid;
+    } else if (!failure_order_val->ok) {
+        add_node_error(g, *failure_order_arg, buf_sprintf("unable to evaluate constant expression"));
+        return g->builtin_types.entry_invalid;
+    }
+
+    if (success_order_val->data.x_enum.tag < AtomicOrderMonotonic) {
+        add_node_error(g, *success_order_arg,
+                buf_sprintf("success atomic ordering must be Monotonic or stricter"));
+        return g->builtin_types.entry_invalid;
+    }
+    if (failure_order_val->data.x_enum.tag < AtomicOrderMonotonic) {
+        add_node_error(g, *failure_order_arg,
+                buf_sprintf("failure atomic ordering must be Monotonic or stricter"));
+        return g->builtin_types.entry_invalid;
+    }
+    if (failure_order_val->data.x_enum.tag > success_order_val->data.x_enum.tag) {
+        add_node_error(g, *failure_order_arg,
+                buf_sprintf("failure atomic ordering must be no stricter than success"));
+        return g->builtin_types.entry_invalid;
+    }
+    if (failure_order_val->data.x_enum.tag == AtomicOrderRelease ||
+        failure_order_val->data.x_enum.tag == AtomicOrderAcqRel)
+    {
+        add_node_error(g, *failure_order_arg,
+                buf_sprintf("failure atomic ordering must not be Release or AcqRel"));
+        return g->builtin_types.entry_invalid;
+    }
+
+    return g->builtin_types.entry_bool;
+}
+
 static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context,
         TypeTableEntry *expected_type, AstNode *node)
 {
@@ -4750,6 +4821,8 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry
             return g->builtin_types.entry_void;
         case BuiltinFnIdEmbedFile:
             return analyze_embed_file(g, import, context, node);
+        case BuiltinFnIdCmpExchange:
+            return analyze_cmpxchg(g, import, context, node);
     }
     zig_unreachable();
 }
src/codegen.cpp
@@ -401,6 +401,46 @@ static LLVMValueRef gen_err_name(CodeGen *g, AstNode *node) {
     return LLVMBuildInBoundsGEP(g->builder, g->err_name_table, indices, 2, "");
 }
 
+static LLVMAtomicOrdering to_LLVMAtomicOrdering(AtomicOrder atomic_order) {
+    switch (atomic_order) {
+        case AtomicOrderUnordered: return LLVMAtomicOrderingUnordered;
+        case AtomicOrderMonotonic: return LLVMAtomicOrderingMonotonic;
+        case AtomicOrderAcquire: return LLVMAtomicOrderingAcquire;
+        case AtomicOrderRelease: return LLVMAtomicOrderingRelease;
+        case AtomicOrderAcqRel: return LLVMAtomicOrderingAcquireRelease;
+        case AtomicOrderSeqCst: return LLVMAtomicOrderingSequentiallyConsistent;
+    }
+    zig_unreachable();
+}
+
+static LLVMValueRef gen_cmp_exchange(CodeGen *g, AstNode *node) {
+    assert(node->type == NodeTypeFnCallExpr);
+
+    AstNode *ptr_arg = node->data.fn_call_expr.params.at(0);
+    AstNode *cmp_arg = node->data.fn_call_expr.params.at(1);
+    AstNode *new_arg = node->data.fn_call_expr.params.at(2);
+    AstNode *success_order_arg = node->data.fn_call_expr.params.at(3);
+    AstNode *failure_order_arg = node->data.fn_call_expr.params.at(4);
+
+    LLVMValueRef ptr_val = gen_expr(g, ptr_arg);
+    LLVMValueRef cmp_val = gen_expr(g, cmp_arg);
+    LLVMValueRef new_val = gen_expr(g, new_arg);
+
+    ConstExprValue *success_order_val = &get_resolved_expr(success_order_arg)->const_val;
+    ConstExprValue *failure_order_val = &get_resolved_expr(failure_order_arg)->const_val;
+
+    assert(success_order_val->ok);
+    assert(failure_order_val->ok);
+
+    LLVMAtomicOrdering success_order = to_LLVMAtomicOrdering((AtomicOrder)success_order_val->data.x_enum.tag);
+    LLVMAtomicOrdering failure_order = to_LLVMAtomicOrdering((AtomicOrder)failure_order_val->data.x_enum.tag);
+
+    LLVMValueRef result_val = ZigLLVMBuildCmpXchg(g->builder, ptr_val, cmp_val, new_val,
+            success_order, failure_order, "");
+
+    return LLVMBuildExtractValue(g->builder, result_val, 1, "");
+}
+
 static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
     assert(node->type == NodeTypeFnCallExpr);
     AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
@@ -546,6 +586,8 @@ static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
         case BuiltinFnIdBreakpoint:
             set_debug_source_node(g, node);
             return LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, "");
+        case BuiltinFnIdCmpExchange:
+            return gen_cmp_exchange(g, node);
     }
     zig_unreachable();
 }
@@ -4052,6 +4094,7 @@ static void define_builtin_types(CodeGen *g) {
             ZigLLVM_EnvironmentType environ_type = get_target_environ(i);
             type_enum_field->name = buf_create_from_str(ZigLLVMGetEnvironmentTypeName(environ_type));
             type_enum_field->value = i;
+            type_enum_field->type_entry = g->builtin_types.entry_void;
 
             if (environ_type == g->zig_target.env_type) {
                 g->target_environ_index = i;
@@ -4064,6 +4107,41 @@ static void define_builtin_types(CodeGen *g) {
 
         g->builtin_types.entry_environ_enum = entry;
     }
+
+    {
+        TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdEnum);
+        entry->deep_const = true;
+        buf_init_from_str(&entry->name, "AtomicOrder");
+        uint32_t field_count = 6;
+        entry->data.enumeration.field_count = field_count;
+        entry->data.enumeration.fields = allocate<TypeEnumField>(field_count);
+        entry->data.enumeration.fields[0].name = buf_create_from_str("Unordered");
+        entry->data.enumeration.fields[0].value = AtomicOrderUnordered;
+        entry->data.enumeration.fields[0].type_entry = g->builtin_types.entry_void;
+        entry->data.enumeration.fields[1].name = buf_create_from_str("Monotonic");
+        entry->data.enumeration.fields[1].value = AtomicOrderMonotonic;
+        entry->data.enumeration.fields[1].type_entry = g->builtin_types.entry_void;
+        entry->data.enumeration.fields[2].name = buf_create_from_str("Acquire");
+        entry->data.enumeration.fields[2].value = AtomicOrderAcquire;
+        entry->data.enumeration.fields[2].type_entry = g->builtin_types.entry_void;
+        entry->data.enumeration.fields[3].name = buf_create_from_str("Release");
+        entry->data.enumeration.fields[3].value = AtomicOrderRelease;
+        entry->data.enumeration.fields[3].type_entry = g->builtin_types.entry_void;
+        entry->data.enumeration.fields[4].name = buf_create_from_str("AcqRel");
+        entry->data.enumeration.fields[4].value = AtomicOrderAcqRel;
+        entry->data.enumeration.fields[4].type_entry = g->builtin_types.entry_void;
+        entry->data.enumeration.fields[5].name = buf_create_from_str("SeqCst");
+        entry->data.enumeration.fields[5].value = AtomicOrderSeqCst;
+        entry->data.enumeration.fields[5].type_entry = g->builtin_types.entry_void;
+
+        entry->data.enumeration.complete = true;
+
+        TypeTableEntry *tag_type_entry = get_smallest_unsigned_int_type(g, field_count);
+        entry->data.enumeration.tag_type = tag_type_entry;
+
+        g->builtin_types.entry_mem_order_enum = entry;
+        g->primitive_type_table.put(&entry->name, entry);
+    }
 }
 
 
@@ -4162,6 +4240,8 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn_with_arg_count(g, BuiltinFnIdCImport, "c_import", 1);
     create_builtin_fn_with_arg_count(g, BuiltinFnIdErrName, "err_name", 1);
     create_builtin_fn_with_arg_count(g, BuiltinFnIdEmbedFile, "embed_file", 1);
+    create_builtin_fn_with_arg_count(g, BuiltinFnIdCmpExchange, "cmpxchg", 5);
+    //create_builtin_fn_with_arg_count(g, BuiltinFnIdAtomicRmw, "atomicrmw", 1);
 }
 
 static void init(CodeGen *g, Buf *source_path) {
src/eval.cpp
@@ -704,6 +704,7 @@ static bool eval_fn_call_builtin(EvalFn *ef, AstNode *node, ConstExprValue *out_
         case BuiltinFnIdCImport:
         case BuiltinFnIdErrName:
         case BuiltinFnIdEmbedFile:
+        case BuiltinFnIdCmpExchange:
             zig_panic("TODO");
         case BuiltinFnIdBreakpoint:
         case BuiltinFnIdInvalid:
src/zig_llvm.cpp
@@ -645,6 +645,31 @@ unsigned ZigLLVMGetPrefTypeAlignment(LLVMTargetDataRef TD, LLVMTypeRef Ty) {
     return unwrap(TD)->getPrefTypeAlignment(unwrap(Ty));
 }
 
+
+static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering) {
+    switch (Ordering) {
+        case LLVMAtomicOrderingNotAtomic: return NotAtomic;
+        case LLVMAtomicOrderingUnordered: return Unordered;
+        case LLVMAtomicOrderingMonotonic: return Monotonic;
+        case LLVMAtomicOrderingAcquire: return Acquire;
+        case LLVMAtomicOrderingRelease: return Release;
+        case LLVMAtomicOrderingAcquireRelease: return AcquireRelease;
+        case LLVMAtomicOrderingSequentiallyConsistent: return SequentiallyConsistent;
+    }
+    abort();
+}
+
+LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp,
+        LLVMValueRef new_val, LLVMAtomicOrdering success_ordering,
+        LLVMAtomicOrdering failure_ordering,
+        const char *name)
+{
+    return wrap(unwrap(builder)->CreateAtomicCmpXchg(unwrap(ptr), unwrap(cmp), unwrap(new_val),
+                mapFromLLVMOrdering(success_ordering), mapFromLLVMOrdering(failure_ordering),
+                CrossThread));
+}
+
+
 //------------------------------------
 
 #include "buffer.hpp"
src/zig_llvm.hpp
@@ -39,6 +39,11 @@ void LLVMZigOptimizeModule(LLVMTargetMachineRef targ_machine_ref, LLVMModuleRef
 LLVMValueRef LLVMZigBuildCall(LLVMBuilderRef B, LLVMValueRef Fn, LLVMValueRef *Args,
         unsigned NumArgs, unsigned CC, const char *Name);
 
+LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp,
+        LLVMValueRef new_val, LLVMAtomicOrdering success_ordering,
+        LLVMAtomicOrdering failure_ordering,
+        const char *name);
+
 // 0 is return value, 1 is first arg
 void LLVMZigAddNonNullAttr(LLVMValueRef fn, unsigned i);
 
test/run_tests.cpp
@@ -1295,8 +1295,20 @@ fn foo() {
 #static_eval_enable(false)
 fn bar() -> i32 { 2 }
     )SOURCE", 1, ".tmp_source.zig:3:15: error: unable to infer expression type");
+
+    add_compile_fail_case("atomic orderings of cmpxchg", R"SOURCE(
+fn f() {
+    var x: i32 = 1234;
+    while (!@cmpxchg(&x, 1234, 5678, AtomicOrder.Monotonic, AtomicOrder.SeqCst)) {}
+    while (!@cmpxchg(&x, 1234, 5678, AtomicOrder.Unordered, AtomicOrder.Unordered)) {}
+}
+    )SOURCE", 2,
+            ".tmp_source.zig:4:72: error: failure atomic ordering must be no stricter than success",
+            ".tmp_source.zig:5:49: error: success atomic ordering must be Monotonic or stricter");
 }
 
+//////////////////////////////////////////////////////////////////////////////
+
 static void add_debug_safety_test_cases(void) {
     add_debug_safety_case("out of bounds slice access", R"SOURCE(
 pub fn main(args: [][]u8) -> %void {
test/self_hosted.zig
@@ -1442,3 +1442,10 @@ fn assign_to_if_var_ptr() {
 
     assert(??maybe_bool == false);
 }
+
+#attribute("test")
+fn cmpxchg() {
+    var x: i32 = 1234;
+    while (!@cmpxchg(&x, 1234, 5678, AtomicOrder.SeqCst, AtomicOrder.SeqCst)) {}
+    assert(x == 5678);
+}