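◆ A note on CountDownLatch
CountDownLatch is a JDK class in the java.util.concurrent package. Spring Boot does not provide it, but it can be used directly with no extra dependency. The same goes for AtomicBoolean, which lives in java.util.concurrent.atomic.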
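◆ How CountDownLatch is used
Typical uses of CountDownLatch:
1. Make one thread wait for n other threads to finish before it continues. Initialize the counter with new CountDownLatch(n). Each time a task thread finishes, it decrements the counter with countDownLatch.countDown(). When the counter reaches 0, the thread blocked in await() on the latch wakes up. A typical scenario is service startup, where the main thread must wait for several components to finish loading before it carries on.
2. Maximize parallelism when several threads start a task. Note that the point is parallelism, not concurrency: the threads are released to start at the same moment. It is like a race, where the runners wait at the starting line for the starting gun and then all set off together. To do this, create a shared CountDownLatch(1), that is, with its counter initialized to 1. Each thread calls countDownLatch.await() before starting its task. When the main thread calls countDown(), the counter drops to 0 and all of those threads wake up at once.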
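Here are the two patterns above in plain Java. This is a minimal, self-contained sketch with no Spring involved. The class and method names (LatchPatterns, waitForWorkers, startTogether) are made up for the illustration, and the "work" is only a counter increment.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

public class LatchPatterns {

    // Pattern 1: the calling thread waits until n worker threads have finished.
    static int waitForWorkers(int n) throws InterruptedException {
        CountDownLatch done = new CountDownLatch(n);
        AtomicInteger finished = new AtomicInteger();
        for (int i = 0; i < n; i++) {
            new Thread(() -> {
                try {
                    finished.incrementAndGet(); // stand-in for real work
                } finally {
                    done.countDown();           // count down even if the work throws
                }
            }).start();
        }
        done.await();          // returns once all n workers have counted down
        return finished.get(); // visible here: countDown() happens-before await() returns
    }

    // Pattern 2: n threads wait at the "starting line" and are released together.
    static int startTogether(int n) throws InterruptedException {
        CountDownLatch startSignal = new CountDownLatch(1); // the starting gun
        CountDownLatch done = new CountDownLatch(n);
        AtomicInteger started = new AtomicInteger();
        for (int i = 0; i < n; i++) {
            new Thread(() -> {
                try {
                    startSignal.await();        // every runner blocks here
                    started.incrementAndGet();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    done.countDown();
                }
            }).start();
        }
        int beforeGun = started.get(); // always 0: no runner can get past await() yet
        startSignal.countDown();       // count 1 -> 0 wakes every waiting runner at once
        done.await();
        return beforeGun == 0 ? started.get() : -1;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(waitForWorkers(5)); // 5
        System.out.println(startTogether(5));  // 5
    }
}
```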
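◆ A quick explanation of CountDownLatch(num)
Create a CountDownLatch(num) object with new.
num is fixed when the object is created. It is the number of threads to wait for, or more precisely, the number of countDown() calls needed before await() returns.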
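// num is fixed at construction: the number of threads (countDown() calls) to wait for
// the main thread waits on this latch
CountDownLatch mainThreadLatch = new CountDownLatch(num);
// the worker threads wait on this latch
CountDownLatch rollBackLatch = new CountDownLatch(1);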
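num counts countDown() calls, not threads, so it helps to watch the counter directly. This small sketch (LatchCount is a hypothetical name) uses getCount() to show three things:
- one thread may count down more than once;
- extra countDown() calls at 0 are ignored;
- a latch at 0 never blocks again.

A CountDownLatch also cannot be reset. That is why fresh latches have to be created for every import request.

```java
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class LatchCount {
    // Records the latch's count after each step.
    static long[] trace() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(2);
        long initial = latch.getCount();                      // 2
        latch.countDown();                                    // the same thread ...
        latch.countDown();                                    // ... counts down twice
        long afterTwo = latch.getCount();                     // 0
        latch.countDown();                                    // no effect once the count is 0
        long afterExtra = latch.getCount();                   // still 0
        boolean open = latch.await(0, TimeUnit.MILLISECONDS); // true: returns immediately at 0
        return new long[] {initial, afterTwo, afterExtra, open ? 1 : 0};
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(Arrays.toString(trace())); // [2, 0, 0, 1]
    }
}
```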
◆ Main thread: mainThreadLatch.await() and mainThreadLatch.countDown()
Create the object:
CountDownLatch mainThreadLatch = new CountDownLatch(num);
To block the main thread until the worker threads are done, put mainThreadLatch.await() in the main thread:
mainThreadLatch.await();
Put mainThreadLatch.countDown() in the worker threads. Each time a worker reaches this line, the count in CountDownLatch(num) automatically drops by 1:
mainThreadLatch.countDown();
Once the count reaches 0, the main thread blocked in mainThreadLatch.await() continues.
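This setup depends on every worker actually reaching countDown(). If a worker fails before that line, mainThreadLatch.await() blocks forever. A failure to get a database connection is one example.

This sketch shows a common remedy under simplified assumptions. SafeCountDown and the simulated failure are hypothetical. The sketch counts down in a finally block, and it also uses the timed await(long, TimeUnit), which returns false instead of waiting forever. The 5-second limit is an arbitrary choice.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class SafeCountDown {
    // Starts `workers` threads; the one whose index equals `failing` throws
    // (pass -1 for no failure). Returns the number of failed workers, or -1
    // if the main thread gave up waiting.
    static int runWorkers(int workers, int failing) throws InterruptedException {
        CountDownLatch mainThreadLatch = new CountDownLatch(workers);
        AtomicInteger failures = new AtomicInteger();
        for (int i = 0; i < workers; i++) {
            final int id = i;
            new Thread(() -> {
                try {
                    if (id == failing) {
                        throw new IllegalStateException("simulated failure in worker " + id);
                    }
                    // ... the real batch insert would go here ...
                } catch (RuntimeException e) {
                    failures.incrementAndGet();
                } finally {
                    mainThreadLatch.countDown(); // reached on success and on failure
                }
            }).start();
        }
        // Timed wait as a second safety net: false means the count did not reach 0 in time.
        boolean allReported = mainThreadLatch.await(5, TimeUnit.SECONDS);
        return allReported ? failures.get() : -1;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runWorkers(4, 2));  // 1
        System.out.println(runWorkers(4, -1)); // 0
    }
}
```

In the two-latch version below, a plain finally around the whole worker is not enough. The countDown() has to happen before rollBackLatch.await(), so the real rule is "every path counts down exactly once".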
◆ Worker threads: rollBackLatch.await() and rollBackLatch.countDown()
Create the object. Note in particular that for the workers the count is 1 (why it can only be 1 is explained below):
CountDownLatch rollBackLatch = new CountDownLatch(1);
This blocks each worker thread, holding back its transaction's commit or rollback:
rollBackLatch.await();
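Put rollBackLatch.countDown(); in the main thread, after the main thread's wait, mainThreadLatch.await();. That way the workers are released only once all of them have reported in.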
rollBackLatch.countDown();
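Why are all the worker threads released at the same instant? Because they all call await() on the same rollBackLatch, whose count starts at 1. The main thread's single rollBackLatch.countDown() brings it to 0, and once a latch reaches 0, every thread blocked in its await() returns. This is also why the count has to be 1. The main thread calls countDown() only once, so a larger count would never reach 0 and the workers would wait forever.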
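◆ How does transaction rollback fit in?
Suppose there are 20 worker threads in total. If one of them throws an error, how do we roll all of them back?
Introduce a variable:
AtomicBoolean rollbackFlag = new AtomicBoolean(false);
It does what the name says. After being released, each worker checks whether rollbackFlag is true or false to decide whether to roll back.
The one thing to be sure of is that every worker thread shares this single rollbackFlag. A failure in any worker sets it to true, and all workers see it when they check it after release.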
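Putting the two latches and the shared flag together, the all-or-nothing flow can be simulated without a database. In this sketch (AllOrNothing is a hypothetical name), a local list stands in for each worker's uncommitted transaction. The list is copied into a shared "committed" list only if the flag is still false after release. The failing worker sets the flag before it counts down, so the decision is already made when the main thread opens the gate.

One limitation carries over to the real JDBC version. Each commit happens on its own connection, so if a commit itself fails after other workers have already committed, those commits stay. The flag coordinates the decision; it is not a distributed transaction.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

public class AllOrNothing {
    // Runs `workers` batches; the batch whose index equals `failing` fails
    // (pass -1 for no failure). Returns the rows that ended up "committed".
    static List<Integer> importRows(int workers, int rowsPerWorker, int failing)
            throws InterruptedException {
        List<Integer> committed = Collections.synchronizedList(new ArrayList<>());
        CountDownLatch mainThreadLatch = new CountDownLatch(workers); // workers report in
        CountDownLatch rollBackLatch = new CountDownLatch(1);         // main releases them
        AtomicBoolean rollbackFlag = new AtomicBoolean(false);        // shared by every worker
        List<Thread> threads = new ArrayList<>();
        for (int w = 0; w < workers; w++) {
            final int id = w;
            Thread t = new Thread(() -> {
                List<Integer> pending = new ArrayList<>(); // stands in for an open transaction
                try {
                    try {
                        for (int r = 0; r < rowsPerWorker; r++) {
                            if (id == failing) {
                                throw new IllegalStateException("bad row in batch " + id);
                            }
                            pending.add(id * rowsPerWorker + r);
                        }
                    } catch (RuntimeException e) {
                        rollbackFlag.set(true);      // one failure dooms every batch
                    } finally {
                        mainThreadLatch.countDown(); // report in exactly once
                    }
                    rollBackLatch.await();           // wait for the main thread's release
                    if (!rollbackFlag.get()) {
                        committed.addAll(pending);   // "commit"
                    }                                // otherwise pending is dropped: "rollback"
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            t.start();
            threads.add(t);
        }
        mainThreadLatch.await();   // every batch has either finished or failed
        rollBackLatch.countDown(); // count 1 -> 0 releases all workers at once
        for (Thread t : threads) {
            t.join();              // so the caller sees the final state
        }
        return committed;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(importRows(4, 3, -1).size()); // 12
        System.out.println(importRows(4, 3, 2).size());  // 0
    }
}
```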
Main-thread class Entry
package org.apache.dolphinscheduler.api.utils;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.dolphinscheduler.api.controller.WorkThread;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.springframework.web.bind.annotation.*;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
@RestController
@RequestMapping("importDatabase")
public class Entry {
    // SFTP connection settings (hard-coded for this example)
    private static String SFTP_HOST = "192.168.1.92";
    private static int SFTP_PORT = 22;
    private static String SFTP_USERNAME = "root";
    private static String SFTP_PASSWORD = "rootroot";
    private static String SFTP_BASEPATH = "/opt/testSFTP/";
    /**
     * @param dbid         id of the target data source
     * @param tablename    target table name
     * @param sftpFileName name of the file on the SFTP server
     * @param head         whether the file has a header row
     * @param splitSign    field separator
     * @param type         database type
     * @param heads        JSON array of the file's header columns (each with a "title")
     * @param scolumns     JSON array of the file columns to import (each with a "columnName")
     * @param tcolumns     JSON array of the target table columns (each with a "name" and a "type")
     */
    @PostMapping("/thread")
    @ResponseBody
    public static JSONObject importDatabase(@RequestParam("dbid") int dbid,
                                            @RequestParam("tablename") String tablename,
                                            @RequestParam("sftpFileName") String sftpFileName,
                                            @RequestParam("head") String head,
                                            @RequestParam("splitSign") String splitSign,
                                            @RequestParam("type") DbType type,
                                            @RequestParam("heads") String heads,
                                            @RequestParam("scolumns") String scolumns,
                                            @RequestParam("tcolumns") String tcolumns) throws Exception {
        JSONObject obForRetrun = new JSONObject();
        try {
            JSONArray jsonArray = JSONArray.parseArray(tcolumns);
            JSONArray scolumnArray = JSONArray.parseArray(scolumns);
            JSONArray headsArray = JSONArray.parseArray(heads);
            List<Integer> listInteger = getRrightDataNum(headsArray, scolumnArray);
            JSONArray bodys = SFTPUtils.getSftpContent(SFTP_HOST, SFTP_PORT, SFTP_USERNAME, SFTP_PASSWORD, SFTP_BASEPATH, sftpFileName, head, splitSign);
            int total = bodys.size();
            int num = 20;                      // number of rows per batch
            int count = total / num;           // number of full batches
            int lastNum = total - count * num; // rows left over for a final, smaller batch
            List<Thread> list = new ArrayList<Thread>();
            // formats the elapsed time; the raw offset is zeroed so the local time zone does not shift the hours
            SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");
            TimeZone t = sdf.getTimeZone();
            t.setRawOffset(0);
            sdf.setTimeZone(t);
            long startTime = System.currentTimeMillis();
            // one latch slot per worker thread, including the leftover batch if there is one
            int countForCountDownLatch = (lastNum == 0) ? count : count + 1;
            // the worker threads wait on this latch until the main thread releases them
            CountDownLatch rollBackLatch = new CountDownLatch(1);
            // the main thread waits on this latch until every worker has reported in
            CountDownLatch mainThreadLatch = new CountDownLatch(countForCountDownLatch);
            AtomicBoolean rollbackFlag = new AtomicBoolean(false);
            StringBuffer message = new StringBuffer();
            message.append("Errors: ");
            // one worker thread per full batch (count is the number of full batches)
            for (int i = 0; i < count; i++) {
                Thread g = new Thread(new WorkThread(i, num, tablename, jsonArray, dbid, type, bodys, listInteger, mainThreadLatch, rollBackLatch, rollbackFlag, message));
                g.start();
                list.add(g);
            }
            if (lastNum != 0) { // a final, smaller batch for the leftover rows
                // i = count starts this batch at row count*num, right after the full batches;
                // WorkThread clamps the end to bodys.size(), so it covers exactly the last lastNum rows
                Thread g = new Thread(new WorkThread(count, num, tablename, jsonArray, dbid, type, bodys, listInteger, mainThreadLatch, rollBackLatch, rollbackFlag, message));
                g.start();
                list.add(g);
            }
            // block until every worker has either executed its batch or failed
            mainThreadLatch.await();
            // release all waiting workers at once; each one now commits or rolls back
            rollBackLatch.countDown();
            // t.join() makes the main thread wait until thread t has finished. It must come after
            // rollBackLatch.countDown(): joining earlier would deadlock, because the workers are
            // still blocked in rollBackLatch.await(). Joining here means the response is sent only
            // after every commit or rollback has actually happened.
            for (Thread thread : list) {
                thread.join();
            }
            long endTime = System.currentTimeMillis();
            System.out.println("Total time: " + sdf.format(new Date(endTime - startTime)));
            if (rollbackFlag.get()) {
                obForRetrun.put("code", 500);
                obForRetrun.put("msg", message.toString());
            } else {
                obForRetrun.put("code", 200);
                obForRetrun.put("msg", "Committed successfully!");
            }
            obForRetrun.put("data", null);
        } catch (InterruptedException e) {
            e.printStackTrace();
            obForRetrun.put("code", 500);
            obForRetrun.put("msg", e.getMessage());
            obForRetrun.put("data", null);
        }
        return obForRetrun;
    }
    /**
     * Returns the (0-based) indexes of the file columns that are selected for import.
     * @param headsArray   the file's header columns
     * @param scolumnArray the columns chosen for import
     * @return indexes into headsArray, in header order
     */
    public static List<Integer> getRrightDataNum(JSONArray headsArray, JSONArray scolumnArray) {
        List<Integer> list = new ArrayList<Integer>();
        String[] arrayA = new String[headsArray.size()];
        for (int i = 0; i < headsArray.size(); i++) {
            JSONObject ob = (JSONObject) headsArray.get(i);
            arrayA[i] = String.valueOf(ob.get("title"));
        }
        String[] arrayB = new String[scolumnArray.size()];
        for (int i = 0; i < scolumnArray.size(); i++) {
            JSONObject ob = (JSONObject) scolumnArray.get(i);
            arrayB[i] = String.valueOf(ob.get("columnName"));
        }
        for (int i = 0; i < arrayA.length; i++) {
            for (int j = 0; j < arrayB.length; j++) {
                if (arrayA[i].equals(arrayB[j])) {
                    list.add(i);
                    break;
                }
            }
        }
        return list;
    }
}
Worker-thread class WorkThread
package org.apache.dolphinscheduler.api.controller;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.apache.dolphinscheduler.api.service.DataSourceService;
import org.apache.dolphinscheduler.common.enums.DbType;
import org.apache.dolphinscheduler.dao.entity.DataSource;
import org.apache.dolphinscheduler.dao.mapper.DataSourceMapper;
import org.apache.dolphinscheduler.service.bean.SpringApplicationContext;
import org.springframework.transaction.PlatformTransactionManager;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.List;
import java.util.TimeZone;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Worker that inserts one batch of rows on its own connection and transaction.
 */
public class WorkThread implements Runnable { // two ways to create a thread: 1) implement Runnable, 2) extend Thread
    private DataSourceService dataSourceService;
    private DataSourceMapper dataSourceMapper;
    private Integer begin;
    private Integer end;
    private String tableName;
    private JSONArray columnArray;
    private Integer dbid;
    private DbType type;
    private JSONArray bodys;
    private List<Integer> listInteger;
    private PlatformTransactionManager transactionManager;
    private CountDownLatch mainThreadLatch;
    private CountDownLatch rollBackLatch;
    private AtomicBoolean rollbackFlag;
    private StringBuffer message;
    /**
     * @param i               batch index; the batch starts at row i*num
     * @param num             batch size
     * @param tableFrom       target table name
     * @param columnArrayFrom target table columns
     * @param dbidFrom        data source id
     * @param typeFrom        database type
     */
    public WorkThread(int i, int num, String tableFrom, JSONArray columnArrayFrom, int dbidFrom,
                      DbType typeFrom, JSONArray bodysFrom, List<Integer> listIntegerFrom,
                      CountDownLatch mainThreadLatch, CountDownLatch rollBackLatch, AtomicBoolean rollbackFlag,
                      StringBuffer messageFrom) {
        begin = i * num;
        // clamp so the last batch never runs past the end of the data
        end = Math.min(begin + num, bodysFrom.size());
        tableName = tableFrom;
        columnArray = columnArrayFrom;
        dbid = dbidFrom;
        type = typeFrom;
        bodys = bodysFrom;
        listInteger = listIntegerFrom;
        this.dataSourceMapper = SpringApplicationContext.getBean(DataSourceMapper.class);
        this.dataSourceService = SpringApplicationContext.getBean(DataSourceService.class);
        this.transactionManager = SpringApplicationContext.getBean(PlatformTransactionManager.class);
        this.mainThreadLatch = mainThreadLatch;
        this.rollBackLatch = rollBackLatch;
        this.rollbackFlag = rollbackFlag;
        this.message = messageFrom;
    }
    public void run() {
        // formats the log timestamps and elapsed time below
        SimpleDateFormat sdf = new SimpleDateFormat("HH:mm:ss:SS");
        TimeZone t = sdf.getTimeZone();
        t.setRawOffset(0);
        sdf.setTimeZone(t);
        long startTime = System.currentTimeMillis();
        Connection con = null;
        boolean reported = false; // whether this worker has already counted down mainThreadLatch
        try {
            DataSource dataSource = dataSourceMapper.queryDataSourceByID(dbid);
            String cp = dataSource.getConnectionParams();
            con = dataSourceService.getConnection(type, cp);
            if (con == null) {
                // without this the worker would exit silently and mainThreadLatch would never reach 0
                throw new SQLException("Could not get a connection for data source " + dbid);
            }
            con.setAutoCommit(false);
            //---------------------------- target column names and types
            StringBuilder columns = new StringBuilder();      // e.g. id,name,age
            StringBuilder placeholders = new StringBuilder(); // e.g. ?,?,?
            String[] columnTypes = new String[columnArray.size()];
            for (int i = 0; i < columnArray.size(); i++) {
                JSONObject ob = (JSONObject) columnArray.get(i);
                if (i > 0) {
                    columns.append(",");
                    placeholders.append(",");
                }
                columns.append(ob.get("name"));
                placeholders.append("?");
                columnTypes[i] = String.valueOf(ob.get("type"));
            }
            // builds: insert into <tableName>(id,name,age) values (?,?,?)
            String sql = "insert into " + tableName + "(" + columns + ") values (" + placeholders + ")";
            PreparedStatement pst = con.prepareStatement(sql);
            for (int i = begin; i < end; i++) {
                JSONObject ob = (JSONObject) bodys.get(i);
                if (ob == null) {
                    continue; // skip empty rows instead of adding a batch entry with unset parameters
                }
                String[] array = ob.get(i).toString().split(",");
                String[] arrayFinal = getFinalData(listInteger, array);
                for (int j = 0; j < columnTypes.length; j++) {
                    String typeString = columnTypes[j].toLowerCase();
                    int z = j + 1; // JDBC parameter index: which ? to fill (1-based)
                    String value = arrayFinal[j];
                    if ("string".equals(typeString) || "varchar".equals(typeString)) {
                        pst.setString(z, value);
                    } else if ("int".equals(typeString)) {
                        pst.setInt(z, Integer.parseInt(value));
                    } else if ("bigint".equals(typeString) || "long".equals(typeString)) {
                        pst.setLong(z, Long.parseLong(value)); // bigint values can exceed the int range
                    } else if ("double".equals(typeString)) {
                        pst.setDouble(z, Double.parseDouble(value));
                    } else if ("date".equals(typeString)) {
                        pst.setDate(z, setDateback(value));
                    } else if ("datetime".equals(typeString) || "timestamp".equals(typeString)) {
                        // compared in lower case because typeString was lower-cased above;
                        // a Timestamp keeps the time of day, which java.sql.Date would drop
                        pst.setTimestamp(z, setTimestampback(value));
                    }
                }
                pst.addBatch();
            }
            pst.executeBatch();
            // report in, then wait until the main thread has heard from every worker
            mainThreadLatch.countDown();
            reported = true;
            rollBackLatch.await();
            if (rollbackFlag.get()) {
                con.rollback(); // some worker failed: undo this batch
            } else {
                con.commit();   // every worker succeeded: commit this batch
            }
        } catch (Exception e) {
            System.out.println(e.getMessage());
            message.append(e.getMessage());
            rollbackFlag.set(true); // set before the countDown() in finally, so the main thread sees it
            if (con != null) {
                try {
                    con.rollback(); // closing with an open transaction is driver-defined, so roll back explicitly
                } catch (SQLException ignored) {
                    // the connection is closed below anyway
                }
            }
        } finally {
            if (!reported) {
                mainThreadLatch.countDown(); // every worker must count down exactly once
            }
            if (con != null) {
                try {
                    con.close();
                } catch (SQLException throwables) {
                    throwables.printStackTrace();
                }
            }
        }
        long endTime = System.currentTimeMillis();
        System.out.println(Thread.currentThread().getName() + ":startTime= " + sdf.format(new Date(startTime)) + ",endTime= " + sdf.format(new Date(endTime))
                + " elapsed: " + sdf.format(new Date(endTime - startTime)));
    }
    // parses the cell text (expected as yyyy-MM-dd HH:mm:ss) into a java.sql.Date
    public java.sql.Date setDateback(String dateString) throws ParseException {
        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        java.util.Date date = sdf.parse(dateString); // parse the actual value, not a fixed sample date
        long lg = date.getTime(); // date -> epoch milliseconds
        return new java.sql.Date(lg);
    }
    // parses the cell text (expected as yyyy-MM-dd HH:mm:ss) into a java.sql.Timestamp
    public java.sql.Timestamp setTimestampback(String dateString) throws ParseException {
        SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        java.util.Date date = sdf.parse(dateString); // parse the actual value, not a fixed sample date
        long lg = date.getTime(); // date -> epoch milliseconds
        return new java.sql.Timestamp(lg);
    }
    // keeps only the columns selected for import, in the order given by listInteger
    public String[] getFinalData(List<Integer> listInteger, String[] array) {
        String[] arrayFinal = new String[listInteger.size()];
        for (int i = 0; i < listInteger.size(); i++) {
            int a = listInteger.get(i);
            arrayFinal[i] = array[a];
        }
        return arrayFinal;
    }
}
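◆ A pitfall I hit when using this code in practice!
Remember the setting for how many rows go into one batch? I set it to 20. In real use the customer sent 200,000 rows, and with a batch size of 20 that meant 10,000 worker threads!
That wasn't even the worst part. Every worker opens its own database connection and holds it until the main thread releases rollBackLatch, and that brought the database down.
So this line:
int num = 20; // number of rows per batch
has to become:
int num = 20000; // number of rows per batch (200,000 rows now means 10 worker threads)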
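A fixed batch size still lets the thread count grow with the data. At 20,000 rows per batch, for example, 2,000,000 rows would again mean 100 threads and 100 open connections.

One option is to fix the maximum number of threads and derive the batch size from it. Here is a minimal sketch, where maxThreads and BatchPlan are hypothetical names; the computed batch size would take the place of the fixed num. Pick maxThreads to stay within what the database connection pool can keep open at the same time.

Do not replace the raw threads with a fixed-size ExecutorService that has fewer threads than batches. Each task blocks in rollBackLatch.await() while holding its pool thread, so the queued tasks never start and mainThreadLatch never reaches 0.

```java
import java.util.ArrayList;
import java.util.List;

public class BatchPlan {
    // Splits `total` rows into at most `maxThreads` contiguous [begin, end) ranges.
    static List<int[]> plan(int total, int maxThreads) {
        if (maxThreads < 1) {
            throw new IllegalArgumentException("maxThreads must be >= 1");
        }
        List<int[]> ranges = new ArrayList<>();
        if (total <= 0) {
            return ranges;
        }
        int batchSize = (total + maxThreads - 1) / maxThreads; // ceiling division
        for (int begin = 0; begin < total; begin += batchSize) {
            // the last range is clamped to total, like the end of the final batch
            ranges.add(new int[] {begin, Math.min(begin + batchSize, total)});
        }
        return ranges;
    }

    public static void main(String[] args) {
        List<int[]> ranges = plan(200_000, 10);
        System.out.println(ranges.size());    // 10
        System.out.println(ranges.get(9)[1]); // 200000
    }
}
```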