• 【JDK源码分析】深入理解ThreadLocal以及破坏它的线程隔离机制


    前言

    众所周知ThreadLocal提供了线程局部变量,独立于变量的初始化副本。ThreadLocal设计初衷是用来存放与当前线程绑定的对象,其它线程不应该去访问也不能访问。文末会用例子来举例说明不当使用会破坏这种设计。

    通过源码深入理解

    ThreadLocal 通过set方法设置的变量并非是放在ThreadLocal对象中,而是通过一个ThreadLocal.ThreadLocalMap类型的对象与当前线程绑定。下面通过源码一探究竟。
    Thread类中有个ThreadLocalMap类型的局部变量,而且在线程退出后会清除threadLocals线程变量,也就是ThreadLocalMap对象的生命周期和当前线程一样,这就是不建议在线程池中使用ThreadLocal的原因。

    public class Thread implements Runnable {
        // 此变量用来存放之前提到的线程局部变量
        ThreadLocal.ThreadLocalMap threadLocals = null;
        ...
        private void exit() {
            ...
            threadLocals = null;
            ...
        }   
    }

    ThreadLocal set方法

    再来看ThreadLocal,set方法是将值value设置到与当前线程绑定的ThreadLocalMap 对象中。

        public void set(T value) {
            Thread t = Thread.currentThread();
            // 以当前线程为键,获取当前线程的ThreadLocalMap对象,此对象就是存放通过初始化的变量或者set方法设置的变量
            ThreadLocalMap map = getMap(t);
            // 当前线程的ThreadLocalMap对象存在时,直接将当前ThreadLocal对象作为键,需要set的对象作为值设置到map中
            if (map != null)
                map.set(this, value);
            else
                // 当前线程ThreadLocalMap对象不存在时,为其创建变量
                createMap(t, value);
        }
    
        // 获取线程t的threadLocals变量
        ThreadLocalMap getMap(Thread t) {
            return t.threadLocals;
        }
    
        //当线程t中threadLocals对象还未被赋值时,需要为其初始化
        void createMap(Thread t, T firstValue) {
            t.threadLocals = new ThreadLocalMap(this, firstValue);
        }

    再来看看ThreadLocal 静态内部类ThreadLocalMap,此map数据结构是一个Entry数组,ThreadLocalMap类中的Entry继承了弱引用,而弱引用会在下次垃圾回收时被回收。 也就是说,当threadLocal对象为强引用时,下次垃圾回收会将ThreadLocalMap中entry的key对象当作垃圾回收,而我们一般在使用中都会将ThreadLocal 作为静态变量使用,此时ThreadLocal对象为强引用,其生命周期和就会和其被引用的类生命周期一样长。

        static class ThreadLocalMap {
    
            // ThreadLocalMap中存放的就是Entry类型的数组,而Entry是以当前ThreadLocal对象为key的键值对
            // Entry中的key为弱引用类型
            static class Entry extends WeakReference<ThreadLocal<?>> {
                // value与当前ThreadLocal对象关联
                Object value;
                Entry(ThreadLocal<?> k, Object v) {
                    super(k);
                    value = v;
                }
            }
    
            private static final int INITIAL_CAPACITY = 16;
    
            private Entry[] table;
    
            /**
             * The number of entries in the table.
             */
            private int size = 0;
    
            /**
             * 扩容阈值
             */
            private int threshold; // Default to 0
    
            /**
             * Set the resize threshold to maintain at worst a 2/3 load factor.
             */
            private void setThreshold(int len) {
                threshold = len * 2 / 3;
            }
    
            /**
             * Increment i modulo len.
             */
            private static int nextIndex(int i, int len) {
                return ((i + 1 < len) ? i + 1 : 0);
            }
    
            /**
             * Decrement i modulo len.
             */
            private static int prevIndex(int i, int len) {
                return ((i - 1 >= 0) ? i - 1 : len - 1);
            }
    
            /**
             * Construct a new map initially containing (firstKey, firstValue).
             * ThreadLocalMaps are constructed lazily, so we only create
             * one when we have at least one entry to put in it.
             */
            ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
                table = new Entry[INITIAL_CAPACITY];
                // rehash得到firstKey在table的位置
                int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
                table[i] = new Entry(firstKey, firstValue);
                size = 1;
                setThreshold(INITIAL_CAPACITY);
            }
    
            private Entry getEntry(ThreadLocal<?> key) {
                int i = key.threadLocalHashCode & (table.length - 1);
                Entry e = table[i];
                if (e != null && e.get() == key)
                    return e;
                else
                    // 在i位置上没有找到对象
                    return getEntryAfterMiss(key, i, e);
            }
            //向i往后遍历entry数组,找到后直接返回;如果没找到还会清除tab中弱引用key被垃圾回收的entry元素(即将该entry设置为null)
            private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
                Entry[] tab = table;
                int len = tab.length;
    
                while (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == key)
                        return e;
                    // e不为空,而key为空,意味key被垃圾回收
                    if (k == null)
                        // 清除已失效的entry
                        expungeStaleEntry(i);
                    else
                        // 移动到下个位置,每次递增1
                        i = nextIndex(i, len);
                    e = tab[i];
                }
                return null;
            }
            // 清楚失效entry
            private int expungeStaleEntry(int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
    
                // 此处将value设置为null,便于gc
                tab[staleSlot].value = null;
                tab[staleSlot] = null;
                size--;
    
                // Rehash until we encounter null
                Entry e;
                int i;
                // 扫描tab数组中staleSlot右边的元素,清除被垃圾回收的entry
                for (i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null;
                        tab[i] = null;
                        size--;
                    } else {
                        int h = k.threadLocalHashCode & (len - 1);
                        if (h != i) {
                            tab[i] = null;
    
                            // Unlike Knuth 6.4 Algorithm R, we must scan until
                            // null because multiple entries could have been stale.
                            while (tab[h] != null)
                                h = nextIndex(h, len);
                            tab[h] = e;
                        }
                    }
                }
                return i;
            }
            private void set(ThreadLocal<?> key, Object value) {
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);
    
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                    ThreadLocal<?> k = e.get();
    
                    //找到key时,覆盖value
                    if (k == key) {
                        e.value = value;
                        return;
                    }
                    // 此时表示改entry的key已被垃圾回收
                    if (k == null) {
                        //用当前值替换掉已失效的entry
                        replaceStaleEntry(key, value, i);
                        return;
                    }
                }
    
                tab[i] = new Entry(key, value);
                int sz = ++size;
                if (!cleanSomeSlots(i, sz) && sz >= threshold)
                    rehash();
            }          
            private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                           int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
                Entry e;
    
                // 反向查找key值失效的最小i
                int slotToExpunge = staleSlot;
                for (int i = prevIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = prevIndex(i, len))
                    if (e.get() == null)
                        slotToExpunge = i;
    
                // Find either the key or trailing null slot of run, whichever
                // occurs first
                for (int i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {
                    ThreadLocal<?> k = e.get();
    
                    // If we find key, then we need to swap it
                    // with the stale entry to maintain hash table order.
                    // The newly stale slot, or any other stale slot
                    // encountered above it, can then be sent to expungeStaleEntry
                    // to remove or rehash all of the other entries in run.
                    if (k == key) {
                        e.value = value;
    
                        tab[i] = tab[staleSlot];
                        tab[staleSlot] = e;
    
                        // Start expunge at preceding stale entry if it exists
                        if (slotToExpunge == staleSlot)
                            slotToExpunge = i;
                         //清除无效的entry,后面的操作都是尽量清除掉失效的entry
                        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                        return;
                    }
    
                    // If we didn't find stale entry on backward scan, the
                    // first stale entry seen while scanning for key is the
                    // first still present in the run.
                    if (k == null && slotToExpunge == staleSlot)
                        slotToExpunge = i;
                }
    
                // If key not found, put new entry in stale slot
                tab[staleSlot].value = null;
                tab[staleSlot] = new Entry(key, value);
    
                // If there are any other stale entries in run, expunge them
                if (slotToExpunge != staleSlot)
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            }

    ThreadLocal get方法

    get方法通过获取当前线程的ThreadLocal.ThreadLocalMap类型的变量,然后通过当前ThreadLocal对象作为key在map中做查找。第一次rehash查询未命中的时候,会触发对数组向右遍历查询直到查询命中或没找到为止,中途会清楚entry数组中的无效元素并且会rehash重新指定key不为null的entry在数组中的位置。

        public T get() {
            //通过当前线程获取到ThreadLocalMap对象
            Thread t = Thread.currentThread();
            ThreadLocalMap map = getMap(t);
            if (map != null) {
                //以当前ThreadLocal对象作为键,在ThreadLocalMap中查找对应的value
                ThreadLocalMap.Entry e = map.getEntry(this);
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    T result = (T)e.value;
                    return result;
                }
            }
            return setInitialValue();
        }
            // 通过当前threadLocal对象获取对应的entry
            private Entry getEntry(ThreadLocal<?> key) {
                int i = key.threadLocalHashCode & (table.length - 1);
                Entry e = table[i];
                if (e != null && e.get() == key)
                    return e;
                else
                    return getEntryAfterMiss(key, i, e);
            }

    ThreadLocal remove方法

    remove方法是通过当前线程中threadLocals对象,找到以当前threadLocal对象为key的entry,然后将其从threadLocals中删除。同时每次使用完threadLocal get到的变量,记得remove操作。

         public void remove() {
             //获取当前线程的threadLocals对象
             ThreadLocalMap m = getMap(Thread.currentThread());
             if (m != null)
                 m.remove(this);
         }
            // 删除threadLocals对象中以当前threadLocal对象为key的entry
            private void remove(ThreadLocal<?> key) {
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);
                //向后遍历数组
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                    if (e.get() == key) {
                        e.clear();
                        expungeStaleEntry(i);
                        return;
                    }
                }
            }    

    使用方式

    开启1条线程,分别对2个threadlocal变量进行读写,然后观察是否影响了主线程读取threadlocal变量

        private static ThreadLocal<String> stringThreadLocal = ThreadLocal.withInitial(() -> "initialValue");
    
        private static ThreadLocal<Integer> integerThreadLocal = ThreadLocal.withInitial(() -> 0);
    
        public static void main(String[] args) throws InterruptedException {
    
            Thread thread1 = new Thread(() -> {
                stringThreadLocal.set("thread-1");
                integerThreadLocal.set(129);
                System.out.println(String.format("%s-----stringThreadLocalValue=%s",
                        Thread.currentThread().getName(), stringThreadLocal.get()));
                System.out.println(String.format("%s-----integerThreadLocalValue=%s",
                        Thread.currentThread().getName(), integerThreadLocal.get()));
    
            }, "thread-1");
            thread1.start();
            thread1.join();
    
            System.out.println(String.format("%s-----stringThreadLocalValue=%s",
                    Thread.currentThread().getName(), stringThreadLocal.get()));
            System.out.println(String.format("%s-----integerThreadLocalValue=%s",
                    Thread.currentThread().getName(), integerThreadLocal.get()));
        }  

    输出结果可见各线程之间对象不会互相影响。

    thread-1-----stringThreadLocalValue=thread-1
    thread-1-----integerThreadLocalValue=129
    main-----stringThreadLocalValue=initialValue
    main-----integerThreadLocalValue=0
     

    如果ThreadLocal中引用的变量是静态变量会怎么样?动态变量呢?

        private static ThreadLocal<Map> mapThreadLocal = new ThreadLocal<Map>() {
            public Map initialValue() {
                return Collections.emptyMap();
            }
    
        };
        private static Map<String, String> staticMap = new HashMap<String, String>() {{
            put("initKey", "initValue");
        }};
    
    
        public static void main(String[] args) throws InterruptedException {
            // 主线程将staticMap设置到当前线程中
            mapThreadLocal.set(staticMap);
            System.out.println(String.format("%s---------initKey->%s",
                    Thread.currentThread().getName(), mapThreadLocal.get().get("initKey")));      
            Thread thread_0 = new Thread(() -> {
    
                System.out.println(String.format("%s-----initKey->%s",
                        Thread.currentThread().getName(), mapThreadLocal.get().get("initKey")));
                // 线程0,修改staticMap的值
                staticMap.put("initKey", "newValue");
            });
            thread_0.start();
            thread_0.join();
            System.out.println(String.format("%s---------initKey->%s",
                    Thread.currentThread().getName(), mapThreadLocal.get().get("initKey")));
            Thread thread_1 = new Thread(() -> {
                // 线程1将新创建的map对象设置到线程1中
                mapThreadLocal.set(new HashMap<String, String>() {{
                    put("initKey", "newerValue");
                }});
                System.out.println(String.format("%s-----initKey->%s",
                        Thread.currentThread().getName(), mapThreadLocal.get().get("initKey")));
            }, "thread-1");
            thread_1.start();
            thread_1.join();
            System.out.println(String.format("%s---------initKey->%s",
                    Thread.currentThread().getName(), mapThreadLocal.get().get("initKey")));
        }

    理解了源码,这个运行结果就不会感到意外了。动态变量也和这个类似,threadlocal中设置的可变对象如果被其它线程修改,当前线程get到的数据就不是原来的值了。

    main---------initKey->initValue
    Thread-0-----initKey->null
    main---------initKey->newValue
    thread-1-----initKey->newerValue
    main---------initKey->newValue

    总结

    1. ThreadLocal并不是为每条线程创建副本变量,一般情况下线程的threadLocals变量为该线程独有,其他线程无法访问。
    2. ThreadLocal在每次被调用get、set、remove方法时,都会在碰到entry不为null, entry的key值为null时对ThreadLocalMap里的entry数组会做向后扫描在遇到下一个entry为null时终止扫描,该过程中会清除key为null的entry。
    3. 谨慎使用ThreadLocal引用可变对象,该对象被set到某个线程时,其它线程如果可以访问到该可变对象并且对其进行修改则会影响到之前线程threadLocals的值。
    4. 每次使用完ThreadLocal时,要记得及时remove。
  • 相关阅读:
    软件工程实践总结-黄紫仪
    beta冲刺总结附(分工)-咸鱼
    beta冲刺总结-咸鱼
    beta冲刺7-咸鱼
    beta冲刺用户测评-咸鱼
    beta冲刺6-咸鱼
    beta冲刺5-咸鱼
    beta冲刺4-咸鱼
    beta冲刺3-咸鱼
    beta冲刺2-咸鱼
  • 原文地址:https://www.cnblogs.com/d-homme/p/9357007.html
Copyright © 2020-2023  润新知