model.fit source code analysis
First, go to the module package, i.e. https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/module, and open base_module.py, where we can find the prototype of fit().
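Before reading the source, it helps to see how fit() is typically invoked. The sketch below is illustrative only: the toy network, the synthetic data, and every name in it are mine, not from the MXNet source.

import mxnet as mx

# A toy network; 'data' and 'softmax_label' match Module's default input names.
data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=10, name='fc1')
net = mx.sym.SoftmaxOutput(net, name='softmax')

# Synthetic data just to keep the sketch self-contained.
train_iter = mx.io.NDArrayIter(mx.nd.random.uniform(shape=(100, 20)),
                               mx.nd.zeros((100,)), batch_size=10)

mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, eval_metric='acc', optimizer='sgd',
        optimizer_params={'learning_rate': 0.01}, num_epoch=2)

Everything below walks through what this single fit() call does internally.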
class BaseModule(object):
################################################################################
# High Level API
################################################################################
def forward_backward(self, data_batch):
"""A convenient function that calls both ``forward`` and ``backward``."""
self.forward(data_batch, is_train=True)
self.backward()
    # Evaluation on a validation set
def score(self, eval_data, eval_metric, num_batch=None, batch_end_callback=None,
score_end_callback=None,
reset=True, epoch=0, sparse_row_id_fn=None):
"""Runs prediction on ``eval_data`` and evaluates the performance according to
the given ``eval_metric``.
Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
to see an end-to-end use-case.
Parameters
----------
eval_data : DataIter
Evaluation data to run prediction on.
eval_metric : EvalMetric or list of EvalMetrics
Evaluation metric to use.
num_batch : int
Number of batches to run. Defaults to ``None``, indicating run until the `DataIter`
finishes.
batch_end_callback : function
Could also be a list of functions.
reset : bool
Defaults to ``True``. Indicates whether we should reset `eval_data` before starting
evaluating.
epoch : int
Defaults to 0. For compatibility, this will be passed to callbacks (if any).
During training, this will correspond to the training epoch number.
sparse_row_id_fn : A callback function
The function takes `data_batch` as an input and returns a dict of
str -> NDArray. The resulting dict is used for pulling row_sparse
parameters from the kvstore, where the str key is the name of the param,
and the value is the row id of the param to pull.
Examples
--------
>>> # An example of using score for prediction.
>>> # Evaluate accuracy on val_dataiter
>>> metric = mx.metric.Accuracy()
>>> mod.score(val_dataiter, metric)
>>> mod.score(val_dataiter, ['mse', 'acc'])
"""
assert self.binded and self.params_initialized
        # Reset the validation data iterator
if reset:
eval_data.reset()
if not isinstance(eval_metric, metric.EvalMetric):
eval_metric = metric.create(eval_metric)
eval_metric.reset()
actual_num_batch = 0
        # Iterate over validation batches
for nbatch, eval_batch in enumerate(eval_data):
if num_batch is not None and nbatch == num_batch:
break
            # Prepare the module for this batch (e.g. sparse row pulling)
self.prepare(eval_batch, sparse_row_id_fn=sparse_row_id_fn)
            # Forward pass (is_train=False: no gradients during evaluation)
self.forward(eval_batch, is_train=False)
            # Call update() on each metric in the list
if isinstance(eval_batch, list):
self.update_metric(eval_metric, [eb.label for eb in eval_batch], pre_sliced=True)
else:
self.update_metric(eval_metric, eval_batch.label)
            # End-of-batch callbacks
if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch,
nbatch=nbatch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(batch_end_callback):
callback(batch_end_params)
actual_num_batch += 1
        # Callback after the whole evaluation run finishes
if score_end_callback:
params = BatchEndParam(epoch=epoch,
nbatch=actual_num_batch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(score_end_callback):
callback(params)
        # Return the metrics' results as (name, value) pairs
return eval_metric.get_name_value()
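score() thus returns a list of (name, value) pairs, one per metric. A minimal sketch of consuming that return value, assuming mod is a bound, parameter-initialized Module and val_dataiter a validation DataIter (as in the docstring example above):

metric = mx.metric.create(['acc', 'mse'])  # a list yields a CompositeEvalMetric
res = mod.score(val_dataiter, metric)      # e.g. [('accuracy', 0.97), ('mse', 0.03)]
for name, val in res:
    print('Validation-%s=%f' % (name, val))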
def fit(self, train_data, eval_data=None, eval_metric='acc',
epoch_end_callback=None, batch_end_callback=None, kvstore='local',
optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
eval_end_callback=None,
eval_batch_end_callback=None, initializer=Uniform(0.01),
arg_params=None, aux_params=None, allow_missing=False,
force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
validation_metric=None, monitor=None, sparse_row_id_fn=None):
"""Trains the module parameters.
Checkout `Module Tutorial <https://mxnet.apache.org/api/python/tutorials/packages/module/index.html>`_
to see an end-to-end use-case.
Parameters
----------
        train_data : DataIter
            Training data iterator.
        eval_data : DataIter
            If not ``None``, it is used as the validation set, on which
            performance is evaluated after every epoch.
        eval_metric : str or EvalMetric
            Defaults to the string 'accuracy'. The performance measure displayed
            during training. Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Called at the end of each epoch with `epoch`, `symbol`, `arg_params`
            and `aux_params` as arguments.
        batch_end_callback : function or list of function
            Called at the end of each batch with a `BatchEndParam` argument.
        kvstore : str or KVStore
            Where parameter updates run; defaults to 'local'.
            'device': gradients are aggregated and weights updated on the GPU(s);
            'local': updates run on the CPU;
            'dist_device_sync': distributed training.
        optimizer : str or Optimizer
            The optimizer; defaults to 'sgd'.
        optimizer_params : dict
            Optimizer parameters; defaults to (('learning_rate', 0.01),).
        eval_end_callback : function or list of function
            Called after a full pass over the evaluation data.
        eval_batch_end_callback : function or list of function
            Called after each evaluation batch.
        initializer : Initializer
            Called to initialize the module parameters if they are not already
            initialized.
        arg_params : dict
            Defaults to None. If not None, these values are used instead of the
            initializer to set the parameters.
        aux_params : dict
            Defaults to None. If not None, these values are used instead of the
            initializer to set the auxiliary states.
        allow_missing : bool
            Defaults to False. Indicates whether missing parameters are allowed
            when arg_params and aux_params are not None. If allow_missing=True,
            the missing parameters are initialized via the initializer.
        force_rebind : bool
            Defaults to False. Whether to force rebinding the executors if they
            are already bound.
        force_init : bool
            Defaults to False. Indicates whether to force initialization even if
            the parameters are already initialized.
        begin_epoch : int
            Defaults to 0. The epoch to start from. Typically, if the previous
            training run was checkpointed at epoch n, resuming training should
            set this to n+1.
        num_epoch : int
            Number of epochs to train for.
sparse_row_id_fn : A callback function
The function takes `data_batch` as an input and returns a dict of
str -> NDArray. The resulting dict is used for pulling row_sparse
parameters from the kvstore, where the str key is the name of the param,
and the value is the row id of the param to pull.
Examples
--------
>>> # An example of using fit for training.
>>> # Assume training dataIter and validation dataIter are ready
>>> # Assume loading a previously checkpointed model
>>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
>>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
... optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
... arg_params=arg_params, aux_params=aux_params,
... eval_metric='acc', num_epoch=10, begin_epoch=3)
"""
assert num_epoch is not None, 'please specify number of epochs'
        # Bind the symbol to the training data and label shapes
self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
for_training=True, force_rebind=force_rebind)
if monitor is not None:
self.install_monitor(monitor)
        # Initialize the weights; see the parameter descriptions above for the strategy
self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
allow_missing=allow_missing, force_init=force_init)
        # Initialize the optimizer
self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
optimizer_params=optimizer_params)
        # Validation falls back to the training metric when none is given
if validation_metric is None:
validation_metric = eval_metric
        # Convert a str eval_metric into a metric.EvalMetric instance
if not isinstance(eval_metric, metric.EvalMetric):
eval_metric = metric.create(eval_metric)
################################################################################
# training loop
################################################################################
        # Epoch loop
for epoch in range(begin_epoch, num_epoch):
tic = time.time()
            # Reset the metrics at the start of each epoch
eval_metric.reset()
            # Batch counter
nbatch = 0
data_iter = iter(train_data)
end_of_batch = False
next_data_batch = next(data_iter)
            # Fetch one training batch at a time via next()
while not end_of_batch:
data_batch = next_data_batch
if monitor is not None:
monitor.tic()
                # Forward pass + backward pass to compute gradients
self.forward_backward(data_batch)
                # Update the weights from the gradients via the optimizer
self.update()
                # Update the training metrics (calls each metric's update())
if isinstance(data_batch, list):
self.update_metric(eval_metric,
[db.label for db in data_batch],
pre_sliced=True)
else:
self.update_metric(eval_metric, data_batch.label)
                # Pre-fetch the next batch so data loading overlaps with computation
try:
# pre fetch next batch
next_data_batch = next(data_iter)
self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
except StopIteration:
end_of_batch = True
if monitor is not None:
monitor.toc_print()
                # At the end of the epoch, collect the metrics' (name, value) results
if end_of_batch:
eval_name_vals = eval_metric.get_global_name_value()
                # End-of-batch callbacks
if batch_end_callback is not None:
batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
eval_metric=eval_metric,
locals=locals())
for callback in _as_list(batch_end_callback):
callback(batch_end_params)
nbatch += 1
# one epoch of training is finished
            # At the end of each epoch, log every metric's result as Train-xxx=xxx
for name, val in eval_name_vals:
self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            # Log the epoch's time cost
toc = time.time()
self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))
            # Sync parameters (including aux states) across devices
arg_params, aux_params = self.get_params()
self.set_params(arg_params, aux_params)
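            # get_params() copies the current values back from the devices; set_params()
            # re-broadcasts them so every device (and the kvstore) holds an identical,
            # consistent copy before the epoch_end_callback (e.g. checkpointing) reads it.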
            # End-of-epoch callbacks
if epoch_end_callback is not None:
for callback in _as_list(epoch_end_callback):
callback(epoch, self.symbol, arg_params, aux_params)
#----------------------------------------
# evaluation on validation set
            # Evaluate on the validation set; when validation_metric is None,
            # the training metrics are reused
if eval_data:
res = self.score(eval_data, validation_metric,
score_end_callback=eval_end_callback,
batch_end_callback=eval_batch_end_callback, epoch=epoch)
#TODO: pull this into default
                # Log the validation results
for name, val in res:
self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
# end of 1 epoch, reset the data-iter for another epoch
            # Reset the training data iterator for the next epoch
train_data.reset()
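epoch_end_callback is the hook that checkpointing plugs into. A hedged sketch, with mxnet imported as mx; the 'mymodel' prefix and the epoch numbers are illustrative, not from the source:

# Save symbol and parameters as mymodel-symbol.json / mymodel-%04d.params each epoch.
mod.fit(train_dataiter, eval_data=val_dataiter,
        epoch_end_callback=mx.callback.do_checkpoint('mymodel'),
        num_epoch=10)

# Later: resume from the epoch-3 checkpoint, as in the fit() docstring example.
sym, arg_params, aux_params = mx.model.load_checkpoint('mymodel', 3)
mod = mx.mod.Module(symbol=sym)
mod.fit(train_dataiter, arg_params=arg_params, aux_params=aux_params,
        begin_epoch=3, num_epoch=10)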
Training log source code analysis
In a typical training script, the per-batch log line is produced by the Speedometer callback, wrapped like this:
_cb = mx.callback.Speedometer(batch_size, frequent)
def _batch_callback(param):
    # Emits the training log, e.g.
    # INFO:root:Epoch[26] Batch [0-20] Speed: 257.26 samples/sec acc=0.968571 lossvalue=0.331392
_cb(param)
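The wrapper is then handed to fit() through batch_end_callback; passing the Speedometer instance directly works just as well. (mod, train_dataiter, batch_size and frequent are assumed to be defined as above.)

mod.fit(train_dataiter, batch_end_callback=_batch_callback, num_epoch=10)
# Equivalent, without the wrapper:
mod.fit(train_dataiter,
        batch_end_callback=mx.callback.Speedometer(batch_size, frequent),
        num_epoch=10)

Speedometer itself is defined in python/mxnet/callback.py: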
class Speedometer(object):
    """Logs training speed and evaluation metrics periodically.

    Parameters
    ----------
    batch_size: int
        Batch size of data.
    frequent: int
        Specifies how frequently training speed and evaluation metrics
        must be logged. Default behavior is to log once every 50 batches.
    auto_reset : bool
        Reset the evaluation metrics after each log.

    Example
    -------
    >>> # Print training speed and evaluation metrics every ten batches. Batch size is one.
    >>> module.fit(iterator, num_epoch=n_epoch,
    ... batch_end_callback=mx.callback.Speedometer(1, 10))
    Epoch[0] Batch [10] Speed: 1910.41 samples/sec Train-accuracy=0.200000
    Epoch[0] Batch [20] Speed: 1764.83 samples/sec Train-accuracy=0.400000
    Epoch[0] Batch [30] Speed: 1740.59 samples/sec Train-accuracy=0.500000
    """
    def __init__(self, batch_size, frequent=50, auto_reset=True):
        self.batch_size = batch_size
        self.frequent = frequent
        self.init = False
        self.tic = 0
        self.last_count = 0
        self.auto_reset = auto_reset

    def __call__(self, param):
        """Callback to show speed."""
        count = param.nbatch
        # A smaller count means a new epoch has started: suppress the nbatch=0 log
        if self.last_count > count:
            self.init = False
        self.last_count = count

        if self.init:
            # Log once every `frequent` batches
            if count % self.frequent == 0:
                # #11504
                # Speed over the last `frequent` batches; "Speed: 257.26 samples/sec"
                # means 257.26 samples are processed per second
                try:
                    speed = self.frequent * self.batch_size / (time.time() - self.tic)
                except ZeroDivisionError:
                    speed = float('inf')
                # Log the speed plus each metric's name=value
                if param.eval_metric is not None:
                    # Get the (name, value) results computed by eval_metric
                    name_value = param.eval_metric.get_name_value()
                    if self.auto_reset:
                        param.eval_metric.reset_local()
                        msg = 'Epoch[%d] Batch [%d-%d] Speed: %.2f samples/sec'
                        msg += ' %s=%f'*len(name_value)
                        logging.info(msg, param.epoch, count-self.frequent, count,
                                     speed, *sum(name_value, ()))
                    else:
                        msg = 'Epoch[%d] Batch [0-%d] Speed: %.2f samples/sec'
                        msg += ' %s=%f'*len(name_value)
                        logging.info(msg, param.epoch, count, speed,
                                     *sum(name_value, ()))
                else:
                    logging.info("Iter[%d] Batch [%d] Speed: %.2f samples/sec",
                                 param.epoch, count, speed)
                self.tic = time.time()
        else:
            self.init = True
            self.tic = time.time()
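As a quick sanity check of the speed formula (the numbers are illustrative): with frequent=50 and batch_size=128, a 24.9 s gap between logs gives roughly the 257 samples/sec shown in the log line earlier.

frequent, batch_size, elapsed = 50, 128, 24.9  # illustrative values
speed = frequent * batch_size / elapsed        # 6400 / 24.9 ≈ 257.03 samples/sec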