From 1a04cdab1de6cb1b876ac7bc5750007a5238ca9c Mon Sep 17 00:00:00 2001
From: Ilya Usov <Ilyausov96@gmail.com>
Date: Sat, 16 Mar 2024 00:34:36 +0100
Subject: [PATCH] Fix RdTask GetAwaiter behaviour (return to original context
 after await)

add test
---
 rd-net/RdFramework/Tasks/RdTaskEx.cs       | 23 ++-----------
 rd-net/Test.RdFramework/AsyncRdTaskTest.cs | 39 +++++++++++++++++++++-
 2 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/rd-net/RdFramework/Tasks/RdTaskEx.cs b/rd-net/RdFramework/Tasks/RdTaskEx.cs
index 4f3a455c0..bb041e81a 100644
--- a/rd-net/RdFramework/Tasks/RdTaskEx.cs
+++ b/rd-net/RdFramework/Tasks/RdTaskEx.cs
@@ -122,27 +122,8 @@ public static Task<T> AsTask<T>(this IRdTask<T> task)
       });
       return tcs.Task;
     }
-    
-    [PublicAPI]
-    public static RdTaskAwaiter<T> GetAwaiter<T>(this IRdTask<T> task) => new(task.Result);
-    
-    public readonly struct RdTaskAwaiter<T> : INotifyCompletion
-    {
-      private readonly IReadonlyProperty<RdTaskResult<T>> myResult;
-
-      internal RdTaskAwaiter(IReadonlyProperty<RdTaskResult<T>> result)
-      {
-        myResult = result;
-      }
 
-      public bool IsCompleted => myResult.Maybe.HasValue;
-
-      public T GetResult() => myResult.Value.Unwrap();
-
-      public void OnCompleted(Action continuation)
-      {
-        myResult.Change.AdviseOnce(Lifetime.Eternal, _ => continuation());
-      }
-    }
+    [PublicAPI]
+    public static TaskAwaiter<T> GetAwaiter<T>(this IRdTask<T> task) => task.AsTask().GetAwaiter();
   }
 }
\ No newline at end of file
diff --git a/rd-net/Test.RdFramework/AsyncRdTaskTest.cs b/rd-net/Test.RdFramework/AsyncRdTaskTest.cs
index 95cfc0fb4..d784962d0 100644
--- a/rd-net/Test.RdFramework/AsyncRdTaskTest.cs
+++ b/rd-net/Test.RdFramework/AsyncRdTaskTest.cs
@@ -3,8 +3,10 @@
 using System.Threading.Tasks;
 using JetBrains.Collections.Viewable;
 using JetBrains.Core;
+using JetBrains.Rd.Base;
 using JetBrains.Rd.Impl;
 using JetBrains.Rd.Tasks;
+using JetBrains.Threading;
 using JetBrains.Util;
 using NUnit.Framework;
 
@@ -31,7 +33,7 @@ public void BindableRdCallListUseRdTaskTest()
     BindableRdCallListTest(TaskKind.Rd);
   }
 
-  private enum TaskKind
+  public enum TaskKind
   {
     System,
     Rd,
@@ -96,6 +98,41 @@ private void BindableRdCallListTest(TaskKind taskKind)
       Assert.IsTrue(bindClientTask.Wait(Timeout(TimeSpan.FromSeconds(10))));
     }
   }
+  
+  [Test]
+  [TestCase(TaskKind.Rd)]
+  [TestCase(TaskKind.System)]
+  public void TestRdTaskAwaiter(TaskKind kind)
+  {
+    var rdTask = new RdTask<Unit>();
+    var scheduler = new TaskSchedulerWrapper(new ConcurrentExclusiveSchedulerPair(TaskScheduler.Default).ExclusiveScheduler, false);
+
+    var task = TestLifetime.StartAsync(scheduler.AsTaskScheduler(), async () =>
+    {
+      scheduler.AssertThread();
+
+      TestLifetime.Start(scheduler.AsTaskScheduler(), () =>
+      {
+        scheduler.AssertThread();
+        TestLifetime.Start(TaskScheduler.Default, () =>
+        {
+          rdTask.ResultInternal.Set(RdTaskResult<Unit>.Success(Unit.Instance));
+        }).NoAwait();
+      }).NoAwait();
+      
+      _ = kind switch
+      {
+        TaskKind.System => await rdTask.AsTask(),
+        TaskKind.Rd     => await rdTask,
+        _ => throw new ArgumentOutOfRangeException(nameof(kind), kind, null)
+      };
+      
+      scheduler.AssertThread();
+    });
+
+    task.Wait(TimeSpan.FromSeconds(10));
+    Assert.IsTrue(task.IsCompleted);
+  } 
 
   private static TimeSpan Timeout(TimeSpan timeout)
   {