如题：基于 HTTP Range 请求的多线程分片文件下载工具类，代码如下：
import com.yanwu.spring.cloud.common.core.common.Contents; import lombok.Data; import lombok.SneakyThrows; import lombok.experimental.Accessors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClients; import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; import org.apache.http.protocol.BasicHttpContext; import org.apache.http.protocol.HttpContext; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import org.springframework.util.Assert; import java.io.*; import java.net.HttpURLConnection; import java.net.URL; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadPoolExecutor; /** * @author <a herf="mailto:yanwu0527@163.com">XuBaofeng</a> * @date 2020/5/8 17:26. * <p> * description: 文件分片下载 */ @Slf4j public class DownLoadUtil { /*** 文件下载线程池 ***/ private static final ThreadPoolTaskExecutor EXECUTOR; /*** 每个线程下载的字节数 */ private static final Long UNIT_SIZE = 1000 * 1024L; /*** 客户端 */ private static final CloseableHttpClient HTTP_CLIENT; static { PoolingHttpClientConnectionManager cm = new PoolingHttpClientConnectionManager(); cm.setMaxTotal(100); HTTP_CLIENT = HttpClients.custom().setConnectionManager(cm).build(); EXECUTOR = new ThreadPoolTaskExecutor(); // ----- 设置核心线程数 EXECUTOR.setCorePoolSize(50); // ----- 设置最大线程数 EXECUTOR.setMaxPoolSize(100); // ----- 设置队列容量 EXECUTOR.setQueueCapacity(Integer.MAX_VALUE); // ----- 设置线程活跃时间(秒) EXECUTOR.setKeepAliveSeconds(120); // ----- 设置默认线程名称 EXECUTOR.setThreadNamePrefix("down-pool-"); // ----- 设置拒绝策略 EXECUTOR.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy()); // ----- 执行初始化 EXECUTOR.initialize(); } /** * 下载 * * @param fileUrl 资源路径 * @param localPath 文件路径 */ public static void download(String fileUrl, String localPath) 
throws Exception { // ----- 获取远程文件的大小,根据文件大小决定线程的个数 log.info("download file begin, localPath: {}, fileUrl: {}", localPath, fileUrl); HttpURLConnection httpConnection = (HttpURLConnection) new URL(fileUrl).openConnection(); httpConnection.setRequestMethod("HEAD"); int responseCode = httpConnection.getResponseCode(); Assert.isTrue((responseCode <= 400), "get remote file size error. code: " + responseCode); long fileSize = Long.parseLong(httpConnection.getHeaderField("Content-Length")); Assert.isTrue((fileSize > 0), "get remote file size error. size is zero"); long threadCount = Math.floorDiv(fileSize, UNIT_SIZE); threadCount = fileSize == threadCount * UNIT_SIZE ? threadCount : threadCount + 1; // ----- 检查文件[父目录是否存在 && 文件是否存在] File file = new File(localPath); if (file.exists() && file.isFile()) { Assert.isTrue(file.delete(), "file delete error."); } FileUtil.checkDirectoryPath(file.getParentFile()); Assert.isTrue(file.createNewFile(), "file create error."); // ----- 根据threadCount开始下载文件 CountDownLatch end = new CountDownLatch((int) threadCount); long offset = 0; long start = System.currentTimeMillis(); while (fileSize > 0) { long length = fileSize > UNIT_SIZE ? 
UNIT_SIZE : fileSize; EXECUTOR.execute(DownLoadTask.getInstance(fileUrl, localPath, offset, length, end)); fileSize -= length; offset += UNIT_SIZE; } try { end.await(); } catch (InterruptedException e) { log.error("downLoad await error.", e); } log.info("download file done!localPath: {}, time: {} S", localPath, (System.currentTimeMillis() - start) / 1000); } public static void main(String[] args) throws Exception { String param = "F:\file\2020\new 1.txt"; File file = new File(param); Reader reader = new FileReader(file); BufferedReader bufferedReader = new BufferedReader(reader); String lien; while (StringUtils.isNotBlank((lien = bufferedReader.readLine()))) { String puffix = lien.substring(lien.lastIndexOf("/")); String path = "F:\file\2020\111\" + puffix; download(lien, path); } } /** * 文件下载任务 */ @Data @Accessors(chain = true) private static class DownLoadTask implements Runnable { /*** 待下载的文件 */ private String url; /*** 本地文件名 */ private String fileName; /*** 偏移量 */ private Long offset; /*** 分配给本线程的下载字节数 */ private Long length; private CountDownLatch end; private HttpContext context; @Override @SneakyThrows public void run() { HttpGet httpGet = new HttpGet(url); httpGet.addHeader("Range", "bytes=" + offset + "-" + (offset + length - 1)); File file = new File(fileName); try (RandomAccessFile raf = new RandomAccessFile(file, "rw"); CloseableHttpResponse response = HTTP_CLIENT.execute(httpGet, context); BufferedInputStream bis = new BufferedInputStream(response.getEntity().getContent())) { int read; byte[] bytes = new byte[Contents.DEFAULT_SIZE]; while ((read = bis.read(bytes, 0, bytes.length)) != -1) { raf.seek(offset); raf.write(bytes, 0, read); offset += read; } } finally { end.countDown(); log.info("task: {} is go on!", end.getCount()); } } static DownLoadTask getInstance(String url, String fileName, long offset, long length, CountDownLatch end) { return new DownLoadTask().setUrl(url).setFileName(fileName).setOffset(offset) 
.setLength(length).setEnd(end).setContext(new BasicHttpContext()); } private DownLoadTask() { } } }