本文使用 ThreadPoolExecutor 实现一个带优先级的线程池。常规的实现方式是直接使用优先级队列(java.util.PriorityQueue / java.util.concurrent.PriorityBlockingQueue),但这种方式没办法同步地获取结果,编程上有点复杂;而 java.util.concurrent.ThreadPoolExecutor 提供了 public <T> Future<T> submit(Callable<T> task); 可以通过 Future.get() 阻塞当前线程、等待结果,从而实现同步调用。
public class PriorityThreadPoolExecutor extends ThreadPoolExecutor;
实现方法很简单:继承 ThreadPoolExecutor,并使用 PriorityBlockingQueue 优先级队列作为工作队列。不过 PriorityBlockingQueue 有个坑,官方文档中写道:
Operations on this class make no guarantees about the ordering of elements with equal priority. *如果优先级相同,不能确定顺序。*
实际测试下来的结果是:如果优先级相同,则执行顺序跟插入顺序相反。这就尴尬了,这还是 FIFO 队列吗?官方文档给出了解决方式:给每一个队列元素编一个递增的序号,优先级相同时按序号比较,照抄就可以了。限制是队列历史上入队的元素总数不能超过 Long.MAX_VALUE 个。为此需要实现一个 Comparable 的类。
class PriorityRunnable<E extends Comparable<? super E>> implements Runnable, Comparable<PriorityRunnable<E>>;
重载线程池添加任务的方法,追加一个 priority 参数;如果使用基类的方法,优先级默认为 0。
public void execute(Runnable command, int priority); public <T> Future<T> submit(Callable<T> task, int priority); public <T> Future<T> submit(Runnable task, T result, int priority); public Future<?> submit(Runnable task, int priority);
最终代码如下:
1 package com.springboot.study.tests.threads; 2 3 /** 4 * @Author: guodong 5 * @Date: 2021/3/22 15:20 6 * @Version: 1.0 7 * @Description: 8 */ 9 import org.slf4j.Logger; 10 import org.slf4j.LoggerFactory; 11 import java.util.concurrent.*; 12 import java.util.concurrent.atomic.AtomicLong; 13 14 public class PriorityThreadPoolExecutor extends ThreadPoolExecutor { 15 16 private static final Logger log = LoggerFactory.getLogger(PriorityThreadPoolExecutor.class); 17 18 private ThreadLocal<Integer> local = new ThreadLocal<Integer>() { 19 @Override 20 protected Integer initialValue() { 21 return 0; 22 } 23 }; 24 25 public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit) { 26 super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue()); 27 } 28 29 public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory) { 30 super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory); 31 } 32 33 public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, RejectedExecutionHandler handler) { 34 super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), handler); 35 } 36 37 public PriorityThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, RejectedExecutionHandler handler) { 38 super(corePoolSize, maximumPoolSize, keepAliveTime, unit, getWorkQueue(), threadFactory, handler); 39 } 40 41 protected static PriorityBlockingQueue getWorkQueue() { 42 return new PriorityBlockingQueue(); 43 } 44 45 @Override 46 public void execute(Runnable command) { 47 int priority = local.get(); 48 try { 49 this.execute(command, priority); 50 } finally { 51 local.set(0); 52 } 53 } 54 55 public void execute(Runnable command, int priority) { 56 super.execute(new PriorityRunnable(command, priority)); 57 } 58 59 public 
<T> Future<T> submit(Callable<T> task, int priority) { 60 local.set(priority); 61 return super.submit(task); 62 } 63 64 public <T> Future<T> submit(Runnable task, T result, int priority) { 65 local.set(priority); 66 return super.submit(task, result); 67 } 68 69 public Future<?> submit(Runnable task, int priority) { 70 local.set(priority); 71 return super.submit(task); 72 } 73 74 protected static class PriorityRunnable<E extends Comparable<? super E>> implements Runnable, Comparable<PriorityRunnable<E>> { 75 private final static AtomicLong seq = new AtomicLong(); 76 private final long seqNum; 77 Runnable run; 78 private int priority; 79 80 public PriorityRunnable(Runnable run, int priority) { 81 seqNum = seq.getAndIncrement(); 82 this.run = run; 83 this.priority = priority; 84 } 85 86 public int getPriority() { 87 return priority; 88 } 89 90 public void setPriority(int priority) { 91 this.priority = priority; 92 } 93 94 public Runnable getRun() { 95 return run; 96 } 97 98 @Override 99 public void run() { 100 this.run.run(); 101 } 102 103 @Override 104 public int compareTo(PriorityRunnable<E> other) { 105 int res = 0; 106 if (this.priority == other.priority) { 107 if (other.run != this.run) {// ASC 108 res = (seqNum < other.seqNum ? -1 : 1); 109 } 110 } else {// DESC 111 res = this.priority > other.priority ? -1 : 1; 112 } 113 return res; 114 } 115 } 116 }
下面是测试用例
package com.springboot.study.tests.threads; /** * @Author: guodong * @Date: 2021/3/22 15:22 * @Version: 1.0 * @Description: */ import org.junit.Test; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import static org.junit.Assert.*; public class PriorityThreadPoolExecutorTest { @Test public void testDefault() throws InterruptedException, ExecutionException { PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES); Future[] futures = new Future[20]; StringBuffer buffer = new StringBuffer(); for (int i = 0; i < futures.length; i++) { int index = i; futures[i] = pool.submit(new Callable() { @Override public Object call() throws Exception { Thread.sleep(10); buffer.append(index + ", "); return null; } }); } // 等待所有任务结束 for (int i = 0; i < futures.length; i++) { futures[i].get(); } System.out.println(buffer); assertEquals("0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, ", buffer.toString()); } @Test public void testSamePriority() throws InterruptedException, ExecutionException { PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES); Future[] futures = new Future[10]; StringBuffer buffer = new StringBuffer(); for (int i = 0; i < futures.length; i++) { futures[i] = pool.submit(new TenSecondTask(i, 1, buffer), 1); } // 等待所有任务结束 for (int i = 0; i < futures.length; i++) { futures[i].get(); } System.out.println(buffer); assertEquals("01@00, 01@01, 01@02, 01@03, 01@04, 01@05, 01@06, 01@07, 01@08, 01@09, ", buffer.toString()); } @Test public void testRandomPriority() throws InterruptedException, ExecutionException { PriorityThreadPoolExecutor pool = new PriorityThreadPoolExecutor(1, 1000, 1, TimeUnit.MINUTES); Future[] futures = new Future[20]; StringBuffer buffer = new StringBuffer(); for (int i = 0; i < futures.length; i++) { int r = (int) (Math.random() * 100); 
futures[i] = pool.submit(new TenSecondTask(i, r, buffer), r); } // 等待所有任务结束 for (int i = 0; i < futures.length; i++) { futures[i].get(); } buffer.append("01@00"); System.out.println(buffer); String[] split = buffer.toString().split(", "); // 从 2 开始, 因为前面的任务可能已经开始 for (int i = 2; i < split.length - 1; i++) { String s = split[i].split("@")[0]; assertTrue(Integer.valueOf(s) >= Integer.valueOf(split[i + 1].split("@")[0])); } } public static class TenSecondTask<T> implements Callable<T> { private StringBuffer buffer; int index; int priority; public TenSecondTask(int index, int priority, StringBuffer buffer) { this.index = index; this.priority = priority; this.buffer = buffer; } @Override public T call() throws Exception { Thread.sleep(10); buffer.append(String.format("%02d@%02d", this.priority, index)).append(", "); return null; } } }