From 1a04cdab1de6cb1b876ac7bc5750007a5238ca9c Mon Sep 17 00:00:00 2001 From: Ilya Usov <Ilyausov96@gmail.com> Date: Sat, 16 Mar 2024 00:34:36 +0100 Subject: [PATCH] Fix RdTask GetAwaiter behaviour (return to original context after await) add test --- rd-net/RdFramework/Tasks/RdTaskEx.cs | 23 ++----------- rd-net/Test.RdFramework/AsyncRdTaskTest.cs | 39 +++++++++++++++++++++- 2 files changed, 40 insertions(+), 22 deletions(-) diff --git a/rd-net/RdFramework/Tasks/RdTaskEx.cs b/rd-net/RdFramework/Tasks/RdTaskEx.cs index 4f3a455c0..bb041e81a 100644 --- a/rd-net/RdFramework/Tasks/RdTaskEx.cs +++ b/rd-net/RdFramework/Tasks/RdTaskEx.cs @@ -122,27 +122,8 @@ public static Task<T> AsTask<T>(this IRdTask<T> task) }); return tcs.Task; } - - [PublicAPI] - public static RdTaskAwaiter<T> GetAwaiter<T>(this IRdTask<T> task) => new(task.Result); - - public readonly struct RdTaskAwaiter<T> : INotifyCompletion - { - private readonly IReadonlyProperty<RdTaskResult<T>> myResult; - - internal RdTaskAwaiter(IReadonlyProperty<RdTaskResult<T>> result) - { - myResult = result; - } - public bool IsCompleted => myResult.Maybe.HasValue; - - public T GetResult() => myResult.Value.Unwrap(); - - public void OnCompleted(Action continuation) - { - myResult.Change.AdviseOnce(Lifetime.Eternal, _ => continuation()); - } - } + [PublicAPI] + public static TaskAwaiter<T> GetAwaiter<T>(this IRdTask<T> task) => task.AsTask().GetAwaiter(); } } \ No newline at end of file diff --git a/rd-net/Test.RdFramework/AsyncRdTaskTest.cs b/rd-net/Test.RdFramework/AsyncRdTaskTest.cs index 95cfc0fb4..d784962d0 100644 --- a/rd-net/Test.RdFramework/AsyncRdTaskTest.cs +++ b/rd-net/Test.RdFramework/AsyncRdTaskTest.cs @@ -3,8 +3,10 @@ using System.Threading.Tasks; using JetBrains.Collections.Viewable; using JetBrains.Core; +using JetBrains.Rd.Base; using JetBrains.Rd.Impl; using JetBrains.Rd.Tasks; +using JetBrains.Threading; using JetBrains.Util; using NUnit.Framework; @@ -31,7 +33,7 @@ public void BindableRdCallListUseRdTaskTest() BindableRdCallListTest(TaskKind.Rd); } - private enum TaskKind + public enum TaskKind { System, Rd, @@ -96,6 +98,41 @@ private void BindableRdCallListTest(TaskKind taskKind) Assert.IsTrue(bindClientTask.Wait(Timeout(TimeSpan.FromSeconds(10)))); } } + + [Test] + [TestCase(TaskKind.Rd)] + [TestCase(TaskKind.System)] + public void TestRdTaskAwaiter(TaskKind kind) + { + var rdTask = new RdTask<Unit>(); + var scheduler = new TaskSchedulerWrapper(new ConcurrentExclusiveSchedulerPair(TaskScheduler.Default).ExclusiveScheduler, false); + + var task = TestLifetime.StartAsync(scheduler.AsTaskScheduler(), async () => + { + scheduler.AssertThread(); + + TestLifetime.Start(scheduler.AsTaskScheduler(), () => + { + scheduler.AssertThread(); + TestLifetime.Start(TaskScheduler.Default, () => + { + rdTask.ResultInternal.Set(RdTaskResult<Unit>.Success(Unit.Instance)); + }).NoAwait(); + }).NoAwait(); + + _ = kind switch + { + TaskKind.System => await rdTask.AsTask(), + TaskKind.Rd => await rdTask, + _ => throw new ArgumentOutOfRangeException(nameof(kind), kind, null) + }; + + scheduler.AssertThread(); + }); + + task.Wait(TimeSpan.FromSeconds(10)); + Assert.IsTrue(task.IsCompleted); + } private static TimeSpan Timeout(TimeSpan timeout) {