之前的token验证一直觉得很烂,今天做下优化,老项目就不贴出来了。
第一步
首先将过滤器注册到Bean容器,拦截相关请求。本来是想通过实现ApplicationContextAware接口的setApplicationContext方法去获取Spring的上下文,但是容器启动时报错,才发现WebMvcConfigurationSupport中已经实现了该接口,所以这里直接通过getApplicationContext()取用;new LoginFilter时要将applicationContext通过构造方法传入,以便过滤器在init中使用。
package com.mbuyy.config;

import com.mbuyy.common.CommonInterceptor;
import com.mbuyy.filter.LoginFilter;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurationSupport;

import java.nio.charset.Charset;

/**
 * Web configuration: declares a UTF-8 string converter bean, installs
 * {@link CommonInterceptor} on every path and registers {@link LoginFilter}
 * for the payment, merchant and mobile URL spaces.
 */
@Configuration
public class WebConfig extends WebMvcConfigurationSupport {

    /**
     * Declares a {@link StringHttpMessageConverter} that uses UTF-8.
     *
     * @return the converter bean
     */
    @Bean
    public HttpMessageConverter<String> responseBodyConverter() {
        return new StringHttpMessageConverter(Charset.forName("UTF-8"));
    }

    /**
     * Adds {@link CommonInterceptor} for all paths, then delegates to the parent.
     *
     * @param registry the interceptor registry supplied by Spring MVC
     */
    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(new CommonInterceptor()).addPathPatterns("/**");
        super.addInterceptors(registry);
    }

    /**
     * Registers the login filter.
     *
     * <p>The filter is created by hand, so the application context inherited
     * from {@link WebMvcConfigurationSupport} is passed in through its constructor.
     *
     * @return the registration bean mapping {@link LoginFilter} to its URL patterns
     */
    @Bean
    public FilterRegistrationBean filterRegist() {
        LoginFilter loginFilter = new LoginFilter(this.getApplicationContext());

        FilterRegistrationBean registration = new FilterRegistrationBean();
        registration.setFilter(loginFilter);
        registration.addUrlPatterns("/payment/*", "/merchant/*", "/mobile/*");

        System.out.println("filter");
        return registration;
    }
}
第二步
编写过滤器(通过构造器注入ApplicationContext)
package com.mbuyy.filter;

import com.mbuyy.annotation.TokenIgnoreUtils;
import com.mbuyy.annotation.TokenVerify;
import com.mbuyy.merchant.controller.*;
import com.mbuyy.mobile.controller.*;
import com.mbuyy.util.JWTUtils;
import com.mbuyy.util.RedisUtils;
import com.mbuyy.util.StringUtils;
import com.mbuyy.util.ValidateUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationContext;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;
import java.util.Enumeration;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Pattern;

import static com.mbuyy.constants.Constants.*;

/**
 * Servlet filter that authenticates requests by token.
 *
 * <p>For every request it sets the CORS headers, reads the token from the
 * {@code token_merchant} or {@code token} header, verifies it through
 * {@link JWTUtils} and Redis, and stores the resulting id on the request under
 * {@code TOKEN_ID}. Whether a request without a token is rejected is decided
 * by {@link TokenIgnoreUtils#startCheck(String)}.
 */
public class LoginFilter implements Filter {

    private static final Logger logger = LoggerFactory.getLogger(LoginFilter.class);

    /** Header carrying the merchant-side token. */
    private static final String HEADER_TOKEN_MERCHANT = "token_merchant";

    /** Header carrying the customer-side token. */
    private static final String HEADER_TOKEN = "token";

    private TokenIgnoreUtils tokenIgnoreUtils;
    private ApplicationContext applicationContext;

    /** Matches an optionally signed run of digits; used to detect the numeric test token. */
    final Pattern pattern = Pattern.compile("^[-\\+]?[\\d]*$");

    public LoginFilter() {
    }

    /**
     * Creates the filter with the Spring context it needs during {@link #init(FilterConfig)}.
     *
     * @param applicationContext the Spring application context
     */
    public LoginFilter(ApplicationContext applicationContext) {
        this.applicationContext = applicationContext;
    }

    @Override
    public void destroy() {
    }

    /**
     * Registers the class of every bean annotated with {@link TokenVerify} so that
     * {@link TokenIgnoreUtils} can inspect its request mappings later.
     *
     * @param arg0 unused filter configuration
     * @throws IllegalStateException if the filter was created without an application context
     */
    @Override
    public void init(FilterConfig arg0) {
        if (this.applicationContext == null) {
            // Fail with a clear message instead of an anonymous NullPointerException.
            throw new IllegalStateException(
                    "LoginFilter requires an ApplicationContext; use the one-argument constructor");
        }
        tokenIgnoreUtils = new TokenIgnoreUtils();
        Map<String, Object> annotatedBeans =
                this.applicationContext.getBeansWithAnnotation(TokenVerify.class);
        for (Map.Entry<String, Object> entry : annotatedBeans.entrySet()) {
            logger.info(entry.getKey());
            tokenIgnoreUtils.registerController(entry.getValue().getClass());
        }
    }

    /**
     * Authenticates the request and either passes it down the chain or writes a
     * JSON error body.
     *
     * @param request  the incoming request; must be an {@link HttpServletRequest}
     * @param response the outgoing response; must be an {@link HttpServletResponse}
     * @param chain    the remaining filter chain
     * @throws IOException      if writing the error body or the chain fails
     * @throws ServletException if the chain fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse rep = (HttpServletResponse) response;
        setCrossDomian(req, rep);

        String requestURI = req.getRequestURI();
        // 0 = customer token, 1 = merchant token
        int tokenType = getTokenType(req);
        String token = getToken(req);
        req.setAttribute(TOKEN_TYPE, tokenType);
        // The token value itself is a credential and is deliberately not logged.
        logger.debug("tokenType = {}, token present = {}", tokenType, ValidateUtils.isNotEmpty(token));

        // SECURITY(review): "universal token" kept from the original for local testing.
        // Any purely numeric token is accepted as the user id WITHOUT verification.
        // This is an authentication bypass and must be removed or gated before production.
        Integer universalTokenId = parseUniversalToken(token);
        if (universalTokenId != null) {
            req.setAttribute(TOKEN_ID, universalTokenId);
            chain.doFilter(request, response);
            return;
        }

        try {
            verify(req, token, tokenType, requestURI);
        } catch (Exception e) {
            logger.warn("Token verification failed for {}", requestURI, e);
            failed(rep, e.getMessage());
            return;
        }
        chain.doFilter(request, response);
    }

    /**
     * Interprets a numeric token as an id for the test-only bypass.
     *
     * @param token the raw token, possibly null or empty
     * @return the parsed id, or {@code null} when the token is not a parseable integer
     *         (for example a lone sign or an out-of-range number)
     */
    private Integer parseUniversalToken(String token) {
        if (!ValidateUtils.isNotEmpty(token) || !pattern.matcher(token).matches()) {
            return null;
        }
        try {
            return Integer.valueOf(token);
        } catch (NumberFormatException e) {
            // Fall through to regular verification rather than failing the request with a 500.
            return null;
        }
    }

    /**
     * Returns the token to verify: the merchant header when it is non-empty,
     * otherwise the customer header.
     *
     * @param req the current request
     * @return the token, or {@code null} when the customer header is absent too
     */
    private String getToken(HttpServletRequest req) {
        String merchantToken = req.getHeader(HEADER_TOKEN_MERCHANT);
        if (ValidateUtils.isNotEmpty(merchantToken)) {
            return merchantToken;
        }
        return req.getHeader(HEADER_TOKEN);
    }

    /**
     * Determines which side the request comes from, using the same rule as
     * {@link #getToken(HttpServletRequest)} so that type and token always agree.
     *
     * <p>Header lookup goes through {@code getHeader}, which is case-insensitive
     * and safe when the header is missing.
     *
     * @param req the current request
     * @return 1 when a non-empty merchant token header is present, otherwise 0
     */
    private int getTokenType(HttpServletRequest req) {
        return ValidateUtils.isNotEmpty(req.getHeader(HEADER_TOKEN_MERCHANT)) ? 1 : 0;
    }

    /**
     * Verifies the token for the request.
     *
     * <p>A supplied token is always parsed, even for endpoints that do not need
     * one. A missing token is only an error when the endpoint needs one.
     *
     * @param req        the current request
     * @param token      the raw token, possibly null or empty
     * @param tokenType  0 for customer, 1 for merchant
     * @param requestURI the request path used to look up the endpoint
     * @throws Exception if the token is missing where needed, or invalid
     */
    private void verify(HttpServletRequest req, String token, int tokenType, String requestURI)
            throws Exception {
        if (ValidateUtils.isNotEmpty(token)) {
            parseToken(req, token, tokenType);
            return;
        }
        if (tokenIgnoreUtils.startCheck(requestURI)) {
            throw new RuntimeException("请输入验证Token");
        }
    }

    /**
     * Sets the request encoding and the CORS response headers.
     *
     * @param req the current request
     * @param rep the current response
     * @throws UnsupportedEncodingException if the encoding name is rejected
     */
    private void setCrossDomian(HttpServletRequest req, HttpServletResponse rep)
            throws UnsupportedEncodingException {
        req.setCharacterEncoding("utf-8");
        // "*" allows every origin.
        rep.setHeader("Access-Control-Allow-Origin", "*");
        rep.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE");
        rep.setHeader("Access-Control-Max-Age", "3600");
        rep.setHeader("Access-Control-Allow-Headers", "*");
    }

    /**
     * Verifies the JWT, checks it against Redis and stores the id on the request.
     *
     * @param req       the current request; receives the {@code TOKEN_ID} attribute
     * @param token     the raw token
     * @param tokenType 1 for merchant, anything else for customer
     * @throws Exception if the JWT is invalid, carries no usable id, or does not match Redis
     */
    private void parseToken(HttpServletRequest req, String token, int tokenType) throws Exception {
        Map<String, String> tokenMap;
        try {
            tokenMap = JWTUtils.getInstance().verifyToken(token);
        } catch (JWTUtils.TokenException e) {
            throw new RuntimeException("鉴权异常", e);
        }

        Integer tokenId = toTokenId(tokenMap == null ? null : tokenMap.get(TOKEN_ID));
        if (tokenId == null) {
            throw new RuntimeException("TokenId获取失败");
        }

        if (tokenType == 1) {
            parseMerchantToken(tokenId, token);
        } else {
            parseCustomerToken(tokenId, token);
        }

        req.setAttribute(TOKEN_ID, tokenId);
        logger.info("token_id = {}", tokenId);
    }

    /**
     * Converts the id claim to an integer.
     *
     * @param rawTokenId the claim value, possibly null
     * @return the id, or {@code null} when the claim is absent or not an integer
     */
    private static Integer toTokenId(String rawTokenId) {
        if (rawTokenId == null) {
            return null;
        }
        try {
            return Integer.valueOf(rawTokenId);
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /**
     * Checks a customer token against the value stored in Redis.
     *
     * @param tokenId the id taken from the JWT
     * @param token   the raw token
     * @throws Exception if Redis has no entry or the entry does not match
     */
    private void parseCustomerToken(Integer tokenId, String token) throws Exception {
        logger.info("解析客户端token, token_id = {}", tokenId);
        checkRedisToken(TOKEN + tokenId, token, "redis match error", "Token Error");
        // NOTE: the account-status check (userPunishService.verifyUser) was disabled in the original.
    }

    /**
     * Checks a merchant token against the value stored in Redis.
     *
     * @param tokenId the id taken from the JWT
     * @param token   the raw token
     * @throws Exception if Redis has no entry or the entry does not match
     */
    private void parseMerchantToken(Integer tokenId, String token) throws Exception {
        logger.info("解析商户端token, token_id = {}", tokenId);
        checkRedisToken(TOKEN_SHOP + tokenId, token, "匹配异常", "Token错误");
        // NOTE: the account-status check (userPunishService.verifyMerchant) was disabled in the original.
    }

    /**
     * Shared Redis comparison used by both token kinds.
     *
     * @param redisKey        the key under which the current token is stored
     * @param token           the raw token presented by the client
     * @param missingMessage  error message when the key does not exist
     * @param mismatchMessage error message when the stored value differs
     */
    private static void checkRedisToken(String redisKey, String token,
                                        String missingMessage, String mismatchMessage) {
        if (!RedisUtils.exists(redisKey)) {
            throw new RuntimeException(missingMessage);
        }
        String redisToken = RedisUtils.get(redisKey);
        // A stored "-1" accepts any token; presumably a wildcard sentinel -- confirm with the writer side.
        if (!Objects.equals(token, redisToken) && !"-1".equals(redisToken)) {
            throw new RuntimeException(mismatchMessage);
        }
    }

    /**
     * Writes the JSON error body. The HTTP status line is left untouched; only
     * the body carries the 401.
     *
     * @param rep the current response
     * @param msg the message to report, possibly null
     * @throws IOException if the response writer cannot be obtained
     */
    private void failed(HttpServletResponse rep, String msg) throws IOException {
        // Must be set before getWriter(), otherwise non-ASCII messages are written in the default charset.
        rep.setContentType("application/json;charset=UTF-8");
        PrintWriter w = rep.getWriter();
        w.write("{\"status\": 401,\"msg\": \"" + escapeJson(msg) + "\"}");
        w.flush();
        w.close();
    }

    /**
     * Escapes a string for use inside a JSON string literal.
     *
     * @param text the text to escape, possibly null
     * @return the escaped text, or an empty string for null
     */
    private static String escapeJson(String text) {
        if (text == null) {
            return "";
        }
        StringBuilder sb = new StringBuilder(text.length() + 8);
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"':
                    sb.append("\\\"");
                    break;
                case '\\':
                    sb.append("\\\\");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                case '\t':
                    sb.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.toString();
    }
}
根据注解(TokenIgnore)验证是否需要Token
public class TokenIgnoreUtils { private List<Class<?>> mControllerClass; public TokenIgnoreUtils() { mControllerClass = new ArrayList<>(); } public void registerController(Class<?> clazz) { mControllerClass.add(clazz); } /** * 该方法有个缺点,就是所有controller下的requestMapping路径长度必须一致/ , * 例:必须是/user/list,不能是/user/collect/list * @param reqUrl * @return */ public boolean startCheck(String reqUrl) { for (Class clazz : mControllerClass) { Method[] methods = clazz.getMethods(); for (Method method : methods) { RequestMapping requestMapping = method.getAnnotation(RequestMapping.class); // PostMapping postMapping = method.getAnnotation(PostMapping.class); // GetMapping getMapping = method.getAnnotation(GetMapping.class); // DeleteMapping deleteMapping = method.getAnnotation(DeleteMapping.class); // PutMapping putMapping = method.getAnnotation(PutMapping.class); if (requestMapping == null) { continue; } if (ConfigConstants.TOKEN_UTILS_TEST) { if (reqUrl.contains(requestMapping.value()[0]) && clazz.getSimpleName().toLowerCase().contains(reqUrl.split("/")[1])) { TokenIgnore tokenIgnore = method.getAnnotation(TokenIgnore.class); return tokenIgnore == null; } } else { String requestMappingUrl = ""; // 类上的映射 RequestMapping cRequestMapping = (RequestMapping) clazz.getAnnotation(RequestMapping.class); if (cRequestMapping != null && cRequestMapping.value() != null && cRequestMapping.value().length > 0){ requestMappingUrl += cRequestMapping.value()[0]; } // 方法上的映射 if (requestMapping != null && requestMapping.value() != null && requestMapping.value().length > 0){ requestMappingUrl += requestMapping.value()[0]; } // 1、判断映射路径是否包含在请求路径中,2、进一步判断映射路径,是否为请求路径最后的全路径 if (reqUrl.contains(requestMappingUrl) && reqUrl.lastIndexOf(requestMappingUrl) + requestMappingUrl.length() == reqUrl.length()) { System.out.println("requestMappingUrl = " + requestMappingUrl); TokenIgnore tokenIgnore = method.getAnnotation(TokenIgnore.class); System.out.println(tokenIgnore == null); return tokenIgnore == null; } } } } 
return true; } }
第三步
获取当前登录用户信息
// Read the id that LoginFilter stored on the request under TOKEN_ID after the token was verified.
Integer currentUserId = (Integer) request.getAttribute(TOKEN_ID);
笔者还很菜,要走的路还很长,如有不妥之处,还请联系笔者,不吝赐教。