Mybatis: 动手封装ORM框架

2,804 阅读12分钟

ORM核心思想在于通过建立MODEL与数据库的映射来简化大量重复的工作量. 对于简单增删改查操作来说, 通过MODEL自动转换为SQL语句并执行可以节省很多工作量. 但是对于复杂的系统来说, 需要各种各样的复杂操作, 并且SQL也需要经过高度优化, 因此通过MODEL自动执行SQL并不可行.

Mybatis中通过所谓的半自动化解决了这一问题. 即手动书写SQL, 自动完成映射. 本章将实现一个简易的Mybatis.

设计思路

Mapper文件是Mybatis中非常重要的组件, 所有的SQL操作及映射方式都在Mapper文件中. 在程序(或项目)启动时加载所有的Mapper文件, 解析其中的SQL节点并保存. 当执行某项操作时从Mapper中找到对应的SQL, 完成参数映射后执行并返回执行结果.

组件与接口

在了解了基本的设计思路后, 下面开始对ORM框架中的组件及接口进行设计.

MappedStatement

SQL节点信息类, 负责保存Mapper中的SQL节点信息. 加载Mapper文件时, 将Mapper中的每个SQL节点进行解析并保存至MappedStatement中.

// SQL节点信息类
public class MappedStatement {

    // SQL节点ID
    private String id;

    // SQL语句
    private String sql;

    // SQL节点类型(select, insert, update, delete)
    private String statementType;

    // 返回值类型(类型为select时需要)
    private String resultType;

    // SQL中需要被映射复制的参数集合
    private List<String> parameters;

    // 实例化时必须传入SQL_ID
    public MappedStatement(String id) {
        this.id = id;
    }

    // Getter & Setter
    // ...

}

Configuration

全局配置类, 保存SQL节点信息及其他配置信息.

// 全局配置(保存SQL节点信息及其他配置)
public class Configuration {

    // 保存所有的SQL节点信息. k: SQL节点ID, v: SQL节点对象
    private Map<String, MappedStatement> mappedStatements = new HashMap<String, MappedStatement>();

    // 根据SQL节点ID获取SQL节点信息
    public MappedStatement getMappedStatement(String id) {
        return this.mappedStatements.get(id);
    }

    // 添加SQL节点
    public MappedStatement addMappedStatement(MappedStatement s) {
        return this.mappedStatements.put(s.getId(), s);
    }

}

SqlSession

SQL会话组件,封装了数据库操作. 对外提供增删改查的接口供使用者调用.

// SQL会话(提供SQL基本操作)
public class SqlSession {

    // 全局配置
    private Configuration config;

    // 实例化时必须传入全局配置
    public SqlSession(Configuration config) {
        this.config = config;
    }

    /**
     * 查询单条数据(查询结果必须为一条记录)
     * 
     * @param sqlId SQL节点ID
     * @return 查询结果对应的MODEL(resultType指定的类)
     */
    public <T> T selectOne(String sqlId) {
        return selectOne(sqlId, null);
    }

    /**
     * 查询单条数据(查询结果必须为一条记录)
     * 
     * @param sqlId SQL节点ID
     * @param args  执行SQL所需参数
     * @return 查询结果对应的MODEL(resultType指定的类)
     */
    public <T> T selectOne(String sqlId, Object args) {
        List<T> list = selectList(sqlId, null);
        if (list.size() > 1) {
            throw new RuntimeException("count 1111");
        }
        if (list.size() == 1) {
            return list.get(0);
        }
        return null;
    }

    /**
     * 查询多条数据
     * 
     * @param sqlId SQL节点ID
     * @return 查询结果对应的MODEL(resultType指定的类)集合
     */
    public <T> List<T> selectList(String sqlId) {
        return selectList(sqlId, null);
    }

    /**
     * 查询多条数据
     * 
     * @param sqlId SQL节点ID
     * @param args  执行SQL所需参数
     * @return 查询结果对应的MODEL(resultType指定的类)集合
     */
    public <T> List<T> selectList(String sqlId, Object args) {
        // TODO
        return null;
    }

    /**
     * 插入数据
     * 
     * @param sqlId SQL节点ID
     * @return 成功被插入的数据条数
     */
    public int insert(String sqlId) {
        return insert(sqlId, null);
    }

    /**
     * 插入数据
     * 
     * @param sqlId SQL节点ID
     * @param args  执行SQL所需参数
     * @return 成功被插入的数据条数
     */
    public int insert(String sqlId, Object args) {
        return update(sqlId, args);
    }

    /**
     * 删除数据
     * 
     * @param sqlId SQL节点ID
     * @return 成功被删除的数据条数
     */
    public int delete(String sqlId) {
        return delete(sqlId, null);
    }

    /**
     * 删除数据
     * 
     * @param sqlId SQL节点ID
     * @param args  执行SQL所需参数
     * @return 成功被删除的数据条数
     */
    public int delete(String sqlId, Object args) {
        return update(sqlId, args);
    }

    /**
     * 更新数据
     * 
     * @param sqlId SQL节点ID
     * @return 成功被更新的数据条数
     */
    public int update(String sqlId) {
        return update(sqlId, null);
    }

    /**
     * 更新数据
     * 
     * @param sqlId SQL节点ID
     * @param args  执行SQL所需参数
     * @return 成功被更新的数据条数
     */
    public int update(String sqlId, Object args) {
        // TODO
        return 0;
    }

}
  • insert, delete本质上是update,无需单独实现.
  • 根据主键查询只返回一条记录, 提供selectOne方法便于用户直接对MODEL操作(无需从集合获取在转换)

SqlSessionFactory SQL会话工厂类. 负责创建SQL会话.

// SQL会话工厂类(负责创建SQL会话)
public class SqlSessionFactory {

    // 全局配置
    private Configuration config;

    // 实例化时必须传入全局配置
    public SqlSessionFactory(Configuration config) {
        this.config = config;
    }

    // 获取SQL会话
    public SqlSession getSession() {
        return new SqlSession(config);
    }

}

SqlSessionFactoryBean

框架初始化入口, 负责解析Mapper文件并创建SQL会话工厂.

// 框架初始化入口(负责解析Mapper文件并创建SQL会话工厂)
public class SqlSessionFactoryBean {

    private static final Pattern PARAMETER_PATTERN = Pattern.compile("#\\{(.+?)\\}");

    // 全局配置
    private Configuration config = new Configuration();

    // Mapper文件路径
    private String mapperLocation;
    
    // 默认构造器
    public SqlSessionFactoryBean() {
    }

    // 通过Mapper文件路径实例化
    public SqlSessionFactoryBean(String mapperLocation) {
        this.setMapperLocation(mapperLocation);
    }
    // 设置Mapper文件
    public void setMapperLocation(String mapperLocation) {
        this.mapperLocation = mapperLocation;
    }

    // 构建SQL会话工厂
    public SqlSessionFactory build() {
        // TODO
        // 加载Mapper
        return new SqlSessionFactory(config);
    }

}

使用流程

  1. 程序(或项目)启动时, 实例化SqlSessionFactoryBean, 设置Mapper所在路径, 调用build方法构建SqlSessionFactory

  2. SqlSessionFactoryBean在build时根据Mapper路径加载并解析Mapper, 将Mapper下的每个SQL节点封装成MappedStatement对象. 并保存在全局配置Configuration中.

  3. 数据访问层(DAO)对数据库进行操作时, 通过SqlSessionFactory获取SqlSession, 调用其对应的方法并传入SQL ID.

  4. SqlSession中根据SQL ID在全局配置Configuration中找到对应的SQL节点信息对象MappedStatement.

  5. 将传入的参数按照MappedStatement中SQL的参数规则进行参数映射后生成可执行的SQL.

  6. 执行SQL后将执行结果返回. 如果是select操作, 将查询结果封装成指定对象后返回

代码实现

SqlSessionFactoryBean作为框架的初始化入口, 在build方法中加载配置文件, 解析Mapper并构建SQL会话工厂.

// 构建SQL会话工厂
public SqlSessionFactory build() {

    if (this.mapperLocation == null) {
        throw new RuntimeException("请设置Mapper文件所在路径");
    }

    // 获取Mapper文件
    List<File> mappers = getMapperFiles();

    // 加载所有Mapper并解析
    for (File mapper : mappers) {
        parseStatement(mapper);
    }

    // 返回SQL会话工厂
    return new SqlSessionFactory(config);

}

Mapper文件配置路径支持通配符(*), 例: /m/*_mapper.xml为Mapper文件位于classpath下的m文件内,并且以_mapper.xml结尾. 在加载Mapper文件时需要对通配符进行处理.

// 根据Mapper所在路径获取所有Mapper文件
private List<File> getMapperFiles() {

    List<File> mappers = new ArrayList<File>();

    String mapperDir; // Mapper文件目录
    String mapperName; // Mapper文件名称

    // 根据配置的Mapper路径分别获取文件目录及文件名称
    // 是否含有文件分隔符
    int lastPos = this.mapperLocation.lastIndexOf("/");
    // 含有文件分隔符
    if (lastPos > -1) {
        mapperDir = this.mapperLocation.substring(0, lastPos);
        mapperName = this.mapperLocation.substring(lastPos + 1);
    }
    // 无文件分隔符
    // 配置路径为Mapper文件名
    else {
        mapperDir = "";
        mapperName = this.mapperLocation;
    }

    // 获取Mapper目录下所有文件
    String classpath = ClassLoader.getSystemResource("").getPath();
    File[] allMappers = new File(classpath, mapperDir).listFiles();

    // *为通配符,将*转换为正则表达式通配符进行匹配
    // *_mapper.xml -> .+?_mapper.xml
    Pattern pattern = Pattern.compile(mapperName.replaceAll("\\*", ".+?"));

    // 遍历Mapper目录下所有文件
    for (File f : allMappers) {
        // 文件是否和指定的Mapper名称一致
        if (pattern.matcher(f.getName()).matches()) {
            mappers.add(f);
        }
    }

    return mappers;
}

解析Mapper文件中的所有SQL节点并保存. Mapper文件格式如下:

<mapper>
	
	<select id="selectGoods" resultType="com.atd681.xc.ssm.orm.test.Goods">
		select id, goods_name goodsName, category, price from t_orm_goods
	</select>
	
	<update id="updateGoods">
		UPDATE t_orm_goods 
		   SET goods_name = #{goodsName}, price = #{price}
		 WHERE id = #{id}
	</update>
	
</mapper> 

将每个SQL节点封装至MappedStatement对象中. 并统一保存至全局配置中.

// 解析Mapper文件
@SuppressWarnings("unchecked")
private void parseStatement(File mapper) {

    Document doc = null;
    // 使用JDom解析XML
    try {
        doc = new SAXBuilder().build(mapper);
    } catch (Exception e) {
        throw new RuntimeException("加载配置文件错误", e);
    }

    // Mapper下所有SQL节点
    List<Element> statementList = doc.getRootElement().getChildren();

    // 遍历Mapper下所有SQL节点
    for (Element statement : statementList) {

        // SQL节点ID
        String sqlId = statement.getAttributeValue("id");

        // SQL节点必须设置ID属性
        if (sqlId == null) {
            throw new RuntimeException("SQL节点需要设置id属性");
        }
        // SQL节点的ID不能重复
        if (config.getMappedStatement(sqlId) != null) {
            throw new RuntimeException("SQL节点id已经存在");
        }

        // 解析SQL节点
        MappedStatement ms = new MappedStatement(sqlId);

        ms.setSql(statement.getTextTrim());
        ms.setStatementType(statement.getName());
        ms.setResultType(statement.getAttributeValue("resultType"));

        // 解析SQL中的参数
        parseSqlAndParameters(ms);

        // 将SQL节点信息添加至全局配置中
        config.addMappedStatement(ms);

    }

}

SQL中如果含有需要被替换的参数时, 需要对SQL进行处理. 保存参数名称并将其替换成?, 以便使用JDBC执行时可以使用PrepareStatement进行赋值.

// 解析SQL中的参数
private void parseSqlAndParameters(MappedStatement ms) {

    List<String> parameters = new ArrayList<String>();
    StringBuffer sql = new StringBuffer();

    // 匹配SQL中的#{}
    Matcher m = PARAMETER_PATTERN.matcher(ms.getSql());

    // 将匹配到的#{}中的参数名称保存, 并替换为?
    // where u_name=#{UName} and u_age=#{UAge} -> where u_name=? and u_age=?
    // 执行SQL时在传入的参数中找到UName对第一个?赋值,UAge对第二个?赋值
    while (m.find()) {
        parameters.add(m.group(1));
        m.appendReplacement(sql, "?");
    }
    m.appendTail(sql);

    ms.setSql(sql.toString());
    ms.setParameters(parameters);

}

至此, Mapper文件已经解析完成. SQL会话工厂也已经创建完成. 在DAO中可以获取SQL会话并调用对应的数据库操作方法执行.

执行SQL时, 根据对SQL中的参数进行赋值后执行. 在解析Mapper时已经将SQL中的参数替换为?并且保存了参数的名称. 赋值时只需要依次将通过参数名称获取对应的参数值并且通过JDBC赋值到SQL中即可.

// 对参数进行赋值
private void setParameters(PreparedStatement ps, MappedStatement ms, Object args) throws Exception {
    if (args == null) {
        return;
    }
    List<String> parameters = ms.getParameters();
    if (parameters == null) {
        return;
    }
    // 依次根据参数名称从对应的MODEL中获取参数值替换SQL中的?
    for (int i = 0; i < parameters.size(); i++) {
        Object value = BeanUtil.getValue(args, parameters.get(i));
        ps.setObject(i + 1, value);
    }

}

DAO调用SQL会话的方法传入的参数支持以下几种类型

  • 只有一个参数, 可以传入字符串
  • 多个参数可以通过Map传入
  • 多个参数可以通过Javabean传入

在根据SQL参数名称获取对应值时需要对三种情况分别进行解析.

// Bean工具类
public class BeanUtil {

    // 从对象中获取执行属性的值
    @SuppressWarnings("rawtypes")
    public static Object getValue(Object bean, String name) throws Exception {
        // 字符串
        if (bean instanceof String) {
            return bean;
        }
        // Map
        if (bean instanceof Map) {
            return ((Map) bean).get(name);

        }
        // Javabean调用属性的Getter方法获取
        Class<?> clazz = bean.getClass();
        Method getter = clazz.getDeclaredMethod(getGetter(name), new Class<?>[] {});
        return getter.invoke(bean, new Object[] {});
    }

    // 获取Getter方法名. userName -> getUserName
    public static String getGetter(String name) {
        return "get" + capitalize(name);
    }

    // 获取Setter方法名. userName -> setUserName
    public static String getSetter(String name) {
        return "set" + capitalize(name);
    }

    // 首字母大写. userName -> UserName
    private static String capitalize(String name) {
        return name.substring(0, 1).toUpperCase() + name.substring(1);
    }

}

对于insert,update,delete三种操作来说,最终在JDBC中的执行方式一致.

/**
 * 更新数据
 * 
 * @param sqlId SQL节点ID
 * @param args  执行SQL所需参数
 * @return 成功被更新的数据条数
 */
public int update(String sqlId, Object args) {

    MappedStatement ms = config.getMappedStatement(sqlId);

    Connection conn = null;
    PreparedStatement ps = null;

    try {
        conn = getConnection();
        ps = getPreparedStatement(conn, ms);
        setParameters(ps, ms, args);
        return ps.executeUpdate();

    } catch (Exception e) {
        throw new RuntimeException("", e);
    } finally {
        close(ps, conn);
    }

}

对于select操作来说, 需要将查询结果封装至指定对象中.

/**
 * 查询多条数据
 * 
 * @param sqlId SQL节点ID
 * @param args  执行SQL所需参数
 * @return 查询结果对应的MODEL(resultType指定的类)集合
 */
public <T> List<T> selectList(String sqlId, Object args) {

    MappedStatement ms = config.getMappedStatement(sqlId);

    Connection conn = null;
    PreparedStatement ps = null;
    ResultSet rs = null;

    try {
        conn = getConnection();
        ps = getPreparedStatement(conn, ms);
        setParameters(ps, ms, args);
        rs = ps.executeQuery();
        
        // 查询结果封装至指定对象中
        return handleResultSet(rs, ms);

    } catch (Exception e) {
        throw new RuntimeException("", e);
    } finally {
        close(rs, ps, conn);
    }

}

查询结果映射的类为SQL节点中resultType属性定义的类. 查询结果的字段名和类中的属性名一致才自动赋值. 如果不一致可以在SQL中加入别名使其与类中属性名一致. 例: select u_name uName from ...

// 将查询结果集封装至对象中
@SuppressWarnings("unchecked")
private <T> List<T> handleResultSet(ResultSet rs, MappedStatement ms) throws Exception {

    List<T> list = new ArrayList<T>();

    // ResultSetMetaData对象保存查询到的数据库相关信息.
    ResultSetMetaData metaData = rs.getMetaData();

    while (rs.next()) {
        // 通过Java反射实例化对应的Javabean, 类型为SQL配置文件中的resultType
        // 如果不设置resultType,则无法知道返回值的类型.所以resultType必须要设置.
        Class<?> classObj = (Class<?>) Class.forName(ms.getResultType());
        T t = (T) classObj.newInstance();

        // 将每个字段的值映射到对应的对象中
        int count = metaData.getColumnCount();
        for (int i = 1; i <= count; i++) {

            // 取得字段对象Javabean中的属性名称
            String ormName = metaData.getColumnLabel(i);
            // 通过属性名称,用Java反射取得Javabean中set方法.
            // Javabean的定义为:所有属性为私有(private),提供共有(public)的get和set方法对其进行操作
            // set方法为设置该属性的方法.set方法格式为set+属性名(首字母大写),例属性为userName,set方法为setUserName()
            Class<?> filedType = classObj.getDeclaredField(ormName).getType();
            Method setter = classObj.getMethod(BeanUtil.getSetter(ormName), filedType);
            
            // 根据数据库字段的类型执行相应的set方法即可将字段值设置到属性中
            setter.invoke(t, getColumnValue(rs, i));

        }

        list.add(t);
    }

    return list;

}
// 根据数据库字段类型获取其在Java中对应类型的值
private Object getColumnValue(ResultSet rs, int index) throws Exception {

    int columnType = rs.getMetaData().getColumnType(index);

    if (columnType == Types.BIGINT) {
        return rs.getLong(index);
    } else if (columnType == Types.INTEGER) {
        return rs.getInt(index);
    } else if (columnType == Types.VARCHAR) {
        return rs.getString(index);
    } else if (columnType == Types.DATE || columnType == Types.TIME || columnType == Types.TIMESTAMP) {
        return rs.getDate(index);
    } else if (columnType == Types.DOUBLE) {
        return rs.getDouble(index);
    }

    return null;
}

测试

用户MODEL

// 用户MODEL
public class User {

    // id
    private Long id;

    // 用户名
    private String uname;

    // 用户年龄
    private Integer uage;

    // 用户地址
    private String uaddr;

    // 备注
    private String remark;

    public User() {
    }

    // 通过用户信息构造用户
    public User(Long id, String uname, Integer uage, String uaddr, String remark) {
        this.id = id;
        this.uname = uname;
        this.uage = uage;
        this.uaddr = uaddr;
        this.remark = remark;
    }
    
    // Getter & Setter
    // ...

}

用户DAO接口

// 用户DAO接口
public interface UserDAO {

    // 创建用户
    int insertUser(User user);

    // 更新用户
    int updateUser(User user);

    // 查询用户
    List<User> selectUser();

}

用户DAO实现类

// 用户DAO
public class UserDAOImpl extends BaseDAO implements UserDAO {

    // SQL会话工厂在父类中
    public UserDAOImpl(SqlSessionFactory sf) {
        super(sf);
    }

    // 添加用户
    public int insertUser(User user) {
        return getSqlSession().update("insertUser", user);
    }

    // 更新用户
    public int updateUser(User user) {
        return getSqlSession().update("updateUser", user);
    }

    // 查询用户
    public List<User> selectUser() {
        return getSqlSession().selectList("selectUser");
    }

}

为便于DAO操作, 所有DAO继承BaseDAO, 其中BaseDAO负责保存SQL会话工厂并提供获取SQL会话的方法.


// DAO基类
public class BaseDAO {

    // SQL会话工厂
    private SqlSessionFactory sf;

    // 通过SQL会话工厂实例化
    public BaseDAO(SqlSessionFactory sf) {
        this.sf = sf;
    }

    // 获取SQL会话
    protected SqlSession getSqlSession() {
        return sf.getSession();
    }

}

新建测试类, 设置Mapper位置并初始化ORM框架, 执行DAO的操作输出结果.

public class TestImpl {

    public static void main(String[] args) {

        // 创建SQL会话工厂
        SqlSessionFactory sf = new SqlSessionFactoryBean("*_mapper.xml").build();

        // 实例化DAO
        UserDAO userDAO = new UserDAOImpl(sf);

        // 调用DAO中方法
        int count = userDAO.insertUser(new User(1L, "zhangsan", 20, "sssss", "ok"));
        List<User> userList = userDAO.selectUser();

        // 输出查询结果
        System.out.println(count);

        for (User u : userList) {
            System.out.println("| " + u.getId() + " | " + u.getUname() + " | ");
        }

    }

}