package hjp.smart4j.framework.helper; import hjp.smart4j.framework.util.CollectionUtil; import hjp.smart4j.framework.util.PropsUtil; import org.apache.commons.dbcp2.BasicDataSource; import org.apache.commons.dbutils.QueryRunner; import org.apache.commons.dbutils.handlers.BeanHandler; import org.apache.commons.dbutils.handlers.BeanListHandler; import org.apache.commons.dbutils.handlers.MapListHandler; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.sql.Connection; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Properties; /** * 数据库操作助手类 */ public class DatabaseHelper { private static final Logger LOGGER = LoggerFactory.getLogger(DatabaseHelper.class); private static final ThreadLocal<Connection> CONNECTION_HOLDER; private static final QueryRunner QUERY_RUNNER; private static final BasicDataSource DATA_SOURCE; static { CONNECTION_HOLDER = new ThreadLocal<Connection>(); QUERY_RUNNER = new QueryRunner(); DATA_SOURCE = new BasicDataSource(); Properties conf = PropsUtil.loadProps("config.properties"); String driver = conf.getProperty("jdbc.driver"); String url = conf.getProperty("jdbc.url"); String username = conf.getProperty("jdbc.username"); String password = conf.getProperty("jdbc.password"); DATA_SOURCE.setDriverClassName(driver); DATA_SOURCE.setUrl(url); DATA_SOURCE.setUsername(username); DATA_SOURCE.setPassword(password); } /** * 获取数据库连接 */ public static Connection getConnection() { Connection connection = CONNECTION_HOLDER.get(); if (connection == null) { try { connection = DATA_SOURCE.getConnection(); } catch (SQLException e) { LOGGER.error("get connection failure", e); throw new RuntimeException(e); } finally { CONNECTION_HOLDER.set(connection); } } return connection; } /** * 查询实体列表 */ public static <T> List<T> queryEntityList(Class<T> entityClass, String sql, Object... params) { List<T> entityList; try { Connection conn = getConnection(); entityList = QUERY_RUNNER.query(conn, sql, new BeanListHandler<T>(entityClass), params); } catch (SQLException e) { LOGGER.error("query entity list failure", e); throw new RuntimeException(e); } return entityList; } /** * 查询实体 */ public static <T> T queryEntity(Class<T> entityClass, String sql, Object... params) { T entity; try { Connection connection = getConnection(); entity = QUERY_RUNNER.query(connection, sql, new BeanHandler<T>(entityClass), params); } catch (SQLException e) { LOGGER.error("query entity failure", e); throw new RuntimeException(e); } return entity; } /** * 执行查询语句 */ public static List<Map<String, Object>> executeQuery(String sql, Object... params) { List<Map<String, Object>> result; try { Connection connection = getConnection(); result = QUERY_RUNNER.query(connection, sql, new MapListHandler(), params); } catch (Exception e) { LOGGER.error("execute query failure", e); throw new RuntimeException(e); } return result; } /** * 执行更新语句(update、insert、delete) */ public static int executeUpdate(String sql, Object... params) { int rows = 0; try { Connection connection = getConnection(); rows = QUERY_RUNNER.update(connection, sql, params); } catch (SQLException e) { LOGGER.error("execute update failure", e); throw new RuntimeException(e); } return rows; } /** * 插入实体 */ public static <T> boolean insertEntity(Class<T> entityClass, Map<String, Object> fieldMap) { if (CollectionUtil.isEmpty(fieldMap)) { LOGGER.error("can not insert entity: fieldMap is empty"); return false; } String sql = "insert into " + getTableName(entityClass); StringBuilder columns = new StringBuilder(" ("); StringBuilder values = new StringBuilder(" ("); for (String fieldName : fieldMap.keySet()) { columns.append(fieldName).append(", "); values.append("?, "); } columns.replace(columns.lastIndexOf(", "), columns.length(), ") "); values.replace(values.lastIndexOf(", "), values.length(), ") "); sql += columns + " values " + values; Object[] params = fieldMap.values().toArray(); return executeUpdate(sql, params) == 1; } /** * 更新实体 */ public static <T> boolean updateEntity(Class<T> entityClass, long id, Map<String, Object> fieldMap) { if (CollectionUtil.isEmpty(fieldMap)) { LOGGER.error("can not update entity: fieldMap is empty"); return false; } String sql = "update " + getTableName(entityClass) + " set "; StringBuilder columns = new StringBuilder(); for (String fieldName : fieldMap.keySet()) { columns.append(fieldName).append("=?, "); } sql += columns.substring(0, columns.lastIndexOf(", ")) + " where id=?"; List<Object> paramList = new ArrayList<Object>(); paramList.addAll(fieldMap.values()); paramList.add(id); Object[] params = paramList.toArray(); return executeUpdate(sql, params) == 1; } /** * 删除实体 */ public static <T> boolean deleteEntity(Class<T> entityClass, long id) { String sql = "delete from " + getTableName(entityClass) + " where id=?"; return executeUpdate(sql, id) == 1; } /** * 通过类名获取表名 */ private static String getTableName(Class<?> entityClass) { return entityClass.getSimpleName(); } /** * 执行SQL文件 */ public static void executeSqlFile(String filePath) { InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(filePath); BufferedReader reader = new BufferedReader(new InputStreamReader(is)); try { String sql; while ((sql = reader.readLine()) != null) { executeUpdate(sql); } } catch (IOException e) { LOGGER.error("execute sql file failure", e); throw new RuntimeException(e); } } /** * 开启事务 */ public static void beginTransaction() { Connection conn = getConnection(); if (conn != null) { try { conn.setAutoCommit(false); } catch (SQLException e) { LOGGER.error("begin transaction failure", e); throw new RuntimeException(e); } finally { CONNECTION_HOLDER.set(conn); } } } /** * 提交事务 */ public static void commitTransaction() { Connection conn = getConnection(); if (conn != null) { try { conn.commit(); conn.close(); } catch (SQLException e) { LOGGER.error("commit transaction failure", e); throw new RuntimeException(e); } finally { CONNECTION_HOLDER.remove(); } } } /** * 回滚事务 */ public static void rollbackTransaction() { Connection conn = getConnection(); if (conn != null) { try { conn.rollback(); conn.close(); } catch (SQLException e) { LOGGER.error("rollback transaction failure", e); throw new RuntimeException(e); } finally { CONNECTION_HOLDER.remove(); } } } }
需要关注的地方
1、ThreadLocal<Connection>
泛型内指定存储数据类型
2、数据库连接对象创建后不再关闭,交由BasicDataSource管理
3、java中泛型T和?
T用于指定参数类型或返回类型,泛型方法要在方法名前标记<T>
?用于泛型类型指定,如Class<?>,任意类型的Class类;List<?> list,任意类型集合