12

【学习笔记】深入理解ThreadLocal

 3 years ago
source link: https://mp.weixin.qq.com/s?__biz=MzI5NjY0MzEwNA%3D%3D&%3Bmid=2247484762&%3Bidx=1&%3Bsn=d6668ca0ee31a4d341b48e07e86aac55
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

目录

  • 一 引言

  • 二 源码解析

  • 三 案例

  • 四 总结

一 引言

ThreadLocal的官方API解释为:

 * This class provides thread-local variables.  These variables differ from
 * their normal counterparts in that each thread that accesses one (via its
 * {@code get} or {@code set} method) has its own, independently initialized
 * copy of the variable.  {@code ThreadLocal} instances are typically private
 * static fields in classes that wish to associate state with a thread (e.g.,
 * a user ID or Transaction ID).

这个类提供线程局部变量。这些变量与正常的变量不同,每个线程访问一个(通过它的get或set方法)都有它自己的、独立初始化的变量副本。ThreadLocal实例通常是类中的私有静态字段,希望将状态与线程关联(例如,用户ID或事务ID)。

1、当使用ThreadLocal维护变量时,ThreadLocal为每个使用该变量的线程提供独立的变量副本,
        所以每一个线程都可以独立地改变自己的副本,而不会影响其它线程所对应的副本
2、使用ThreadLocal通常是定义为 private static ,更好是 private final static
3、Synchronized用于线程间的数据共享,而ThreadLocal则用于线程间的数据隔离
4、ThreadLocal类封装了getMap()、Set()、Get()、Remove()4个核心方法

从表面上来看ThreadLocal内部是封闭了一个Map数组,来实现对象的线程封闭,map的key就是当前的线程id,value就是我们要存储的对象。

实际上是ThreadLocal的静态内部类ThreadLocalMap为每个Thread都维护了一个数组table,hreadLocal确定了一个数组下标,而这个下标就是value存储的对应位置,继承自弱引用,用来保存ThreadLocal和Value之间的对应关系,之所以用弱引用,是为了解决线程与ThreadLocal之间的强绑定关系,会导致如果线程没有被回收,则GC便一直无法回收这部分内容。

二 源码剖析

2.1 ThreadLocal

    //set方法
    public void set(T value) {
        //获取当前线程
        Thread t = Thread.currentThread();
        //实际存储的数据结构类型
        ThreadLocalMap map = getMap(t);
        //判断map是否为空,如果有就set当前对象,没有创建一个ThreadLocalMap
        //并且将其中的值放入创建对象中
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

    //get方法 
    public T get() {
        //获取当前线程
        Thread t = Thread.currentThread();
        //实际存储的数据结构类型
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //传入了当前线程的ID,到底层Map Entry里面去取
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }   

    //remove方法
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);//调用ThreadLocalMap删除变量
     }

       //ThreadLocalMap中getEntry方法
      private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        } 

   //getMap()方法
   ThreadLocalMap getMap(Thread t) {
    //Thread中维护了一个ThreadLocalMap
        return t.threadLocals;
    }

    //setInitialValue方法
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

    //createMap()方法
   void createMap(Thread t, T firstValue) {
   //实例化一个新的ThreadLocalMap,并赋值给线程的成员变量threadLocals
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

从上面源码中我们看到不管是 set() get() remove() 他们都是操作ThreadLocalMap这个静态内部类的,每一个新的线程Thread都会实例化一个ThreadLocalMap并赋值给成员变量threadLocals,使用时若已经存在threadLocals则直接使用已经存在的对象

ThreadLocal.get()

  • 获取当前线程对应的ThreadLocalMap

  • 如果当前ThreadLocal对象对应的Entry还存在,并且返回对应的值

  • 如果获取到的ThreadLocalMap为null,则证明还没有初始化,就调用setInitialValue()方法

ThreadLocal.set()

  • 获取当前线程,根据当前线程获取对应的ThreadLocalMap

  • 如果对应的ThreadLocalMap不为null,则调用set方法保存对应关系

  • 如果为null,创建一个并保存k-v关系

ThreadLocal.remove()

  • 获取当前线程,根据当前线程获取对应的ThreadLocalMap

  • 如果对应的ThreadLocalMap不为null,则调用ThreadLocalMap中的remove方法,根据key.threadLocalHashCode & (len-1)获取当前下标并移除

  • 成功后调用expungeStaleEntry进行一次连续段清理

3emqqiN.png!mobile

2.2 ThreadLocalMap

ThreadLocalMap是ThreadLocal的一个内部类

static class ThreadLocalMap {

         /**    
         * 自定义一个Entry类,并继承自弱引用
         * 同时让ThreadLocal和储值形成key-value的关系
         * 之所以用弱引用,是为了解决线程与ThreadLocal之间的强绑定关系
         * 会导致如果线程没有被回收,则GC便一直无法回收这部分内容
         * 
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * Entry数组的初始化大小(初始化长度16,后续每次都是2倍扩容)
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * 根据需要调整大小
         * 长度必须是2的N次幂
         */
        private Entry[] table;

        /**
         * The number of entries in the table.
         * table中的个数
         */
        private int size = 0;

        /**
         * The next size value at which to resize.
         * 下一个要调整大小的大小值(扩容的阈值)
         */
        private int threshold; // Default to 0

        /**
         * Set the resize threshold to maintain at worst a 2/3 load factor.
         * 根据长度计算扩容阈值
         * 保持一定的负债系数
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

        /**
         * Increment i modulo len
         * nextIndex:从字面意思我们可以看出来就是获取下一个索引
         * 获取下一个索引,超出长度则返回
         */
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }

        /**
         * Decrement i modulo len.
         * 返回上一个索引,如果-1为负数,返回长度-1的索引
         */
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }

        /**
         * ThreadLocalMap构造方法
         * ThreadLocalMaps是延迟构造的,因此只有在至少要放置一个节点时才创建一个
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            //内部成员数组,INITIAL_CAPACITY值为16的常量
            table = new Entry[INITIAL_CAPACITY];
            //通过threadLocalHashCode(HashCode) & (长度-1)的位运算,确定键值对的位置
            //位运算,结果与取模相同,计算出需要存放的位置
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            // 创建一个新节点保存在table当中
            table[i] = new Entry(firstKey, firstValue);
            //设置table元素为1
            size = 1;
            //根据长度计算扩容阈值
            setThreshold(INITIAL_CAPACITY);
        }

        /**
         * 构造一个包含所有可继承ThreadLocals的新映射,只能createInheritedMap调用
         * ThreadLocal本身是线程隔离的,一般来说是不会出现数据共享和传递的行为
         */
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }

        /**
         * ThreadLocalMap中getEntry方法
         */
        private Entry getEntry(ThreadLocal<?> key) {
            //通过hashcode确定下标
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            //如果找到则直接返回
            if (e != null && e.get() == key)
                return e;
            else
             // 找不到的话接着从i位置开始向后遍历,基于线性探测法,是有可能在i之后的位置找到的
                return getEntryAfterMiss(key, i, e);
        }


        /**
         * ThreadLocalMap的set方法
         */
        private void set(ThreadLocal<?> key, Object value) {
           //新开一个引用指向table
            Entry[] tab = table;
            //获取table长度
            int len = tab.length;
            ////获取索引值,threadLocalHashCode进行一个位运算(取模)得到索引i
            int i = key.threadLocalHashCode & (len-1);
            /**
            * 遍历tab如果已经存在(key)则更新值(value)
            * 如果该key已经被回收失效,则替换该失效的key
            **/
            //
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }
                //如果 k 为null,则替换当前失效的k所在Entry节点
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            //如果上面没有遍历成功则创建新值
            tab[i] = new Entry(key, value);
            // table内元素size自增
            int sz = ++size;
            //满足条件数组扩容x2
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

        /**
         * remove方法
         * 将ThreadLocal对象对应的Entry节点从table当中删除
         */
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();//将引用设置null,方便GC回收
                    expungeStaleEntry(i);//从i的位置开始连续段清理工作
                    return;
                }
            }
        }

        /**
        * ThreadLocalMap中replaceStaleEntry方法
         */
        private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
            // 新建一个引用指向table
            Entry[] tab = table;
            //获取table的长度
            int len = tab.length;
            Entry e;


            // 记录当前失效的节点下标
            int slotToExpunge = staleSlot;

           /**
             * 通过prevIndex(staleSlot, len)可以看出,由staleSlot下标向前扫描
             * 查找并记录最前位置value为null的下标
             */
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                if (e.get() == null)
                    slotToExpunge = i;

            // nextIndex(staleSlot, len)可以看出,这个是向后扫描
            // occurs first
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                 // 获取Entry节点对应的ThreadLocal对象
                ThreadLocal<?> k = e.get();

                  //如果和新的key相等的话,就直接赋值给value,替换i和staleSlot的下标
                if (k == key) {
                    e.value = value;

                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // 如果之前的元素存在,则开始调用cleanSomeSlots清理
                    if (slotToExpunge == staleSlot)
                        slotToExpunge = i;
                     /**
                     *在调用cleanSomeSlots()    清理之前,会调用
                     *expungeStaleEntry()从slotToExpunge到table下标所在为
                     *null的连续段进行一次清理,返回值就是table为null的下标
                     *然后以该下标 len进行一次启发式清理
                     * 最终里面的方法实际上还是调用了expungeStaleEntry
                      * 可以看出expungeStaleEntry方法是ThreadLocal核心的清理函数
                     */
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // 如果在table中没有找到这个key,则直接在当前位置new Entry(key, value)
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // 如果有其他过时的节点正在运行,会将它们进行清除,slotToExpunge会被重新赋值
            if (slotToExpunge != staleSlot)
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }

        /**
         * expungeStaleEntry() 启发式地清理被回收的Entry
         * 有两个地方调用到这个方法
         * 1、set方法,在判断是否需要resize之前,会清理并rehash一遍
         * 2、替换失效的节点时候,也会进行一次清理
        */
          private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                //判断如果Entry对象不为空
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    //调用该方法进行回收,
                    //对 i 开始到table所在下标为null的范围内进行一次清理和rehash
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }  

        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    int h = k.threadLocalHashCode & (len - 1);
                    if (h != i) {
                        tab[i] = null;
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }



        /**
         * Re-pack and/or re-size the table. First scan the entire
         * table removing stale entries. If this doesn't sufficiently
         * shrink the size of the table, double the table size.
         */
        private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }

        /**
         * 对table进行扩容,因为要保证table的长度是2的幂,所以扩容就扩大2倍
         */
        private void resize() {
        //获取旧table的长度
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            int newLen = oldLen * 2;
            //创建一个长度为旧长度2倍的Entry数组
            Entry[] newTab = new Entry[newLen];
            //记录插入的有效Entry节点数
            int count = 0;

             /**
             * 从下标0开始,逐个向后遍历插入到新的table当中
             * 通过hashcode & len - 1计算下标,如果该位置已经有Entry数组,则通过线性探测向后探测插入
             */
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {//如遇到key已经为null,则value设置null,方便GC回收
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
            // 重新设置扩容的阈值
            setThreshold(newLen);
            // 更新size
            size = count;
             // 指向新的Entry数组
            table = newTab;
        }


    }

ThreadLocalMap.set()

  • key.threadLocalHashCode & (len-1),将threadLocalHashCode进行一个位运算(取模)得到索引 " i " ,也就是在table中的下标

  • for循环遍历,如果Entry中的key和我们的需要操作的ThreadLocal的相等,这直接赋值替换

  • 如果拿到的key为null ,则调用replaceStaleEntry()进行替换

  • 如果上面的条件都没有成功满足,直接在计算的下标中创建新值

  • 在进行一次清理之后,调用rehash()下的resize()进行扩容

ThreadLocalMap.expungeStaleEntry()

  • 这是 ThreadLocal 中一个核心的清理方法

  • 为什么需要清理?

  • 在我们 Entry 中,如果有很多节点是已经过时或者回收了,但是在table数组中继续存在,会导致资源浪费

  • 我们在清理节点的同时,也会将后面的Entry节点,重新排序, 调整Entry大小 ,这样我们在取值(get())的时候,可以快速定位资源,加快我们的程序的获取效率

ThreadLocalMap.remove()

  • 我们在使用remove节点的时候,会使用线性探测的方式,找到当前的key

  • 如果当前key一致,调用clear()将引用指向null

  • 从"i"开始的位置进行一次连续段清理

三 案例

目录结构:

IrQVRv2.png!mobile 在这里插入图片描述

HttpFilter.java

package com.lyy.threadlocal.config;

import lombok.extern.slf4j.Slf4j;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;

@Slf4j
public class HttpFilter implements Filter {

//初始化需要做的事情
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    //核心操作在这个里面
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest)servletRequest;
//        request.getSession().getAttribute("user");
        System.out.println("do filter:"+Thread.currentThread().getId()+":"+request.getServletPath());
        RequestHolder.add(Thread.currentThread().getId());
        //让这个请求完,,同时做下一步处理
        filterChain.doFilter(servletRequest,servletResponse);


    }

    //不再使用的时候做的事情
    @Override
    public void destroy() {

    }
}

HttpInterceptor.java

package com.lyy.threadlocal.config;

import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

public class HttpInterceptor extends HandlerInterceptorAdapter {

    //接口处理之前
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        System.out.println("preHandle:");
        return true;
    }

    //接口处理之后
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {
        RequestHolder.remove();
       System.out.println("afterCompletion");

        return;
    }
}

RequestHolder.java

package com.lyy.threadlocal.config;

public class RequestHolder {

    private final static ThreadLocal<Long> requestHolder = new ThreadLocal<>();//

    //提供方法传递数据
    public static void add(Long id){
        requestHolder.set(id);

    }

    public static Long getId(){
        //传入了当前线程的ID,到底层Map里面去取
        return requestHolder.get();
    }

    //移除变量信息,否则会造成逸出,导致内容永远不会释放掉
    public static void remove(){
        requestHolder.remove();
    }
}

ThreadLocalController.java

package com.lyy.threadlocal.controller;

import com.lyy.threadlocal.config.RequestHolder;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.ResponseBody;

@Controller
@RequestMapping("/thredLocal")
public class ThreadLocalController {

    @RequestMapping("test")
    @ResponseBody
    public Long test(){
        return RequestHolder.getId();
    }

}

ThreadlocalDemoApplication.java

package com.lyy.threadlocal;

import com.lyy.threadlocal.config.HttpFilter;
import com.lyy.threadlocal.config.HttpInterceptor;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurerAdapter;

@SpringBootApplication
public class ThreadlocalDemoApplication extends WebMvcConfigurerAdapter {

    public static void main(String[] args) {
        SpringApplication.run(ThreadlocalDemoApplication.class, args);
    }

    @Bean
    public FilterRegistrationBean httpFilter(){
        FilterRegistrationBean registrationBean = new FilterRegistrationBean();
        registrationBean.setFilter(new HttpFilter());
        registrationBean.addUrlPatterns("/thredLocal/*");


        return registrationBean;
    }


    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new HttpInterceptor()).addPathPatterns("/**");
    }

}

输入:http://localhost:8080/thredLocal/test

7zAV3ii.png!mobile

后台打印:

do filter:35:/thredLocal/test preHandle:
afterCompletion

四 总结

1、ThreadLocal是通过每个线程单独一份存储空间,每个ThreadLocal只能保存一个变量副本。

2、相比于Synchronized,ThreadLocal具有线程隔离的效果,只有在线程内才能获取到对应的值,线程外则不能访问到想要的值,很好的实现了线程封闭。

3、每次使用完ThreadLocal,都调用它的remove()方法,清除数据,避免内存泄漏的风险

4、通过上面的源码分析,我们也可以看到大神在写代码的时候会考虑到整体实现的方方面面,一些逻辑上的处理是真严谨的,我们在看源代码的时候不能只是做了解,也要看到别人实现功能后面的目的。

源码地址:https://github.com/839022478/other/tree/master/threadlocal_demo

我是牧小农,怕什么真理无穷,进一步有进一步的欢喜,大家加油!

--  End  --

———————

1.原创不易,你的 在看 是我创作的动力。

2.欢迎关注公众号  牧小码农 「带你一起学Java」

3.疫情期间,勤洗手,戴口罩,做好个人防护。

“在看转发” 是最大的支持


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK