自动缓存文件
读取网络硬盘上的文件时,常常因为网速问题导致大量时间浪费在 IO 操作上。这个方法在第一次调用时会自动将网络文件缓存到本地临时文件夹,
第二次运行时就会直接使用本地缓存文件,免去网络 IO 的限制。
import os
import shutil
import time
class auto_cache(object):
    """Cache a (possibly network-mounted) file into a local directory.

    On first use the file is copied to ``cache_path``; later uses hit the
    local copy, avoiding repeated slow network I/O.  Acts as a context
    manager that yields the path of the cached copy.
    """

    def __init__(self, myfile, cache_path="/tmp"):
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not os.path.isfile(myfile):
            raise FileNotFoundError("source file not found: {}".format(myfile))
        path, filename = os.path.split(myfile)
        self.cache_file = os.path.join(cache_path, filename)
        if not os.path.isfile(self.cache_file):
            print("Caching {} from {} to {}".format(filename, path, cache_path))
            start_time = time.perf_counter()
            shutil.copy(myfile, self.cache_file)
            end_time = time.perf_counter()
            print("Elapsed Time: {:.2f}s".format(end_time - start_time))
        else:
            print("Cache file {} been detected!".format(self.cache_file))

    def __enter__(self):
        # Hand back the local path so callers can open it directly.
        return self.cache_file

    def __exit__(self, exc_type, exc_val, exc_tb):
        # BUG FIX: the original returned True, which silently suppressed
        # every exception raised inside the `with` block.  Returning False
        # lets exceptions propagate normally.
        return False
使用方法举例:
# original code
data = pd.read_csv(filename)
# new code
with auto_cache(filename) as filename:
data = pd.read_csv(filename)
自动计时方法
class tic_toc(object):
    """Context manager that prints the wall-clock time spent in its body.

    Usage::

        with tic_toc("loading data"):
            ...
    """

    def __init__(self, comment):
        self.comment = comment  # label prefixed to the timing report

    def __enter__(self):
        self.st_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # BUG FIX: the original format string contained a raw line break
        # inside the literal, which is a SyntaxError.  print() already
        # appends the trailing newline.
        elapsed = time.time() - self.st_time
        print("{}, 耗时:{:.2f}s".format(self.comment, elapsed))
import sys
class print_and_save(object):
    """Tee stdout: everything printed goes both to the original stream and a file.

    Redirection starts immediately in ``__init__`` (``sys.stdout`` is replaced
    by this object) and is undone by ``__exit__`` or, in decorator form, after
    the wrapped call returns.  Usable as a context manager or a decorator.
    """

    def __init__(self, filepath):
        self.file = open(filepath, 'a')
        self.old = sys.stdout  # remember the current stdout so it can be restored
        self._closed = False   # guards against double teardown
        sys.stdout = self

    def __enter__(self):
        # BUG FIX: return self (the original returned None), so
        # `with print_and_save(p) as tee:` yields the tee object.
        return self

    def __call__(self, func):
        # Decorator form.  BUG FIX: restore stdout even when the wrapped
        # function raises (the original skipped _exit on exception,
        # leaving sys.stdout hijacked and the file open).
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            finally:
                self._exit()
        return wrapper

    def write(self, message):
        # File-like protocol: mirror every write to both sinks.
        self.old.write(message)
        self.file.write(message)

    def flush(self):
        self.old.flush()
        self.file.flush()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._exit()

    def _exit(self):
        # Idempotent teardown: the original crashed on a second call
        # because it flushed an already-closed file.
        if self._closed:
            return
        self._closed = True
        self.file.flush()
        self.file.close()
        sys.stdout = self.old
def zeroPadding(seqs, fillvalue=0, max_seq_length=None):
    '''Transpose-and-pad: column j of the result holds token j of every
    sequence; shorter sequences are padded with ``fillvalue``.
    '''
    if max_seq_length:
        assert max_seq_length > 0
        # Truncate every sequence to at most max_seq_length tokens.
        seqs = [s[:max_seq_length] for s in seqs]
    padded_rows = itertools.zip_longest(*seqs, fillvalue=fillvalue)
    return np.array(list(padded_rows))
def batch_itr(data, batch_size):
    '''Yield shuffled batches of ``data``; each batch is internally sorted
    by ``len(item[0])`` in descending order.

    BUG FIX: the original shuffled an index list but then sliced ``data``
    sequentially, so the shuffle had no effect at all.  Batches are now
    drawn through the shuffled indices.  Items beyond the last full batch
    are dropped, as before.
    '''
    data_size = len(data)
    ids = list(range(data_size))
    random.shuffle(ids)
    num_batch = data_size // batch_size
    for i in range(num_batch):
        batch_ids = ids[i * batch_size:(i + 1) * batch_size]
        gen_data = [data[j] for j in batch_ids]
        # Longest first element first (descending by sequence length).
        gen_data.sort(key=lambda x: len(x[0]), reverse=True)
        yield np.array(gen_data)
def model_load(resultpath):
    '''Interactively pick and reload a pickled model from ``resultpath``.

    Lists the entries of ``resultpath`` (newest name first), asks the user
    for an index via ``input``, reads the selected entry's ``checkpoint``
    file (``"<step>,<model_file>"``), and unpickles the model.

    Returns ``(checkpoint, model)`` on success; ``(False, 0)`` when the
    directory does not yet exist (it is created) or loading fails.
    '''
    if not os.path.exists(resultpath):
        os.makedirs(resultpath, exist_ok=True)
        return False, 0
    dirlist = sorted(os.listdir(resultpath), reverse=True)
    print("============== 模型列表 ======================")
    for i, p in enumerate(dirlist):
        print(" [{}]: {}".format(i, p))
    print("==============================================")
    model_id = int(input("请选择模型:"))
    try:
        path = os.path.join(resultpath, dirlist[model_id])
        with open(os.path.join(path, "checkpoint"), "r") as f:
            # strip() tolerates a trailing newline in the checkpoint file.
            checkpoint, model_save_name = f.read().strip().split(",")
        # BUG FIX: the original opened model_save_name relative to the CWD
        # while reporting the joined path; join with `path` (a no-op when
        # the stored name is already absolute).
        # NOTE(review): pickle.load is unsafe on untrusted files — only
        # load checkpoints you wrote yourself.
        with open(os.path.join(path, model_save_name), "rb") as f:
            model = pickle.load(f)
        print("成功加载模型:{}".format(os.path.join(path, model_save_name)))
        return int(checkpoint), model
    except Exception as e:
        print("模型加载失败:{}".format(e))
        return False, 0