最近在做bert文本分类,有一个生成器,记录一下使用,跟我网上查到的不太一样,主要在.iter()这个地方,很多代码都是没有这个,不知道是不是版本原因
datalist, labellist = get_data_from_excel(r'data/test.xlsx')
data = data_generator(datalist).__iter__() # 注意这个.__iter__()
# 获取一批数据
print(next(data))
# 或者
for x in data:
print(x)
点击查看代码
class data_generator:
"""
data_generator只是一种为了节约内存的数据方式
"""
def __init__(self, data, batch_size=Batch_size, shuffle=True):
"""
:param data: 训练的文本列表
:param batch_size: 每次训练的个数
:param shuffle: 文本是否打乱
"""
self.data = data
self.batch_size = batch_size
self.shuffle = shuffle
self.steps = len(self.data) // self.batch_size
if len(self.data) % self.batch_size != 0:
self.steps += 1
def __len__(self):
return self.steps
def __iter__(self):
while True:
idxs = list(range(len(self.data))) # 生成一个序列
if self.shuffle:
np.random.shuffle(idxs) # 打乱序列
X1, X2, Y = [], [], []
for i in idxs:
d = self.data[i]
text = d[0][:maxlen]
x1, x2 = tokenizer.encode(first=text) # 添加[CLS]和[SEP]
y = d[1]
X1.append(x1)
X2.append(x2)
Y.append([y])
if len(X1) == self.batch_size or i == idxs[-1]:
# 对一批数据(最后一批不满batch_size)进行padding
X1 = seq_padding(X1) # 内部转为了np.array
X2 = seq_padding(X2)
Y = seq_padding(Y)
yield [X1, X2], Y[:, 0, :]
[X1, X2, Y] = [], [], []