//! A semaphore is an unsigned counter of permits that blocks the calling
//! thread whenever taking a permit would make the count negative.
//! This API supports static initialization and does not require deinitialization.
//!
//! Example:
//! ```
//! var s = Semaphore{};
//!
//! fn consumer() void {
//!     s.wait();
//! }
//!
//! fn producer() void {
//!     s.post();
//! }
//!
//! const thread = try std.Thread.spawn(.{}, producer, .{});
//! consumer();
//! thread.join();
//! ```

/// Protects `permits`; `wait`, `timedWait` and `post` all hold it while
/// reading or updating the count.
mutex: Mutex = .{},
/// Signalled when a permit may have become available for a blocked waiter.
cond: Condition = .{},
/// Number of permits currently available.
/// It is OK to initialize this field to any value.
permits: usize = 0,

const Semaphore = @This();

const std = @import("../std.zig");
const builtin = @import("builtin");

const Condition = std.Thread.Condition;
const Mutex = std.Thread.Mutex;
const testing = std.testing;

/// Blocks the calling thread until a permit is available, then takes it.
///
/// `permits` is re-checked in a loop after every wakeup, so spurious
/// wakeups are harmless. If permits remain after taking one, the condition
/// variable is signalled again so another blocked thread (if any) can
/// claim one too.
pub fn wait(sem: *Semaphore) void {
    sem.mutex.lock();
    defer sem.mutex.unlock();

    while (sem.permits == 0) {
        sem.cond.wait(&sem.mutex);
    }

    // Cannot underflow: the loop above only exits once permits != 0.
    const remaining = sem.permits - 1;
    sem.permits = remaining;
    if (remaining != 0) {
        // Pass the wakeup along so a leftover permit does not sit idle.
        sem.cond.signal();
    }
}

/// Like `wait`, but gives up after roughly `timeout_ns` nanoseconds.
///
/// On success exactly one permit has been consumed. Returns
/// `error.Timeout` if no permit could be taken before the timeout elapsed.
///
/// Elapsed time is measured with `std.time.Timer`. Failure to start the
/// timer is treated as unreachable (a monotonic clock is assumed).
pub fn timedWait(sem: *Semaphore, timeout_ns: u64) error{Timeout}!void {
    var timeout_timer = std.time.Timer.start() catch unreachable;

    sem.mutex.lock();
    defer sem.mutex.unlock();

    while (sem.permits == 0) {
        const elapsed = timeout_timer.read();
        if (elapsed > timeout_ns)
            return error.Timeout;

        // A timed-out wait can race with a concurrent `post`: by the time the
        // mutex has been re-acquired, a permit may already be available. Only
        // report the timeout if there is still nothing to take. Otherwise the
        // loop condition fails and we consume the permit below.
        sem.cond.timedWait(&sem.mutex, timeout_ns - elapsed) catch {
            if (sem.permits == 0) return error.Timeout;
        };
    }

    sem.permits -= 1;
    if (sem.permits > 0)
        sem.cond.signal();
}

/// Adds one permit and signals the condition variable so that one thread
/// blocked in `wait` or `timedWait` (if any) can claim it.
pub fn post(sem: *Semaphore) void {
    const lock = &sem.mutex;
    lock.lock();
    defer lock.unlock();

    // Publish the new permit before waking anyone, while still holding the lock.
    sem.permits += 1;
    sem.cond.signal();
}

test Semaphore {
    if (builtin.single_threaded) {
        return error.SkipZigTest;
    }

    // Each worker takes a permit, bumps the shared counter while holding it,
    // and then hands the permit back. With a single permit available, the
    // increments are serialized.
    const TestContext = struct {
        sem: *Semaphore,
        n: *i32,

        fn worker(ctx: *@This()) void {
            ctx.sem.wait();
            defer ctx.sem.post();
            ctx.n.* += 1;
        }
    };

    const num_threads = 3;
    var sem = Semaphore{ .permits = 1 };
    var n: i32 = 0;
    var ctx = TestContext{ .sem = &sem, .n = &n };

    var threads: [num_threads]std.Thread = undefined;
    for (&threads) |*t| t.* = try std.Thread.spawn(.{}, TestContext.worker, .{&ctx});
    for (threads) |t| t.join();

    // Every worker returned the permit it took, so one is available again.
    // This call would block forever if a `post` had been lost.
    sem.wait();
    try testing.expectEqual(num_threads, n);
}

test timedWait {
    var s: Semaphore = .{};
    try testing.expectEqual(0, s.permits);

    // No permit has been posted yet, so a 1ns wait must time out.
    try testing.expectError(error.Timeout, s.timedWait(1));

    // Once a permit is posted, the timed wait takes it instead of timing out.
    s.post();
    try testing.expectEqual(1, s.permits);

    try s.timedWait(1);
    try testing.expectEqual(0, s.permits);
}