falco_plugin/async_event/
background_task.rs1use std::sync::{Arc, Condvar, Mutex};
2use std::thread::JoinHandle;
3use std::time::Duration;
4
/// The state a [`BackgroundTask`] has been asked to be in.
///
/// The default value is [`RequestedState::Stopped`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum RequestedState {
    /// The worker should keep iterating.
    Running,
    /// The worker should wind down; this is the default state.
    #[default]
    Stopped,
}
12
13#[derive(Default, Debug)]
20pub struct BackgroundTask {
21 lock: Mutex<RequestedState>,
22 cond: Condvar,
23}
24
25impl BackgroundTask {
26 pub fn request_start(&self) -> Result<(), anyhow::Error> {
28 *self
29 .lock
30 .lock()
31 .map_err(|e| anyhow::anyhow!(e.to_string()))? = RequestedState::Running;
32
33 Ok(())
34 }
35
36 pub fn request_stop_and_notify(&self) -> Result<(), anyhow::Error> {
38 *self
39 .lock
40 .lock()
41 .map_err(|e| anyhow::anyhow!(e.to_string()))? = RequestedState::Stopped;
42 self.cond.notify_one();
43
44 Ok(())
45 }
46
47 pub fn should_keep_running(&self, timeout: Duration) -> Result<bool, anyhow::Error> {
57 let (_guard, wait_res) = self
58 .cond
59 .wait_timeout_while(
60 self.lock
61 .lock()
62 .map_err(|e| anyhow::anyhow!(e.to_string()))?,
63 timeout,
64 |&mut state| state == RequestedState::Running,
65 )
66 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
67
68 Ok(wait_res.timed_out())
69 }
70
71 pub fn spawn<F>(
85 self: &Arc<Self>,
86 interval: Duration,
87 mut func: F,
88 ) -> Result<JoinHandle<Result<(), anyhow::Error>>, anyhow::Error>
89 where
90 F: FnMut() -> Result<(), anyhow::Error> + 'static + Send,
91 {
92 self.request_start()?;
93 let clone = Arc::clone(self);
94
95 Ok(std::thread::spawn(move || {
96 while clone.should_keep_running(interval)? {
97 func()?
98 }
99
100 Ok(())
101 }))
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::BackgroundTask;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// Polling interval used by the periodic tests.
    const INTERVAL: Duration = Duration::from_millis(100);
    /// How long the periodic tests let the worker run before stopping it.
    const RUN_FOR: Duration = Duration::from_millis(450);

    /// Check the outcome of a run that was stopped after [`RUN_FOR`].
    ///
    /// With a 100 ms interval and a stop at ~450 ms, the worker should have
    /// finished exactly four iterations (ticks at ~100/200/300/400 ms). The stop
    /// must also wake it before its next tick at ~500 ms. These bounds are
    /// timing-sensitive and may be tight on a heavily loaded machine.
    fn assert_stopped_after_four_ticks(counter: &AtomicUsize, elapsed: Duration) {
        assert_eq!(counter.load(Ordering::Relaxed), 4);

        let millis = elapsed.as_millis();
        assert!(millis >= 450, "elapsed {} ms, expected >= 450 ms", millis);
        assert!(millis < 500, "elapsed {} ms, expected < 500 ms", millis);
    }

    #[test]
    fn test_stop_request() {
        let req = Arc::new(BackgroundTask::default());
        let counter = Arc::new(AtomicUsize::default());

        let req_clone = Arc::clone(&req);
        let counter_clone = Arc::clone(&counter);

        req.request_start().unwrap();
        let handle = std::thread::spawn(move || {
            while req_clone.should_keep_running(INTERVAL).unwrap() {
                counter_clone.fetch_add(1, Ordering::Relaxed);
            }
        });

        let start_time = Instant::now();
        std::thread::sleep(RUN_FOR);
        req.request_stop_and_notify().unwrap();
        handle.join().unwrap();

        assert_stopped_after_four_ticks(&counter, start_time.elapsed());
    }

    #[test]
    fn test_spawn() {
        let req = Arc::new(BackgroundTask::default());
        let counter = Arc::new(AtomicUsize::default());
        let counter_clone = Arc::clone(&counter);

        let handle = req
            .spawn(INTERVAL, move || {
                counter_clone.fetch_add(1, Ordering::Relaxed);
                Ok(())
            })
            .unwrap();

        let start_time = Instant::now();
        std::thread::sleep(RUN_FOR);
        req.request_stop_and_notify().unwrap();
        handle.join().unwrap().unwrap();

        assert_stopped_after_four_ticks(&counter, start_time.elapsed());
    }

    #[test]
    fn test_not_started_returns_immediately() {
        // A default task is `Stopped`, so waiting must not block at all.
        let req = BackgroundTask::default();
        let start_time = Instant::now();

        assert!(!req.should_keep_running(Duration::from_secs(10)).unwrap());
        assert!(start_time.elapsed() < Duration::from_secs(1));
    }
}