项目中需要用到包扫描的情况是很多的,一般是在项目初始化的时候,根据一些条件来对某个package下的类进行特殊处理。现在想实现的功能是,在一个filter或interceptor初始化的时候,扫描指定的一些package路径,遍历下面的每个class,找出method上使用了一个特殊注解的所有方法,然后缓存起来,当方法拦截器工作的时候,就不用再挨个判断方法是否需要拦截了
网上有很多自己编码实现scan package功能的例子,但是如果有工具已经帮你实现了,并且经受了普遍的验证,那么,自己造轮子的必要性就不大了
Spring 框架中负责扫描包的类是 ClassPathBeanDefinitionScanner,它的 findCandidateComponents 方法(继承自父类 ClassPathScanningCandidateComponentProvider)就是我们进行改造的依据:
/**
 * Scan the class path for candidate components.
 * <p>Quoted from Spring's class-path scanner (see ClassPathBeanDefinitionScanner); the
 * utility code later in this article is adapted from it.
 * @param basePackage the package to check for annotated classes
 * @return a corresponding Set of autodetected bean definitions
 */
public Set<BeanDefinition> findCandidateComponents(String basePackage) {
    // LinkedHashSet: candidates keep the order in which their resources were returned
    Set<BeanDefinition> candidates = new LinkedHashSet<BeanDefinition>();
    try {
        // Search path = all-classpath prefix + package as a resource path + "/" + resource pattern
        String packageSearchPath = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX +
                resolveBasePackage(basePackage) + "/" + this.resourcePattern;
        Resource[] resources = this.resourcePatternResolver.getResources(packageSearchPath);
        // Log levels are read once, outside the loop, rather than per resource
        boolean traceEnabled = logger.isTraceEnabled();
        boolean debugEnabled = logger.isDebugEnabled();
        for (Resource resource : resources) {
            if (traceEnabled) {
                logger.trace("Scanning " + resource);
            }
            if (resource.isReadable()) {
                try {
                    // Class metadata is read from the resource via a MetadataReader
                    MetadataReader metadataReader = this.metadataReaderFactory.getMetadataReader(resource);
                    // First check: the metadata-based filter (see the "not matching any filter" log below)
                    if (isCandidateComponent(metadataReader)) {
                        ScannedGenericBeanDefinition sbd = new ScannedGenericBeanDefinition(metadataReader);
                        sbd.setResource(resource);
                        sbd.setSource(resource);
                        // Second check on the bean definition; per the log message below, rejects
                        // classes that are not concrete top-level classes
                        if (isCandidateComponent(sbd)) {
                            if (debugEnabled) {
                                logger.debug("Identified candidate component class: " + resource);
                            }
                            candidates.add(sbd);
                        }
                        else {
                            if (debugEnabled) {
                                logger.debug("Ignored because not a concrete top-level class: " + resource);
                            }
                        }
                    }
                    else {
                        if (traceEnabled) {
                            logger.trace("Ignored because not matching any filter: " + resource);
                        }
                    }
                }
                catch (Throwable ex) {
                    // Any failure while reading one class aborts the whole scan (wrapped and rethrown),
                    // unlike the adapted version in this article, which logs and continues
                    throw new BeanDefinitionStoreException(
                            "Failed to read candidate component class: " + resource, ex);
                }
            }
            else {
                if (traceEnabled) {
                    logger.trace("Ignored because not readable: " + resource);
                }
            }
        }
    }
    catch (IOException ex) {
        // I/O failure while resolving the search path itself
        throw new BeanDefinitionStoreException("I/O failure during classpath scanning", ex);
    }
    return candidates;
}
改造如下:
方法loadCheckClassMethods的入参是逗号分隔的包路径,如com.xx
利用 Spring 的 ResourcePatternResolver 来查找包下面的资源(Resource)。因为我们的扫描 pattern 匹配的是 .class 文件,所以这里得到的每个 Resource 都是一个 class 文件。
protected static final String DEFAULT_RESOURCE_PATTERN = "**/*.class";
/** * 根据扫描包的配置 * 加载需要检查的方法 */ private void loadCheckClassMethods(String scanPackages) { String[] scanPackageArr = scanPackages.split(","); ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver(); MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(resourcePatternResolver); for (String basePackage : scanPackageArr) { if (StringUtils.isBlank(basePackage)) { continue; } String packageSearchPath = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + ClassUtils.convertClassNameToResourcePath(SystemPropertyUtils.resolvePlaceholders(basePackage)) + "/" + DEFAULT_RESOURCE_PATTERN; try { Resource[] resources = resourcePatternResolver.getResources(packageSearchPath); for (Resource resource : resources) { //检查resource,这里的resource都是class loadClassMethod(metadataReaderFactory, resource); } } catch (Exception e) { log.error("初始化SensitiveWordInterceptor失败", e); } } } /** * 加载资源,判断里面的方法 * * @param metadataReaderFactory spring中用来读取resource为class的工具 * @param resource 这里的资源就是一个Class * @throws IOException */ private void loadClassMethod(MetadataReaderFactory metadataReaderFactory, Resource resource) throws IOException { try { if (resource.isReadable()) { MetadataReader metadataReader = metadataReaderFactory.getMetadataReader(resource); if (metadataReader != null) { String className = metadataReader.getClassMetadata().getClassName(); try { tryCacheMethod(className); } catch (ClassNotFoundException e) { log.error("检查" + className + "是否含有需要信息失败", e); } } } } catch (Exception e) { log.error("判断类中的方法实现需要检测xxx失败", e); } } /** * 把action下面的所有method遍历一次,标记他们是否需要进行xxx验证 * 如果需要,放入cache中 * * @param fullClassName */ private void tryCacheMethod(String fullClassName) throws ClassNotFoundException { Class<?> clz = Class.forName(fullClassName); Method[] methods = clz.getDeclaredMethods(); for (Method method : methods) { if (method.getModifiers() != Modifier.PUBLIC) { continue; } if (CheckXXX.class.isAssignableFrom(CHECK_ANNOTATION)) { 
CheckXXX checkXXX = (CheckXXX) method.getAnnotation(CHECK_ANNOTATION); if (checkXXX != null && checkXXX.check()) { cache.put(fullClassName + "." + method.getName(), checkXXX); log.info("检测到需要检查xxx的方法:" + fullClassName + "." + method.getName()); } } } }
tryCacheMethod做的事就是缓存需要处理的public方法
经测试,这种方式既能扫描到 web 工程中的 class 文件,也能扫描到 jar 包中的 class 文件。
加强版:增加对 package 父子关系的判断,避免同一路径被重复扫描
package utils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.core.io.Resource; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; import org.springframework.core.io.support.ResourcePatternResolver; import org.springframework.core.type.classreading.CachingMetadataReaderFactory; import org.springframework.core.type.classreading.MetadataReader; import org.springframework.core.type.classreading.MetadataReaderFactory; import org.springframework.util.SystemPropertyUtils; import java.io.IOException; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; import java.util.Set; /** * Created by cdliujian1 on 2016/1/13. * 包工具,根据package路径,加载class */ public class PackageUtil { private final static Log log = LogFactory.getLog(PackageUtil.class); //扫描 scanPackages 下的文件的匹配符 protected static final String DEFAULT_RESOURCE_PATTERN = "**/*.class"; /** * 结合spring的类扫描方式 * 根据需要扫描的包路径及相应的注解,获取最终测method集合 * 仅返回public方法,如果方法是非public类型的,不会被返回 * 可以扫描工程下的class文件及jar中的class文件 * * @param scanPackages * @param annotation * @return */ public static Set<Method> findClassAnnotationMethods(String scanPackages, Class<? 
extends Annotation> annotation) { //获取所有的类 Set<String> clazzSet = findPackageClass(scanPackages); Set<Method> methods = new HashSet<Method>(); //遍历类,查询相应的annotation方法 for (String clazz : clazzSet) { try { Set<Method> ms = findAnnotationMethods(clazz, annotation); if (ms != null) { methods.addAll(ms); } } catch (ClassNotFoundException ignore) { } } return methods; } /** * 根据扫描包的,查询下面的所有类 * * @param scanPackages 扫描的package路径 * @return */ public static Set<String> findPackageClass(String scanPackages) { if (StringUtils.isBlank(scanPackages)) { return Collections.EMPTY_SET; } //验证及排重包路径,避免父子路径多次扫描 Set<String> packages = checkPackage(scanPackages); ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver(); MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(resourcePatternResolver); Set<String> clazzSet = new HashSet<String>(); for (String basePackage : packages) { if (StringUtils.isBlank(basePackage)) { continue; } String packageSearchPath = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + org.springframework.util.ClassUtils.convertClassNameToResourcePath(SystemPropertyUtils.resolvePlaceholders(basePackage)) + "/" + DEFAULT_RESOURCE_PATTERN; try { Resource[] resources = resourcePatternResolver.getResources(packageSearchPath); for (Resource resource : resources) { //检查resource,这里的resource都是class String clazz = loadClassName(metadataReaderFactory, resource); clazzSet.add(clazz); } } catch (Exception e) { log.error("获取包下面的类信息失败,package:" + basePackage, e); } } return clazzSet; } /** * 排重、检测package父子关系,避免多次扫描 * * @param scanPackages * @return 返回检查后有效的路径集合 */ private static Set<String> checkPackage(String scanPackages) { if (StringUtils.isBlank(scanPackages)) { return Collections.EMPTY_SET; } Set<String> packages = new HashSet<String>(); //排重路径 Collections.addAll(packages, scanPackages.split(",")); for (String pInArr : packages.toArray(new String[packages.size()])) { if (StringUtils.isBlank(pInArr) || 
pInArr.equals(".") || pInArr.startsWith(".")) { continue; } if (pInArr.endsWith(".")) { pInArr = pInArr.substring(0, pInArr.length() - 1); } Iterator<String> packageIte = packages.iterator(); boolean needAdd = true; while (packageIte.hasNext()) { String pack = packageIte.next(); if (pInArr.startsWith(pack + ".")) { //如果待加入的路径是已经加入的pack的子集,不加入 needAdd = false; } else if (pack.startsWith(pInArr + ".")) { //如果待加入的路径是已经加入的pack的父集,删除已加入的pack packageIte.remove(); } } if (needAdd) { packages.add(pInArr); } } return packages; } /** * 加载资源,根据resource获取className * * @param metadataReaderFactory spring中用来读取resource为class的工具 * @param resource 这里的资源就是一个Class * @throws IOException */ private static String loadClassName(MetadataReaderFactory metadataReaderFactory, Resource resource) throws IOException { try { if (resource.isReadable()) { MetadataReader metadataReader = metadataReaderFactory.getMetadataReader(resource); if (metadataReader != null) { return metadataReader.getClassMetadata().getClassName(); } } } catch (Exception e) { log.error("根据resource获取类名称失败", e); } return null; } /** * 把action下面的所有method遍历一次,标记他们是否需要进行敏感词验证 * 如果需要,放入cache中 * * @param fullClassName */ public static Set<Method> findAnnotationMethods(String fullClassName, Class<? extends Annotation> anno) throws ClassNotFoundException { Set<Method> methodSet = new HashSet<Method>(); Class<?> clz = Class.forName(fullClassName); Method[] methods = clz.getDeclaredMethods(); for (Method method : methods) { if (method.getModifiers() != Modifier.PUBLIC) { continue; } Annotation annotation = method.getAnnotation(anno); if (annotation != null) { methodSet.add(method); } } return methodSet; } public static void main(String[] args) { String packages = "com.a,com.ab,com.c,com.as.t,com.as,com.as.ta,com.at.ja,com.at.jc,com.at."; System.out.println("检测前的package: " + packages); System.out.println("检测后的package: " + StringUtils.join(checkPackage(packages), ",")); } }