本节介绍 Spring 通过 XML 配置和注解两种方式加载 Bean 的原理,并手写了简化版源码,仅供参考。
/** * @description: spring的上下文 * @author: ZhuCJ * @date: 2020-08-27 12:32 */ public class SpringContext implements BaseFactory{ /** * 指定扫描的包名 */ private String packerName; /** * 指定spring 配置Bean的xml位置 */ private String[] xmlPath; public SpringContext(String packerName) { this.packerName = packerName; } public SpringContext(String[] xmlPath) { this.xmlPath = xmlPath; } public SpringContext(String packerName, String[] xmlPath) { this.packerName = packerName; this.xmlPath = xmlPath; } @Override public Object getBean(String beanName) { if (Objects.nonNull(this.xmlPath) && xmlPath.length>0){ //加载XML配置Bean loadXml(); } if (!StringUtils.isEmpty(this.packerName)){ //加载注解配置Bean loadAnnotation(); } return BEAN_MAP.get(beanName) ; } /** * 创建注解加载工厂 */ public void loadAnnotation(){ new AnnotationBeanFactory(this.packerName); } /** * 创建XML方式工厂 */ public void loadXml(){ for (String xmlPath:this.xmlPath){ new XmlBeanFactory(xmlPath); } } }
/** * @description: * @author: ZhuCJ * @date: 2020-08-27 10:16 */ public class AnnotationBeanFactory implements BaseFactory { private static final Logger logger = LoggerFactory.getLogger(AnnotationBeanFactory.class); private String packerName ; private static final String EXT = "class"; @Override public Object getBean(String beanName) { return BEAN_MAP.get(beanName); } public AnnotationBeanFactory(String packerName){ this.packerName = packerName; //加载注解Bean loadBean(); //加载注入的属性Bean loadInjectBean(); } /** * 加载bean到容器中 */ public void loadBean(){ //读取包名的路径 String packerPath = null; try { packerPath = getPkgPath(packerName); } catch (UnsupportedEncodingException e) { logger.info("文件路径编码异常:{}",e.getMessage()); throw new RuntimeException("packerName path error"); } logger.info("扫描文件目录的路径:{}", packerPath); // 查找包含Component注解的类 Map<Class<? extends Annotation>, Set<Class<?>>> classesMap = scanClassesByAnnotations(packerName, packerPath, true, Arrays.asList(ServiceTest.class)); if (classesMap.size() == 0){ logger.error("目录:{}下,未获取到需要加载的类", packerPath); return; } //标记的反射对象 Set<Class<?>> classSet = new HashSet<>(); classesMap.forEach((k, v) -> { classSet.addAll(v); }); //默认设置类名为类名小写 for (Class<?> classObj:classSet){ Object object = null; try { object = classObj.newInstance(); } catch (InstantiationException e) { throw new RuntimeException(classObj.getSimpleName()+ " create error"); } catch (IllegalAccessException e) { throw new RuntimeException(classObj.getSimpleName()+ " create error"); } BEAN_MAP.put(StringUtils.uncapitalize(classObj.getSimpleName()),object); } } /** * 加载注入的Bean属性 */ public void loadInjectBean(){ BEAN_MAP.forEach((k,v)->{ setAttributeValue(v); }); } /** * 根据包名获取包的URL * @param pkgName com.demo.controller * @return */ public static String getPkgPath(String pkgName) throws UnsupportedEncodingException { String pkgDirName = pkgName.replace('.', File.separatorChar); URL url = Thread.currentThread().getContextClassLoader().getResource(pkgDirName); return url == 
null ? null : URLDecoder.decode(url.getFile(), "UTF-8"); } /** * 获取指定包下包含指定注解的所有类对象的集合 * @param pkgName 包名(com.demo.controller) * @param pkgPath 包路径(/Users/xxx/workspace/java/project/out/production/classes/com/demo/controller) * @param recursive 是否递归遍历子目录 * @param targetAnnotations 指定注解 * @return 以注解和对应类集合构成的键值对 */ public static Map<Class<? extends Annotation>, Set<Class<?>>> scanClassesByAnnotations( String pkgName, String pkgPath, final boolean recursive, List<Class<? extends Annotation>> targetAnnotations){ Map<Class<? extends Annotation>, Set<Class<?>>> resultMap = new HashMap<>(16); Collection<File> allClassFile = getAllClassFile(pkgPath, recursive); for (File curFile : allClassFile){ try { Class<?> curClass = getClassObj(curFile, pkgPath, pkgName); for (Class<? extends Annotation> annotation : targetAnnotations){ if (curClass.isAnnotationPresent(annotation)){ if (!resultMap.containsKey(annotation)){ resultMap.put(annotation, new HashSet<Class<?>>()); } resultMap.get(annotation).add(curClass); } } } catch (ClassNotFoundException e) { logger.error("load class fail", e); } } return resultMap; } /** * 遍历指定目录下所有扩展名为class的文件 * @param pkgPath 包目录 * @param recursive 是否递归遍历子目录 * @return */ private static Collection<File> getAllClassFile(String pkgPath, boolean recursive){ File fPkgDir = new File(pkgPath); if (!(fPkgDir.exists() && fPkgDir.isDirectory())){ logger.error("the directory to package is empty: {}", pkgPath); return null; } return FileUtils.listFiles(fPkgDir, new String[]{EXT}, recursive); } /** * 加载类 * @param file * @param pkgPath * @param pkgName * @return * @throws ClassNotFoundException */ private static Class<?> getClassObj(File file, String pkgPath, String pkgName) throws ClassNotFoundException{ // 考虑class文件在子目录中的情况 String absPath = file.getAbsolutePath().substring(0, file.getAbsolutePath().length() - EXT.length() - 1); String className = absPath.substring(pkgPath.length()).replace(File.separatorChar, '.'); className = className.startsWith(".") ? 
pkgName + className : pkgName + "." + className; return Thread.currentThread().getContextClassLoader().loadClass(className); } /** * 属性赋值 * @param object */ private static void setAttributeValue(Object object){ Class<?> aClass = object.getClass(); Field[] declaredFields = aClass.getDeclaredFields(); for (Field field:declaredFields){ if (field.isAnnotationPresent(AutowiredTest.class)){ //默认取属性值类型小写为 BeanId String simpleName = field.getType().getSimpleName(); Object obj = BEAN_MAP.get(StringUtils.uncapitalize(simpleName)); //允许私有属性赋值 field.setAccessible(true); try { field.set(object,obj); } catch (IllegalAccessException e) { throw new RuntimeException(field.getName() +" attribute set value exception"); } } } } public static void main(String[] args) { BaseFactory baseFactory = new AnnotationBeanFactory("com.spring"); BaseFactory baseFactory1 = new XmlBeanFactory("/spring/test.xml"); Body body = (Body) baseFactory.getBean("body"); System.out.println("加载类:"+BaseFactory.BEAN_MAP.size()+"个"); Object apple = baseFactory.getBean("apple"); System.out.println(body); System.out.println(apple); Order order = new Order(); for (Field field:order.getClass().getDeclaredFields()){ System.out.println(field.getName()); System.out.println(field.getType().getSimpleName()); } } }
/** * @description: Xml方式创建Bean * @author: ZhuCJ * @date: 2020-08-26 12:43 */ public class XmlBeanFactory implements BaseFactory { /** * *******XML形式注册Bean * 1.指定Resources资源 Xml文件位置 * 2.加载Xml文件Document对象 拿到id 和classes值 * 3.反射创建对象 */ private String filePath; public XmlBeanFactory(String xmlPath){ this.filePath = xmlPath; loadBean(); } @Override public Object getBean(String beanName) { return BEAN_MAP.get(beanName); } public void loadBean(){ //读取resource资源的路径 String sysPath = ClassUtils.getDefaultClassLoader().getResource("").getPath(); String path = sysPath + filePath; //dom4j解析XML文件 SAXReader saxReader = new SAXReader(); Document read ; try { read = saxReader.read(new File(path)); } catch (Exception e) { throw new RuntimeException("file No Find"); } Element root; Element rootElement = read.getRootElement(); for (Iterator i = rootElement.elementIterator("bean");i.hasNext();){ root =(Element) i.next(); Attribute id = root.attribute("id"); Attribute aClass = root.attribute("class"); //利用反射创建对象 Class<?> beanClass; try { beanClass = Class.forName(aClass.getText()); } catch (ClassNotFoundException e) { throw new RuntimeException(id.getText()+"class not Found"); } Object object = null; try { object = beanClass.newInstance(); } catch (InstantiationException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } BeanInfo beanInfo = null; try { //获取bean对象信息 beanInfo = Introspector.getBeanInfo(beanClass); } catch (IntrospectionException e) { e.printStackTrace(); } //ben对象的属性描述信息 PropertyDescriptor[] propertyDescriptors = beanInfo.getPropertyDescriptors(); for (Iterator k = root.elementIterator("property");k.hasNext();){ Element propertyElem =(Element) k.next(); Attribute name = propertyElem.attribute("name"); Attribute value = propertyElem.attribute("value"); //判断属性名称,是否和name相等 for (PropertyDescriptor desc:propertyDescriptors){ if (desc.getName().equalsIgnoreCase(name.getText())){ Method writeMethod = desc.getWriteMethod(); try { //赋值 
writeMethod.invoke(object,value.getValue()); } catch (IllegalAccessException e) { e.printStackTrace(); } catch (InvocationTargetException e) { e.printStackTrace(); } } } } BEAN_MAP.put(id.getText(),object); } } public static void main(String[] args) { BaseFactory baseFactory = new XmlBeanFactory("/spring/test.xml"); System.out.println(BaseFactory.BEAN_MAP.size()); } }
/**
 * @description: Base bean factory implemented by every container in this demo.
 * @author: ZhuCJ
 * @date: 2020-08-26 12:42
 */
public interface BaseFactory {

    /**
     * Static bean registry keyed by bean id, shared by all implementations. Beans registered by
     * one factory are therefore visible to, and can be overwritten by, the others. Being a
     * {@link ConcurrentHashMap}, it rejects {@code null} keys and values.
     */
    Map<String, Object> BEAN_MAP = new ConcurrentHashMap<>();

    /**
     * Looks up a bean by id.
     *
     * @param beanName bean id
     * @return the bean; the implementations in this project return {@code null} when absent
     */
    Object getBean(String beanName);
}
/**
 * @description: Order service contract.
 * @author: ZhuCJ
 * @date: 2020-08-27 13:18
 */
public interface OrderService {

    /**
     * Looks up an order by its identifier.
     *
     * @param id order identifier
     * @return the order for {@code id}
     */
    Order selectById(String id);
}
/**
 * @description: Demo implementation of {@link OrderService}, registered as a bean via
 * {@link ServiceTest}; its {@link Order} collaborator is field-injected.
 * @author: ZhuCJ
 * @date: 2020-08-27 13:18
 */
@ServiceTest
public class OrderServiceImpl implements OrderService {

    /** Populated by field injection through {@link AutowiredTest}. */
    @AutowiredTest
    private Order order;

    /**
     * Returns the injected order; the {@code id} argument is not used by this stub.
     *
     * @param id order identifier (ignored)
     * @return the injected {@link Order}, possibly {@code null} if nothing was injected
     */
    @Override
    public Order selectById(String id) {
        return this.order;
    }
}
/**
 * @description: Imitates Spring's {@code @Autowired}: marks a field to be injected with the bean
 * whose id is the field type's simple name with a lower-case first letter.
 * {@code @Inherited} is deliberately omitted: it has no effect on annotations that are not
 * applied to classes, and this one only targets fields.
 * @author: ZhuCJ
 * @date: 2020-08-27 11:56
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface AutowiredTest {
}
/**
 * @description: Imitates Spring's {@code @Service}: marks a class as a bean to be discovered
 * by package scanning. Retained at runtime so it can be read reflectively, and inherited by
 * subclasses of an annotated class.
 * @author: ZhuCJ
 * @date: 2020-08-27 10:17
 */
@Inherited
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface ServiceTest {
}
/**
 * @description: Order entity. Lombok generates getters/setters, equals/hashCode, toString
 * ({@code @Data}), plus no-arg and all-args constructors. Field order defines the all-args
 * constructor's parameter order and the toString layout.
 * @author: ZhuCJ 80004071
 * @date: 2020-08-26 19:09
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Order {

    /** Order identifier. */
    private String id;

    /** Creation timestamp. */
    private Date createTime;

    /** Order number. */
    private String orderNo;
}
/**
 * @description: Demo entry point. Builds a context from the {@code com.spring} package plus
 * {@code /spring/test.xml}, fetches the order service and prints the order it returns.
 * @author: ZhuCJ 80004071
 * @date: 2020-08-11 12:36
 */
public class Main {

    public static void main(String[] args) {
        String[] xmlLocations = {"/spring/test.xml"};
        SpringContext context = new SpringContext("com.spring", xmlLocations);

        Object bean = context.getBean("orderServiceImpl");
        OrderService orderService = (OrderService) bean;
        System.out.println(orderService.selectById("1"));
    }
}
测试结果
我已经被创建
我已经被创建
Order(id=12121212, createTime=null, orderNo=T21000)
加载到容器中Bean数量:10
--------------------------------------------------------------------