falco_plugin/plugin/async_event/
background_task.rs

1use std::sync::{Arc, Condvar, Mutex};
2use std::thread::JoinHandle;
3use std::time::Duration;
4
/// The state that was most recently requested for the async background task
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestedState {
    /// The task has been asked to run
    Running,
    /// The task has been asked to stop
    Stopped,
}
11
12impl Default for RequestedState {
13    fn default() -> Self {
14        Self::Stopped
15    }
16}
17
18/// # A helper to periodically run a background task until shutdown is requested
19///
20/// Can be used to spawn a separate thread or as a building block for some other
21/// (synchronous/blocking) abstraction.
22///
23/// The implementation is little more than a [`Condvar`] and some helper methods.
24#[derive(Default, Debug)]
25pub struct BackgroundTask {
26    lock: Mutex<RequestedState>,
27    cond: Condvar,
28}
29
30impl BackgroundTask {
31    /// Mark the task as ready to run
32    pub fn request_start(&self) -> Result<(), anyhow::Error> {
33        *self
34            .lock
35            .lock()
36            .map_err(|e| anyhow::anyhow!(e.to_string()))? = RequestedState::Running;
37
38        Ok(())
39    }
40
41    /// Request the task to stop
42    pub fn request_stop_and_notify(&self) -> Result<(), anyhow::Error> {
43        *self
44            .lock
45            .lock()
46            .map_err(|e| anyhow::anyhow!(e.to_string()))? = RequestedState::Stopped;
47        self.cond.notify_one();
48
49        Ok(())
50    }
51
52    /// Wait for a stop request for up to `timeout`
53    ///
54    /// Usable in a loop like:
55    ///
56    /// ```ignore
57    /// while task.should_keep_running(timeout)? {
58    ///     do_things_on_every_timeout()?;
59    /// }
60    /// ```
61    pub fn should_keep_running(&self, timeout: Duration) -> Result<bool, anyhow::Error> {
62        let (_guard, wait_res) = self
63            .cond
64            .wait_timeout_while(
65                self.lock
66                    .lock()
67                    .map_err(|e| anyhow::anyhow!(e.to_string()))?,
68                timeout,
69                |&mut state| state == RequestedState::Running,
70            )
71            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
72
73        Ok(wait_res.timed_out())
74    }
75
76    /// Spawn a background thread that calls `func` every `interval` until shutdown
77    ///
78    /// Ideally, the called closure should not block for any noticeable time, as shutdown
79    /// requests are not processed while it's running.
80    ///
81    /// This method does not attempt to compensate for the closure running time and does not
82    /// try to guarantee that it's executed exactly every `interval`. If you need precise
83    /// intervals between each execution, you should start the thread yourself and calculate
84    /// the timeout passed to [`BackgroundTask::should_keep_running`] every time. You will also
85    /// need to handle the case when the closure takes longer than the interval:
86    /// - just skip the next execution?
87    /// - try to catch up by running the closure back-to-back without a delay?
88    /// - return an error (and stop the background thread)?
89    pub fn spawn<F>(
90        self: &Arc<Self>,
91        interval: Duration,
92        mut func: F,
93    ) -> Result<JoinHandle<Result<(), anyhow::Error>>, anyhow::Error>
94    where
95        F: FnMut() -> Result<(), anyhow::Error> + 'static + Send,
96    {
97        self.request_start()?;
98        let clone = Arc::clone(self);
99
100        Ok(std::thread::spawn(move || {
101            while clone.should_keep_running(interval)? {
102                func()?
103            }
104
105            Ok(())
106        }))
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::BackgroundTask;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// Timeout between two consecutive runs of the background closure
    const TICK: Duration = Duration::from_millis(100);
    /// How long the main thread lets the task run before requesting a stop
    const RUN_FOR: Duration = Duration::from_millis(450);

    /// Shared checks for both tests (they rely on wall-clock timing):
    /// four full ticks fit into the run time, and the stop request interrupted
    /// the fifth wait instead of letting it run to the 500 ms mark.
    fn check_outcome(counter: &AtomicUsize, elapsed: Duration) {
        assert_eq!(counter.load(Ordering::Relaxed), 4);

        let millis = elapsed.as_millis();
        assert!(millis >= 450);
        assert!(millis < 500);
    }

    /// Drive the wait loop by hand on a plain thread and stop it mid-interval
    #[test]
    fn test_stop_request() {
        let task = Arc::new(BackgroundTask::default());
        let runs = Arc::new(AtomicUsize::default());

        task.request_start().unwrap();
        let worker = {
            let task = Arc::clone(&task);
            let runs = Arc::clone(&runs);
            std::thread::spawn(move || loop {
                if !task.should_keep_running(TICK).unwrap() {
                    break;
                }
                runs.fetch_add(1, Ordering::Relaxed);
            })
        };

        let started = Instant::now();
        std::thread::sleep(RUN_FOR);
        task.request_stop_and_notify().unwrap();
        worker.join().unwrap();

        check_outcome(&runs, started.elapsed());
    }

    /// Same scenario, but with the thread created by `BackgroundTask::spawn`
    #[test]
    fn test_spawn() {
        let task = Arc::new(BackgroundTask::default());
        let runs = Arc::new(AtomicUsize::default());

        let worker = {
            let runs = Arc::clone(&runs);
            task.spawn(TICK, move || {
                runs.fetch_add(1, Ordering::Relaxed);
                Ok(())
            })
            .unwrap()
        };

        let started = Instant::now();
        std::thread::sleep(RUN_FOR);
        task.request_stop_and_notify().unwrap();
        worker.join().unwrap().unwrap();

        check_outcome(&runs, started.elapsed());
    }
}