falco_plugin/plugin/async_event/
background_task.rs1use std::sync::{Arc, Condvar, Mutex};
2use std::thread::JoinHandle;
3use std::time::Duration;
4
/// The state a [`BackgroundTask`] has been asked to be in.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestedState {
    /// The background loop should keep going.
    Running,
    /// The background loop should exit the next time it checks the state.
    Stopped,
}
11
12impl Default for RequestedState {
13 fn default() -> Self {
14 Self::Stopped
15 }
16}
17
18#[derive(Default, Debug)]
25pub struct BackgroundTask {
26 lock: Mutex<RequestedState>,
27 cond: Condvar,
28}
29
30impl BackgroundTask {
31 pub fn request_start(&self) -> Result<(), anyhow::Error> {
33 *self
34 .lock
35 .lock()
36 .map_err(|e| anyhow::anyhow!(e.to_string()))? = RequestedState::Running;
37
38 Ok(())
39 }
40
41 pub fn request_stop_and_notify(&self) -> Result<(), anyhow::Error> {
43 *self
44 .lock
45 .lock()
46 .map_err(|e| anyhow::anyhow!(e.to_string()))? = RequestedState::Stopped;
47 self.cond.notify_one();
48
49 Ok(())
50 }
51
52 pub fn should_keep_running(&self, timeout: Duration) -> Result<bool, anyhow::Error> {
62 let (_guard, wait_res) = self
63 .cond
64 .wait_timeout_while(
65 self.lock
66 .lock()
67 .map_err(|e| anyhow::anyhow!(e.to_string()))?,
68 timeout,
69 |&mut state| state == RequestedState::Running,
70 )
71 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
72
73 Ok(wait_res.timed_out())
74 }
75
76 pub fn spawn<F>(
90 self: &Arc<Self>,
91 interval: Duration,
92 mut func: F,
93 ) -> Result<JoinHandle<Result<(), anyhow::Error>>, anyhow::Error>
94 where
95 F: FnMut() -> Result<(), anyhow::Error> + 'static + Send,
96 {
97 self.request_start()?;
98 let clone = Arc::clone(self);
99
100 Ok(std::thread::spawn(move || {
101 while clone.should_keep_running(interval)? {
102 func()?
103 }
104
105 Ok(())
106 }))
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::BackgroundTask;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// Wait interval used by the background loop in both tests.
    const TICK: Duration = Duration::from_millis(100);
    /// How long each test lets the loop run before requesting a stop.
    const RUN_FOR: Duration = Duration::from_millis(450);

    /// Lets the worker run for `RUN_FOR`, requests a stop, and waits for the
    /// worker via `join`.
    ///
    /// Then checks two things: exactly four ticks ran, and the worker finished
    /// in under 500 ms. Since `RUN_FOR` is 450 ms and `TICK` is 100 ms, that means
    /// the stop was seen before a fifth tick would have fired.
    fn stop_after_run_and_check<J: FnOnce()>(task: &BackgroundTask, ticks: &AtomicUsize, join: J) {
        let started = Instant::now();
        std::thread::sleep(RUN_FOR);
        task.request_stop_and_notify().unwrap();
        join();

        let elapsed = started.elapsed();
        assert_eq!(ticks.load(Ordering::Relaxed), 4);

        let elapsed_ms = elapsed.as_millis();
        assert!(elapsed_ms >= 450);
        assert!(elapsed_ms < 500);
    }

    #[test]
    fn test_stop_request() {
        let task = Arc::new(BackgroundTask::default());
        let ticks = Arc::new(AtomicUsize::default());

        let worker_task = Arc::clone(&task);
        let worker_ticks = Arc::clone(&ticks);

        task.request_start().unwrap();
        let worker = std::thread::spawn(move || {
            while worker_task.should_keep_running(TICK).unwrap() {
                worker_ticks.fetch_add(1, Ordering::Relaxed);
            }
        });

        stop_after_run_and_check(&task, &ticks, || worker.join().unwrap());
    }

    #[test]
    fn test_spawn() {
        let task = Arc::new(BackgroundTask::default());
        let ticks = Arc::new(AtomicUsize::default());
        let worker_ticks = Arc::clone(&ticks);

        let worker = task
            .spawn(TICK, move || {
                worker_ticks.fetch_add(1, Ordering::Relaxed);
                Ok(())
            })
            .unwrap();

        stop_after_run_and_check(&task, &ticks, || worker.join().unwrap().unwrap());
    }
}