Q-learning
实例代码
1 import numpy as np
2 import random
3 from environment import Env
4 from collections import defaultdict
5
class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy.

    Q-values are kept in ``q_table``, a ``defaultdict`` mapping a state key to
    a list with one value per action. The list is indexed by action, so the
    actions are expected to be the integer indices ``0 .. len(actions) - 1``.
    """

    def __init__(self, actions, learning_rate=0.01, discount_factor=0.9,
                 epsilon=0.1):
        """Create an agent with an all-zero Q-table.

        Args:
            actions: Sized sequence of action indices, e.g. ``[0, 1, 2, 3]``.
            learning_rate: Step size applied to each temporal-difference error.
            discount_factor: Weight of the best next-state value in the target.
            epsilon: Probability of taking a random (exploratory) action.
        """
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon
        # Size each row from the action list instead of hard-coding four
        # entries, so the table stays valid for any number of actions.
        n_actions = len(actions)
        self.q_table = defaultdict(lambda: [0.0] * n_actions)

    def learn(self, state, action, reward, next_state):
        """Update Q(state, action) from one sample <s, a, r, s'>.

        Args:
            state: Hashable key of the state the action was taken in.
            action: Index of the action that was taken.
            reward: Reward observed after taking the action.
            next_state: Hashable key of the resulting state.
        """
        current_q = self.q_table[state][action]
        # Bellman optimality target: bootstrap from the best next-state value
        # regardless of which action will actually be taken next (off-policy).
        target_q = reward + self.discount_factor * max(self.q_table[next_state])
        self.q_table[state][action] += self.learning_rate * (target_q - current_q)

    def get_action(self, state):
        """Pick an action for ``state`` with the epsilon-greedy policy.

        With probability ``epsilon`` a random element of ``actions`` is
        returned; otherwise the index of a highest-valued action is returned.
        """
        if np.random.rand() < self.epsilon:
            # Explore: take a random action.
            return np.random.choice(self.actions)
        # Exploit: take a best action according to the Q-table.
        return self.arg_max(self.q_table[state])

    @staticmethod
    def arg_max(state_action):
        """Return the index of the largest value, breaking ties at random.

        Args:
            state_action: Non-empty sequence of Q-values for a single state.
        """
        best_value = max(state_action)
        best_indices = [
            index for index, value in enumerate(state_action)
            if value == best_value
        ]
        return random.choice(best_indices)
47
if __name__ == "__main__":
    # Build the environment and an agent whose actions are the environment's
    # action indices.
    env = Env()
    agent = QLearningAgent(actions=list(range(env.n_actions)))

    # Run a fixed number of training episodes.
    for _episode in range(1000):
        state = env.reset()
        done = False

        # Keep stepping until the environment reports the episode is over.
        while not done:
            env.render()

            # Choose an action for the current state and advance one step.
            action = agent.get_action(str(state))
            next_state, reward, done = env.step(action)

            # Update the Q-table from the observed <s, a, r, s'> sample.
            agent.learn(str(state), action, reward, str(next_state))

            state = next_state
            env.print_value_all(agent.q_table)