// falco_plugin/async_event/background_task.rs

use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use std::thread::JoinHandle;
use std::time::Duration;
4
/// The state most recently requested for a [`BackgroundTask`].
///
/// This only records what was *asked for*; it says nothing about whether a
/// worker thread is actually alive at the moment.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum RequestedState {
    /// The task has been asked to run (see `BackgroundTask::request_start`).
    Running,
    /// The task has been asked to stop, or has never been started.
    #[default]
    Stopped,
}
12
13/// # A helper to periodically run a background task until shutdown is requested
14///
15/// Can be used to spawn a separate thread or as a building block for some other
16/// (synchronous/blocking) abstraction.
17///
18/// The implementation is little more than a [`Condvar`] and some helper methods.
19#[derive(Default, Debug)]
20pub struct BackgroundTask {
21    lock: Mutex<RequestedState>,
22    cond: Condvar,
23}
24
25impl BackgroundTask {
26    /// Mark the task as ready to run
27    pub fn request_start(&self) -> Result<(), anyhow::Error> {
28        *self
29            .lock
30            .lock()
31            .map_err(|e| anyhow::anyhow!(e.to_string()))? = RequestedState::Running;
32
33        Ok(())
34    }
35
36    /// Request the task to stop
37    pub fn request_stop_and_notify(&self) -> Result<(), anyhow::Error> {
38        *self
39            .lock
40            .lock()
41            .map_err(|e| anyhow::anyhow!(e.to_string()))? = RequestedState::Stopped;
42        self.cond.notify_one();
43
44        Ok(())
45    }
46
47    /// Wait for a stop request for up to `timeout`
48    ///
49    /// Usable in a loop like:
50    ///
51    /// ```ignore
52    /// while task.should_keep_running(timeout)? {
53    ///     do_things_on_every_timeout()?;
54    /// }
55    /// ```
56    pub fn should_keep_running(&self, timeout: Duration) -> Result<bool, anyhow::Error> {
57        let (_guard, wait_res) = self
58            .cond
59            .wait_timeout_while(
60                self.lock
61                    .lock()
62                    .map_err(|e| anyhow::anyhow!(e.to_string()))?,
63                timeout,
64                |&mut state| state == RequestedState::Running,
65            )
66            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
67
68        Ok(wait_res.timed_out())
69    }
70
71    /// Spawn a background thread that calls `func` every `interval` until shutdown
72    ///
73    /// Ideally, the called closure should not block for any noticeable time, as shutdown
74    /// requests are not processed while it's running.
75    ///
76    /// This method does not attempt to compensate for the closure running time and does not
77    /// try to guarantee that it's executed exactly every `interval`. If you need precise
78    /// intervals between each execution, you should start the thread yourself and calculate
79    /// the timeout passed to [`BackgroundTask::should_keep_running`] every time. You will also
80    /// need to handle the case when the closure takes longer than the interval:
81    /// - just skip the next execution?
82    /// - try to catch up by running the closure back-to-back without a delay?
83    /// - return an error (and stop the background thread)?
84    pub fn spawn<F>(
85        self: &Arc<Self>,
86        interval: Duration,
87        mut func: F,
88    ) -> Result<JoinHandle<Result<(), anyhow::Error>>, anyhow::Error>
89    where
90        F: FnMut() -> Result<(), anyhow::Error> + 'static + Send,
91    {
92        self.request_start()?;
93        let clone = Arc::clone(self);
94
95        Ok(std::thread::spawn(move || {
96            while clone.should_keep_running(interval)? {
97                func()?
98            }
99
100            Ok(())
101        }))
102    }
103}
104
#[cfg(test)]
mod tests {
    use crate::async_event::background_task::BackgroundTask;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// Period between iterations of the worker loop.
    const TICK: Duration = Duration::from_millis(100);

    /// How long the worker is allowed to run before a stop is requested.
    const RUN_FOR: Duration = Duration::from_millis(450);

    /// Let the worker run for `RUN_FOR`, request a stop, wait for the worker
    /// through `join`, and return how long the whole sequence took.
    fn run_then_stop(task: &BackgroundTask, join: impl FnOnce()) -> Duration {
        let started = Instant::now();
        std::thread::sleep(RUN_FOR);
        task.request_stop_and_notify().unwrap();
        join();
        started.elapsed()
    }

    /// Check that the worker ticked four times (at ~100/200/300/400 ms). Also
    /// check that the stop request was handled before the next tick (~500 ms).
    fn assert_stopped_promptly(ticks: &AtomicUsize, elapsed: Duration) {
        assert_eq!(ticks.load(Ordering::Relaxed), 4);

        let millis = elapsed.as_millis();
        assert!(millis >= 450);
        assert!(millis < 500);
    }

    #[test]
    fn test_stop_request() {
        let task = Arc::new(BackgroundTask::default());
        let ticks = Arc::new(AtomicUsize::default());

        task.request_start().unwrap();
        let worker = {
            let task = Arc::clone(&task);
            let ticks = Arc::clone(&ticks);
            std::thread::spawn(move || {
                while task.should_keep_running(TICK).unwrap() {
                    ticks.fetch_add(1, Ordering::Relaxed);
                }
            })
        };

        let elapsed = run_then_stop(&task, || worker.join().unwrap());
        assert_stopped_promptly(&ticks, elapsed);
    }

    #[test]
    fn test_spawn() {
        let task = Arc::new(BackgroundTask::default());
        let ticks = Arc::new(AtomicUsize::default());

        let worker = {
            let ticks = Arc::clone(&ticks);
            task.spawn(TICK, move || {
                ticks.fetch_add(1, Ordering::Relaxed);
                Ok(())
            })
            .unwrap()
        };

        let elapsed = run_then_stop(&task, || worker.join().unwrap().unwrap());
        assert_stopped_promptly(&ticks, elapsed);
    }
}