什么是ThreadLocal?
ThreadLocal直译为“线程本地”或“本地线程”,如果真的这么认为,那就错了!其实它就是一个容器,用于存放线程的局部变量,应该叫ThreadLocalVariable(线程局部变量)才对。
早在JDK1.2的时代,java.lang.ThreadLocal就诞生了,它是为了解决多线程并发问题而设计的,只不过设计得有些难用而已,所以至今没有得到广泛的应用。
一个序列号生成器的程序可能同时会有多个线程并发访问它,要保证每个线程得到的序列号都是自增的,而补鞥呢互相干扰。
先定义一个接口:
public interface Sequence { int getNumber(); }
每次调用getNumber方法可获取一个序列号,下次再调用时,序列号会自增。
在做一个线程类:
public class ClientThread extends Thread{ private Sequence sequence; public ClientThread(Sequence sequence) { this.sequence = sequence; } @Override public void run() { for (int i=0;i<3;i++){ System.out.println(Thread.currentThread().getName() + " =>"+sequence.getNumber()); } } }
在线程中连续输出三次线程名与其对应的序列号。
我们不用ThreadLocal,先做一个实现类:
public class SequenceA implements Sequence { private static int number = 0; @Override public int getNumber() { number = number +1; return number; } public static void main(String[] args) { Sequence sequence = new SequenceA(); ClientThread thread1 = new ClientThread(sequence); ClientThread thread2 = new ClientThread(sequence); ClientThread thread3 = new ClientThread(sequence); thread1.start(); thread2.start(); thread3.start(); } }
序列号初始值是0,在main方法中模拟了三个线程,运行后结果如下:
分析发现,线程之间共享的static变量无法保证对于不同线程而言是安全的,也就是说,此时无法保证"线程安全"。
那么如何才能做到“线程安全”呢?对于这个案例,就是说不同的线程可拥有自己的static变量,如何实现呢?下面看另一个实现:
public class SequenceB implements Sequence { //private static int number = 0; private static ThreadLocal<Integer> numberContainer = new ThreadLocal<Integer>(){ @Override protected Integer initialValue() { return 0; } }; @Override public int getNumber() { numberContainer.set(numberContainer.get()+1); return numberContainer.get(); } public static void main(String[] args) { Sequence sequence = new SequenceB(); ClientThread thread1 = new ClientThread(sequence); ClientThread thread2 = new ClientThread(sequence); ClientThread thread3 = new ClientThread(sequence); thread1.start(); thread2.start(); thread3.start(); } }
通过ThreadLocal封装了一个Integer类型的numberContainer静态成员变量,并且初始值是0。再看getNumber方法,首先从numberContainer中get出当前的值,加1,随后set到numberContainer中,最后在numberContainer中get出当前的值并返回。
是不是很绕?但是很强大!我们不妨把ThreadLocal看作是一个容器,这样理解起来就简单了。所以,这里故意用了Container这个词作为后缀来命名ThreadLocal变量。
每个线程独立了,同样是static变量,对于不同的线程而言,它没有被共享,而是每个线程各一份,这样也就保证了线程安全。也就是说,ThreadLocal为每一个线程提供了一个独立的副本。
搞清楚ThreadLocal的原理后,总结一下API:
public void set(T value):将值放入线程局部变量中;
public T get():从线程局部变量中获取值;
public void remove():从线程局部变量中移除值(有助于JVM垃圾回收);
protected T initialValue():返回线程局部变量中的初始值(默认为null)。
为什么initialValue方法是protected的呢?就是为了提醒程序员,这个方法是要程序员来实现的,要给这个线程局部变量设置一个初始值。
自己实现ThreadLocal
熟悉了原理之后与这些API之后,可以想想ThreadLocal里面不就是封装了一个Map吗?我们自己可以写一个ThreadLocal了:
package com.autumn.threadlocal; import java.util.Collections; import java.util.HashMap; import java.util.Map; /** * @program: MyThreadLocal * @description: 模式ThreadLocal * @author: Created by Autumn * @create: 2018-11-21 17:09 */ public class MyThreadLocal<T> { private Map<Thread,T> container = Collections.synchronizedMap(new HashMap<Thread, T>()); public void set(T value){ container.put(Thread.currentThread(),value); } public T get(){ Thread thread = Thread.currentThread(); T value = container.get(thread); if (value == null && !container.containsKey(thread)){ value = initialValue(); container.put(thread,value); } return value; } public void remove(){ container.remove(Thread.currentThread()); } protected T initialValue(){ return null; } }
上面定义了一个山寨版的ThreadLocal,其中定义了一个同步Map(这个操作会在map上加锁)
写个类运行一下
/** * @program: SequenceB * @description: 用ThreadLocal实现线程共享 * @author: Created by Autumn * @create: 2018-11-21 15:45 */ public class SequenceC implements Sequence { //private static int number = 0; private static MyThreadLocal<Integer> numberContainer = new MyThreadLocal<Integer>(){ @Override protected Integer initialValue() { return 0; } }; @Override public int getNumber() { numberContainer.set(numberContainer.get()+1); return numberContainer.get(); } public static void main(String[] args) { Sequence sequence = new SequenceC(); ClientThread thread1 = new ClientThread(sequence); ClientThread thread2 = new ClientThread(sequence); ClientThread thread3 = new ClientThread(sequence); thread1.start(); thread2.start(); thread3.start(); } }
返回结果
只是把ThreadLocal换成了MyThreadLocal而已,运行效果和之前的一样,也是正确的。
提示:当在一个类中使用了static成员变量的时候,一定要多问问自己,这个static成员变量考虑“线程安全”了吗?也就是说,多个线程需要独享自己的static成员变量吗?如果需要考虑,不妨用ThreadLocal。
ThreadLocal使用例子
ThreadLocal具体有哪些使用案例呢?
首先要说的就是通过ThreadLocal存放JDBC Connection,以达到事务控制的能力。
记得在很久以前,用户提出过一个需求,需求就很繁琐,就一句话:
当修改产品价格的时候,需要记录操作日志,什么时候做了什么事情。
想必这个案例,只要是做个应用系统的小伙伴都应该遇到过。不外乎数据库里就两张表:product与log,用两条sql语句应该就可以解决问题:
update product set price = ? and id = ? insert into log(created,description) values(?,?)
但要确保这两条sql语句必须在同一个事务里进行提交,否则有可能update提交了,但是insert却没有提交。
为了解决这个问题,首先我们写一个DBUtil的工具类
/** * @program: DBUtil * @description: 数据库配置工具类 * @author: qiuyu * @create: 2018-11-28 05:52 **/ public class DBUtil { private static final Logger LOGGER = LoggerFactory.getLogger(DBUtil.class); //数据库配置 private static final String DRIVER = "com.mysql.jdbc.Driver"; private static final String URL = "jdbc:mysql://222.222.221.198:3306/customer"; private static final String USERNAME = "root"; private static final String PASSWORD = "root"; //定义一个数据库连接 private static Connection conn = null; /** * 获取数据库连接 * @return */ public static Connection getConnection(){ try { /*JDBC获取连接*/ Class.forName(DRIVER); conn = DriverManager.getConnection(URL,USERNAME,PASSWORD); } catch (Exception e) { e.printStackTrace(); //在catlina.out中打印 LOGGER.error("get connection failure",e); } return conn; } /** * 关闭数据库连接 * @param conn */ public static void closeConnection(Connection conn){ if (conn!=null){ try { conn.close(); } catch (SQLException e) { e.printStackTrace(); LOGGER.error("close connection failure",e); } } } }
里面设置了一个static的Connection,这下数据库连接就好操作了。
然后定义一个借口用于逻辑层调用:
/** * 接口 - 更新数据添加日志表记录 */ public interface ProductService { void updateProductPrice(long id,int price); }
根据productId去更新对应的Product的price,然后再插入一条数据到log表中。
实现类
/** * @program: ProductServiceImpl * @description: ProductService实现类 * @author: qiuyu * @create: 2018-11-28 06:04 **/ public class ProductServiceImpl implements ProductService{ private static final String UPDATE_PRODUCT_SQL = "update product set price = ? where id = ?"; private static final String INSERT_LOG_SQL = "insert into log(createid,description) value (?,?)"; @Override public void updateProductPrice(long id, int price) { try { //获取连接 Connection conn = DBUtil.getConnection(); conn.setAutoCommit(false); //关闭自动提交事物(开启事物) //执行操作 updateProduct(conn,UPDATE_PRODUCT_SQL,id,price); //更新产品 insertLog(conn,INSERT_LOG_SQL,"create product."); //插入日志 //提交事物 conn.commit(); }catch (Exception e){ e.printStackTrace(); }finally { DBUtil.closeConnection(); //关闭连接 } } private void updateProduct(Connection conn,String updateProdutSQL,long productId,int productPrice) throws SQLException { PreparedStatement pstmt = conn.prepareStatement(updateProdutSQL); pstmt.setInt(1,productPrice); pstmt.setLong(2,productId); int rows = pstmt.executeUpdate(); if (rows != 0){ System.out.println("Update Product Success!"); } } private void insertLog(Connection conn,String insertLogSQL,String logDescription) throws SQLException { PreparedStatement pstmt = conn.prepareStatement(insertLogSQL); pstmt.setString(1,new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date())); pstmt.setString(2,logDescription); int rows = pstmt.executeUpdate(); if (rows != 0){ System.out.println("Insert log Success!"); } } }
这里用到了JDBC的高级特性Transaction。暗自庆幸了一番后,是不是有必要写一个客户端来测试一下执行结果是不是我想要的呢?于是偷懒,直接在ProductServiceImpl中加了一个main方法:
public static void main(String[] args) { ProductService service = new ProductServiceImpl(); service.updateProductPrice(1,3000); }
运行程序
作为一名专业的程序员,为了万无一失,我一定要到数据库里再看看。没错!product表对应的记录更新了,log表也插入了一条记录。这样就可以将ProductService接口交付给别人来调用了。
几个小时过去了,QA妹妹开始对着我嚷:“那谁!我刚才模拟10个请求,你这个接口怎么就挂了?报错说是数据库连接关闭了!”。
她是用工具模拟的,也就是模拟多个线程了!那我也可以模拟,于是写了一个线程类:
/** * @program: ClientThread * @description: 线程类 * @author: qiuyu * @create: 2018-11-28 07:16 **/ public class ClientThread extends Thread { private ProductService productService; public ClientThread(ProductService productService) { this.productService = productService; } @Override public void run() { System.out.println(Thread.currentThread().getName()); productService.updateProductPrice(1,3000); } }
用这个线程去调用ProductService的方法,看看是不是有问题。此时,还要再修改一下main方法:
public static void main(String[] args) { /*调用*/ /*ProductService service = new ProductServiceImpl(); service.updateProductPrice(1,3000);*/ /*多线程调用*/ for (int i=1;i<10;i++){ ProductService service = new ProductServiceImpl(); ClientThread thread = new ClientThread(service); thread.start(); } }
模拟十个线程,运行结果如下:
没想到!竟然在多线程的环境下报错了,果然是数据库连接关闭了。怎么回事呢?我陷入了沉思中。在百度、Google,还有OSC上都查找了那句报错信息,解答实在是千奇百怪。
既然是跟Connection有关系,那就将主要精力放在检查Connection相关的代码上。是不是Connection不应该是static呢?当初设计成static的主要目的是为了让DBUtil的static方法访问更加方便,用static变量来存放Connection也提高了性能。怎么办呢?
后来看到OSC上非常火爆的一片文章“ThreadLocal”那点事,才终于明白了,原来要使每个线程都拥有自己的连接,而不是共享同一个连接,否则“线程一”有可能会关闭“线程二”的连接,所以“线程二”就报错了。
于是将DBUtil重构:
public class DBUtil { private static final Logger LOGGER = LoggerFactory.getLogger(DBUtil.class); //数据库配置 private static final String DRIVER = "com.mysql.jdbc.Driver"; private static final String URL = "jdbc:mysql://222.222.221.198:3306/demo2"; private static final String USERNAME = "root"; private static final String PASSWORD = "root"; //定义一个数据库连接 //private static Connection conn = null; //定义一个用于放置数据库连接的局部线程变量(是每个线程拥有自己的连接) private static ThreadLocal<Connection> connContainer = new ThreadLocal<Connection>(); /** * 获取数据库连接 * @return */ public static Connection getConnection(){ Connection conn = connContainer.get(); //从ThreadLocal中获取conn try { if (conn == null) { //从ThreadLocal中拿到的conn如果为null /*JDBC获取连接*/ Class.forName(DRIVER); conn = DriverManager.getConnection(URL, USERNAME, PASSWORD); } } catch (Exception e) { e.printStackTrace(); //在catlina.out中打印 LOGGER.error("get connection failure",e); }finally { connContainer.set(conn); } return conn; } /** * 关闭数据库连接 */ public static void closeConnection(){ Connection conn = connContainer.get(); //从ThreadLocal中获取conn if (conn!=null){ try { conn.close(); } catch (SQLException e) { e.printStackTrace(); LOGGER.error("close connection failure",e); }finally { connContainer.remove(); //从ThreadLocal中删除当前线程的conn } } } }
把Connection放入到了ThreadLocal中,这样每个线程之间就隔离了,不会互相干扰了。
此外,在getConnection方法中,首先从ThreadLocal(也就是ConnContainer)中获取Connection,如果没有,就通过JDBC来创建连接,最后再把创建好的连接放入这个ThreadLocal中。可以把ThreadLocal看作一个容器。
同样也对closeConnection方法做了重构,先从容器中获取Connection,拿到了就close掉,最后从容器中将其remove掉,以保持容器的清洁。
注意:该示例仅用于ThreadLocal的基本用法。在实际工作中,推荐使用连接池来管理数据库连接。
该demo中虽然每次都从当前线程中获取Connection,Connection是线程隔离的,但是Mysql连接不是隔离的,Connection对象有可能用到其他Mysql的连接,如果mysql的连接关闭了,但是Connection对象没关闭会造成 java.sql.SQLException: No operations allowed after connection closed.