trainECAPAModel.py
import argparse, glob, os, torch, warnings, time
from tools import *
from dataLoader import train_loader
from ECAPAModel import ECAPAModel
# Command-line interface of the ECAPA-TDNN training / evaluation script.
parser = argparse.ArgumentParser(description="ECAPA_trainer")

## Training Settings
parser.add_argument('--num_frames', type=int, default=200, help='Duration of the input segments, eg: 200 for 2 second')
parser.add_argument('--max_epoch', type=int, default=80, help='Maximum number of epochs')
parser.add_argument('--batch_size', type=int, default=400, help='Batch size')
parser.add_argument('--n_cpu', type=int, default=4, help='Number of loader threads')
parser.add_argument('--test_step', type=int, default=1, help='Test and save every [test_step] epochs')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
parser.add_argument("--lr_decay", type=float, default=0.97, help='Learning rate decay every [test_step] epochs')

## Training and evaluation path/lists, save path
# NOTE: the help strings of --train_list and --eval_list were missing their
# closing quote, which made the whole script a syntax error.
parser.add_argument('--train_list', type=str, default="/data08/VoxCeleb2/train_list.txt", help='The path of the training list, eg:"/data08/VoxCeleb2/train_list.txt" in my case, which contains 1092009 lines')
parser.add_argument('--train_path', type=str, default="/data08/VoxCeleb2/train/wav", help='The path of the training data, eg:"/data08/VoxCeleb2/train/wav" in my case')
parser.add_argument('--eval_list', type=str, default="/data08/VoxCeleb1/veri_test2.txt", help='The path of the evaluation list, veri_test2.txt comes from https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt')
parser.add_argument('--eval_path', type=str, default="/data08/VoxCeleb1/test/wav", help='The path of the evaluation data, eg:"/data08/VoxCeleb1/test/wav" in my case')
parser.add_argument('--musan_path', type=str, default="/data08/Others/musan_split", help='The path to the MUSAN set, eg:"/data08/Others/musan_split" in my case')
parser.add_argument('--rir_path', type=str, default="/data08/Others/RIRS_NOISES/simulated_rirs", help='The path to the RIR set, eg:"/data08/Others/RIRS_NOISES/simulated_rirs" in my case')
parser.add_argument('--save_path', type=str, default="exps/exp1", help='Path to save the score.txt and models')
parser.add_argument('--initial_model', type=str, default="", help='Path of the initial_model')

## Model and Loss settings
parser.add_argument('--C', type=int, default=1024, help='Channel size for the speaker encoder')
parser.add_argument('--m', type=float, default=0.2, help='Loss margin in AAM softmax')
parser.add_argument('--s', type=float, default=30, help='Loss scale in AAM softmax')
parser.add_argument('--n_class', type=int, default=5994, help='Number of speakers')

## Command
parser.add_argument('--eval', dest='eval', action='store_true', help='Only do evaluation')

## Initialization
# Silence every Python warning for the rest of the run.
warnings.simplefilter("ignore")
torch.multiprocessing.set_sharing_strategy('file_system')
args = parser.parse_args()
# init_args comes from `tools`. The rest of the script reads
# args.model_save_path and args.score_save_path, which are not declared
# above, so this call presumably derives them from --save_path (TODO confirm).
args = init_args(args)
## Define the data loader
# The dataset receives every parsed argument as a keyword; it picks the ones
# it needs and ignores the rest through **kwargs.
trainloader = train_loader(**vars(args))
trainLoader = torch.utils.data.DataLoader(
    trainloader,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.n_cpu,
    drop_last=True,
)

## Search for the existing models (sorted so the last entry is the newest epoch)
modelfiles = sorted(glob.glob('%s/model_0*.model' % args.model_save_path))

## Only do evaluation, the initial_model is necessary
if args.eval:
    s = ECAPAModel(**vars(args))
    print("Model %s loaded from previous state!" % args.initial_model)
    s.load_parameters(args.initial_model)
    EER, minDCF = s.eval_network(eval_list=args.eval_list, eval_path=args.eval_path)
    print("EER %2.2f%%, minDCF %.4f%%" % (EER, minDCF))
    # Terminate here; SystemExit is what quit() raises, but it does not
    # depend on the interactive helpers installed by the `site` module.
    raise SystemExit

## If initial_model exists, the system will train from the initial_model
if args.initial_model != "":
    print("Model %s loaded from previous state!" % args.initial_model)
    s = ECAPAModel(**vars(args))
    s.load_parameters(args.initial_model)
    epoch = 1
## Otherwise, the system will try to resume from the newest saved model & epoch
elif len(modelfiles) >= 1:
    print("Model %s loaded from previous state!" % modelfiles[-1])
    # File names look like "model_0012.model": drop the 6-character "model_"
    # prefix to recover the epoch number, then continue with the next epoch.
    epoch = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][6:]) + 1
    s = ECAPAModel(**vars(args))
    s.load_parameters(modelfiles[-1])
## Otherwise, the system will train from scratch
else:
    epoch = 1
    s = ECAPAModel(**vars(args))

EERs = []
# The context manager guarantees the score file is flushed and closed on every
# exit path (normal termination as well as an exception during training).
with open(args.score_save_path, "a+") as score_file:
    while True:
        ## Training for one epoch
        loss, lr, acc = s.train_network(epoch=epoch, loader=trainLoader)

        ## Evaluation every [test_step] epochs
        if epoch % args.test_step == 0:
            s.save_parameters(args.model_save_path + "/model_%04d.model" % epoch)
            EERs.append(s.eval_network(eval_list=args.eval_list, eval_path=args.eval_path)[0])
            print(time.strftime("%Y-%m-%d %H:%M:%S"), "%d epoch, ACC %2.2f%%, EER %2.2f%%, bestEER %2.2f%%" % (epoch, acc, EERs[-1], min(EERs)))
            score_file.write("%d epoch, LR %f, LOSS %f, ACC %2.2f%%, EER %2.2f%%, bestEER %2.2f%%\n" % (epoch, lr, loss, acc, EERs[-1], min(EERs)))
            # Flush after each record so progress survives an interrupted run.
            score_file.flush()

        # The check runs after training, so at least one epoch is always trained.
        if epoch >= args.max_epoch:
            raise SystemExit

        epoch += 1
loss.py
import torch, math
import torch.nn as nn
import torch.nn.functional as F
from tools import *
class AAMsoftmax(nn.Module):
    """Additive angular margin (AAM) softmax classification head.

    Holds one learnable weight row per class. The logits are the cosines
    between the L2-normalised input embeddings and the L2-normalised weight
    rows; the target-class logit has the angular margin ``m`` added to its
    angle before all logits are scaled by ``s`` and passed to cross-entropy.

    Args:
        n_class: Number of classes (rows of the weight matrix).
        m: Angular margin, in radians.
        s: Scale applied to the logits before the cross-entropy.
        embedding_dim: Size of the input embeddings. Defaults to 192, the
            value that used to be hard-coded.
    """

    def __init__(self, n_class, m, s, embedding_dim=192):
        super(AAMsoftmax, self).__init__()
        self.m = m
        self.s = s
        # Allocated uninitialised (float32, as before); filled by Xavier init below.
        self.weight = torch.nn.Parameter(
            torch.empty(n_class, embedding_dim, dtype=torch.float32),
            requires_grad=True,
        )
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.weight, gain=1)
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        # Cosine threshold below which theta + m would exceed pi; past it the
        # linear fallback `cosine - mm` is used instead of cos(theta + m).
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

    def forward(self, x, label=None):
        """Compute the AAM-softmax loss and the top-1 accuracy.

        Args:
            x: Embeddings of shape (batch, embedding_dim).
            label: Integer class indices, one per row of ``x``. Despite the
                ``None`` default it is required: it is used unconditionally.

        Returns:
            Tuple ``(loss, prec1)``: the cross-entropy loss on the margin-adjusted,
            scaled logits, and the first element returned by the project helper
            ``accuracy(..., topk=(1,))`` (presumably the top-1 accuracy).
        """
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # Clamp guards against tiny negative values from rounding before sqrt.
        sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
        # Apply the margin only to the target class of every sample.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s
        loss = self.ce(output, label)
        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
        return loss, prec1
dataLoader.py
'''
DataLoader for training
'''
import glob, numpy, os, random, soundfile, torch
from scipy import signal
class train_loader(object):
    """Map-style training dataset producing augmented fixed-length waveforms.

    Every item is a randomly positioned segment of ``num_frames * 160 + 240``
    samples from one training utterance, optionally corrupted with
    reverberation or additive noise, together with its integer speaker label.

    Args:
        train_list: Text file with one ``<speaker> <relative wav path>`` per line.
        train_path: Directory the relative wav paths are joined onto.
        musan_path: Root of the noise corpus, globbed as ``*/*/*/*.wav``; the
            first directory level names the category
            (expected: 'noise', 'speech', 'music').
        rir_path: Root of the room impulse responses, globbed as ``*/*/*.wav``.
        num_frames: Segment duration in frames (see the length formula above).
        **kwargs: Ignored; lets callers pass a whole argument namespace.
    """

    def __init__(self, train_list, train_path, musan_path, rir_path, num_frames, **kwargs):
        self.train_path = train_path
        self.num_frames = num_frames
        # Load and configure augmentation files
        self.noisetypes = ['noise', 'speech', 'music']
        # Per category: [min, max] SNR in dB and [min, max] number of mixed files.
        self.noisesnr = {'noise': [0, 15], 'speech': [13, 20], 'music': [5, 15]}
        self.numnoise = {'noise': [1, 1], 'speech': [3, 8], 'music': [1, 1]}
        self.noiselist = {}
        augment_files = glob.glob(os.path.join(musan_path, '*/*/*/*.wav'))
        for file in augment_files:
            # The fourth path component from the end is the category directory.
            self.noiselist.setdefault(file.split('/')[-4], []).append(file)
        self.rir_files = glob.glob(os.path.join(rir_path, '*/*/*.wav'))
        # Load data & labels
        with open(train_list) as list_file:
            lines = list_file.read().splitlines()
        # Speakers are numbered by their position in sorted order, so the
        # mapping is deterministic across runs.
        speakers = sorted({line.split()[0] for line in lines})
        dictkeys = {key: ii for ii, key in enumerate(speakers)}
        self.data_list = []
        self.data_label = []
        for line in lines:
            fields = line.split()
            self.data_label.append(dictkeys[fields[0]])
            self.data_list.append(os.path.join(train_path, fields[1]))

    def _random_crop(self, samples):
        """Return a random ``num_frames * 160 + 240``-sample slice of *samples*.

        Signals that are not longer than the target are first extended by
        wrapping around to their beginning.
        """
        length = self.num_frames * 160 + 240
        if samples.shape[0] <= length:
            shortage = length - samples.shape[0]
            samples = numpy.pad(samples, (0, shortage), 'wrap')
        start_frame = numpy.int64(random.random() * (samples.shape[0] - length))
        return samples[start_frame:start_frame + length]

    def __getitem__(self, index):
        """Return ``(waveform, label)`` for the utterance at *index*.

        The waveform is a ``torch.FloatTensor`` holding the randomly cropped
        and randomly augmented segment; the label is the speaker index.
        """
        # Read the utterance and randomly select the segment
        audio, sr = soundfile.read(self.data_list[index])
        audio = self._random_crop(audio)
        # Add a leading axis so the augmentation helpers see a 2-D array.
        audio = numpy.stack([audio], axis=0)
        # Data Augmentation: 0 keeps the original segment untouched.
        augtype = random.randint(0, 5)
        if augtype == 1:  # Reverberation
            audio = self.add_rev(audio)
        elif augtype == 2:  # Babble
            audio = self.add_noise(audio, 'speech')
        elif augtype == 3:  # Music
            audio = self.add_noise(audio, 'music')
        elif augtype == 4:  # Noise
            audio = self.add_noise(audio, 'noise')
        elif augtype == 5:  # Television noise
            audio = self.add_noise(audio, 'speech')
            audio = self.add_noise(audio, 'music')
        return torch.FloatTensor(audio[0]), self.data_label[index]

    def __len__(self):
        """Return the number of training utterances."""
        return len(self.data_list)

    def add_rev(self, audio):
        """Convolve *audio* with a random, energy-normalised impulse response.

        The result is truncated to the segment length, so the output has the
        same number of samples as a cropped segment.
        """
        rir_file = random.choice(self.rir_files)
        rir, sr = soundfile.read(rir_file)
        # numpy.float was an alias of the builtin float (float64) and was
        # removed in NumPy 1.24; numpy.float64 is the exact equivalent.
        rir = numpy.expand_dims(rir.astype(numpy.float64), 0)
        rir = rir / numpy.sqrt(numpy.sum(rir ** 2))
        return signal.convolve(audio, rir, mode='full')[:, :self.num_frames * 160 + 240]

    def add_noise(self, audio, noisecat):
        """Add randomly chosen noise of category *noisecat* to *audio*.

        A random number of files (range from ``self.numnoise``) is drawn from
        the category; each is cropped to the segment length and scaled to a
        random SNR (range from ``self.noisesnr``) relative to *audio* before
        all of them are summed onto the signal.
        """
        # The 1e-4 floor keeps log10 finite for silent signals.
        clean_db = 10 * numpy.log10(numpy.mean(audio ** 2) + 1e-4)
        numnoise = self.numnoise[noisecat]
        noiselist = random.sample(self.noiselist[noisecat], random.randint(numnoise[0], numnoise[1]))
        noises = []
        for noise in noiselist:
            noiseaudio, sr = soundfile.read(noise)
            noiseaudio = self._random_crop(noiseaudio)
            noiseaudio = numpy.stack([noiseaudio], axis=0)
            noise_db = 10 * numpy.log10(numpy.mean(noiseaudio ** 2) + 1e-4)
            noisesnr = random.uniform(self.noisesnr[noisecat][0], self.noisesnr[noisecat][1])
            # Amplitude gain that puts the noise `noisesnr` dB below the clean signal.
            noises.append(numpy.sqrt(10 ** ((clean_db - noise_db - noisesnr) / 10)) * noiseaudio)
        noise = numpy.sum(numpy.concatenate(noises, axis=0), axis=0, keepdims=True)
        return noise + audio