本篇是solver.py
# -*- coding: utf-8 -*-
"""Contains main LSPI method and various LSTDQ solvers."""

import abc
import logging

import numpy as np

import scipy.linalg


# Module-level logger. Calling ``logging.warning`` directly on the root
# logger implicitly runs ``logging.basicConfig`` when no handlers exist,
# which would reconfigure the host application's logging from library code.
logger = logging.getLogger(__name__)

# ``__metaclass__`` is only honoured by Python 2; under Python 3 it is an
# ordinary class attribute and ``@abc.abstractmethod`` is silently ignored.
# Building the base class through ``abc.ABCMeta`` directly works on both.
_ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()})


class Solver(_ABC):

    r"""ABC for LSPI solvers.

    Implementations of this class will implement the various LSTDQ algorithms
    with various linear algebra solving techniques. This solver will be used
    by the lspi.learn method. The instance will be called iteratively until
    the convergence parameters are satisfied.

    Subclasses must override :meth:`solve`; ``Solver`` itself cannot be
    instantiated.
    """

    @abc.abstractmethod
    def solve(self, data, policy):
        r"""Return one-step update of the policy weights for the given data.

        Parameters
        ----------
        data:
            This is the data used by the solver. In most cases this will be
            a list of samples. But it can be anything supported by the
            specific Solver implementation's solve method.
        policy: Policy
            The current policy to find an improvement to.

        Returns
        -------
        numpy.array
            Return the new weights as determined by this method.

        """
        pass  # pragma: no cover


class LSTDQSolver(Solver):

    """LSTDQ Implementation with standard matrix solvers.

    Uses the algorithm from Figure 5 of the LSPI paper. If the A matrix
    turns out to be full rank then scipy's standard linalg solver is used. If
    the matrix turns out to be less than full rank then least squares method
    will be used.

    By default the A matrix will have its diagonal preconditioned with a small
    positive value. This will help to ensure that even with few samples the
    A matrix will be full rank. If you do not want the A matrix to be
    preconditioned then you can set this value to 0.

    Parameters
    ----------
    precondition_value: float
        Initial value of the A matrix diagonal: A starts as
        ``precondition_value * I`` before sample contributions are added.
        Should be a small positive number. If you do not want preconditioning
        enabled then set it 0.
    """

    def __init__(self, precondition_value=.1):
        """Initialize LSTDQSolver."""
        self.precondition_value = precondition_value

    def solve(self, data, policy):
        """Run LSTDQ iteration.

        See Figure 5 of the LSPI paper for more information. Accumulates

            A = c * I + sum_i phi_i (phi_i - gamma * phi'_i)^T
            b = sum_i phi_i * r_i

        where c is ``precondition_value``, ``phi_i`` are the features of the
        sample's (state, action), ``phi'_i`` are the features of
        (next_state, policy.best_action(next_state)) -- zero for absorbing
        samples -- and gamma is ``policy.discount``. Then solves ``A w = b``,
        falling back to least squares when A is rank deficient.

        Parameters
        ----------
        data: iterable
            Samples exposing ``state``, ``action``, ``reward``,
            ``next_state`` and ``absorb`` attributes.
        policy: Policy
            Must provide ``basis`` (with ``size()`` and
            ``evaluate(state, action)``), ``discount`` and
            ``best_action(state)``.

        Returns
        -------
        numpy.array
            1-D weight vector of length ``policy.basis.size()``.
        """
        k = policy.basis.size()
        a_mat = np.zeros((k, k))
        # Seed the diagonal so A is more likely to be full rank.
        np.fill_diagonal(a_mat, self.precondition_value)

        b_vec = np.zeros((k, 1))

        for sample in data:
            # Column vector of features for the (state, action) taken.
            phi_sa = (policy.basis.evaluate(sample.state, sample.action)
                      .reshape((-1, 1)))

            if not sample.absorb:
                # Features of the next state under the current greedy policy.
                best_action = policy.best_action(sample.next_state)
                phi_sprime = (policy.basis
                              .evaluate(sample.next_state, best_action)
                              .reshape((-1, 1)))
            else:
                # Absorbing sample: nothing is bootstrapped from next state.
                phi_sprime = np.zeros((k, 1))

            a_mat += phi_sa.dot((phi_sa - policy.discount*phi_sprime).T)
            b_vec += phi_sa*sample.reward

        a_rank = np.linalg.matrix_rank(a_mat)
        if a_rank == k:
            w = scipy.linalg.solve(a_mat, b_vec)
        else:
            # Singular A: use the least-squares solution instead.
            logger.warning('A matrix is not full rank. %d < %d', a_rank, k)
            w = scipy.linalg.lstsq(a_mat, b_vec)[0]
        return w.reshape((-1, ))