29

深入底层,仿MyBatis自己写框架

 4 years ago
source link: https://www.tuicool.com/articles/BzeMZry
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

前言:

最近研究了一下Mybatis的底层代码,写了一个操作数据库的小工具,实现了Mybatis的部分功能:

1.SQL语句在mapper.xml中配置。

2.支持int,String,自定义数据类型的入参。

3.根据mapper.xml动态创建接口的代理实现对象。

功能有限,目的是搞清楚MyBatis框架的底层思想,多学习研究优秀框架的实现思路,对提升自己的编码能力大有裨益。

小工具使用到的核心技术点: xml解析+反射+jdk动态代理

接下来,一步一步来实现。

首先来说为什么要使用jdk动态代理。

传统的开发方式:

1.接口定义业务方法。

2.实现类实现业务方法。

3.实例化实现类对象来完成业务操作。

接口:

public interface UserDAO {
    public User get(int id);
}

实现类:

public class UserDAOImpl implements UserDAO{

    @Override
    public User get(int id) {
        Connection conn = JDBCTools.getConnection();
        String sql = "select * from user where id = ?";
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = conn.prepareStatement(sql);
            pstmt.setInt(1, id);
            rs = pstmt.executeQuery();
            if(rs.next()){
                int sid = rs.getInt(1);
                String name = rs.getString(2);
                User user = new User(sid,name);
                return user;
            }
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }finally{
            JDBCTools.release(conn, pstmt, rs);
        }
        return null;
    }

}

测试:

public static void main(String[] args) {

        UserDAO userDAO = new UserDAOImpl();
        User user = userDAO.get(1);
        System.out.println(user);

    }

Mybatis的方式:

1.开发者只需要创建接口,定义业务方法。

2. 不需要创建实现类。

3.具体的业务操作通过配置xml来完成。

接口:

public interface StudentDAO {
    public Student getById(int id);
    public Student getByStudent(Student student);
    public Student getByName(String name);
    public Student getByStudent2(Student student);
}

StudentDAO.xml:

<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> 
<mapper namespace="com.southwind.dao.StudentDAO"> 

    <select id="getById" parameterType="int" 
        resultType="com.southwind.entity.Student">
        select * from student where id=#{id}
    </select>

    <select id="getByStudent" parameterType="com.southwind.entity.Student" 
        resultType="com.southwind.entity.Student">
        select * from student where id=#{id} and name=#{name}
    </select>

    <select id="getByStudent2" parameterType="com.southwind.entity.Student" 
        resultType="com.southwind.entity.Student">
        select * from student where name=#{name} and tel=#{tel} 
    </select>

    <select id="getByName" parameterType="java.lang.String" 
        resultType="com.southwind.entity.Student">
        select * from student where name=#{name}
    </select>

</mapper>

测试:

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student stu = studentDAO.getById(1);
        System.out.println(stu);

    }

通过以上代码可以看到, MyBatis的方式省去了实现类的创建,改为用xml来定义业务方法的具体实现。

那么问题来了。

我们知道Java是面向对象的编程语言, 程序在运行时执行业务方法,必须要有实例化的对象。 但是,接口是不能被实例化的,而且也没有接口的实现类,那么此时这个对象从哪来呢?

程序在运行时,动态创建代理对象。

即jdk动态代理,运行时结合接口和mapper.xml来动态创建一个代理对象,程序调用该代理对象的方法来完成业务。

如何使用jdk动态代理?

创建一个类,实现InvocationHandler接口,该类就具备了创建动态代理对象的功能。

两个核心方法:

1.自定义getInstance方法:入参为目标对象,通过Proxy.newProxyInstance方法创建代理对象,并返回。

    public Object getInstance(Class cls){
        Object newProxyInstance = Proxy.newProxyInstance(  
                cls.getClassLoader(),  
                new Class[] { cls }, 
                this); 
        return (Object)newProxyInstance;
    }

2.实现接口的invoke方法,通过反射机制完成业务逻辑代码。

   @Override
    public Object invoke(Object proxy, Method method, Object[] args)
            throws Throwable {
        // TODO Auto-generated method stub
        return null;
    }

invoke方法是核心代码,在该方法中实现具体的业务需求。接下来我们来看如何实现。

既然是对数据库进行操作,则一定需要数据库连接对象,数据库相关信息配置在config.xml中。

所以invoke方法第一步,就是要解析config.xml,创建数据库连接对象,使用C3P0数据库连接池。

    //读取C3P0数据源配置信息
    public static Map<String,String> getC3P0Properties(){
        Map<String,String> map = new HashMap<String,String>();
        SAXReader reader = new SAXReader();
        try {
            Document document = reader.read("src/config.xml");
            //获取根节点
            Element root = document.getRootElement();
            Iterator iter = root.elementIterator();
            while(iter.hasNext()){
                Element e = (Element) iter.next();
                //解析environments节点
                if("environments".equals(e.getName())){
                    Iterator iter2 = e.elementIterator();
                    while(iter2.hasNext()){
                        //解析environment节点
                        Element e2 = (Element) iter2.next();
                        Iterator iter3 = e2.elementIterator();
                        while(iter3.hasNext()){
                            Element e3 = (Element) iter3.next();
                            //解析dataSource节点
                            if("dataSource".equals(e3.getName())){
                                if("POOLED".equals(e3.attributeValue("type"))){
                                    Iterator iter4 = e3.elementIterator();
                                    //获取数据库连接信息
                                    while(iter4.hasNext()){
                                        Element e4 = (Element) iter4.next();
                                        map.put(e4.attributeValue("name"),e4.attributeValue("value"));
                                    }
                                }
                            }
                        }
                    }
                }
            }
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return map; 
    }
//获取C3P0信息,创建数据源对象
Map<String,String> map = ParseXML.getC3P0Properties();
ComboPooledDataSource datasource = new ComboPooledDataSource();
datasource.setDriverClass(map.get("driver"));
datasource.setJdbcUrl(map.get("url"));
datasource.setUser(map.get("username"));
datasource.setPassword(map.get("password"));
datasource.setInitialPoolSize(20);
datasource.setMaxPoolSize(40);
datasource.setMinPoolSize(2);
datasource.setAcquireIncrement(5);
Connection conn = datasource.getConnection();

有了数据库连接,接下来就需要获取待执行的SQL语句,SQL的定义全部写在StudentDAO.xml中,继续解析xml,执行SQL语句。

SQL执行完毕,查询结果会保存在ResultSet中,还需要将ResultSet对象中的数据进行解析,封装到JavaBean中返回。

两步完成:

1.反射机制创建Student对象。

2.通过反射动态执行类中所有属性的setter方法,完成赋值。

这样就将ResultSet中的数据封装到JavaBean中了。

//获取sql语句
String sql = element.getText();
//获取参数类型
String parameterType = element.attributeValue("parameterType");
//创建pstmt
PreparedStatement pstmt = createPstmt(sql,parameterType,conn,args);
ResultSet rs = pstmt.executeQuery();
if(rs.next()){
    //读取返回数据类型
    String resultType = element.attributeValue("resultType");   
    //反射创建对象
    Class clazz = Class.forName(resultType);
    obj = clazz.newInstance();
    //获取ResultSet数据
    ResultSetMetaData rsmd = rs.getMetaData();
    //遍历实体类属性集合,依次将结果集中的值赋给属性
    Field[] fields = clazz.getDeclaredFields();
    for(int i = 0; i < fields.length; i++){
        Object value = setFieldValueByResultSet(fields[i],rsmd,rs);
        //通过属性名找到对应的setter方法
        String name = fields[i].getName();
        name = name.substring(0, 1).toUpperCase() + name.substring(1);
        String MethodName = "set"+name;
        Method methodObj = clazz.getMethod(MethodName,fields[i].getType());
        //调用setter方法完成赋值
        methodObj.invoke(obj, value);
        }
}

代码的实现大致思路如上所述,具体实现起来有很多细节需要处理。 使用到两个自定义工具类:ParseXML,MyInvocationHandler。

完整代码:

ParseXML

public class ParseXML {

    //读取C3P0数据源配置信息
    public static Map<String,String> getC3P0Properties(){
        Map<String,String> map = new HashMap<String,String>();
        SAXReader reader = new SAXReader();
        try {
            Document document = reader.read("src/config.xml");
            //获取根节点
            Element root = document.getRootElement();
            Iterator iter = root.elementIterator();
            while(iter.hasNext()){
                Element e = (Element) iter.next();
                //解析environments节点
                if("environments".equals(e.getName())){
                    Iterator iter2 = e.elementIterator();
                    while(iter2.hasNext()){
                        //解析environment节点
                        Element e2 = (Element) iter2.next();
                        Iterator iter3 = e2.elementIterator();
                        while(iter3.hasNext()){
                            Element e3 = (Element) iter3.next();
                            //解析dataSource节点
                            if("dataSource".equals(e3.getName())){
                                if("POOLED".equals(e3.attributeValue("type"))){
                                    Iterator iter4 = e3.elementIterator();
                                    //获取数据库连接信息
                                    while(iter4.hasNext()){
                                        Element e4 = (Element) iter4.next();
                                        map.put(e4.attributeValue("name"),e4.attributeValue("value"));
                                    }
                                }
                            }
                        }
                    }
                }
            }
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return map; 
    }

    //根据接口查找对应的mapper.xml
    public static String getMapperXML(String className){
        //保存xml路径
        String xml = "";
        SAXReader reader = new SAXReader();
        Document document;
        try {
            document = reader.read("src/config.xml");
            Element root = document.getRootElement();
            Iterator iter = root.elementIterator();
            while(iter.hasNext()){
                Element mappersElement = (Element) iter.next();
                if("mappers".equals(mappersElement.getName())){
                    Iterator iter2 = mappersElement.elementIterator();
                    while(iter2.hasNext()){
                        Element mapperElement = (Element) iter2.next();
                        //com.southwin.dao.UserDAO . 替换 #
                        className = className.replace(".", "#");
                        //获取接口结尾名
                        String classNameEnd = className.split("#")[className.split("#").length-1];
                        String resourceName = mapperElement.attributeValue("resource");
                        //获取resource结尾名
                        String resourceName2 = resourceName.split("/")[resourceName.split("/").length-1];
                        //UserDAO.xml . 替换 #
                        resourceName2 = resourceName2.replace(".", "#");
                        String resourceNameEnd = resourceName2.split("#")[0];
                        if(classNameEnd.equals(resourceNameEnd)){
                            xml="src/"+resourceName;
                        }
                    }
                }
            }
        } catch (DocumentException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return xml;
    }
}

MyInvocationHandler:

public class MyInvocationHandler implements InvocationHandler{

    private String className;

    public Object getInstance(Class cls){
        //保存接口类型
        className = cls.getName();
        Object newProxyInstance = Proxy.newProxyInstance(  
                cls.getClassLoader(),  
                new Class[] { cls }, 
                this); 
        return (Object)newProxyInstance;
    }

    public Object invoke(Object proxy, Method method, Object[] args)  throws Throwable {        
        SAXReader reader = new SAXReader();
        //返回结果
        Object obj = null;
        try {
            //获取对应的mapper.xml
            String xml = ParseXML.getMapperXML(className);
            Document document = reader.read(xml);
            Element root = document.getRootElement();
            Iterator iter = root.elementIterator();
            while(iter.hasNext()){
                Element element = (Element) iter.next();
                String id = element.attributeValue("id");
                if(method.getName().equals(id)){
                    //获取C3P0信息,创建数据源对象
                    Map<String,String> map = ParseXML.getC3P0Properties();
                    ComboPooledDataSource datasource = new ComboPooledDataSource();
                    datasource.setDriverClass(map.get("driver"));
                    datasource.setJdbcUrl(map.get("url"));
                    datasource.setUser(map.get("username"));
                    datasource.setPassword(map.get("password"));
                    datasource.setInitialPoolSize(20);
                    datasource.setMaxPoolSize(40);
                    datasource.setMinPoolSize(2);
                    datasource.setAcquireIncrement(5);
                    Connection conn = datasource.getConnection();
                    //获取sql语句
                    String sql = element.getText();
                    //获取参数类型
                    String parameterType = element.attributeValue("parameterType");
                    //创建pstmt
                    PreparedStatement pstmt = createPstmt(sql,parameterType,conn,args);
                    ResultSet rs = pstmt.executeQuery();
                    if(rs.next()){
                        //读取返回数据类型
                        String resultType = element.attributeValue("resultType");   
                        //反射创建对象
                        Class clazz = Class.forName(resultType);
                        obj = clazz.newInstance();
                        //获取ResultSet数据
                        ResultSetMetaData rsmd = rs.getMetaData();
                        //遍历实体类属性集合,依次将结果集中的值赋给属性
                        Field[] fields = clazz.getDeclaredFields();
                        for(int i = 0; i < fields.length; i++){
                            Object value = setFieldValueByResultSet(fields[i],rsmd,rs);
                            //通过属性名找到对应的setter方法
                            String name = fields[i].getName();
                            name = name.substring(0, 1).toUpperCase() + name.substring(1);
                            String MethodName = "set"+name;
                            Method methodObj = clazz.getMethod(MethodName,fields[i].getType());
                            //调用setter方法完成赋值
                            methodObj.invoke(obj, value);
                        }
                    }
                    conn.close();
                }
            }
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

       return obj;
    }

    /**
     * 根据条件创建pstmt
     * @param sql
     * @param parameterType
     * @param conn
     * @param args
     * @return
     * @throws Exception
     */
    public PreparedStatement createPstmt(String sql,String parameterType,Connection conn,Object[] args) throws Exception{
        PreparedStatement pstmt = null;
        try {
            switch(parameterType){
                case "int":
                    int start = sql.indexOf("#{");
                    int end = sql.indexOf("}");
                    //获取参数占位符 #{name}
                    String target = sql.substring(start, end+1);
                    //将参数占位符替换为?
                    sql = sql.replace(target, "?");
                    pstmt = conn.prepareStatement(sql);
                    int num = Integer.parseInt(args[0].toString());
                    pstmt.setInt(1, num);
                    break;
                case "java.lang.String":
                    int start2 = sql.indexOf("#{");
                    int end2 = sql.indexOf("}");
                    String target2 = sql.substring(start2, end2+1);
                    sql = sql.replace(target2, "?");
                    pstmt = conn.prepareStatement(sql);
                    String str = args[0].toString();
                    pstmt.setString(1, str);
                    break;
                default:
                    Class clazz = Class.forName(parameterType);
                    Object obj = args[0];
                    boolean flag = true;
                    //存储参数
                    List<Object> values = new ArrayList<Object>();
                    //保存带#的sql
                    String sql2 = "";
                    while(flag){
                        int start3 = sql.indexOf("#{");
                        //判断#{}是否替换完成
                        if(start3<0){
                            flag = false;
                            break;
                        }
                        int end3 = sql.indexOf("}");
                        String target3 = sql.substring(start3, end3+1);
                        //获取#{}的值 如#{name}拿到name
                        String name = sql.substring(start3+2, end3);
                        //通过反射获取对应的getter方法
                        name = name.substring(0, 1).toUpperCase() + name.substring(1);
                        String MethodName = "get"+name;
                        Method methodObj = clazz.getMethod(MethodName);
                        //调用getter方法完成赋值
                        Object value = methodObj.invoke(obj);
                        values.add(value);
                        sql = sql.replace(target3, "?");
                        sql2 = sql.replace("?", "#");
                    }
                    //截取sql2,替换参数
                    String[] sqls = sql2.split("#");
                    pstmt = conn.prepareStatement(sql);
                    for(int i = 0; i < sqls.length-1; i++){
                        Object value = values.get(i);
                        if("java.lang.String".equals(value.getClass().getName())){
                            pstmt.setString(i+1, (String)value);
                        }
                        if("java.lang.Integer".equals(value.getClass().getName())){
                            pstmt.setInt(i+1, (Integer)value);
                        }
                    }
                    break;
                }
        } catch (SQLException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return pstmt;
    }

    /**
     * 根据将结果集中的值赋给对应的属性
     * @param field
     * @param rsmd
     * @param rs
     * @return
     */
    public Object setFieldValueByResultSet(Field field,ResultSetMetaData rsmd,ResultSet rs){
        Object result = null;
        try {
            int count = rsmd.getColumnCount();
            for(int i=1;i<=count;i++){
                if(field.getName().equals(rsmd.getColumnName(i))){
                    String type = field.getType().getName();
                    switch (type) {
                        case "int":
                            result = rs.getInt(field.getName());
                            break;
                        case "java.lang.String":
                            result = rs.getString(field.getName());
                            break;
                    default:
                        break;
                    }
                }
            }
        } catch (SQLException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return result;
    }


}

代码测试:

StudnetDAO.getById

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student stu = studentDAO.getById(1);
        System.out.println(stu);

    }

代码中的studentDAO为动态代理对象,此对象通过 MyInvocationHandler().getInstance(StudentDAO.class)方法动态创建, 并且结合StudentDAO.xml实现了StudentDAO接口的全部方法,直接调用studentDAO对象的方法即可完成业务需求。

VR3qmeF.jpg!web

StudnetDAO.getByName

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student stu = studentDAO.getByName("李四");
        System.out.println(stu);

    }

vMJraqZ.jpg!web

StudnetDAO.getByStudent(根据id和name查询)

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student student = new Student();
        student.setId(1);
        student.setName("张三");
        Student stu = studentDAO.getByStudent(student);
        System.out.println(stu);

    }

ZNfm63A.jpg!web

StudnetDAO.getByStudent2(根据name和tel查询)

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student student = new Student();
        student.setName("李四");
        student.setTel("18367895678");
        Student stu = studentDAO.getByStudent2(student);
        System.out.println(stu);

    }

yQnYRji.jpg!web

以上就是仿MyBatis实现自定义小工具的大致思路,细节之处还需具体查看源码,最后附上小工具源码链接。

源码:

链接:  https://pan.baidu.com/s/ 1pMz0FDh  

密码:  fnjb

iaiAR3V.png!web


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK