From 51f6ac69f81cde91c2ed5392541e10dc26794a39 Mon Sep 17 00:00:00 2001
From: Anton Lazarev <antonok35@gmail.com>
Date: Mon, 9 Jan 2023 19:35:52 -0800
Subject: [PATCH] support async trait methods

---
 Cargo.toml           |  1 +
 src/expansion.rs     | 23 +++++++++++++++++++----
 tests/async-trait.rs | 39 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 59 insertions(+), 4 deletions(-)
 create mode 100644 tests/async-trait.rs

diff --git a/Cargo.toml b/Cargo.toml
index a6f7862..975b1d0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,3 +25,4 @@ enum_derive = "= 0.1.7"
 custom_derive = "= 0.1.7"
 serde = { version = "= 1.0.136", features = ["derive"] }
 serde_json = "= 1.0.78"
+smol = "1.3.0"
diff --git a/src/expansion.rs b/src/expansion.rs
index 4a2cf8d..0f16bbd 100644
--- a/src/expansion.rs
+++ b/src/expansion.rs
@@ -221,7 +221,8 @@ fn create_trait_fn_call(
     trait_method: &syn::TraitItemMethod,
     trait_generics: &syn::TypeGenerics,
     trait_name: &syn::Ident,
-) -> syn::ExprCall {
+    is_async: bool,
+) -> syn::Expr {
     let trait_args = trait_method.to_owned().sig.inputs;
     let (method_type, mut args) = extract_fn_args(trait_args);
 
@@ -229,7 +230,7 @@ fn create_trait_fn_call(
     let explicit_self_arg = syn::Ident::new(FIELDNAME, trait_method.span());
     args.insert(0, plain_identifier_expr(explicit_self_arg));
 
-    syn::ExprCall {
+    let call = syn::Expr::from(syn::ExprCall {
         attrs: vec![],
         func: {
             if let MethodType::Static = method_type {
@@ -269,6 +270,17 @@ fn create_trait_fn_call(
         },
         paren_token: Default::default(),
         args,
+    });
+
+    if is_async {
+        syn::Expr::from(syn::ExprAwait {
+            attrs: Default::default(),
+            base: Box::new(call),
+            dot_token: Default::default(),
+            await_token: Default::default(),
+        })
+    } else {
+        call
     }
 }
 
@@ -280,8 +292,9 @@ fn create_match_expr(
     trait_name: &syn::Ident,
     enum_name: &syn::Ident,
     enumvariants: &[&EnumDispatchVariant],
+    is_async: bool,
 ) -> syn::Expr {
-    let trait_fn_call = create_trait_fn_call(trait_method, trait_generics, trait_name);
+    let trait_fn_call = create_trait_fn_call(trait_method, trait_generics, trait_name, is_async);
 
     // Creates a Vec containing a match arm for every enum variant
     let match_arms = enumvariants
@@ -302,7 +315,7 @@ fn create_match_expr(
                 },
                 guard: None,
                 fat_arrow_token: Default::default(),
-                body: Box::new(syn::Expr::from(trait_fn_call.to_owned())),
+                body: Box::new(trait_fn_call.to_owned()),
                 comma: Some(Default::default()),
             }
         })
@@ -332,6 +345,7 @@ fn create_trait_match(
     match trait_item {
         syn::TraitItem::Method(mut trait_method) => {
             identify_signature_arguments(&mut trait_method.sig);
+            let is_async = trait_method.sig.asyncness.is_some();
 
             let match_expr = create_match_expr(
                 &trait_method,
@@ -339,6 +353,7 @@ fn create_trait_match(
                 trait_name,
                 enum_name,
                 enumvariants,
+                is_async,
             );
 
             let mut impl_attrs = trait_method.attrs.clone();
diff --git a/tests/async-trait.rs b/tests/async-trait.rs
new file mode 100644
index 0000000..ccbe81f
--- /dev/null
+++ b/tests/async-trait.rs
@@ -0,0 +1,39 @@
+#![feature(async_fn_in_trait)]
+
+use enum_dispatch::enum_dispatch;
+
+struct A;
+struct B;
+
+impl XTrait for A {
+    async fn run(&mut self) -> Result<u32, ()> {
+        Ok(10)
+    }
+}
+impl XTrait for B {
+    async fn run(&mut self) -> Result<u32, ()> {
+        Ok(20)
+    }
+}
+
+#[enum_dispatch]
+enum X {
+    A,
+    B,
+}
+
+#[enum_dispatch(X)]
+trait XTrait {
+    async fn run(&mut self) -> Result<u32, ()>;
+}
+
+fn main() -> smol::io::Result<()> {
+    let mut a: X = A.into();
+    let mut b: X = B.into();
+
+    smol::block_on(async {
+        assert_eq!(10, a.run().await.unwrap());
+        assert_eq!(20, b.run().await.unwrap());
+        Ok(())
+    })
+}
-- 
GitLab