接上文:封装多线程处理大量数据操作(一)
我们需要解决WaitAny和取得异步执行的返回值的问题。地球人都知道Thread和ThreadPool接受的委托都是没有返回值的。要想取得返回值,我们就得自己动手了:我们需要构造一个AsyncContext类,由这个类来保存异步执行的状态并存储返回值。
代码如下:
using System;
using System.Collections.Generic;
using System.Text;
using System.Collections;
using System.Threading;
using System.Diagnostics;

namespace AppUtility
{
    /// <summary>
    /// Processes a single data item and returns the result produced for it.
    /// </summary>
    /// <param name="state">The data item to process.</param>
    /// <returns>The result to store for that item.</returns>
    public delegate object DoGetObjTask(object state);

    /// <summary>
    /// Splits a data collection across several thread-pool work items and
    /// blocks the caller until all of them (or the first of them) finish.
    /// </summary>
    public static class AsyncHelper
    {
        private const int MinThreadCount = 2;

        // WaitHandle.WaitAll accepts at most 64 handles, so that is the upper bound.
        private const int MaxThreadCount = 64;

        /// <summary>
        /// Processes every item of the collection on multiple threads and waits for all of them.
        /// </summary>
        /// <param name="dataCollection">The data items to process.</param>
        /// <param name="threadCn">How many workers to split the data across (2 to 64).</param>
        /// <param name="processItemMethod">The method invoked for each single data item.</param>
        public static void DoAsync(IList dataCollection, int threadCn, WaitCallback processItemMethod)
        {
            DoAsync(dataCollection, threadCn, processItemMethod, true);
        }

        /// <summary>
        /// Processes the collection on multiple threads and collects one result per item.
        /// </summary>
        /// <param name="dataCollection">The data items to process.</param>
        /// <param name="threadCn">How many workers to split the data across (2 to 64).</param>
        /// <param name="processItemMethod">The method invoked for each single data item; its return value is stored.</param>
        /// <param name="needWaitAll">
        /// true to return only after every worker has finished; false to return as soon as
        /// the first worker finishes (the remaining workers are then asked to stop early).
        /// </param>
        /// <param name="processResult">
        /// A synchronized table mapping each processed data item to its result. When
        /// <paramref name="needWaitAll"/> is false, workers may still be adding entries after
        /// this method returns, so lock <c>processResult.SyncRoot</c> while enumerating it.
        /// </param>
        public static void DoAsync(IList dataCollection, int threadCn, DoGetObjTask processItemMethod, bool needWaitAll, out Hashtable processResult)
        {
            DoAsyncPrivate(dataCollection, threadCn, null, processItemMethod, needWaitAll, true, out processResult);
        }

        /// <summary>
        /// Processes the collection on multiple threads, waits for all workers and
        /// collects one result per item.
        /// </summary>
        /// <param name="dataCollection">The data items to process.</param>
        /// <param name="threadCn">How many workers to split the data across (2 to 64).</param>
        /// <param name="processItemMethod">The method invoked for each single data item; its return value is stored.</param>
        /// <param name="processResult">A synchronized table mapping each data item to its result.</param>
        public static void DoAsync(IList dataCollection, int threadCn, DoGetObjTask processItemMethod, out Hashtable processResult)
        {
            DoAsyncPrivate(dataCollection, threadCn, null, processItemMethod, true, true, out processResult);
        }

        /// <summary>
        /// Processes the collection on multiple threads without collecting results.
        /// </summary>
        /// <param name="dataCollection">The data items to process.</param>
        /// <param name="threadCn">How many workers to split the data across (2 to 64).</param>
        /// <param name="processItemMethod">The method invoked for each single data item.</param>
        /// <param name="needWaitAll">
        /// true to return only after every worker has finished; false to return as soon as
        /// the first worker finishes (the remaining workers are then asked to stop early).
        /// </param>
        public static void DoAsync(IList dataCollection, int threadCn, WaitCallback processItemMethod, bool needWaitAll)
        {
            Hashtable hash;
            DoAsyncPrivate(dataCollection, threadCn, processItemMethod, null, needWaitAll, false, out hash);
        }

        /// <summary>
        /// Shared implementation behind all DoAsync overloads: validates the arguments,
        /// deals the items round-robin into one partition per worker, queues the workers
        /// on the thread pool and waits for all of them or for the first one.
        /// </summary>
        /// <exception cref="ArgumentNullException">
        /// dataCollection is null, or the delegate required by the chosen mode is null.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">threadCn is outside 2..64.</exception>
        private static void DoAsyncPrivate(IList dataCollection, int threadCn, WaitCallback processItemMethod, DoGetObjTask getObjMethod, bool needWaitAll, bool hasReturnValue, out Hashtable processResult)
        {
            if (dataCollection == null)
            {
                throw new ArgumentNullException("dataCollection");
            }
            if (threadCn > MaxThreadCount || threadCn < MinThreadCount)
            {
                throw new ArgumentOutOfRangeException("threadCn", "threadCn 参数必须在2和64之间");
            }
            // Fail on the calling thread instead of with a NullReferenceException on a pool thread.
            if (hasReturnValue ? getObjMethod == null : processItemMethod == null)
            {
                throw new ArgumentNullException("processItemMethod");
            }

            int itemCount = dataCollection.Count;
            if (itemCount == 0)
            {
                // Nothing to do. Waiting on an empty handle array would throw, so return now.
                processResult = hasReturnValue ? Hashtable.Synchronized(new Hashtable()) : null;
                return;
            }
            if (threadCn > itemCount)
            {
                threadCn = itemCount;
            }

            List<DataWithState>[] partitions = new List<DataWithState>[threadCn];
            AutoResetEvent[] evts = new AutoResetEvent[threadCn];
            for (int i = 0; i < threadCn; i++)
            {
                partitions[i] = new List<DataWithState>();
                evts[i] = new AutoResetEvent(false);
            }

            // Partitions hold the state wrappers themselves, so state is tracked per
            // position in the input rather than per value (duplicate values stay distinct).
            DataWithStateList dataWithStates = new DataWithStateList();
            for (int i = 0; i < itemCount; i++)
            {
                partitions[i % threadCn].Add(dataWithStates.AddWaiting(dataCollection[i]));
            }

            AsyncContext context = AsyncContext.GetContext(threadCn, dataWithStates, needWaitAll, hasReturnValue, processItemMethod, getObjMethod);
            for (int i = 0; i < threadCn; i++)
            {
                ThreadPool.QueueUserWorkItem(DoPrivate, new object[] { partitions[i], context, evts[i] });
            }

            if (needWaitAll)
            {
                WaitHandle.WaitAll(evts);
                // Every worker has signalled, so nobody will touch the events again.
                for (int i = 0; i < evts.Length; i++)
                {
                    evts[i].Close();
                }
            }
            else
            {
                WaitHandle.WaitAny(evts);
                context.SetBreakSignal();
                // The events are deliberately left open: the remaining workers still
                // call Set() on them when they notice the break signal and exit.
            }
            processResult = context.ProcessResult;
        }

        /// <summary>
        /// State shared by all workers of one DoAsync call: the delegates to run,
        /// the per-item states, the result table and the early-exit signal.
        /// </summary>
        private class AsyncContext
        {
            /// <summary>
            /// Builds the context for one DoAsync call. A synchronized result table is
            /// created only when <paramref name="hasReturnValue"/> is true.
            /// </summary>
            public static AsyncContext GetContext(
                int threadCn,
                DataWithStateList dataWithStates,
                bool needWaitAll,
                bool hasReturnValue,
                WaitCallback processItemMethod,
                DoGetObjTask hasReturnValueMethod)
            {
                AsyncContext context = new AsyncContext();
                context.ThreadCount = threadCn;
                context.DataWithStates = dataWithStates;
                context.NeedWaitAll = needWaitAll;
                if (hasReturnValue)
                {
                    context.ProcessResult = Hashtable.Synchronized(new Hashtable());
                    context.HasReturnValueMethod = hasReturnValueMethod;
                }
                else
                {
                    context.VoidMethod = processItemMethod;
                }
                context.HasReturnValue = hasReturnValue;
                return context;
            }

            internal int ThreadCount;
            internal DataWithStateList DataWithStates;
            internal bool NeedWaitAll;
            internal bool HasReturnValue;
            internal WaitCallback VoidMethod;
            internal DoGetObjTask HasReturnValueMethod;

            // Written by the calling thread and polled by the workers, hence volatile.
            private volatile bool _breakSignal;
            private Hashtable _processResult;

            /// <summary>The result table; null when results are not collected.</summary>
            internal Hashtable ProcessResult
            {
                get { return _processResult; }
                set { _processResult = value; }
            }

            /// <summary>Stores <paramref name="result"/> under the key <paramref name="obj"/>.</summary>
            internal void SetReturnValue(object obj, object result)
            {
                lock (_processResult.SyncRoot)
                {
                    _processResult[obj] = result;
                }
            }

            /// <summary>Asks the workers to stop before their next item.</summary>
            /// <exception cref="NotSupportedException">The context was created with NeedWaitAll.</exception>
            internal void SetBreakSignal()
            {
                if (NeedWaitAll)
                {
                    throw new NotSupportedException("设定为NeedWaitAll时不可设置BreakSignal");
                }
                _breakSignal = true;
            }

            /// <summary>true once the caller has stopped waiting and workers should exit.</summary>
            internal bool NeedBreak
            {
                get { return !NeedWaitAll && _breakSignal; }
            }

            /// <summary>
            /// Runs the configured delegate on an item the caller has already claimed,
            /// stores the result when results are collected, and marks the item processed.
            /// </summary>
            internal void Exec(DataWithState item)
            {
                if (HasReturnValue)
                {
                    SetReturnValue(item.Data, HasReturnValueMethod(item.Data));
                }
                else
                {
                    VoidMethod(item.Data);
                }
                DataWithStates.MarkProcessed(item);
            }
        }

        private enum ProcessState : byte
        {
            WaitForProcess = 0,
            Processing = 1,
            Processed = 2
        }

        /// <summary>
        /// Holds every item of one DoAsync call together with its processing state.
        /// A single lock guards all state transitions, so an item can be claimed by
        /// exactly one worker.
        /// </summary>
        private class DataWithStateList
        {
            private readonly object _gate = new object();
            private readonly List<DataWithState> _items = new List<DataWithState>();
            private int _waitingCount;

            // States only move forward, so everything before this index has already
            // been claimed and never needs to be scanned again.
            private int _scanStart;

            /// <summary>Wraps <paramref name="data"/> in a waiting entry and registers it.</summary>
            public DataWithState AddWaiting(object data)
            {
                DataWithState item = new DataWithState(data, ProcessState.WaitForProcess);
                lock (_gate)
                {
                    _items.Add(item);
                    _waitingCount++;
                }
                return item;
            }

            /// <summary>Number of items nobody has claimed yet.</summary>
            public int WaitForDataCount
            {
                get
                {
                    lock (_gate)
                    {
                        return _waitingCount;
                    }
                }
            }

            /// <summary>
            /// Claims <paramref name="item"/> for the calling worker.
            /// </summary>
            /// <returns>true when the item was still waiting and now belongs to the caller.</returns>
            public bool TryClaim(DataWithState item)
            {
                lock (_gate)
                {
                    if (item.State != ProcessState.WaitForProcess)
                    {
                        return false;
                    }
                    item.State = ProcessState.Processing;
                    _waitingCount--;
                    return true;
                }
            }

            /// <summary>
            /// Claims the first item that is still waiting, regardless of which
            /// partition it belongs to.
            /// </summary>
            /// <returns>The claimed item, or null when nothing is waiting.</returns>
            public DataWithState ClaimNextWaiting()
            {
                lock (_gate)
                {
                    while (_scanStart < _items.Count)
                    {
                        DataWithState candidate = _items[_scanStart];
                        _scanStart++;
                        if (candidate.State == ProcessState.WaitForProcess)
                        {
                            candidate.State = ProcessState.Processing;
                            _waitingCount--;
                            return candidate;
                        }
                    }
                    return null;
                }
            }

            /// <summary>Marks a claimed item as fully processed.</summary>
            public void MarkProcessed(DataWithState item)
            {
                lock (_gate)
                {
                    item.State = ProcessState.Processed;
                }
            }
        }

        /// <summary>One input item plus its current processing state.</summary>
        private class DataWithState
        {
            public readonly object Data;
            public ProcessState State;

            public DataWithState(object data, ProcessState state)
            {
                Data = data;
                State = state;
            }
        }

        private static int _threadNo = 0;

        /// <summary>
        /// Thread-pool entry point for one worker. Processes the worker's own partition,
        /// then (in wait-all mode) helps with other partitions while more than
        /// ThreadCount items are still unclaimed, and finally signals its event.
        /// </summary>
        /// <param name="state">
        /// An object[] holding the worker's partition, the shared AsyncContext and the
        /// AutoResetEvent to signal on exit, in that order.
        /// </param>
        private static void DoPrivate(object state)
        {
            object[] objs = (object[])state;
            List<DataWithState> datas = (List<DataWithState>)objs[0];
            AsyncContext context = (AsyncContext)objs[1];
            AutoResetEvent evt = (AutoResetEvent)objs[2];
            DataWithStateList objStates = context.DataWithStates;
#if DEBUG
            // Pool threads are reused and Thread.Name can be assigned only once,
            // so the label is kept in a local instead of being written to the thread.
            int workerNo = Interlocked.Increment(ref _threadNo) - 1;
            string threadName = "Thread " + workerNo + "[" + Thread.CurrentThread.ManagedThreadId + "]";
            Trace.WriteLine("线程ID:" + threadName);
#endif
            try
            {
                for (int i = 0; i < datas.Count; i++)
                {
                    if (context.NeedBreak)
                    {
#if DEBUG
                        Trace.WriteLine("线程" + threadName + "未执行完跳出");
#endif
                        break;
                    }

                    DataWithState item = datas[i];
                    if (!objStates.TryClaim(item))
                    {
                        // Another worker already took this item while helping out.
                        continue;
                    }

                    // The caller may have stopped waiting while the claim was taken.
                    if (context.NeedBreak)
                    {
#if DEBUG
                        Trace.WriteLine("线程" + threadName + "未执行完跳出");
#endif
                        break;
                    }

                    context.Exec(item);
#if DEBUG
                    Trace.WriteLine(string.Format("线程{0}处理{1}", threadName, item.Data));
#endif
                }

                if (context.NeedWaitAll)
                {
                    // Own partition done: take over other workers' items while more than
                    // ThreadCount of them are still unclaimed.
                    while (objStates.WaitForDataCount > context.ThreadCount)
                    {
                        DataWithState item = objStates.ClaimNextWaiting();
                        if (item == null)
                        {
                            break;
                        }
                        context.Exec(item);
#if DEBUG
                        Trace.WriteLine(string.Format("线程{0}执行另一个进程的数据{1}", threadName, item.Data));
#endif
                    }
                }
            }
            finally
            {
                // Always signal, so the waiting caller is released even when a delegate throws.
                evt.Set();
            }
        }
    }
}
如何使用AsyncHelper类,请看下面的测试代码:
using System;
using System.Collections.Generic;
using System.Text;
using System.Diagnostics;
using AppUtility;
using System.IO;
using System.Collections;
using System.Threading;

namespace ConsoleApplication2
{
    /// <summary>
    /// Demo that times writing 200 files through AsyncHelper against writing
    /// 200 files sequentially.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            Stopwatch sw = new Stopwatch();
            sw.Start();

            /* Alternative: the overload without return values, waiting for all workers.
            List<string> testFiles = new List<string>();
            for (int i = 0; i < 100; i++)
            {
                testFiles.Add("D:\\test\\async\\file_" + i.ToString() + ".log");
            }
            AsyncHelper.DoAsync(testFiles, 10, WriteFile);
            Console.WriteLine("异步写耗时"+sw.ElapsedMilliseconds + "ms");
            */

            List<string> testFiles = new List<string>();
            for (int i = 0; i < 200; i++)
            {
                testFiles.Add("D:\\test\\async\\file_" + i.ToString() + ".log");
            }

            // needWaitAll = false: DoAsync returns as soon as the first worker finishes.
            Hashtable result;
            AsyncHelper.DoAsync(testFiles, 20, WriteFileAndReturnRowCount, false, out result);
            Console.WriteLine("异步写耗时" + sw.ElapsedMilliseconds + "ms");
            Thread.Sleep(10);

            if (result != null)
            {
                // Other workers may still be adding entries. A synchronized Hashtable
                // does not protect enumeration, so hold SyncRoot for the whole loop.
                lock (result.SyncRoot)
                {
                    foreach (object key in result.Keys)
                    {
                        Console.WriteLine("{0}={1}", key, result[key]);
                    }
                }
            }

            sw.Reset();
            sw.Start();
            for (int i = 0; i < 200; i++)
            {
                WriteFile("D:\\test\\sync\\file_" + i.ToString() + ".log");
            }
            Console.WriteLine("同步写耗时" + sw.ElapsedMilliseconds + "ms");
            Console.Read();
        }

        /// <summary>
        /// Writes 10000 lines, each a new GUID, to the given file, creating the
        /// directory first when it does not exist.
        /// </summary>
        /// <param name="objFilePath">The target file path, passed as a string.</param>
        static void WriteFile(object objFilePath)
        {
            string filePath = (string)objFilePath;
            string dir = Path.GetDirectoryName(filePath);
            if (!Directory.Exists(dir))
            {
                Directory.CreateDirectory(dir);
            }

            int rowCn = 10000;
            using (StreamWriter writer = new StreamWriter(filePath, false, Encoding.Default))
            {
                for (int i = 0; i < rowCn; i++)
                {
                    writer.WriteLine(Guid.NewGuid());
                }
            }
        }

        /// <summary>
        /// Writes the file exactly like <see cref="WriteFile"/> and returns a value for
        /// the result table.
        /// </summary>
        /// <param name="objFilePath">The target file path, passed as a string.</param>
        /// <returns>
        /// The current local time as a long time string. Despite the method name,
        /// it does not return the row count.
        /// </returns>
        static object WriteFileAndReturnRowCount(object objFilePath)
        {
            WriteFile(objFilePath);
            return DateTime.Now.ToLongTimeString();
        }
    }
}
Sorry,代码太多,文字太少。发个牢骚,代码写完之后,再写思路是一件痛苦的事情!