Commit 8dc188ebe0

Vexu <git@vexu.eu>
2020-03-10 21:33:32
support atomic operations with bools
1 parent 675f01f
Changed files (3)
src
test
stage1
behavior
src/codegen.cpp
@@ -5224,6 +5224,15 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
     LLVMValueRef cmp_val = ir_llvm_value(g, instruction->cmp_value);
     LLVMValueRef new_val = ir_llvm_value(g, instruction->new_value);
 
+    ZigType *operand_type = instruction->new_value->value->type;
+    if (operand_type->id == ZigTypeIdBool) {
+        // treat bool as u8
+        ptr_val = LLVMBuildBitCast(g->builder, ptr_val,
+            LLVMPointerType(g->builtin_types.entry_u8->llvm_type, 0), "");
+        cmp_val = LLVMConstZExt(cmp_val, g->builtin_types.entry_u8->llvm_type);
+        new_val = LLVMConstZExt(new_val, g->builtin_types.entry_u8->llvm_type);
+    }
+
     LLVMAtomicOrdering success_order = to_LLVMAtomicOrdering(instruction->success_order);
     LLVMAtomicOrdering failure_order = to_LLVMAtomicOrdering(instruction->failure_order);
 
@@ -5236,6 +5245,9 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
 
     if (!handle_is_ptr(g, optional_type)) {
         LLVMValueRef payload_val = LLVMBuildExtractValue(g->builder, result_val, 0, "");
+        if (operand_type->id == ZigTypeIdBool) {
+            payload_val = LLVMBuildTrunc(g->builder, payload_val, g->builtin_types.entry_bool->llvm_type, "");
+        }
         LLVMValueRef success_bit = LLVMBuildExtractValue(g->builder, result_val, 1, "");
         return LLVMBuildSelect(g->builder, success_bit, LLVMConstNull(get_llvm_type(g, child_type)), payload_val, "");
     }
@@ -5250,6 +5262,9 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
     ir_assert(type_has_bits(g, child_type), &instruction->base);
 
     LLVMValueRef payload_val = LLVMBuildExtractValue(g->builder, result_val, 0, "");
+    if (operand_type->id == ZigTypeIdBool) {
+        payload_val = LLVMBuildTrunc(g->builder, payload_val, g->builtin_types.entry_bool->llvm_type, "");
+    }
     LLVMValueRef val_ptr = LLVMBuildStructGEP(g->builder, result_loc, maybe_child_index, "");
     gen_assign_raw(g, val_ptr, get_pointer_to_type(g, child_type, false), payload_val);
 
@@ -5827,6 +5842,16 @@ static LLVMValueRef ir_render_atomic_rmw(CodeGen *g, IrExecutableGen *executable
     LLVMValueRef ptr = ir_llvm_value(g, instruction->ptr);
     LLVMValueRef operand = ir_llvm_value(g, instruction->operand);
 
+    if (operand_type->id == ZigTypeIdBool) {
+        // treat bool as u8
+        LLVMValueRef casted_ptr = LLVMBuildBitCast(g->builder, ptr,
+            LLVMPointerType(g->builtin_types.entry_u8->llvm_type, 0), "");
+        LLVMValueRef casted_operand = LLVMBuildPtrToInt(g->builder, operand, g->builtin_types.entry_u8->llvm_type, "");
+        LLVMValueRef uncasted_result = ZigLLVMBuildAtomicRMW(g->builder, op, casted_ptr, casted_operand, ordering,
+                g->is_single_threaded);
+        return LLVMBuildTrunc(g->builder, uncasted_result, g->builtin_types.entry_bool->llvm_type, "");
+    }
+
     if (get_codegen_ptr_type_bail(g, operand_type) == nullptr) {
         return ZigLLVMBuildAtomicRMW(g->builder, op, ptr, operand, ordering, g->is_single_threaded);
     }
@@ -5845,6 +5870,16 @@ static LLVMValueRef ir_render_atomic_load(CodeGen *g, IrExecutableGen *executabl
 {
     LLVMAtomicOrdering ordering = to_LLVMAtomicOrdering(instruction->ordering);
     LLVMValueRef ptr = ir_llvm_value(g, instruction->ptr);
+
+    ZigType *operand_type = instruction->ptr->value->type->data.pointer.child_type;
+    if (operand_type->id == ZigTypeIdBool) {
+        // treat bool as u8
+        ptr = LLVMBuildBitCast(g->builder, ptr,
+                LLVMPointerType(g->builtin_types.entry_u8->llvm_type, 0), "");
+        LLVMValueRef load_inst = gen_load(g, ptr, instruction->ptr->value->type, "");
+        LLVMSetOrdering(load_inst, ordering);
+        return LLVMBuildTrunc(g->builder, load_inst, g->builtin_types.entry_bool->llvm_type, "");
+    }
     LLVMValueRef load_inst = gen_load(g, ptr, instruction->ptr->value->type, "");
     LLVMSetOrdering(load_inst, ordering);
     return load_inst;
@@ -5856,6 +5891,14 @@ static LLVMValueRef ir_render_atomic_store(CodeGen *g, IrExecutableGen *executab
     LLVMAtomicOrdering ordering = to_LLVMAtomicOrdering(instruction->ordering);
     LLVMValueRef ptr = ir_llvm_value(g, instruction->ptr);
     LLVMValueRef value = ir_llvm_value(g, instruction->value);
+
+    ZigType *operand_type = instruction->value->value->type;
+    if (operand_type->id == ZigTypeIdBool) {
+        // treat bool as u8
+        ptr = LLVMBuildBitCast(g->builder, ptr,
+                LLVMPointerType(g->builtin_types.entry_u8->llvm_type, 0), "");
+        value = LLVMConstZExt(value, g->builtin_types.entry_u8->llvm_type);
+    }
     LLVMValueRef store_inst = gen_store(g, value, ptr, instruction->ptr->value->type);
     LLVMSetOrdering(store_inst, ordering);
     return nullptr;
src/ir.cpp
@@ -28357,6 +28357,8 @@ static ZigType *ir_resolve_atomic_operand_type(IrAnalyze *ira, IrInstGen *op) {
                     max_atomic_bits, (uint32_t) operand_type->data.floating.bit_count));
             return ira->codegen->builtin_types.entry_invalid;
         }
+    } else if (operand_type->id == ZigTypeIdBool) {
+        // will be treated as u8
     } else {
         Error err;
         ZigType *operand_ptr_type;
@@ -28397,6 +28399,10 @@ static IrInstGen *ir_analyze_instruction_atomic_rmw(IrAnalyze *ira, IrInstSrcAto
         ir_add_error(ira, &instruction->op->base,
             buf_sprintf("@atomicRmw on enum only works with .Xchg"));
         return ira->codegen->invalid_inst_gen;
+    } else if (operand_type->id == ZigTypeIdBool && op != AtomicRmwOp_xchg) {
+        ir_add_error(ira, &instruction->op->base,
+            buf_sprintf("@atomicRmw on bool only works with .Xchg"));
+        return ira->codegen->invalid_inst_gen;
     } else if (operand_type->id == ZigTypeIdFloat && op > AtomicRmwOp_sub) {
         ir_add_error(ira, &instruction->op->base,
             buf_sprintf("@atomicRmw with float only works with .Xchg, .Add and .Sub"));
test/stage1/behavior/atomics.zig
@@ -161,3 +161,13 @@ fn testAtomicRmwFloat() void {
     _ = @atomicRmw(f32, &x, .Sub, 2, .SeqCst);
     expect(x == 4);
 }
+
+test "atomics with bool" {
+    var x = false;
+    @atomicStore(bool, &x, true, .SeqCst);
+    expect(x == true);
+    expect(@atomicLoad(bool, &x, .SeqCst) == true);
+    expect(@atomicRmw(bool, &x, .Xchg, false, .SeqCst) == true);
+    expect(@cmpxchgStrong(bool, &x, false, true, .SeqCst, .SeqCst) == null);
+    expect(@cmpxchgStrong(bool, &x, false, true, .SeqCst, .SeqCst).? == true);
+}