TPL实现Task.WhileAll扩展方法
文章翻译整理自 Nikola Malovic 两篇博文:
当 Task.WhenAll 遇见 Task.WhenAny
在 TPL (Task Parallel Library) 中,有两种通过非阻塞方式等待 Task 数组任务结束的方式:Task.WhenAll 和 Task.WhenAny 。
它们的工作方式是:
- WhenAll 当每项任务都完成时为完成。
- WhenAny 当任意项任务完成时为完成。
现在我们需要一项功能,完成 Task 数组中的所有任务,并且当有任务完成时汇报状态。
我们称这个扩展方法为:Task.WhileAll 。
扩展方法实现
1 public static class TaskExtensions 2 { 3 public static async Task<IList<T>> WhileAll<T>(this IList<Task<T>> tasks, IProgress<T> progress) 4 { 5 var result = new List<T>(tasks.Count); 6 var done = new List<Task<T>>(tasks); 7 8 while (done.Count > 0) 9 { 10 await Task.WhenAny(tasks); 11 12 var spinning = new List<Task<T>>(done.Count - 1); 13 for (int i = 0; i < done.Count; i++) 14 { 15 if (done[i].IsCompleted) 16 { 17 result.Add(done[i].Result); 18 progress.Report(done[i].Result); 19 } 20 else 21 { 22 spinning.Add(done[i]); 23 } 24 } 25 26 done = spinning; 27 } 28 29 return result; 30 } 31 }
代码实现很简单:
- 其是 IList<Task<T>> 的一个 async 扩展方法
- 方法返回完整的 IList<T> 结果
- 方法会接受一个 IProgress<T> 类型的参数,用于向订阅者发布 Task 完成信息
- 在方法体内,我们使用一个循环来检测,直到所有 Task 完成
- 通过使用 Task.WhenAny 来异步等待 Task 完成
单元测试
1 [TestClass] 2 public class UnitTest1 3 { 4 [TestMethod] 5 public async Task TestTaskExtensionsWhileAll() 6 { 7 var task1 = Task.Run(() => 101); 8 var task2 = Task.Run(() => 102); 9 var tasks = new List<Task<int>>() { task1, task2 }; 10 11 List<int> result = new List<int>(); 12 var listener = new Progress<int>( 13 taskResult => 14 { 15 result.Add(taskResult); 16 }); 17 18 var actual = await tasks.WhileAll(listener); 19 Thread.Sleep(50); // wait a bit for progress reports to complete 20 21 Assert.AreEqual(2, result.Count); 22 Assert.IsTrue(result.Contains(101)); 23 Assert.IsTrue(result.Contains(102)); 24 25 Assert.AreEqual(2, actual.Count); 26 Assert.IsTrue(actual.Contains(101)); 27 Assert.IsTrue(actual.Contains(102)); 28 } 29 }
同样,测试代码也不复杂:
- 创建两个哑元 Task,并存到数组中
- 定义进度侦听器 Progress<T>,来监测每个任务运行的结果
- 通过 await 方式来调用方法
- 使用 Thread.Sleep 来等待 50ms ,以便 Progress 可以来得及处理结果
- 检查所有 Task 执行完毕后均已上报 Progress
- 检查所有 Task 均已执行完毕
我知道每当使用 Thread.Sleep 时绝不是件好事,所以我决定摆脱它。
实现IProgressAsync<T>
问题实际上是因为 IProgress<T> 接口定义的是 void 委托,因此无法使用 await 进行等待。
因此我决定定义一个新的接口,使用同样的 Report 行为,但会返回 Task ,用以实现真正的异步。
1 public interface IProgressAsync<in T> 2 { 3 Task ReportAsync(T value); 4 }
有了异步版本的支持,将使订阅者更容易处理 await 调用。当然也可以使用 async void 来达成,但我认为 async void 总会延伸出更差的设计。所以,我还是选择通过定义 Task 返回值签名的接口来达成这一功能。
如下为接口实现:
1 public class ProgressAsync<T> : IProgressAsync<T> 2 { 3 private readonly Func<T, Task> handler; 4 5 public ProgressAsync(Func<T, Task> handler) 6 { 7 this.handler = handler; 8 } 9 10 public async Task ReportAsync(T value) 11 { 12 await this.handler.InvokeAsync(value); 13 } 14 }
显然也没什么特别的:
- 使用 Func<T, Task> 来代替 Action<T>,以便可以使用 await
- ReportAsync 通过使用 await 方式来提供 Task
有了这些之后,我们来更新扩展方法:
1 public static class TaskExtensions 2 { 3 public static async Task<IList<T>> WhileAll<T>(this IList<Task<T>> tasks, IProgressAsync<T> progress) 4 { 5 var result = new List<T>(tasks.Count); 6 var remainingTasks = new List<Task<T>>(tasks); 7 8 while (remainingTasks.Count > 0) 9 { 10 await Task.WhenAny(tasks); 11 var stillRemainingTasks = new List<Task<T>>(remainingTasks.Count - 1); 12 for (int i = 0; i < remainingTasks.Count; i++) 13 { 14 if (remainingTasks[i].IsCompleted) 15 { 16 result.Add(remainingTasks[i].Result); 17 await progress.ReportAsync(remainingTasks[i].Result); 18 } 19 else 20 { 21 stillRemainingTasks.Add(remainingTasks[i]); 22 } 23 } 24 25 remainingTasks = stillRemainingTasks; 26 } 27 28 return result; 29 } 30 31 public static Task InvokeAsync<T>(this Func<T, Task> task, T value) 32 { 33 return Task<Task>.Factory.FromAsync(task.BeginInvoke, task.EndInvoke, value, null); 34 } 35 }
所有都就绪后,我们就可以将 Thread.Sleep 从单元测试中移除了。
1 [TestClass] 2 public class UnitTest1 3 { 4 private List<int> result = new List<int>(); 5 private async Task OnProgressAsync(int arg) 6 { 7 result.Add(arg); 8 } 9 10 [TestMethod] 11 public async Task TestTaskExtensionsWhileAll() 12 { 13 var task1 = Task.Run(() => 101); 14 var task2 = Task.Run(() => 102); 15 var tasks = new List<Task<int>>() { task1, task2 }; 16 17 var listener = new ProgressAsync<int>(this.OnProgressAsync); 18 var actual = await tasks.WhileAll(listener); 19 20 Assert.AreEqual(2, this.result.Count); 21 Assert.IsTrue(this.result.Contains(101)); 22 Assert.IsTrue(this.result.Contains(102)); 23 24 Assert.AreEqual(2, actual.Count); 25 Assert.IsTrue(actual.Contains(101)); 26 Assert.IsTrue(actual.Contains(102)); 27 } 28 }