Continuing from the previous post, this one walks through the code of the LSPI algorithm.
# -*- coding: utf-8 -*-
"""Contains main interface to LSPI algorithm."""
# the main interface to the LSPI algorithm

from copy import copy

import numpy as np


def learn(data, initial_policy, solver, epsilon=10**-5, max_iterations=10):
    r"""Find the optimal policy for the specified data.
    # i.e. find the optimal policy for the given data

    Parameters
    ----------
    data:
        # data is usually a list of samples, but its exact type does not
        # matter much: the solver deals with it in its solve routine. For
        # model-based learning, for instance, we pass in a model instead of
        # sample data!!
        Generally a list of samples, however, the type of data does not matter
        so long as the specified solver can handle it in its solve routine. For
        example when doing model based learning one might pass in a model
        instead of sample data
    initial_policy: Policy
        # the starting policy; a copy is made at the start, so the policy we
        # pass in is preserved
        Starting policy. A copy of this policy will be made at the start of the
        method. This means that the provided initial policy will be preserved.
    solver: Solver
        # a subclass of the Solver abstract base class that implements the
        # actual computation, e.g. steepest descent or a linear-system solve;
        # in practice this is an implementation of the LSTDQ algorithm (see
        # the sketch after this listing)
        A subclass of the Solver abstract base class. This class must implement
        the solve method. Examples of solvers might be steepest descent or
        any other linear system of equation matrix solver. This is basically
        going to be implementations of the LSTDQ algorithm.
    epsilon: float
        # the threshold on the change in policy weights; it decides whether
        # the policy has converged: if the size of the weight update falls
        # below this value the policy is considered converged
        The threshold of the change in policy weights. Determines if the policy
        has converged. When the L2-norm of the change in weights is less than
        this value the policy is considered converged
    max_iterations: int
        # the maximum number of iterations
        The maximum number of iterations to run before giving up on
        convergence. The change in policy weights are not guaranteed to ever
        go below epsilon. To prevent an infinite loop this parameter must be
        specified.

    Return
    ------
    Policy
        # the converged policy; if it did not converge, the policy from the
        # last iteration is returned
        The converged policy. If the policy does not converge by max_iterations
        then this will be the last iteration's policy.

    Raises
    ------
    # the errors that can be raised
    ValueError
        If epsilon is <= 0
    ValueError
        If max_iteration <= 0

    """
    if epsilon <= 0:  # check the convergence threshold
        raise ValueError('epsilon must be > 0: %g' % epsilon)
    if max_iterations <= 0:  # check the maximum number of iterations
        raise ValueError('max_iterations must be > 0: %d' % max_iterations)

    # this is just to make sure that changing the weight vector doesn't
    # affect the original policy weights
    # (we copy the initial policy so that updating the weights never touches
    # the original)
    curr_policy = copy(initial_policy)

    distance = float('inf')  # initialize the weight-change distance
    iteration = 0            # initialize the iteration counter
    # loop while the weight update is still large and the maximum number of
    # iterations has not been reached
    while distance > epsilon and iteration < max_iterations:
        iteration += 1  # increment the iteration counter
        # use the solver to compute the new weights -- only one solve per iteration!
        new_weights = solver.solve(data, curr_policy)

        # distance between the new weights and the old weights
        distance = np.linalg.norm(new_weights - curr_policy.weights)
        curr_policy.weights = new_weights  # update the weights

    return curr_policy  # return the resulting policy
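To see how the pieces fit together, here is a minimal, self-contained sketch of how learn could be called. The names LinearPolicy, LSTDQSolver and basis below are hypothetical stand-ins for illustration, not the actual classes of the lspi-python package: all that learn relies on is a policy object exposing a weights vector and a solver whose solve(data, policy) method returns the new weight vector. The toy solver follows the LSTDQ idea of accumulating A = sum phi(s,a)(phi(s,a) - gamma*phi(s',pi(s')))^T and b = sum phi(s,a)*r over the samples and then solving A w = b.

# Minimal sketch (hypothetical classes, not the lspi-python API); assumes the
# learn() function from the listing above is defined in the same module.
import numpy as np


class LinearPolicy(object):
    """Toy policy: Q(s, a) = phi(s, a) . weights, acting greedily over actions."""

    def __init__(self, basis, weights, num_actions):
        self.basis = basis              # callable: (state, action) -> feature vector
        self.weights = weights          # weight vector that learn() updates
        self.num_actions = num_actions

    def best_action(self, state):
        # pick the action with the highest approximate Q-value
        q_values = [self.basis(state, a).dot(self.weights)
                    for a in range(self.num_actions)]
        return int(np.argmax(q_values))


class LSTDQSolver(object):
    """Toy LSTDQ-style solver: builds A and b from the samples and solves A w = b."""

    def __init__(self, discount=0.9):
        self.discount = discount

    def solve(self, data, policy):
        k = policy.weights.shape[0]
        A = np.zeros((k, k))
        b = np.zeros(k)
        # data is assumed to be a list of (state, action, reward, next_state) tuples
        for state, action, reward, next_state in data:
            phi = policy.basis(state, action)
            next_action = policy.best_action(next_state)
            phi_next = policy.basis(next_state, next_action)
            A += np.outer(phi, phi - self.discount * phi_next)
            b += reward * phi
        # least-squares solve instead of an explicit inverse, for stability
        return np.linalg.lstsq(A, b, rcond=None)[0]


def basis(state, action):
    # toy basis: copy the 2-d state into the block belonging to the chosen action
    features = np.zeros(4)
    features[2 * action: 2 * action + 2] = state
    return features


# a handful of made-up samples from a 2-dimensional toy problem
rng = np.random.default_rng(0)
samples = [(rng.random(2), int(rng.integers(2)), float(rng.random()), rng.random(2))
           for _ in range(100)]

initial_policy = LinearPolicy(basis, np.zeros(4), num_actions=2)
learned = learn(samples, initial_policy, LSTDQSolver(discount=0.9),
                epsilon=10**-5, max_iterations=10)
print(learned.weights)

With epsilon=10**-5 and max_iterations=10, the loop stops as soon as the L2-norm of the weight change drops below the threshold, and otherwise after at most ten calls to solve.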