# 使用 __iter__ 和 __getitem__ 模拟数据处理（数据加载）流程
import torch.utils.data


class Model(torch.utils.data.Dataset):
    """Map-style dataset wrapping a list of animal names.

    Despite its name this is a dataset, not a model: ``DataLoader`` calls
    ``__len__`` to size its sampler and ``__getitem__`` once per sample
    index, then collates the individual samples into batches.
    """

    def __init__(self, animal_list):
        """Store the samples.

        Args:
            animal_list: indexable sequence of samples (e.g. animal names).
        """
        self.animal_list = animal_list

    def __getitem__(self, index):
        """Return the single sample at ``index``.

        Batching is done by the ``DataLoader``, not here: each call returns
        exactly one sample, and with ``shuffle=True`` the index order is
        randomized by the loader's sampler.

        Returns:
            ``{'A': animal_list[index], 'B': 1}``; ``'B'`` is a constant
            (presumably a placeholder label -- confirm intended use).
        """
        return {'A': self.animal_list[index], 'B': 1}

    def __len__(self):
        """Return the number of samples."""
        return len(self.animal_list)


class Animal:
    """Iterable that yields collated batches of animal samples.

    Attributes:
        animals_name: the sample list passed to the constructor.
        m: the ``Model`` dataset built from ``animals_name``.
        model: the ``DataLoader`` over ``m`` (name kept for compatibility;
            it is a loader, not a model).
    """

    def __init__(self, animal_list, batch_size=2, shuffle=True):
        """Build the dataset and its DataLoader.

        Args:
            animal_list: indexable sequence of samples.
            batch_size: samples per batch; the last batch may be smaller
                because ``drop_last`` is left at its default (False).
            shuffle: when True the loader reshuffles sample order on each
                pass over the data.
        """
        self.animals_name = animal_list
        self.m = Model(self.animals_name)
        self.model = torch.utils.data.DataLoader(
            self.m,
            batch_size=batch_size,
            shuffle=shuffle,
        )

    def __iter__(self):
        """Yield one collated batch at a time from the DataLoader.

        With PyTorch's default collate function each batch is a dict whose
        ``'A'`` entry is a list of names and whose ``'B'`` entry is a tensor.
        """
        yield from self.model

    def __len__(self):
        """Return the number of batches produced by one pass."""
        return len(self.model)


if __name__ == "__main__":
    animals = Animal(['dog', 'cat', 'fish'])
    for animal in animals:
        print(animal)