@@ -12677,6 +12677,7 @@ F: Documentation/dev-tools/kunit/
F: include/kunit/
F: lib/kunit/
F: rust/kernel/kunit.rs
+F: rust/macros/kunit.rs
F: scripts/rustdoc_test_*
F: tools/testing/kunit/
@@ -40,6 +40,8 @@ pub fn info(args: fmt::Arguments<'_>) {
}
}
+use macros::kunit_tests;
+
/// Asserts that a boolean expression is `true` at runtime.
///
/// Public but hidden since it should only be used from generated tests.
@@ -283,3 +285,12 @@ macro_rules! kunit_unsafe_test_suite {
};
};
}
+
+#[kunit_tests(rust_kernel_kunit)]
+mod tests {
+ #[test]
+ fn rust_test_kunit_example_test() {
+ #![expect(clippy::eq_op)]
+ assert_eq!(1 + 1, 2);
+ }
+}
new file mode 100644
@@ -0,0 +1,161 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Procedural macro to run KUnit tests using a user-space like syntax.
+//!
+//! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
+
+use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
+use std::fmt::Write;
+
+pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
+ let attr = attr.to_string();
+
+ if attr.is_empty() {
+ panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
+ }
+
+ if attr.len() > 255 {
+ panic!(
+ "The test suite name `{}` exceeds the maximum length of 255 bytes",
+ attr
+ )
+ }
+
+ let mut tokens: Vec<_> = ts.into_iter().collect();
+
+ // Scan for the `mod` keyword.
+ tokens
+ .iter()
+ .find_map(|token| match token {
+ TokenTree::Ident(ident) => match ident.to_string().as_str() {
+ "mod" => Some(true),
+ _ => None,
+ },
+ _ => None,
+ })
+ .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
+
+ // Retrieve the main body. The main body should be the last token tree.
+ let body = match tokens.pop() {
+ Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
+ _ => panic!("Cannot locate main body of module"),
+ };
+
+ // Get the functions set as tests. Search for `[test]` -> `fn`.
+ let mut body_it = body.stream().into_iter();
+ let mut tests = Vec::new();
+ while let Some(token) = body_it.next() {
+ match token {
+ TokenTree::Group(ident) if ident.to_string() == "[test]" => match body_it.next() {
+ Some(TokenTree::Ident(ident)) if ident.to_string() == "fn" => {
+ let test_name = match body_it.next() {
+ Some(TokenTree::Ident(ident)) => ident.to_string(),
+ _ => continue,
+ };
+ tests.push(test_name);
+ }
+ _ => continue,
+ },
+ _ => (),
+ }
+ }
+
+ // Add `#[cfg(CONFIG_KUNIT)]` before the module declaration.
+ let config_kunit = "#[cfg(CONFIG_KUNIT)]".to_owned().parse().unwrap();
+ tokens.insert(
+ 0,
+ TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
+ );
+
+ // Generate the test KUnit test suite and a test case for each `#[test]`.
+ // The code generated for the following test module:
+ //
+ // ```
+ // #[kunit_tests(kunit_test_suit_name)]
+ // mod tests {
+ // #[test]
+ // fn foo() {
+ // assert_eq!(1, 1);
+ // }
+ //
+ // #[test]
+ // fn bar() {
+ // assert_eq!(2, 2);
+ // }
+ // }
+ // ```
+ //
+ // Looks like:
+ //
+ // ```
+ // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut kernel::bindings::kunit) { foo(); }
+ // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut kernel::bindings::kunit) { bar(); }
+ //
+ // static mut TEST_CASES: [kernel::bindings::kunit_case; 3] = [
+ // kernel::kunit::kunit_case(kernel::c_str!("foo"), kunit_rust_wrapper_foo),
+ // kernel::kunit::kunit_case(kernel::c_str!("bar"), kunit_rust_wrapper_bar),
+ // kernel::kunit::kunit_case_null(),
+ // ];
+ //
+ // kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
+ // ```
+ let mut kunit_macros = "".to_owned();
+ let mut test_cases = "".to_owned();
+ for test in &tests {
+ let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{}", test);
+ let kunit_wrapper = format!(
+ "unsafe extern \"C\" fn {}(_test: *mut kernel::bindings::kunit) {{ {}(); }}",
+ kunit_wrapper_fn_name, test
+ );
+ writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
+ writeln!(
+ test_cases,
+ " kernel::kunit::kunit_case(kernel::c_str!(\"{}\"), {}),",
+ test, kunit_wrapper_fn_name
+ )
+ .unwrap();
+ }
+
+ writeln!(kunit_macros).unwrap();
+ writeln!(
+ kunit_macros,
+ "static mut TEST_CASES: [kernel::bindings::kunit_case; {}] = [\n{test_cases} kernel::kunit::kunit_case_null(),\n];",
+ tests.len() + 1
+ )
+ .unwrap();
+
+ writeln!(
+ kunit_macros,
+ "kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
+ )
+ .unwrap();
+
+ // Remove the `#[test]` macros.
+ // We do this at a token level, in order to preserve span information.
+ let mut new_body = vec![];
+ let mut body_it = body.stream().into_iter();
+
+ while let Some(token) = body_it.next() {
+ match token {
+ TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
+ Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
+ Some(next) => {
+ new_body.extend([token, next]);
+ }
+ _ => {
+ new_body.push(token);
+ }
+ },
+ _ => {
+ new_body.push(token);
+ }
+ }
+ }
+
+ let mut new_body = TokenStream::from_iter(new_body);
+ new_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
+
+ tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body)));
+
+ tokens.into_iter().collect()
+}
@@ -10,6 +10,7 @@
mod quote;
mod concat_idents;
mod helpers;
+mod kunit;
mod module;
mod paste;
mod pin_data;
@@ -492,3 +493,31 @@ pub fn paste(input: TokenStream) -> TokenStream {
pub fn derive_zeroable(input: TokenStream) -> TokenStream {
zeroable::derive(input)
}
+
+/// Registers a KUnit test suite and its test cases using a user-space like syntax.
+///
+/// This macro should be used on modules. If `CONFIG_KUNIT` (in `.config`) is `n`, the target module
+/// is ignored.
+///
+/// # Examples
+///
+/// ```ignore
+/// # use macros::kunit_tests;
+///
+/// #[kunit_tests(kunit_test_suit_name)]
+/// mod tests {
+/// #[test]
+/// fn foo() {
+/// assert_eq!(1, 1);
+/// }
+///
+/// #[test]
+/// fn bar() {
+/// assert_eq!(2, 2);
+/// }
+/// }
+/// ```
+#[proc_macro_attribute]
+pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
+ kunit::kunit_tests(attr, ts)
+}