
H2 Database hack: a dirty implementation of batch insert

 2 years ago
source link: https://blog.yuantops.com/tech/h2_database_hack_batch_insert/

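H2 is an excellent in-memory database. It has several strengths: it is small, well documented, feature-complete, and written in Java.

Recently I put these strengths to use for in-memory computation. I started an H2 instance in in-memory mode, and the next step was to import external data into it. That raised a question: with a large data volume (tens of thousands of rows and up), how do you keep inserts fast?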

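Any JDBC persistence tool, such as JdbcTemplate or MyBatis, wraps a batch interface. On the naive belief that less wrapping means more speed, I went straight to H2's native JDBC Connection and executed INSERT statements in a loop. For 27,000 rows this took about 3 s.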
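
The article does not show its baseline code, so the following is only a sketch of what a plain-JDBC insert loop might look like, using nothing beyond java.sql. The class name JdbcBatchBaseline, the method names, and the batchSize parameter are all hypothetical; whether batching changes the 3 s figure was not measured here.

```java
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

// Hypothetical helper, not part of the original article.
final class JdbcBatchBaseline {

    /** Builds "INSERT INTO t (a, b) VALUES (?, ?)" for the given columns. */
    static String buildInsertSql(String tableName, List<String> columns) {
        StringJoiner names = new StringJoiner(", ", "(", ")");
        StringJoiner marks = new StringJoiner(", ", "(", ")");
        for (String c : columns) {
            names.add(c);
            marks.add("?");
        }
        return "INSERT INTO " + tableName + " " + names + " VALUES " + marks;
    }

    /** Inserts all rows through one PreparedStatement, flushing every batchSize rows. */
    static void insertAll(Connection conn, String tableName, List<String> columns,
                          List<Map<String, Object>> rows, int batchSize) throws SQLException {
        try (PreparedStatement ps = conn.prepareStatement(buildInsertSql(tableName, columns))) {
            int pending = 0;
            for (Map<String, Object> row : rows) {
                for (int i = 0; i < columns.size(); i++) {
                    // JDBC parameter indexes start at 1; a missing key binds NULL.
                    ps.setObject(i + 1, row.get(columns.get(i)));
                }
                ps.addBatch();
                if (++pending == batchSize) {
                    ps.executeBatch();
                    pending = 0;
                }
            }
            if (pending > 0) {
                ps.executeBatch(); // flush the final partial batch
            }
        }
    }
}
```

Reusing one PreparedStatement at least avoids re-parsing the SQL text for every row, but each row still travels through the JDBC and SQL layers, which is the overhead the rest of the post tries to bypass.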

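The H2 documentation also describes another approach: export the data to a CSV file first, then load the CSV (H2 provides the CSVREAD function for this). I have not actually tried it, but on paper, even if the load itself gets faster, the approach adds two rounds of I/O (one write and one read of the file), so I would not expect the result to be outstanding.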
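
For comparison, here is a rough sketch of the CSV route, which the article did not try. CSVREAD is H2's built-in function; everything else (the CsvLoadSketch class, its method names, quoting every field, writing null as an empty field) is an assumption made for illustration.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;

// Hypothetical helper, not part of the original article.
final class CsvLoadSketch {

    /** Quotes a field CSV-style: wrap in double quotes and double any embedded quote. */
    static String quote(Object value) {
        if (value == null) {
            return ""; // assumption: null is written as an empty, unquoted field
        }
        return "\"" + value.toString().replace("\"", "\"\"") + "\"";
    }

    /** Renders a header line followed by one line per row. */
    static String toCsv(List<String> columns, List<Map<String, Object>> rows) {
        StringBuilder sb = new StringBuilder();
        sb.append(String.join(",", columns)).append('\n');
        for (Map<String, Object> row : rows) {
            StringJoiner line = new StringJoiner(",");
            for (String c : columns) {
                line.add(quote(row.get(c)));
            }
            sb.append(line).append('\n');
        }
        return sb.toString();
    }

    /** First I/O pass: write the CSV text to disk. */
    static void writeCsv(Path file, List<String> columns, List<Map<String, Object>> rows)
            throws IOException {
        Files.write(file, toCsv(columns, rows).getBytes(StandardCharsets.UTF_8));
    }

    /** Second I/O pass happens inside H2 when this statement is executed. */
    static String loadSql(String tableName, Path file) {
        String path = file.toAbsolutePath().toString().replace("'", "''");
        return "INSERT INTO " + tableName
                + " SELECT * FROM CSVREAD('" + path + "', NULL, 'charset=UTF-8')";
    }
}
```

The two methods writeCsv and loadSql correspond to the two extra I/O passes mentioned above: the data is serialized to disk once and then parsed back by the database.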

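A colleague had a wild idea: in an in-memory database, an INSERT statement is first parsed as SQL, and then Java objects are written into memory. Since it is all Java objects anyway, could we skip the SQL detour and write to memory directly?

Bypassing both JDBC and SQL is certainly not playing by the rules. But the reasoning holds up, and it is bound to be faster because the SQL parsing step disappears.

Stepping through with the debugger, breakpoint by breakpoint, turned up the key class: org.h2.table.Table. At the end of its execution path, an INSERT statement adds a row (org.h2.result.Row) to the table. In other words, it is enough to get hold of the table and construct rows in the right format.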

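  • Getting the Table. The H2 authors presumably did not intend for users to manipulate the Table object directly. But that cannot stop a dirty hack: with reflection, anything can be reached. The following digs out the Table object step by step.

    // Prepare a trivial SELECT only to obtain H2's internal command objects.
    String sql = "select * from " + tableName;
    try (JdbcPreparedStatement ps = (JdbcPreparedStatement) connection.prepareStatement(sql)) {
        // "command", "session" and "prepared" are private fields of H2 internals.
        CommandContainer commandContainer = (CommandContainer) getFieldByForce(ps, JdbcPreparedStatement.class,
                "command");
        Session session = (Session) getFieldByForce(ps, JdbcPreparedStatement.class, "session");
        Select command = (Select) getFieldByForce(commandContainer, CommandContainer.class, "prepared");
        // The SELECT reads a single table, so take the first (and only) entry.
        Table table = new ArrayList<>(command.getTables()).get(0);
        // ... build rows and add them to the table here ...
    }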
    
  • Building a row. Each record to insert is a Map whose keys are column names and whose values are the column values. Mapped onto org.h2.result.Row, each map entry corresponds to one column. This involves some column-name extraction and conversion, plus data type handling. The following builds a row.

    Row newRow = table.getTemplateRow();
    Column[] columns = table.getColumns();
    for (Column c : columns) {
        int index = c.getColumnId();
        String columnName = c.getName();
        // A missing key and a key mapped to null both become SQL NULL.
        Object value = map.get(columnName);
        if (value == null) {
            newRow.setValue(index, ValueNull.INSTANCE);
        } else if (value instanceof String) {
            newRow.setValue(index, ValueString.get((String) value));
        } else if (value instanceof Integer) {
            newRow.setValue(index, ValueInt.get((Integer) value));
        } else if (value instanceof Timestamp) {
            newRow.setValue(index, ValueTimestamp.get(TimeZone.getDefault(), (Timestamp) value));
        } else if (value instanceof BigDecimal) {
            newRow.setValue(index, ValueDecimal.get((BigDecimal) value));
        } else {
            // TODO: the type mapping still needs to be enumerated fully
            newRow.setValue(index, ValueString.get(value.toString()));
        }
    }
    
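  • Committing the insert. The Table object was pulled out from under org.h2.engine.Session, and H2 is a transactional database, so after the inserts finish a commit is still needed for the changes to take effect.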

    session.commit(false);
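
The full listing reads private fields through Spring's ReflectionUtils. For readers without Spring on the classpath, the same idea can be sketched with plain java.lang.reflect. The FieldPeek and Holder classes below are hypothetical stand-ins; Holder only plays the role of an H2 internal class with a private field.

```java
import java.lang.reflect.Field;

// Hypothetical helper, not part of the original article.
final class FieldPeek {

    /** Reads a (possibly private) field declared on clazz or one of its superclasses. */
    static Object readField(Object target, Class<?> clazz, String fieldName) {
        for (Class<?> c = clazz; c != null; c = c.getSuperclass()) {
            try {
                Field f = c.getDeclaredField(fieldName);
                f.setAccessible(true); // lift the private access check
                return f.get(target);
            } catch (NoSuchFieldException e) {
                // Not declared on this class; continue with the superclass.
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Cannot read field " + fieldName, e);
            }
        }
        throw new IllegalArgumentException(
                "No field '" + fieldName + "' on " + clazz.getName());
    }
}

// Stand-in for a library class that hides its state.
final class Holder {
    private final String secret = "table";
}
```

Calling FieldPeek.readField(new Holder(), Holder.class, "secret") yields the hidden string. Field names such as "command" or "prepared" are private implementation details, so a lookup like this can break when the library version changes; failing loudly with an exception makes that easier to notice than a null.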
    

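27,000 rows took 700 ms. Compared with the conventional approach (27,000 rows, 3000 ms), the elapsed time dropped by about 77%, nearly four fifths, which is quite considerable.

The complete utility class is as follows.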


import lombok.extern.slf4j.Slf4j;
import org.h2.command.CommandContainer;
import org.h2.command.dml.Select;
import org.h2.engine.Session;
import org.h2.jdbc.JdbcConnection;
import org.h2.jdbc.JdbcPreparedStatement;
import org.h2.result.Row;
import org.h2.table.Column;
import org.h2.table.Table;
import org.h2.value.*;
import org.springframework.util.ReflectionUtils;

import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;

@Slf4j
public class H2InsertUtil {

    public static void batchInsert(Connection toSqlSession, String tableName, List<Map<String, Object>> data) {
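        // Only checked when assertions are enabled (-ea); unwrap() below throws for non-H2 connections anyway.
        assert isH2Dialect(toSqlSession);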

        try {
            JdbcConnection connection = toSqlSession.unwrap(JdbcConnection.class);
            doBatchInsert(connection, tableName, data);
        } catch (SQLException e) {
            throw new RuntimeException("H2 batch insert failed.", e);
        }
    }

    private static boolean isH2Dialect(Connection sqlSession) {
        try {
            return sqlSession.isWrapperFor(JdbcConnection.class);
        } catch (SQLException e) {
            log.warn("Failed to determine the connection type", e);
            return false;
        }
    }

    private static void doBatchInsert(JdbcConnection connection, String tableName, List<Map<String, Object>> batchData) throws SQLException {
        String sql = "select * from " + tableName;
        try (JdbcPreparedStatement ps = (JdbcPreparedStatement) connection.prepareStatement(sql)) {
            CommandContainer commandContainer = (CommandContainer) getFieldByForce(ps, JdbcPreparedStatement.class,
                    "command");
            Session session = (Session) getFieldByForce(ps, JdbcPreparedStatement.class, "session");
            Select command = (Select) getFieldByForce(commandContainer, CommandContainer.class, "prepared");
            Table table = new ArrayList<>(command.getTables()).get(0);

            for (Map<String, Object> data : batchData) {
                Row newRow = createRow(table, data);
                table.addRow(session, newRow);
            }
            session.commit(false);
        } catch (Exception e) {
            log.error("H2 direct batch insert failed, table={}", tableName, e);
            throw e;
        }
    }

    private static Object getFieldByForce(Object obj, Class<?> clazz, String fieldName) {
        Field field = ReflectionUtils.findField(clazz, fieldName);
        ReflectionUtils.makeAccessible(field);
        return ReflectionUtils.getField(field, obj);
    }

    private static Row createRow(Table table, Map<String, Object> map) {
        Row newRow = table.getTemplateRow();
        Column[] columns = table.getColumns();
        for (Column c : columns) {
            int index = c.getColumnId();
            String columnName = c.getName();
            // A missing key and a key mapped to null both become SQL NULL.
            Object value = map.get(columnName);
            if (value == null) {
                newRow.setValue(index, ValueNull.INSTANCE);
            } else if (value instanceof String) {
                newRow.setValue(index, ValueString.get((String) value));
            } else if (value instanceof Integer) {
                newRow.setValue(index, ValueInt.get((Integer) value));
            } else if (value instanceof Timestamp) {
                newRow.setValue(index, ValueTimestamp.get(TimeZone.getDefault(), (Timestamp) value));
            } else if (value instanceof BigDecimal) {
                newRow.setValue(index, ValueDecimal.get((BigDecimal) value));
            } else {
                // TODO: the type mapping still needs to be enumerated fully
                newRow.setValue(index, ValueString.get(value.toString()));
            }
        }
        return newRow;
    }
}
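
One detail of createRow is easy to trip over: it looks up map keys by the exact value of c.getName(). Assuming H2's default settings and unquoted column names in the DDL, those names are stored in upper case, so a map keyed by "id" would silently produce NULL for column ID. A small normalization step, sketched here with the hypothetical helper ColumnKeys, could be applied to each record before it is passed in.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

// Hypothetical helper, not part of the original article.
final class ColumnKeys {

    /** Returns a copy of the row whose keys are upper-cased; values are left untouched. */
    static Map<String, Object> upperCaseKeys(Map<String, Object> row) {
        Map<String, Object> out = new HashMap<>();
        for (Map.Entry<String, Object> e : row.entrySet()) {
            // Locale.ROOT avoids locale-specific case rules (for example the Turkish dotless i).
            out.put(e.getKey().toUpperCase(Locale.ROOT), e.getValue());
        }
        return out;
    }
}
```

If the table was created with quoted, mixed-case column names, this normalization would be wrong; in that case the keys need to match the quoted names exactly.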
