Skip to main content

falco_plugin/source/
wrappers.rs

1use crate::base::wrappers::PluginWrapper;
2use crate::error::ffi_result::FfiResult;
3use crate::error::panic::catch_panic;
4use crate::source::SourcePluginInstanceWrapper;
5use crate::source::{EventBatch, EventInput, SourcePlugin, SourcePluginInstance};
6use crate::strings::cstring_writer::WriteIntoCString;
7use crate::strings::from_ptr::try_str_from_ptr;
8use falco_plugin_api::plugin_api__bindgen_ty_1 as source_plugin_api;
9use falco_plugin_api::{
10    ss_instance_t, ss_plugin_event, ss_plugin_event_input, ss_plugin_rc,
11    ss_plugin_rc_SS_PLUGIN_FAILURE, ss_plugin_rc_SS_PLUGIN_SUCCESS, ss_plugin_t,
12};
13use std::ffi::c_char;
14use std::io::Write;
15use std::marker::PhantomData;
16use std::panic::AssertUnwindSafe;
17
18/// Marker trait to mark a source plugin as exported to the API
19///
20/// # Safety
21///
22/// Only implement this trait if you export the plugin either statically or dynamically
23/// to the plugin API. This is handled by the `source_plugin!` and `static_plugin!` macros, so you
24/// should never need to implement this trait manually.
25#[diagnostic::on_unimplemented(
26    message = "Source plugin is not exported",
27    note = "use either `source_plugin!` or `static_plugin!`"
28)]
29pub unsafe trait SourcePluginExported {}
30
31pub trait SourcePluginFallbackApi {
32    const SOURCE_API: source_plugin_api = source_plugin_api {
33        get_id: None,
34        get_event_source: None,
35        open: None,
36        close: None,
37        list_open_params: None,
38        get_progress: None,
39        event_to_string: None,
40        next_batch: None,
41    };
42
43    const IMPLEMENTS_SOURCE: bool = false;
44}
45impl<T> SourcePluginFallbackApi for T {}
46
47#[allow(missing_debug_implementations)]
48pub struct SourcePluginApi<T>(std::marker::PhantomData<T>);
49
50const fn impl_source_plugin_api<T: SourcePlugin>() -> source_plugin_api {
51    if T::PLUGIN_ID != 0 && T::EVENT_SOURCE.is_empty() {
52        panic!("EVENT_SOURCE cannot be empty if PLUGIN_ID is non-zero")
53    }
54
55    source_plugin_api {
56        get_id: Some(plugin_get_id::<T>),
57        get_event_source: Some(plugin_get_event_source::<T>),
58        open: Some(plugin_open::<T>),
59        close: Some(plugin_close::<T>),
60        list_open_params: Some(plugin_list_open_params::<T>),
61        get_progress: Some(plugin_get_progress::<T>),
62        event_to_string: Some(plugin_event_to_string::<T>),
63        next_batch: Some(plugin_next_batch::<T>),
64    }
65}
66
67impl<T: SourcePlugin> SourcePluginApi<T> {
68    pub const SOURCE_API: source_plugin_api = impl_source_plugin_api::<T>();
69
70    pub const IMPLEMENTS_SOURCE: bool = true;
71}
72
73pub extern "C" fn plugin_get_event_source<T: SourcePlugin>() -> *const c_char {
74    T::EVENT_SOURCE.as_ptr()
75}
76
77pub extern "C" fn plugin_get_id<T: SourcePlugin>() -> u32 {
78    T::PLUGIN_ID
79}
80
81/// # Safety
82///
83/// All pointers must be valid
84pub unsafe extern "C" fn plugin_list_open_params<T: SourcePlugin>(
85    plugin: *mut ss_plugin_t,
86    rc: *mut i32,
87) -> *const c_char {
88    let plugin = plugin as *mut PluginWrapper<T>;
89    let plugin = unsafe {
90        let Some(plugin) = plugin.as_mut() else {
91            return std::ptr::null();
92        };
93        plugin
94    };
95    let Some(actual_plugin) = &mut plugin.plugin else {
96        return std::ptr::null();
97    };
98
99    // .as_ptr() breaks the lifetime requirement, which implies that the caller
100    // cannot hold on to the result before the next API call. As this whole method
101    // is apparently unused, this holds (for now).
102    match catch_panic(AssertUnwindSafe(|| {
103        actual_plugin.plugin.list_open_params().map(|s| s.as_ptr())
104    })) {
105        Ok(s) => {
106            unsafe {
107                *rc = ss_plugin_rc_SS_PLUGIN_SUCCESS;
108            }
109            s
110        }
111        Err(e) => {
112            unsafe {
113                *rc = e.status_code();
114            }
115            e.set_last_error(&mut plugin.error_buf);
116            std::ptr::null()
117        }
118    }
119}
120
121/// # Safety
122///
123/// All pointers must be valid
124pub unsafe extern "C" fn plugin_open<T: SourcePlugin>(
125    plugin: *mut ss_plugin_t,
126    params: *const c_char,
127    rc: *mut ss_plugin_rc,
128) -> *mut ss_instance_t {
129    let plugin = plugin as *mut PluginWrapper<T>;
130    unsafe {
131        let Some(plugin) = plugin.as_mut() else {
132            return std::ptr::null_mut();
133        };
134        let Some(actual_plugin) = &mut plugin.plugin else {
135            return std::ptr::null_mut();
136        };
137
138        let Some(rc) = rc.as_mut() else {
139            return std::ptr::null_mut();
140        };
141
142        let params = if params.is_null() {
143            None
144        } else {
145            match try_str_from_ptr(&params) {
146                Ok(params) => Some(params),
147                Err(e) => {
148                    plugin
149                        .error_buf
150                        .write_into(|w| w.write_all(e.to_string().as_bytes()))
151                        .ok();
152                    *rc = ss_plugin_rc_SS_PLUGIN_FAILURE;
153
154                    return std::ptr::null_mut();
155                }
156            }
157        };
158
159        match catch_panic(AssertUnwindSafe(|| actual_plugin.plugin.open(params))) {
160            Ok(instance) => {
161                *rc = ss_plugin_rc_SS_PLUGIN_SUCCESS;
162                Box::into_raw(Box::new(SourcePluginInstanceWrapper {
163                    instance,
164                    batch: Default::default(),
165                }))
166                .cast()
167            }
168            Err(e) => {
169                e.set_last_error(&mut plugin.error_buf);
170                *rc = e.status_code();
171                std::ptr::null_mut()
172            }
173        }
174    }
175}
176
177/// # Safety
178///
179/// All pointers must be valid
180pub unsafe extern "C" fn plugin_close<T: SourcePlugin>(
181    plugin: *mut ss_plugin_t,
182    instance: *mut ss_instance_t,
183) {
184    let plugin = plugin as *mut PluginWrapper<T>;
185    let plugin = unsafe {
186        let Some(plugin) = plugin.as_mut() else {
187            return;
188        };
189        plugin
190    };
191    let Some(actual_plugin) = &mut plugin.plugin else {
192        return;
193    };
194
195    let instance = instance as *mut SourcePluginInstanceWrapper<T::Instance>;
196    if instance.is_null() {
197        return;
198    }
199    unsafe {
200        let mut inst = Box::from_raw(instance);
201        let res = catch_panic(AssertUnwindSafe(|| {
202            actual_plugin.plugin.close(&mut inst.instance);
203            Ok(())
204        }));
205
206        if let Err(e) = res {
207            log::error!("Error closing source plugin instance: {}", e)
208        }
209    }
210}
211
212/// # Safety
213///
214/// All pointers must be valid
215pub unsafe extern "C" fn plugin_next_batch<T: SourcePlugin>(
216    plugin: *mut ss_plugin_t,
217    instance: *mut ss_instance_t,
218    nevts: *mut u32,
219    evts: *mut *mut *mut ss_plugin_event,
220) -> ss_plugin_rc {
221    let plugin = plugin as *mut PluginWrapper<T>;
222    let instance = instance as *mut SourcePluginInstanceWrapper<T::Instance>;
223    unsafe {
224        let Some(plugin) = plugin.as_mut() else {
225            return ss_plugin_rc_SS_PLUGIN_FAILURE;
226        };
227        let Some(actual_plugin) = &mut plugin.plugin else {
228            return ss_plugin_rc_SS_PLUGIN_FAILURE;
229        };
230
231        let Some(instance) = instance.as_mut() else {
232            return ss_plugin_rc_SS_PLUGIN_FAILURE;
233        };
234
235        instance.batch.reset();
236        let mut batch = EventBatch::new(&instance.batch);
237        let batch_result = catch_panic(AssertUnwindSafe(|| {
238            instance
239                .instance
240                .next_batch(&mut actual_plugin.plugin, &mut batch)
241        }));
242        match batch_result {
243            Ok(()) => {
244                let (events, events_len) = batch.get_events_ptr_len();
245                *nevts = events_len as u32;
246                *evts = events.cast();
247                ss_plugin_rc_SS_PLUGIN_SUCCESS
248            }
249            Err(e) => {
250                *nevts = 0;
251                *evts = std::ptr::null_mut();
252                e.set_last_error(&mut plugin.error_buf);
253                e.status_code()
254            }
255        }
256    }
257}
258
259/// # Safety
260///
261/// All pointers must be valid
262pub unsafe extern "C" fn plugin_get_progress<T: SourcePlugin>(
263    _plugin: *mut ss_plugin_t,
264    instance: *mut ss_instance_t,
265    progress_pct: *mut u32,
266) -> *const c_char {
267    let instance = instance as *mut SourcePluginInstanceWrapper<T::Instance>;
268    let Some(instance) = (unsafe { instance.as_mut() }) else {
269        unsafe {
270            *progress_pct = 0;
271        }
272        return std::ptr::null();
273    };
274
275    // .as_ptr() implies that the caller cannot hold on to the result before the next API call
276    let progress = catch_panic(AssertUnwindSafe(|| {
277        let progress = instance.instance.get_progress();
278        match progress.detail {
279            Some(s) => Ok((progress.value, s.as_ptr())),
280            None => Ok((progress.value, std::ptr::null())),
281        }
282    }));
283
284    match progress {
285        Ok((pct, detail)) => {
286            unsafe {
287                *progress_pct = (pct * 100.0) as u32;
288            }
289
290            detail
291        }
292        Err(e) => {
293            log::error!("Error getting progress for source plugin instance: {}", e);
294
295            unsafe {
296                *progress_pct = 0;
297            }
298
299            std::ptr::null()
300        }
301    }
302}
303
304/// # Safety
305///
306/// All pointers must be valid
307pub unsafe extern "C" fn plugin_event_to_string<T: SourcePlugin>(
308    plugin: *mut ss_plugin_t,
309    event: *const ss_plugin_event_input,
310) -> *const c_char {
311    let plugin = plugin as *mut PluginWrapper<T>;
312    unsafe {
313        let Some(plugin) = plugin.as_mut() else {
314            return std::ptr::null();
315        };
316        let Some(actual_plugin) = &mut plugin.plugin else {
317            return std::ptr::null();
318        };
319
320        let Some(event) = event.as_ref() else {
321            return std::ptr::null();
322        };
323        let event = EventInput(*event, PhantomData);
324
325        match catch_panic(AssertUnwindSafe(|| {
326            actual_plugin.plugin.event_to_string(&event)
327        })) {
328            Ok(s) => {
329                plugin.string_storage = s;
330                plugin.string_storage.as_ptr()
331            }
332            Err(_) => std::ptr::null(),
333        }
334    }
335}
336
337/// # Register a source plugin
338///
339/// This macro must be called at most once in a crate (it generates public functions with fixed
340/// `#[unsafe(no_mangle)]` names) with a type implementing [`SourcePlugin`] as the sole parameter.
341#[macro_export]
342macro_rules! source_plugin {
343    ($ty:ty) => {
344        unsafe impl $crate::source::wrappers::SourcePluginExported for $ty {}
345
346        $crate::wrap_ffi! {
347            #[unsafe(no_mangle)]
348            use $crate::source::wrappers: <$ty>;
349            unsafe fn plugin_next_batch(
350                plugin: *mut falco_plugin::api::ss_plugin_t,
351                instance: *mut falco_plugin::api::ss_instance_t,
352                nevts: *mut u32,
353                evts: *mut *mut *mut falco_plugin::api::ss_plugin_event,
354            ) -> i32;
355            unsafe fn plugin_get_progress(
356                plugin: *mut falco_plugin::api::ss_plugin_t,
357                instance: *mut falco_plugin::api::ss_instance_t,
358                progress_pct: *mut u32,
359            ) -> *const ::std::ffi::c_char;
360            unsafe fn plugin_get_id() -> u32;
361            unsafe fn plugin_get_event_source() -> *const ::std::ffi::c_char;
362            unsafe fn plugin_list_open_params(
363                plugin: *mut falco_plugin::api::ss_plugin_t,
364                rc: *mut i32,
365            ) -> *const ::std::ffi::c_char;
366            unsafe fn plugin_open(
367                plugin: *mut falco_plugin::api::ss_plugin_t,
368                params: *const ::std::ffi::c_char,
369                rc: *mut i32,
370            ) -> *mut falco_plugin::api::ss_instance_t;
371            unsafe fn plugin_close(
372                plugin: *mut falco_plugin::api::ss_plugin_t,
373                instance: *mut falco_plugin::api::ss_instance_t,
374            ) -> ();
375            unsafe fn plugin_event_to_string(
376                plugin: *mut falco_plugin::api::ss_plugin_t,
377                event_input: *const falco_plugin::api::ss_plugin_event_input,
378            ) -> *const std::ffi::c_char;
379        }
380
381        #[allow(dead_code)]
382        fn __typecheck_plugin_source_api() -> falco_plugin::api::plugin_api__bindgen_ty_1 {
383            falco_plugin::api::plugin_api__bindgen_ty_1 {
384                next_batch: Some(plugin_next_batch),
385                get_progress: Some(plugin_get_progress),
386                get_id: Some(plugin_get_id),
387                get_event_source: Some(plugin_get_event_source),
388                list_open_params: Some(plugin_list_open_params),
389                open: Some(plugin_open),
390                close: Some(plugin_close),
391                event_to_string: Some(plugin_event_to_string),
392            }
393        }
394    };
395}