ThreadPoolExecutor is designed to be extended. It provides several hook methods that a subclass can override:
```java
protected void beforeExecute(Thread t, Runnable r) { }
protected void afterExecute(Runnable r, Throwable t) { }
protected void terminated() { }
```
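To see when each hook fires, here is a minimal sketch (the class name `HookOrderDemo` is hypothetical) that overrides all three hooks on a single-thread pool and records the order of the calls. It relies only on documented behavior: `beforeExecute` and `afterExecute` are invoked by the worker thread that runs the task, `afterExecute` receives a `null` Throwable when the task completes normally, and `terminated` runs once after shutdown, when all tasks have finished.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/** Hypothetical demo class: records the order in which the three hooks fire. */
public class HookOrderDemo {

    /** Runs one task on a single-thread pool and returns the recorded events in order. */
    public static List<String> run() throws InterruptedException {
        // Thread-safe list, because the hooks run on the worker thread, not on main
        List<String> events = new CopyOnWriteArrayList<>();

        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>()) {
            @Override
            protected void beforeExecute(Thread t, Runnable r) {
                // Invoked by thread t itself, right before it runs r
                events.add("before:" + (t == Thread.currentThread()));
                super.beforeExecute(t, r);
            }

            @Override
            protected void afterExecute(Runnable r, Throwable t) {
                super.afterExecute(r, t);
                // t is null when r completed normally
                events.add("after:" + (t == null));
            }

            @Override
            protected void terminated() {
                super.terminated();
                // Invoked once, after shutdown() and after every task has finished
                events.add("terminated");
            }
        };

        pool.execute(() -> events.add("task"));
        pool.shutdown();
        // awaitTermination returns once the pool is TERMINATED, i.e. after terminated() returned
        pool.awaitTermination(5, TimeUnit.SECONDS);
        return new ArrayList<>(events);
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(run());
        // prints [before:true, task, after:true, terminated]
    }
}
```

The `before:true` entry confirms that the `Thread t` argument of `beforeExecute` is the thread that executes the hook, which is what makes the `ThreadLocal` technique described next work.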
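Building on these hooks, the following example implements a custom thread pool that measures how long each task takes to execute and computes the average execution time. The beforeExecute, afterExecute and terminated methods are used to add logging and collect statistics. To measure a task's running time, beforeExecute records the start time and stores it in a ThreadLocal variable, which afterExecute then reads. This works because both hooks are invoked by the same worker thread that runs the task. Two AtomicLong variables track the number of tasks processed and the total processing time, and terminated logs a message containing the average task time.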
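The bookkeeping described above can be isolated from the executor and checked on its own. The sketch below is an illustrative helper, not part of the original example: the class name `TaskTimingStats`, its methods, and the injectable clock are all assumptions introduced so the arithmetic can be verified deterministically. `start()` plays the role of `beforeExecute`, `stop()` the role of `afterExecute`, and `averageTime()` what `terminated` logs.

```java
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongSupplier;

/** Hypothetical helper: per-thread start times plus shared counters for the average. */
public class TaskTimingStats {
    private final ThreadLocal<Long> startTime = new ThreadLocal<>();
    private final AtomicLong numTasks = new AtomicLong();
    private final AtomicLong totalTime = new AtomicLong();
    private final LongSupplier clock; // assumption: injectable so tests can fake time

    public TaskTimingStats(LongSupplier clock) {
        this.clock = clock;
    }

    public TaskTimingStats() {
        this(System::nanoTime);
    }

    /** Call from beforeExecute: stores the start time for the calling thread only. */
    public void start() {
        startTime.set(clock.getAsLong());
    }

    /** Call from afterExecute on the same thread: records and returns the elapsed time. */
    public long stop() {
        Long start = startTime.get();
        startTime.remove(); // do not leave a stale value for the next task on this thread
        if (start == null) {
            return 0L; // start() was never called on this thread; nothing to record
        }
        long elapsed = clock.getAsLong() - start;
        numTasks.incrementAndGet();
        totalTime.addAndGet(elapsed);
        return elapsed;
    }

    public long taskCount() {
        return numTasks.get();
    }

    /** Average elapsed time; 0 when no task has finished, avoiding division by zero. */
    public long averageTime() {
        long n = numTasks.get();
        return n == 0 ? 0L : totalTime.get() / n;
    }

    public static void main(String[] args) {
        long[] fakeNow = {0L};
        TaskTimingStats stats = new TaskTimingStats(() -> fakeNow[0]);

        fakeNow[0] = 100;
        stats.start();
        fakeNow[0] = 150;
        stats.stop(); // 50
        fakeNow[0] = 200;
        stats.start();
        fakeNow[0] = 230;
        stats.stop(); // 30

        System.out.println(stats.taskCount() + " tasks, avg=" + stats.averageTime());
        // prints "2 tasks, avg=40"
    }
}
```

Note that `averageTime()` reads the two counters separately, so it is not an atomic snapshot while tasks are still running. That is acceptable when it is called from `terminated`, because by then every task has finished.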
The custom thread pool code is as follows:
```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Logger;

/**
 * Custom thread pool that logs each task's execution time
 * and reports the average when the pool terminates.
 */
public class TimingThreadPool extends ThreadPoolExecutor {

    // Start time of the task currently running on each worker thread
    private final ThreadLocal<Long> startTime = new ThreadLocal<>();
    private final Logger log = Logger.getLogger("TimingThreadPool");
    // Number of completed tasks and their total execution time (ns)
    private final AtomicLong numTasks = new AtomicLong();
    private final AtomicLong totalTime = new AtomicLong();

    public TimingThreadPool(int corePoolSize, int maximumPoolSize, long keepAliveTime,
                            TimeUnit unit, BlockingQueue<Runnable> workQueue) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
    }

    @Override
    protected void beforeExecute(Thread t, Runnable r) {
        super.beforeExecute(t, r);
        log.info(String.format("Thread %s: start %s", t, r));
        // Runs on worker thread t, so this value belongs to the task about to run
        startTime.set(System.nanoTime());
    }

    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        try {
            long endTime = System.nanoTime();
            long taskTime = endTime - startTime.get();
            numTasks.incrementAndGet();
            totalTime.addAndGet(taskTime);
            // Here t is the Throwable (null on normal completion), not a Thread,
            // so log the current worker thread instead
            log.info(String.format("Thread %s: end %s, time=%dns",
                    Thread.currentThread(), r, taskTime));
        } finally {
            // Clear the value so it cannot leak into the next task on this thread
            startTime.remove();
            super.afterExecute(r, t);
        }
    }

    @Override
    protected void terminated() {
        try {
            long n = numTasks.get();
            // Avoid division by zero if the pool terminates without running any task
            long avg = n == 0 ? 0 : totalTime.get() / n;
            log.info(String.format("Terminated: avg time=%dns", avg));
        } finally {
            super.terminated();
        }
    }
}
```
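One caveat when relying on the `Throwable t` argument of `afterExecute`: tasks passed to `submit()` or `invokeAll()` (as in the test below) are wrapped in a `FutureTask`, which catches the exception itself, so `t` arrives as `null` even when the task failed. The sketch below adapts the pattern shown in the `ThreadPoolExecutor.afterExecute` Javadoc to recover the real cause from the `Future`. The class `ExceptionAwarePool` and its `lastFailure()` accessor are hypothetical names for illustration.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

/** Hypothetical pool that surfaces exceptions hidden inside FutureTask. */
public class ExceptionAwarePool extends ThreadPoolExecutor {
    private final AtomicReference<Throwable> failure = new AtomicReference<>();

    public ExceptionAwarePool(int threads) {
        super(threads, threads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
    }

    /** Last failure observed by afterExecute, or null if none (demo only). */
    public Throwable lastFailure() {
        return failure.get();
    }

    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        super.afterExecute(r, t);
        // A FutureTask stores the exception instead of throwing it, so t is null here
        if (t == null && r instanceof Future<?> && ((Future<?>) r).isDone()) {
            try {
                ((Future<?>) r).get(); // already done, so this does not block
            } catch (CancellationException ce) {
                t = ce;
            } catch (ExecutionException ee) {
                t = ee.getCause(); // the exception thrown by the task itself
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt();
            }
        }
        if (t != null) {
            failure.set(t);
        }
    }

    public static void main(String[] args) throws InterruptedException {
        ExceptionAwarePool pool = new ExceptionAwarePool(1);
        Callable<Integer> failing = () -> { throw new IllegalStateException("boom"); };
        pool.submit(failing);
        pool.shutdown();
        pool.awaitTermination(5, TimeUnit.SECONDS);
        System.out.println(pool.lastFailure());
        // prints java.lang.IllegalStateException: boom
    }
}
```

The same unwrapping could be added to `TimingThreadPool.afterExecute` if failed tasks should be logged or counted separately; without it, a task that throws is timed and counted exactly like a successful one.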
The test code is as follows:
```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * Test for the custom thread pool
 */
public class TestCustomThreadPool {

    public static void main(String[] args) {
        // Fixed-size pool: 10 core threads, 10 max threads, unbounded queue
        TimingThreadPool threadPool = new TimingThreadPool(10, 10, 0L,
                TimeUnit.MILLISECONDS, new LinkedBlockingQueue<Runnable>());
        try {
            List<TestCallable> tasks = new ArrayList<>();
            for (int i = 0; i < 100; i++) {
                tasks.add(new TestCallable());
            }
            // invokeAll blocks until all 100 tasks have completed
            List<Future<Long>> futures = threadPool.invokeAll(tasks);
            for (Future<Long> future : futures) {
                System.out.print(" - " + future.get());
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            // Always shut down; terminated() runs once the pool has fully terminated
            threadPool.shutdown();
        }
    }

    static class TestCallable implements Callable<Long> {
        @Override
        public Long call() throws Exception {
            // Sum 100 random values in [0, 10], then sleep that many milliseconds
            long total = 0;
            for (int i = 0; i < 100; i++) {
                total += getRandom();
            }
            Thread.sleep(total);
            return total;
        }

        public long getRandom() {
            return Math.round(Math.random() * 10);
        }
    }
}
```
Execution result: