From 51f6ac69f81cde91c2ed5392541e10dc26794a39 Mon Sep 17 00:00:00 2001 From: Anton Lazarev <antonok35@gmail.com> Date: Mon, 9 Jan 2023 19:35:52 -0800 Subject: [PATCH] support async trait methods --- Cargo.toml | 1 + src/expansion.rs | 23 +++++++++++++++++++---- tests/async-trait.rs | 39 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) create mode 100644 tests/async-trait.rs diff --git a/Cargo.toml b/Cargo.toml index a6f7862..975b1d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,4 @@ enum_derive = "= 0.1.7" custom_derive = "= 0.1.7" serde = { version = "= 1.0.136", features = ["derive"] } serde_json = "= 1.0.78" +smol = "1.3.0" diff --git a/src/expansion.rs b/src/expansion.rs index 4a2cf8d..0f16bbd 100644 --- a/src/expansion.rs +++ b/src/expansion.rs @@ -221,7 +221,8 @@ fn create_trait_fn_call( trait_method: &syn::TraitItemMethod, trait_generics: &syn::TypeGenerics, trait_name: &syn::Ident, -) -> syn::ExprCall { + is_async: bool, +) -> syn::Expr { let trait_args = trait_method.to_owned().sig.inputs; let (method_type, mut args) = extract_fn_args(trait_args); @@ -229,7 +230,7 @@ fn create_trait_fn_call( let explicit_self_arg = syn::Ident::new(FIELDNAME, trait_method.span()); args.insert(0, plain_identifier_expr(explicit_self_arg)); - syn::ExprCall { + let call = syn::Expr::from(syn::ExprCall { attrs: vec![], func: { if let MethodType::Static = method_type { @@ -269,6 +270,17 @@ fn create_trait_fn_call( }, paren_token: Default::default(), args, + }); + + if is_async { + syn::Expr::from(syn::ExprAwait { + attrs: Default::default(), + base: Box::new(call), + dot_token: Default::default(), + await_token: Default::default(), + }) + } else { + call } } @@ -280,8 +292,9 @@ fn create_match_expr( trait_name: &syn::Ident, enum_name: &syn::Ident, enumvariants: &[&EnumDispatchVariant], + is_async: bool, ) -> syn::Expr { - let trait_fn_call = create_trait_fn_call(trait_method, trait_generics, trait_name); + let trait_fn_call = create_trait_fn_call(trait_method, trait_generics, trait_name, is_async); // Creates a Vec containing a match arm for every enum variant let match_arms = enumvariants @@ -302,7 +315,7 @@ fn create_match_expr( }, guard: None, fat_arrow_token: Default::default(), - body: Box::new(syn::Expr::from(trait_fn_call.to_owned())), + body: Box::new(trait_fn_call.to_owned()), comma: Some(Default::default()), } }) @@ -332,6 +345,7 @@ fn create_trait_match( match trait_item { syn::TraitItem::Method(mut trait_method) => { identify_signature_arguments(&mut trait_method.sig); + let is_async = trait_method.sig.asyncness.is_some(); let match_expr = create_match_expr( &trait_method, @@ -339,6 +353,7 @@ fn create_trait_match( trait_name, enum_name, enumvariants, + is_async, ); let mut impl_attrs = trait_method.attrs.clone(); diff --git a/tests/async-trait.rs b/tests/async-trait.rs new file mode 100644 index 0000000..ccbe81f --- /dev/null +++ b/tests/async-trait.rs @@ -0,0 +1,39 @@ +#![feature(async_fn_in_trait)] + +use enum_dispatch::enum_dispatch; + +struct A; +struct B; + +impl XTrait for A { + async fn run(&mut self) -> Result<u32, ()> { + Ok(10) + } +} +impl XTrait for B { + async fn run(&mut self) -> Result<u32, ()> { + Ok(20) + } +} + +#[enum_dispatch] +enum X { + A, + B, +} + +#[enum_dispatch(X)] +trait XTrait { + async fn run(&mut self) -> Result<u32, ()>; +} + +fn main() -> smol::io::Result<()> { + let mut a: X = A.into(); + let mut b: X = B.into(); + + smol::block_on(async { + assert_eq!(10, a.run().await.unwrap()); + assert_eq!(20, b.run().await.unwrap()); + Ok(()) + }) +} -- GitLab