• Java xss攻击拦截,Java CSRF跨站点伪造请求拦截


    Java xss攻击拦截,Java CSRF跨站点伪造请求拦截

    ================================

    ©Copyright 蕃薯耀 2021-05-07

    https://www.cnblogs.com/fanshuyao/

    一、Java xss攻击拦截

    XssFilter过滤器

    import java.io.IOException;
    
    import javax.servlet.Filter;
    import javax.servlet.FilterChain;
    import javax.servlet.FilterConfig;
    import javax.servlet.ServletException;
    import javax.servlet.ServletRequest;
    import javax.servlet.ServletResponse;
    import javax.servlet.http.HttpServletRequest;
    
    public class XssFilter implements Filter{
        FilterConfig filterConfig = null;  
          
        public void init(FilterConfig filterConfig) throws ServletException {  
            this.filterConfig = filterConfig;  
        }  
      
        public void destroy() {  
            this.filterConfig = null;  
        }  
      
        public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
            chain.doFilter(new XssWrapper((HttpServletRequest) request), response);  
        }  
    }

    XssWrapper类:

    import java.io.BufferedReader;
    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.io.InputStreamReader;
    import java.nio.charset.Charset;
    import java.util.HashMap;
    import java.util.LinkedHashMap;
    import java.util.Map;
    
    import javax.servlet.ReadListener;
    import javax.servlet.ServletInputStream;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletRequestWrapper;
    
    import org.apache.commons.lang3.StringUtils;
    import org.apache.commons.text.StringEscapeUtils;
    
    import com.szpl.csgx.utils.JsonUtil;
    
    
    /**
     * Request wrapper that neutralises HTML-special characters in data read by downstream code.
     *
     * <ul>
     *   <li>Headers and request parameters are HTML-escaped.</li>
     *   <li>HTML-special characters in the query string are percent-encoded.</li>
     *   <li>The top-level string fields of a JSON request body are HTML-escaped.</li>
     * </ul>
     *
     * <p>Only requests whose Content-Type starts with {@code application/json} (for example
     * {@code application/json;charset=UTF-8}) have their body buffered and rewritten. That body is
     * read once in the constructor and replayed from memory by {@link #getInputStream()} and
     * {@link #getReader()}. Other bodies (form posts, multipart uploads) are never read by this
     * wrapper, so the container can still parse them.
     */
    public class XssWrapper extends HttpServletRequestWrapper {

        public static final String JSON_TYPE = "application/json";
        public static final String CONTENT_TYPE = "Content-Type";
        public static final String CHARSET = "UTF-8";

        // Escaped JSON body. Null when the request is not JSON; the original stream is used then.
        private String mBody;
        // The wrapped request. For JSON requests its input stream has been consumed by this wrapper.
        HttpServletRequest originalRequest = null;

        /**
         * Wraps {@code request}. For JSON requests, the whole body is read and escaped up front.
         *
         * @param request the request to wrap
         * @throws IOException if reading a JSON request body fails
         */
        public XssWrapper(HttpServletRequest request) throws IOException {
            super(request);
            originalRequest = request;
            // Reading the input stream of a form post or multipart request here would leave the
            // container nothing to parse POST parameters or file parts from, so only JSON is buffered.
            if (isJsonRequest()) {
                setRequestBody(request.getInputStream());
            }
        }

        /**
         * Returns the wrapped, unescaped request.
         *
         * <p>For JSON requests its input stream has already been consumed by this wrapper; read the
         * body through the wrapper instead.
         *
         * @return the original request
         */
        public HttpServletRequest getOrgRequest() {
            return originalRequest;
        }

        /**
         * Unwraps {@code req} if it is an {@code XssWrapper}; otherwise returns it unchanged.
         *
         * @param req a request that may or may not be wrapped
         * @return the original request
         */
        public static HttpServletRequest getOriginalRequest(HttpServletRequest req) {
            if (req instanceof XssWrapper) {
                return ((XssWrapper) req).getOrgRequest();
            }
            return req;
        }

        /** Returns the header HTML-escaped; blank or missing values are returned unchanged. */
        @Override
        public String getHeader(String name) {
            String value = super.getHeader(name);
            if (StringUtils.isBlank(value)) {
                return value;
            }
            return StringEscapeUtils.escapeHtml4(value);
        }

        /**
         * Returns the query string with {@code < > " '} percent-encoded.
         *
         * <p>Blank or missing query strings are returned unchanged, so a missing one stays
         * {@code null} as the servlet API specifies.
         */
        @Override
        public String getQueryString() {
            String queryString = super.getQueryString();
            if (StringUtils.isBlank(queryString)) {
                return queryString;
            }
            // The query string is URL-encoded data, not HTML. HTML-escaping it turns every "&"
            // separator into "&amp;" and corrupts it for anything that parses or re-emits it.
            // Percent-encoding keeps it a valid query string that decodes to the same parameters.
            return queryString.replace("<", "%3C")
                    .replace(">", "%3E")
                    .replace("\"", "%22")
                    .replace("'", "%27");
        }

        /** Returns the parameter HTML-escaped; blank or missing values are returned unchanged. */
        @Override
        public String getParameter(String name) {
            String value = super.getParameter(name);
            if (StringUtils.isBlank(value)) {
                return value;
            }
            return StringEscapeUtils.escapeHtml4(value);
        }

        /** Returns an HTML-escaped copy of the parameter values, or null if there are none. */
        @Override
        public String[] getParameterValues(String name) {
            return escapeAll(super.getParameterValues(name));
        }

        /**
         * Returns a new map holding HTML-escaped copies of every parameter's values.
         *
         * <p>The container's own arrays are never modified. A container may cache its parameter
         * map; escaping those arrays in place would double-escape on every further call.
         */
        @Override
        public Map<String, String[]> getParameterMap() {
            Map<String, String[]> parameterMap = super.getParameterMap();
            if (parameterMap == null) {
                return parameterMap;
            }
            Map<String, String[]> escapedMap = new LinkedHashMap<String, String[]>();
            for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
                escapedMap.put(entry.getKey(), escapeAll(entry.getValue()));
            }
            return escapedMap;
        }

        /** Returns a new array with each element HTML-escaped; null stays null. */
        private static String[] escapeAll(String[] values) {
            if (values == null) {
                return null;
            }
            String[] escaped = new String[values.length];
            for (int i = 0; i < values.length; i++) {
                escaped[i] = StringEscapeUtils.escapeHtml4(values[i]);
            }
            return escaped;
        }

        /** True when the Content-Type header starts with application/json; parameters such as charset are allowed. */
        private boolean isJsonRequest() {
            return StringUtils.startsWithIgnoreCase(super.getHeader(CONTENT_TYPE), JSON_TYPE);
        }

        /**
         * Reads the JSON body and stores it in {@code mBody} with top-level string values escaped.
         *
         * <p>NOTE(review): only top-level string values are escaped. Strings nested in objects or
         * arrays are passed through unchanged.
         */
        private void setRequestBody(InputStream stream) throws IOException {
            mBody = readFully(stream);
            if (StringUtils.isBlank(mBody)) { // nothing to escape
                return;
            }
            @SuppressWarnings("unchecked")
            Map<String, Object> map = JsonUtil.string2Obj(mBody, Map.class);
            if (map == null) {
                // NOTE(review): JsonUtil's behaviour on unparsable input or a top-level JSON array is
                // not visible here. When it yields null, the body is replayed as read instead of
                // failing with a NullPointerException.
                return;
            }
            // LinkedHashMap keeps the client's field order when the body is serialised again.
            Map<String, Object> resultMap = new LinkedHashMap<>(map.size());
            for (Map.Entry<String, Object> entry : map.entrySet()) {
                Object val = entry.getValue();
                if (val instanceof String) {
                    resultMap.put(entry.getKey(), StringEscapeUtils.escapeHtml4((String) val));
                } else {
                    resultMap.put(entry.getKey(), val);
                }
            }
            mBody = JsonUtil.obj2String(resultMap);
        }

        /** Reads the whole stream as UTF-8, keeping line breaks (unlike a readLine loop). */
        private static String readFully(InputStream stream) throws IOException {
            // Not closed on purpose: the container owns the request stream.
            BufferedReader reader = new BufferedReader(new InputStreamReader(stream, Charset.forName(CHARSET)));
            StringBuilder body = new StringBuilder();
            char[] buffer = new char[4096];
            int read;
            while ((read = reader.read(buffer)) != -1) {
                body.append(buffer, 0, read);
            }
            return body.toString();
        }

        /** Returns a UTF-8 reader over the escaped JSON body, or the container's reader otherwise. */
        @Override
        public BufferedReader getReader() throws IOException {
            if (mBody == null) {
                return super.getReader();
            }
            return new BufferedReader(new InputStreamReader(getInputStream(), Charset.forName(CHARSET)));
        }

        /**
         * Returns a fresh stream over the escaped JSON body, or the container's stream for non-JSON
         * requests (whose body this wrapper never reads).
         */
        @Override
        public ServletInputStream getInputStream() throws IOException {
            if (mBody == null) {
                return super.getInputStream();
            }

            final ByteArrayInputStream bais = new ByteArrayInputStream(mBody.getBytes(CHARSET));
            return new ServletInputStream() {
                @Override
                public int read() throws IOException {
                    return bais.read();
                }

                @Override
                public int read(byte[] b, int off, int len) throws IOException {
                    return bais.read(b, off, len);
                }

                @Override
                public boolean isFinished() {
                    return bais.available() == 0;
                }

                @Override
                public boolean isReady() {
                    // The whole body is in memory, so a read never blocks.
                    return true;
                }

                @Override
                public void setReadListener(ReadListener listener) {
                    // Non-blocking reads are not supported on the in-memory body.
                }
            };
        }

    }

    二、Java CSRF跨站点伪造请求拦截

    import java.io.IOException;
    import java.net.URL;
    
    import javax.servlet.Filter;
    import javax.servlet.FilterChain;
    import javax.servlet.FilterConfig;
    import javax.servlet.ServletException;
    import javax.servlet.ServletRequest;
    import javax.servlet.ServletResponse;
    import javax.servlet.http.HttpServletRequest;
    import javax.servlet.http.HttpServletResponse;
    
    import org.apache.commons.lang3.StringUtils;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.stereotype.Component;
    
    import cn.hutool.json.JSONUtil;
    
    
    /**
     * CSRF filter based on the Referer header.
     *
     * <p>Rules, in order:
     * <ul>
     *   <li>A request without a Referer passes.</li>
     *   <li>A request whose Referer has the same host[:port] as the request URL passes.</li>
     *   <li>A cross-site request passes if the referer's host[:port] is in
     *       {@code csrf.white.domains}.</li>
     *   <li>A cross-site request passes if the requested path, with the context path removed, is
     *       in {@code csrf.white.paths}.</li>
     *   <li>Anything else is redirected to {@code <contextPath>/illegal}.</li>
     * </ul>
     */
    @Component
    public class CsrfFilter implements Filter {

        private static final Logger log = LoggerFactory.getLogger(CsrfFilter.class);

        // Paths allowed from other sites, comma-separated, set in application.properties.
        // The empty default keeps startup working when the property is absent.
        @Value("${csrf.white.paths:}")
        private String[] csrfWhitePaths;

        // Referer domains (host, optionally host:port) allowed from other sites, comma-separated,
        // set in application.properties.
        @Value("${csrf.white.domains:}")
        private String[] csrfWhiteDomains;

        @Override
        public void init(FilterConfig filterConfig) throws ServletException {
            log.info("init……");
        }

        @Override
        public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
                throws IOException, ServletException {
            HttpServletRequest req = (HttpServletRequest) request;
            HttpServletResponse res = (HttpServletResponse) response;
            String referer = req.getHeader("Referer");

            // NOTE(review): requests without a Referer pass (direct navigation, or a page sending
            // Referrer-Policy: no-referrer), so this check alone is not a complete CSRF defence.
            if (StringUtils.isBlank(referer)) {
                filterChain.doFilter(request, response);
                return;
            }

            String requestURL = req.getRequestURL().toString();
            URL urlRequest = new URL(requestURL);
            // NOTE(review): behind a reverse proxy the request URL may carry an internal host name;
            // confirm the container is configured to report the public host.
            String requestHostAndPort = hostAndPort(urlRequest);

            // A Referer that is not a valid URL cannot be same-site or whitelisted by domain.
            // It still goes through the path whitelist, so /illegal stays reachable.
            URL refererUrl = parseUrl(referer);
            if (refererUrl != null) {
                String refererHostAndPort = hostAndPort(refererUrl);
                // Same host and port: a request from this application itself.
                // Otherwise check the domain whitelist.
                if (requestHostAndPort.equalsIgnoreCase(refererHostAndPort)
                        || isCsrfWhiteDomains(refererHostAndPort)) {
                    filterChain.doFilter(request, response);
                    return;
                }
            }

            String path = urlRequest.getPath();
            log.info("path = {}", path);
            String actionPath = stripContextPath(path, req.getContextPath());
            log.info("actionPath = {}", actionPath);

            if (isCsrfWhitePaths(actionPath)) {
                filterChain.doFilter(request, response);
                return;
            }

            log.warn("csrf跨站点伪造请求已经被拦截:");
            log.warn("requestURL = {}", requestURL);
            log.warn("referer = {}", referer);
            res.sendRedirect(req.getContextPath() + "/illegal");
        }

        @Override
        public void destroy() {
            log.info("destroy……");
        }

        /** Returns "host" when the URL has no explicit port, otherwise "host:port". */
        private static String hostAndPort(URL url) {
            int port = url.getPort();
            return port == -1 ? url.getHost() : url.getHost() + ":" + port;
        }

        /** Parses {@code spec} as a URL, returning null instead of throwing when it is malformed. */
        private static URL parseUrl(String spec) {
            try {
                return new URL(spec);
            } catch (java.net.MalformedURLException e) {
                return null;
            }
        }

        /**
         * Removes a leading context path from {@code path}.
         *
         * <p>This is a literal prefix match only: the context path is not treated as a regex, and
         * later occurrences inside the path are kept.
         */
        private static String stripContextPath(String path, String contextPath) {
            if (StringUtils.isNotEmpty(contextPath) && path.startsWith(contextPath)) {
                return path.substring(contextPath.length());
            }
            return path;
        }

        /**
         * Checks whether a path (context path already removed) is in the cross-site path whitelist.
         *
         * @param path request path relative to the context path
         * @return true if a whitelist entry, trimmed, equals {@code path}
         */
        private boolean isCsrfWhitePaths(String path) {
            if (csrfWhitePaths == null) {
                return false;
            }
            for (String csrfWhitePath : csrfWhitePaths) {
                // Trim so that entries written as "a, b" in the properties file still match.
                if (!StringUtils.isBlank(csrfWhitePath) && csrfWhitePath.trim().equals(path)) {
                    log.info("跨站点请求所有路径白名单:csrfWhitePaths = {}", JSONUtil.toJsonStr(csrfWhitePaths));
                    log.info("符合跨站点请求路径白名单:path = {}", path);
                    return true;
                }
            }
            return false;
        }

        /**
         * Checks whether a referer's host[:port] is in the external-domain whitelist.
         *
         * <p>The comparison ignores case, because host names are case-insensitive; this matches the
         * same-origin check in doFilter.
         *
         * @param refererHostAndPort the referer's "host" or "host:port"
         * @return true if a whitelist entry, trimmed, equals it ignoring case
         */
        private boolean isCsrfWhiteDomains(String refererHostAndPort) {
            if (csrfWhiteDomains == null || csrfWhiteDomains.length == 0) {
                return false;
            }
            for (String csrfWhiteDomain : csrfWhiteDomains) {
                if (!StringUtils.isBlank(csrfWhiteDomain)
                        && csrfWhiteDomain.trim().equalsIgnoreCase(refererHostAndPort)) {
                    log.info("跨站点请求所有【域名】]白名单:csrfWhiteDomains = {}", JSONUtil.toJsonStr(csrfWhiteDomains));
                    log.info("符合跨站点请求【域名】白名单:refererHost = {}", refererHostAndPort);
                    return true;
                }
            }
            log.info("跨站点请求非法【域名】:refererHost = {}", refererHostAndPort);
            return false;
        }

    }

    配置文件(application.properties):

    #跨站点请求域名白名单,通过英文逗号分隔。如(abc.com:9010,abc.org:9010)
    csrf.white.domains=www.abc.com,abc.cn:9011
    #跨站点伪造请求
    #跨站点请求路径白名单,通过英文逗号分隔。如(/illegal,/illegal2,/illegal3)
    csrf.white.paths=/illegal

    三、SpringBoot注册过滤器

    import javax.servlet.Filter;
    
    import org.springframework.boot.web.servlet.FilterRegistrationBean;
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    
    import com.szpl.csgx.security.CsrfFilter;
    import com.szpl.csgx.security.XssFilter;
    
    /**
     * Registers the XSS and CSRF servlet filters with Spring Boot.
     *
     * <p>Filters are registered through this configuration class so that a Spring-managed filter
     * (such as {@code CsrfFilter}, which relies on {@code @Value} injection) is the instance that
     * ends up in the chain.
     */
    @Configuration
    public class HttpFilterConfig {

        /**
         * Registers a new {@code XssFilter} under the name "xssFilter" for all URLs, order 4.
         *
         * @return the registration bean
         */
        @Bean
        public FilterRegistrationBean<Filter> xssFilter() {
            FilterRegistrationBean<Filter> xssRegistration = new FilterRegistrationBean<>(new XssFilter());
            xssRegistration.setName("xssFilter");
            xssRegistration.addUrlPatterns("/*");
            xssRegistration.setOrder(4);
            return xssRegistration;
        }

        /**
         * Registers the Spring-managed {@code CsrfFilter} under the name "csrfFilter" for all URLs,
         * order 0.
         *
         * @param csrfFilter the filter bean. It must not be created with {@code new}: an instance
         *     Spring does not manage gets no {@code @Value} injection.
         * @return the registration bean
         */
        @Bean
        public FilterRegistrationBean<Filter> csrfFilterRegistrationBean(CsrfFilter csrfFilter) {
            FilterRegistrationBean<Filter> csrfRegistration = new FilterRegistrationBean<Filter>();
            csrfRegistration.setFilter(csrfFilter);
            csrfRegistration.addUrlPatterns("/*");
            csrfRegistration.setName("csrfFilter");
            csrfRegistration.setOrder(0);
            return csrfRegistration;
        }

    }

    (时间宝贵,分享不易,捐赠回馈,^_^)

    ================================

    ©Copyright 蕃薯耀 2021-05-07

    https://www.cnblogs.com/fanshuyao/

    今天越懒,明天要做的事越多。
  • 相关阅读:
    JDK1.8HashMap底层实现原理
    关于map转json,空key丢失的问题
    spring一些注解的使用及相关注解差异
    搭建基础项目遇到的一些小坑
    解析ftp上word文档的文字并输入
    R语言中回归模型预测的不同类型置信区间应用比较分析
    R语言中的广义线性模型(GLM)和广义相加模型(GAM):多元(平滑)回归分析保险资金投资组合信用风险敞口
    R语言对巨灾风险下的再保险合同定价研究案例:广义线性模型和帕累托分布Pareto distributions分析
    R语言中GLM(广义线性模型),非线性和异方差可视化分析
    如何用R语言绘制生成正态分布图表
  • 原文地址:https://www.cnblogs.com/fanshuyao/p/14741668.html
Copyright © 2020-2023  润新知