package test.filter; import java.io.IOException; import java.util.Iterator; import java.util.Map; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import org.springframework.beans.BeanWrapper; import org.springframework.beans.BeansException; import org.springframework.web.filter.OncePerRequestFilter; import org.springframework.web.multipart.MultipartHttpServletRequest; import org.springframework.web.multipart.commons.CommonsMultipartResolver; /** * 使用Spring过滤器来过滤请求中的非法字符<br> * 如果请求被重定向,则在被重定向的控制器方法执行前此过滤器也会执行 * @author admin * */ public class CharacterFilter extends OncePerRequestFilter { // 如果使用CommonsMultipartResolver处理文件上传,并且表单类型为multipart/form-data // 则此处需使用CommonsMultipartResolver,其参数设置应与配置文件中保持一致 private CommonsMultipartResolver multipartResolver = null; /** * 过滤器加载时,initBeanWrapper(BeanWrapper)方法会在initFilterBean()方法之前加载<br> * 可以通过super.getFilterConfig().getInitParameter("param1")方法获取在web.xml中配置的init-param参数 */ @Override protected void initBeanWrapper(BeanWrapper bw) throws BeansException { String param1 = super.getFilterConfig().getInitParameter("param1"); System.out.println("param1:" + param1); super.initBeanWrapper(bw); } @Override protected void initFilterBean() throws ServletException { multipartResolver = new CommonsMultipartResolver(); multipartResolver.setMaxInMemorySize(104857600); multipartResolver.setDefaultEncoding("utf-8"); super.initFilterBean(); } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { //此处可通过配置参数判断是否需要过滤 ... HttpServletRequest httpRequest = (HttpServletRequest)request; // 此处使用httpRequest,直接使用request可能造成CharacterFilterRequestWrapper中request获取不到值 if(httpRequest.getContentType().toLowerCase().contains("multipart/form-data")){ MultipartHttpServletRequest resolveMultipart = multipartResolver.resolveMultipart(httpRequest); filterChain.doFilter(new CharacterFilterRequestWrapper(resolveMultipart), response); }else{ filterChain.doFilter(new CharacterFilterRequestWrapper(httpRequest), response); } } class CharacterFilterRequestWrapper extends HttpServletRequestWrapper { public CharacterFilterRequestWrapper(HttpServletRequest request) { super(request); } @Override public String getParameter(String name) { return super.getParameter(name); } @Override public String[] getParameterValues(String name) { return filterString(super.getParameterValues(name)); } @Override public Map<String, String[]> getParameterMap() { Map<String, String[]> map = super.getParameterMap(); if(map == null){ return null; } Iterator<String> it = map.keySet().iterator(); while(it.hasNext()){ String param = it.next(); String[] value = map.get(param); map.put(param, filterString(value)); } return map; } private String filterString(String value){ if(value == null){ return null; } // 此处可根据需要选择需要过滤的字符 value = value.replaceAll(" ", ""); value = value.replaceAll(" ", " "); value = value.replaceAll(">", ">"); value = value.replaceAll("<", "<"); value = value.replaceAll(""", """); return value; } private String[] filterString(String[] values){ if(values == null){ return null; } for (int i = 0; i < values.length; i++) { values[i] = filterString(values[i]); } return values; } } }