深度阅读Spring5.x源码后,使用Java实现迷你版Spring的基本思路实践

342 阅读5分钟

看过Spring 5源码的同学,一开始大多会边读边打断点调试,可调试来调试去,时间一长难免晕头转向。更有效的方式是先猜测、再验证。当然,猜测要建立在相当丰富的源码阅读经验之上,否则依然是一头雾水。对于长期使用Spring框架的开发者来说,其架构和结构应该不会太陌生,完全可以大胆地进行猜测。

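根据对源码的猜测,整理出迷你版Spring的基本实现思路(原文此处的思路图未能保留,要点如下):配置阶段——配置web.xml和application.properties,定义自定义注解;初始化阶段——加载配置文件、扫描包下的类、实例化并放入IoC容器、完成依赖注入、初始化HandlerMapping;运行阶段——doGet()/doPost()调用doDispatch(),根据请求URL找到对应的方法并通过反射调用。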

一、web.xml配置文件

所有依赖Web容器的项目,基本都是从web.xml文件开始的。首先配置好web.xml,内容如下:

<?xml version="1.0" encoding="UTF-8"?>
<!-- Deployment descriptor (Servlet 4.0 schema) for the mini-spring demo. -->
<web-app xmlns="http://xmlns.jcp.org/xml/ns/javaee"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/javaee http://xmlns.jcp.org/xml/ns/javaee/web-app_4_0.xsd"
         version="4.0">

    <display-name>mini-spring</display-name>
    <!-- The framework's front controller: the only servlet declared in this application. -->
    <servlet>
        <servlet-name>symvc</servlet-name>
        <servlet-class>com.sy.sa.framework.servlet.SyDispatcherServlet</servlet-class>
        <!-- Read by SyDispatcherServlet.init(): classpath location of the properties file to load. -->
        <init-param>
            <param-name>contextConfigLocation</param-name>
            <param-value>application.properties</param-value>
        </init-param>

        <!-- Positive value: the container initializes this servlet at deployment, not on the first request. -->
        <load-on-startup>1</load-on-startup>
    </servlet>
    
    <!-- "/*" sends every request URL of this web application to the dispatcher servlet. -->
    <servlet-mapping>
        <servlet-name>symvc</servlet-name>
        <url-pattern>/*</url-pattern>
    </servlet-mapping>
    
</web-app>

SyDispatcherServlet是这个迷你版Spring的核心类(对应Spring MVC中的DispatcherServlet),后文会给出它的具体实现源码。

二、配置application.properties

无论是xml、properties还是yml,都只是配置文件的不同表现形式;格式虽然不同,表达的内容大体相同。这里只需要配置待扫描的包路径scanPackage,具体内容如下:

# Root package that SyDispatcherServlet scans (recursively, including sub-packages) for annotated classes.
scanPackage=com.sy.sa

三、自定义注解Annotation

3.1 @SyAutoWired

package com.sy.sa.framework.annotation;

import java.lang.annotation.*;

/**
 * Marks a field that the mini container fills in during the dependency-injection phase.
 *
 * <p>The annotation is kept at runtime so the container can find it via reflection.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface SyAutoWired {

    /**
     * Name of the bean to inject. When left empty, the container looks the bean up by the
     * fully qualified name of the field's declared type.
     */
    String value() default "";
}

3.2 @SyController

package com.sy.sa.framework.annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a web controller. The container instantiates it into the IoC container and
 * registers its {@code @SyRequestMapping} methods as request handlers.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SyController {

    /**
     * Reserved for a custom bean name. It is not read by the container code shown here;
     * controllers are registered under their simple class name with a lower-cased first letter.
     */
    String value() default "";
}

3.3 @SyRequestMapping

package com.sy.sa.framework.annotation;

import java.lang.annotation.*;

/**
 * Maps a URL path to a handler.
 *
 * <p>On a controller class it declares the base path shared by all of the class's handlers.
 * On a method it declares that method's own path. The two are joined with "/" to form the final
 * route. Both targets are required: the class-level usage ({@code @SyRequestMapping("/sy")} on a
 * controller) would not compile with a METHOD-only target.
 */
@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface SyRequestMapping {

    /** The path (class-level prefix or method-level suffix); empty means no path segment. */
    String value() default "";
}

3.4 @SyRequestParam

package com.sy.sa.framework.annotation;

import java.lang.annotation.*;

/**
 * Binds a handler-method parameter to a request parameter with the given name.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
public @interface SyRequestParam {

    /** Name of the request parameter; a blank name leaves the parameter unbound. */
    String value() default "";
}

3.5 @SyService

package com.sy.sa.framework.annotation;

import java.lang.annotation.*;

/**
 * Marks a class as a service bean. The container instantiates it and also registers it under
 * the name of every interface it implements, so interface-typed fields can be injected.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SyService {

    /**
     * Custom bean name. When blank, the simple class name with a lower-cased first letter is used.
     */
    String value() default "";
}

四、在业务代码中使用注解

4.1 IDemoService接口

/**
 * Service contract used by the demo controller.
 */
public interface IDemoService {

    /**
     * Computes a response string for the given name.
     *
     * @param name the value supplied by the caller (may be {@code null} if the request omitted it)
     * @return the text produced for {@code name}
     */
    String get(String name);
}

4.2 DemoService实现类

public class DemoService implements IDemoService {
    @Override
    public String get(String name) {
        return "Hello World, [" + name + "]";
    }
}

4.3 DemoController类

/**
 * Sample controller; its handlers are mounted under the {@code /sy} path prefix.
 */
@SyController
@SyRequestMapping("/sy")
public class DemoController {

    /** Injected by the container; no bean name given, so it is looked up by its declared type. */
    @SyAutoWired
    private IDemoService demoService;

    /**
     * Handles {@code /sy/query}: passes the {@code name} request parameter to the service and
     * writes the service's result to the response body.
     *
     * @param request the current request (not used)
     * @param response the response the result is written to
     * @param name value of the {@code name} request parameter
     */
    @SyRequestMapping("/query")
    public void query(HttpServletRequest request,
                      HttpServletResponse response,
                      @SyRequestParam("name") String name) {
        final String answer = this.demoService.get(name);
        try {
            response.getWriter()
                    .write(answer);
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

}

五、自定义SyDispatcherServlet的实现

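5.1 重写init()方法

说明:以下只列出了SyDispatcherServlet(继承自HttpServlet)中的方法,类中还需要声明contextConfig(Properties)、classNames(扫描到的类名列表)、ioc(Map<String, Object>)、handlerMapping(Handler列表)等成员变量。另外,init(ServletConfig)实际定义在HttpServlet的父类GenericServlet中。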

/**
 * Initialization phase of the mini framework, run once by the servlet container.
 *
 * <p>Loads the configuration file, scans the configured package, instantiates annotated classes
 * into the IoC container, injects {@code @SyAutoWired} fields and finally builds the
 * URL-to-method handler mapping.
 *
 * @param config servlet configuration carrying the {@code contextConfigLocation} init-param
 * @throws ServletException if the superclass initialization fails or the {@code scanPackage}
 *     property is missing from the configuration file
 */
@Override
public void init(ServletConfig config) throws ServletException {
    // GenericServlet.init(ServletConfig) stores the config; overriding it without this call leaves
    // getServletConfig()/getServletContext() unusable for the rest of the servlet's life.
    super.init(config);

    // 1. Load the properties file named by the contextConfigLocation init-param.
    doLoadConfig(config.getInitParameter("contextConfigLocation"));

    // 2. Collect the names of all classes under the configured package.
    String scanPackage = contextConfig.getProperty("scanPackage");
    if (scanPackage == null || scanPackage.trim().isEmpty()) {
        throw new ServletException("Missing 'scanPackage' property in the configuration file");
    }
    doScanner(scanPackage.trim());

    // 3. Instantiate the annotated classes and put them into the IoC container.
    doInstance();

    // 4. Dependency injection of @SyAutoWired fields.
    doAutowried();

    // 5. Build the URL -> handler method mapping.
    initHandlerMapping();

    System.out.println("Sy Spring framework is init.");
}

/**
 * Builds the routing table: for every controller bean in the IoC container, registers one
 * {@link Handler} per public method carrying {@code @SyRequestMapping}.
 *
 * <p>The route is "/" + class-level path + "/" + method-level path with runs of slashes collapsed
 * to one (e.g. "//sy///query" becomes "/sy/query"). It is compiled into a regex {@link Pattern}.
 */
private void initHandlerMapping() {
    if (ioc.isEmpty()) {
        return;
    }

    for (Object bean : ioc.values()) {
        Class<?> beanClass = bean.getClass();
        if (!beanClass.isAnnotationPresent(SyController.class)) {
            continue;
        }

        // Optional class-level prefix.
        SyRequestMapping classMapping = beanClass.getAnnotation(SyRequestMapping.class);
        String prefix = (classMapping == null) ? "" : classMapping.value();

        // getMethods() yields public methods only (including inherited ones).
        for (Method candidate : beanClass.getMethods()) {
            SyRequestMapping methodMapping = candidate.getAnnotation(SyRequestMapping.class);
            if (methodMapping == null) {
                continue;
            }
            String route = ("/" + prefix + "/" + methodMapping.value()).replaceAll("/+", "/");
            Pattern compiled = Pattern.compile(route);
            handlerMapping.add(new Handler(compiled, bean, candidate));
            System.out.println("Mapped :" + compiled + "," + candidate);
        }
    }
}

/**
 * Dependency-injection phase: assigns every {@code @SyAutoWired} field declared by the class of
 * each bean in the IoC container. Inherited fields are not visited.
 *
 * <p>The dependency is looked up by the annotation's value when it is non-blank. Otherwise it is
 * looked up by the fully qualified name of the field's type, because services are also registered
 * under their interface names. When no explicit name was given and that lookup misses, the single
 * distinct bean that is an instance of the field type is used.
 *
 * @throws IllegalStateException if a dependency cannot be resolved (none or several candidates)
 *     or the field cannot be written. Failing here at startup is preferable to injecting
 *     {@code null} and failing later with a NullPointerException at request time.
 */
private void doAutowried() {
    for (Map.Entry<String, Object> entry : ioc.entrySet()) {
        Object bean = entry.getValue();
        // getDeclaredFields() includes private/protected/package-private fields, not just public.
        for (Field field : bean.getClass().getDeclaredFields()) {
            SyAutoWired autoWired = field.getAnnotation(SyAutoWired.class);
            if (autoWired == null) {
                continue;
            }

            String beanName = autoWired.value().trim();
            boolean byType = beanName.isEmpty();
            if (byType) {
                // Interface-typed fields hit the "interface name -> service" keys written by doInstance.
                beanName = field.getType().getName();
            }

            Object dependency = ioc.get(beanName);
            if (dependency == null && byType) {
                dependency = findUniqueBeanOfType(field.getType());
            }
            if (dependency == null) {
                throw new IllegalStateException("Cannot inject field '" + field.getName()
                        + "' of bean '" + entry.getKey() + "': no unique bean found for '"
                        + beanName + "'");
            }

            // Non-public fields must be made accessible before they can be set reflectively.
            field.setAccessible(true);
            try {
                field.set(bean, dependency);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Cannot write field '" + field.getName()
                        + "' of bean '" + entry.getKey() + "'", e);
            }
        }
    }
}

/**
 * Returns the only distinct bean in the IoC container that is an instance of {@code type}, or
 * {@code null} when there is none or more than one. The same instance registered under several
 * keys counts once.
 */
private Object findUniqueBeanOfType(Class<?> type) {
    Object match = null;
    for (Object candidate : ioc.values()) {
        if (!type.isInstance(candidate) || candidate == match) {
            continue;
        }
        if (match != null) {
            return null; // ambiguous: two different beans fit the type
        }
        match = candidate;
    }
    return match;
}

/**
 * Instantiation phase: creates one instance of every scanned class annotated with
 * {@code @SyController} or {@code @SyService} and registers it in the IoC container.
 *
 * <p>Controllers are keyed by their simple class name with a lower-cased first letter (Spring's
 * default bean name). Services are keyed by the trimmed {@code @SyService} value when it is
 * non-blank, otherwise by the same default name. They are additionally keyed by the fully
 * qualified name of every interface they implement, so interface-typed {@code @SyAutoWired}
 * fields can be resolved. A class carrying both annotations is treated as a controller.
 *
 * @throws IllegalStateException if a class cannot be loaded or instantiated (e.g. it has no
 *     accessible no-arg constructor), or if a bean name or interface key is already taken
 */
private void doInstance() {
    for (String className : classNames) {
        try {
            Class<?> clazz = Class.forName(className);
            String beanName;
            boolean isService = false;
            if (clazz.isAnnotationPresent(SyController.class)) {
                beanName = toLowerFirstCase(clazz.getSimpleName());
            } else if (clazz.isAnnotationPresent(SyService.class)) {
                isService = true;
                // A non-blank @SyService value overrides the default bean name; trimmed to match
                // the trimmed name lookup done by @SyAutoWired injection.
                beanName = clazz.getAnnotation(SyService.class).value().trim();
                if (beanName.isEmpty()) {
                    beanName = toLowerFirstCase(clazz.getSimpleName());
                }
            } else {
                continue; // not a bean
            }

            if (ioc.containsKey(beanName)) {
                throw new IllegalStateException(
                        "Duplicate bean name '" + beanName + "' for class " + className);
            }
            // Class.newInstance() is deprecated; go through the no-arg constructor instead.
            Object instance = clazz.getDeclaredConstructor().newInstance();
            ioc.put(beanName, instance);

            if (isService) {
                // Shortcut for injection by type: use each implemented interface's name as a key.
                for (Class<?> iface : clazz.getInterfaces()) {
                    if (ioc.containsKey(iface.getName())) {
                        throw new IllegalStateException("The “" + iface.getName() + "” is exists!!");
                    }
                    ioc.put(iface.getName(), instance);
                }
            }
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot instantiate bean class " + className, e);
        }
    }
}

/**
 * Scanning phase: recursively collects into {@code classNames} the fully qualified names of all
 * {@code .class} files found under {@code scanPackage} and its sub-packages.
 *
 * <p>Only packages located in a file-system directory on the classpath can be scanned; classes
 * packaged inside a JAR are not supported.
 *
 * @param scanPackage dot-separated package name, e.g. {@code com.sy.sa}
 * @throws IllegalArgumentException if the package cannot be found on the classpath
 * @throws IllegalStateException if the package location is not a plain {@code file:} URL
 */
private void doScanner(String scanPackage) {
    // ClassLoader resource names are relative to the classpath root and must not start with "/";
    // a leading slash makes the standard class loaders return null.
    String resourcePath = scanPackage.replace('.', '/');
    URL url = this.getClass().getClassLoader().getResource(resourcePath);
    if (url == null) {
        throw new IllegalArgumentException("Package not found on the classpath: " + scanPackage);
    }

    File classPath;
    try {
        // toURI() decodes escapes such as %20, which URL.getFile() would leave in the path.
        classPath = new File(url.toURI());
    } catch (java.net.URISyntaxException | IllegalArgumentException e) {
        // IllegalArgumentException: not a hierarchical file: URI (e.g. an entry inside a JAR).
        throw new IllegalStateException("Cannot scan package " + scanPackage + " at " + url, e);
    }

    File[] files = classPath.listFiles();
    if (files == null) {
        return; // not a directory, or it could not be listed
    }
    for (File file : files) {
        String fileName = file.getName();
        if (file.isDirectory()) {
            doScanner(scanPackage + "." + fileName);
        } else if (fileName.endsWith(".class")) {
            // Strip only the trailing ".class" suffix, then store the fully qualified class name.
            String simpleName = fileName.substring(0, fileName.length() - ".class".length());
            classNames.add(scanPackage + "." + simpleName);
        }
    }
}

/**
 * Loads the properties file named by {@code contextConfigLocation} from the classpath into
 * {@code contextConfig}.
 *
 * @param contextConfigLocation classpath location of the file, e.g. {@code application.properties};
 *     an optional {@code classpath:} prefix and leading {@code /} characters are ignored
 * @throws IllegalArgumentException if the location is blank or the file is not on the classpath
 * @throws IllegalStateException if the file exists but cannot be read
 */
private void doLoadConfig(String contextConfigLocation) {
    if (contextConfigLocation == null || contextConfigLocation.trim().isEmpty()) {
        throw new IllegalArgumentException("The 'contextConfigLocation' init-param is not set");
    }
    String location = contextConfigLocation.trim();
    if (location.startsWith("classpath:")) {
        location = location.substring("classpath:".length());
    }
    // ClassLoader resource names must not start with "/".
    while (location.startsWith("/")) {
        location = location.substring(1);
    }

    // try-with-resources closes the stream; a null resource is simply skipped on close.
    try (InputStream is = this.getClass().getClassLoader().getResourceAsStream(location)) {
        if (is == null) {
            // Properties.load(null) would otherwise fail with an unexplained NullPointerException.
            throw new IllegalArgumentException(
                    "Config file not found on the classpath: " + contextConfigLocation);
        }
        contextConfig.load(is);
    } catch (IOException e) {
        throw new IllegalStateException("Failed to read config file " + contextConfigLocation, e);
    }
}

/**
 * Returns {@code simpleName} with its first character lower-cased, e.g. {@code DemoController}
 * becomes {@code demoController}; used to derive default bean names.
 *
 * <p>A first character that is already lower-case or not a letter is left unchanged. {@code null}
 * and the empty string (the simple name of an anonymous class) are returned as-is.
 *
 * @param simpleName simple class name
 * @return the name with a lower-case first character
 */
private String toLowerFirstCase(String simpleName) {
    if (simpleName == null || simpleName.isEmpty()) {
        return simpleName;
    }
    char[] chars = simpleName.toCharArray();
    // Adding 32 only works for ASCII 'A'..'Z' and corrupts anything else (e.g. 'a' + 32 = '\u0081');
    // Character.toLowerCase handles every character correctly.
    chars[0] = Character.toLowerCase(chars[0]);
    return new String(chars);
}

5.2 Handler类

保存url和method的对应关系,代码如下:

package com.sy.sa.framework.handler;

import com.sy.sa.framework.annotation.SyRequestParam;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

/**
 * Holds one URL pattern together with the controller method that serves it, plus the
 * information needed to bind request data to that method's parameters.
 */
public class Handler {

    private final Pattern pattern;
    private final Method method;
    private final Object controller;
    private final Class<?>[] paramTypes;

    /**
     * Parameter key mapped to its position in the method's parameter list. Keys are
     * {@code @SyRequestParam} names and, for {@code HttpServletRequest}/{@code HttpServletResponse}
     * parameters, the fully qualified class name of that type.
     */
    private final Map<String, Integer> paramIndexMapping;

    /**
     * @param pattern URL pattern that selects this handler
     * @param controller controller instance the method is invoked on
     * @param method the handler method
     */
    public Handler(Pattern pattern, Object controller, Method method) {
        this.pattern = pattern;
        this.method = method;
        this.controller = controller;
        this.paramTypes = method.getParameterTypes();
        this.paramIndexMapping = new HashMap<String, Integer>();
        putParamIndexMapping(method);
    }

    public Pattern getPattern() {
        return pattern;
    }

    public Method getMethod() {
        return method;
    }

    public Object getController() {
        return controller;
    }

    /** Returns a copy of the handler method's parameter types, in declaration order. */
    public Class<?>[] getParamTypes() {
        return paramTypes.clone();
    }

    /**
     * Returns a read-only view of the parameter-key to parameter-index mapping used by the
     * dispatcher to place request values, the request and the response into the argument array.
     */
    public Map<String, Integer> getParamIndexMapping() {
        return Collections.unmodifiableMap(paramIndexMapping);
    }

    private void putParamIndexMapping(Method method) {
        // getParameterAnnotations() is indexed [parameter][annotation]: a method has several
        // parameters and each parameter may carry several annotations.
        Annotation[][] annotations = method.getParameterAnnotations();
        for (int i = 0; i < annotations.length; i++) {
            for (Annotation annotation : annotations[i]) {
                if (annotation instanceof SyRequestParam) {
                    // Trimmed so it matches request parameter names exactly.
                    String paramName = ((SyRequestParam) annotation).value().trim();
                    if (!paramName.isEmpty()) {
                        paramIndexMapping.put(paramName, i);
                    }
                }
            }
        }

        // Servlet request/response parameters are keyed by their class name for injection.
        for (int i = 0; i < paramTypes.length; i++) {
            Class<?> type = paramTypes[i];
            if (type == HttpServletRequest.class || type == HttpServletResponse.class) {
                paramIndexMapping.put(type.getName(), i);
            }
        }
    }
}

5.3 doPost()/doGet()方法

 /**
 * Handles GET requests by delegating to {@code doPost}, so both verbs share one dispatch path.
 */
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp)
        throws ServletException, IOException {
    doPost(req, resp);
}

/**
 * Runtime phase: dispatches every request to its handler method via {@code doDispatch}.
 *
 * <p>Any exception raised while dispatching becomes an HTTP 500 response. The stack trace is
 * printed on the server only; sending it to the client would expose internal details.
 *
 * @throws IOException if the error response cannot be written
 */
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
    try {
        doDispatch(req, resp);
    } catch (Exception e) {
        // Method.invoke wraps the handler's own exception; report that one, not the wrapper.
        Throwable cause = (e instanceof java.lang.reflect.InvocationTargetException && e.getCause() != null)
                ? e.getCause()
                : e;
        cause.printStackTrace();
        resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
        resp.getWriter().write("500 Internal Server Error");
    }
}

/**
 * Finds the handler for the request, binds request data to the handler method's parameters and
 * invokes it reflectively. A non-null return value is written to the response as text.
 *
 * <p>Request parameters are bound to {@code @SyRequestParam} arguments; a parameter with several
 * values is joined with "," before conversion. {@code HttpServletRequest} and
 * {@code HttpServletResponse} arguments receive the current request and response.
 *
 * @throws Exception anything thrown by parameter conversion or by the handler method
 *     (the latter wrapped in an {@code InvocationTargetException})
 */
private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception {
    Handler handler = this.getHandler(req);
    if (handler == null) {
        resp.setStatus(HttpServletResponse.SC_NOT_FOUND);
        resp.getWriter().write("404 Not Found!!!");
        return;
    }

    Map<String, Integer> paramIndexMapping = handler.getParamIndexMapping();

    // The handler method's formal parameter types; one argument slot per parameter.
    Class<?>[] paramTypes = handler.getParamTypes();
    Object[] paramValues = new Object[paramTypes.length];

    Map<String, String[]> params = req.getParameterMap();
    for (Map.Entry<String, String[]> param : params.entrySet()) {
        Integer index = paramIndexMapping.get(param.getKey());
        if (index == null) {
            continue; // no handler parameter bound to this request parameter
        }
        // Join the raw values verbatim; whitespace and brackets inside a value are preserved.
        String value = String.join(",", param.getValue());
        paramValues[index] = convert(paramTypes[index], value);
    }

    Integer reqIndex = paramIndexMapping.get(HttpServletRequest.class.getName());
    if (reqIndex != null) {
        paramValues[reqIndex] = req;
    }

    Integer respIndex = paramIndexMapping.get(HttpServletResponse.class.getName());
    if (respIndex != null) {
        paramValues[respIndex] = resp;
    }

    Object returnValue = handler.getMethod().invoke(handler.getController(), paramValues);
    // void methods return null from Method.invoke.
    if (returnValue != null) {
        resp.getWriter().write(returnValue.toString());
    }
}

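doPost()方法使用了委派模式,具体的分发逻辑在doDispatch()方法中。注意:doDispatch()中用到的getHandler()和convert()两个辅助方法原文没有给出,需要自行实现(从名称推测,前者根据请求URL匹配对应的Handler,后者把请求参数的字符串值转换为形参类型)。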

至此,迷你版Spring的基本思路代码实践就完成了。