我正在学习强化学习,我想实现 Q-Network 来解决 OpenAI Taxi 环境问题。我在网上找到了这段代码,我正在尝试运行代码,但出现错误。下面是代码
import numpy as np
import gym
import random
def main():
# create Taxi environment
env = gym.make('Taxi-v3')
# initialize q-table
state_size = env.observation_space.n
action_size = env.action_space.n
qtable = np.zeros((state_size, action_size))
# hyperparameters
learning_rate = 0.9
discount_rate = 0.8
epsilon = 1.0
decay_rate= 0.005
# training variables
num_episodes = 1000
max_steps = 99 # per episode
# training
for episode in range(num_episodes):
# reset the environment
state = env.reset()
done = False
for s in range(max_steps):
# exploration-exploitation tradeoff
if random.uniform(0,1) < epsilon:
# explore
action = env.action_space.sample()
else:
# exploit
action = np.argmax(qtable[state,:])
# take action and observe reward
new_state, reward, done, trunc, info = env.step(action)
# Q-learning algorithm
qtable[state,action] = qtable[state,action] + learning_rate * int(reward + discount_rate * np.max(qtable[new_state,:]) - qtable[state,action])
# Update to our new state
state = new_state
# if done, finish episode
if done == True:
break
# Decrease epsilon
epsilon = np.exp(-decay_rate*episode)
print(f"Training completed over {num_episodes} episodes")
input("Press Enter to watch trained agent...")
# watch trained agent
state = env.reset()
done = False
rewards = 0
for s in range(max_steps):
print(f"TRAINED AGENT")
print("Step {}".format(s+1))
action = np.argmax(qtable[state,:])
new_state, reward, done, trunc, info = env.step(action)
rewards += reward
env.render()
print(f"score: {rewards}")
state = new_state
if done == True:
break
env.close()
if __name__ == "__main__":
main()
当我尝试运行上面的代码时,我收到以下错误消息:
Traceback (most recent call last):
File "/tmp/ipykernel_2838/974516385.py", line 84, in <module>
main()
File "/tmp/ipykernel_2838/974516385.py", line 46, in main
qtable[state,action] = qtable[state,action] + learning_rate * int(reward +
discount_rate * np.max(qtable[new_state,:]) - qtable[state,action])
IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and
integer or boolean arrays are valid indices
这个错误是什么意思,我该如何解决?
这是什么意思?
你希望每个人
new_state
, state
, action
是整数。
至少其中一个不是。
关注 通常的建议。 插入
print(...)
语句
或者以其他方式说服自己
下标有效。
修复那些具有意外值的值。