• 增强学习--值迭代



     class ValueIteration:
         """Value-iteration agent over a 2-d table of state values.

         The environment object passed in must provide:

         * ``width`` and ``height`` -- sizes used to build the value table
           (``height`` rows of ``width`` entries each),
         * ``possible_actions`` -- an iterable of actions,
         * ``get_all_states()`` -- an iterable of states,
         * ``state_after_action(state, action)`` -- the successor state,
         * ``get_reward(state, action)`` -- the reward for taking the action.

         A state is a two-element sequence; the value of ``state`` is stored at
         ``value_table[state[0]][state[1]]``.
         """

         def __init__(self, env, discount_factor=0.9, terminal_state=(2, 2)):
             """Create an agent with an all-zero value table.

             Args:
                 env: The environment (see the class docstring for the
                     attributes and methods that are used).
                 discount_factor: Weight applied to the value of the next state.
                     Defaults to 0.9.
                 terminal_state: Two-element sequence identifying the state
                     whose value is pinned to 0.0 and for which no action is
                     returned. Defaults to ``(2, 2)``.
             """
             self.env = env
             # 2-d list for the value function
             self.value_table = [[0.0] * env.width for _ in range(env.height)]
             self.discount_factor = discount_factor
             # Stored as a list so it compares equal to list- or tuple-shaped
             # states once those are normalised with ``list(state)``.
             self.terminal_state = list(terminal_state)

         def _is_terminal(self, state):
             """Return True when ``state`` is the configured terminal state."""
             return list(state) == self.terminal_state

         def _action_value(self, state, action):
             """Return ``reward + discount_factor * value(next_state)``."""
             next_state = self.env.state_after_action(state, action)
             reward = self.env.get_reward(state, action)
             return reward + self.discount_factor * self.get_value(next_state)

         def value_iteration(self):
             """Run one sweep of value iteration and replace ``value_table``.

             Every new entry is computed from the current table only (the new
             table is built separately and swapped in at the end). For each
             non-terminal state the new value is the maximum action value,
             rounded to two decimals; the terminal state is set to 0.0.
             """
             next_value_table = [[0.0] * self.env.width
                                 for _ in range(self.env.height)]
             for state in self.env.get_all_states():
                 if self._is_terminal(state):
                     next_value_table[state[0]][state[1]] = 0.0
                     continue
                 # Optimality equation: each update keeps the value of the
                 # action with the largest return.
                 best_value = max(self._action_value(state, action)
                                  for action in self.env.possible_actions)
                 next_value_table[state[0]][state[1]] = round(best_value, 2)
             self.value_table = next_value_table

         def get_action(self, state):
             """Return every action whose value is maximal in ``state``.

             Args:
                 state: Two-element sequence indexing the value table.

             Returns:
                 A list of the actions tied for the highest action value under
                 the current value table, in the order they appear in
                 ``env.possible_actions``. An empty list for the terminal state.
             """
             if self._is_terminal(state):
                 return []

             action_list = []
             # Start from -inf so that arbitrarily negative values still win
             # over the initial sentinel.
             max_value = float("-inf")
             for action in self.env.possible_actions:
                 value = self._action_value(state, action)
                 if value > max_value:
                     # Strictly better: discard the previous candidates.
                     action_list = [action]
                     max_value = value
                 elif value == max_value:
                     # Tie: keep all equally good actions.
                     action_list.append(action)
             return action_list

         def get_value(self, state):
             """Return the stored value of ``state`` rounded to two decimals."""
             return round(self.value_table[state[0]][state[1]], 2)
  • 相关阅读:
    ls 按大小排序 按时间排序
    【33.33%】【codeforces 608C】Chain Reaction
    【44.19%】【codeforces 608D】Zuma
    【22.73%】【codeforces 606D】Lazy Student
    【27.40%】【codeforces 599D】Spongebob and Squares
    【26.67%】【codeforces 596C】Wilbur and Points
    【13.91%】【codeforces 593D】Happy Tree Party
  • 原文地址:https://www.cnblogs.com/buyizhiyou/p/10250089.html
Copyright © 2020-2023  润新知