I'm trying to understand how to use the Actor class in tf_agents. I'm using DDPG (an actor-critic method, though that doesn't really matter here). I'm also still learning the gym package, though that isn't central to the question either.
I dug into the class definition of train.Actor, and under the hood its run method calls py_driver.PyDriver. As I understand it, once a terminal state is reached, the gym environment needs to be reset. However, tracing through the Actor and PyDriver classes, I don't see env.reset() being called anywhere (outside of the __init__ method). I then looked at the sac_agent.SacAgent tutorial, and I don't see them calling env.reset() there either.
Can someone help me understand what I'm missing? Do I not need to call env.reset()? Or is there some code I'm overlooking that calls it?
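For reference, here is roughly how I'm driving collection, following the tutorial. This is a minimal sketch, not my real DDPG setup: the environment name, the random policy, and the stand-in list observer are placeholders I chose for illustration.

import tensorflow as tf
from tf_agents.environments import suite_gym
from tf_agents.policies import random_py_policy
from tf_agents.train import actor

# Load a gym environment through the tf_agents wrapper (Pendulum-v1 is arbitrary).
collect_env = suite_gym.load('Pendulum-v1')

# A random policy just to drive collection; the question is the same
# with my DDPG collect policy.
collect_policy = random_py_policy.RandomPyPolicy(
    collect_env.time_step_spec(), collect_env.action_spec())

train_step = tf.Variable(0, dtype=tf.int64)

# A replay-buffer observer would normally go here; a plain list stands in.
collected = []
rb_observer = collected.append

collect_actor = actor.Actor(
    collect_env,
    collect_policy,
    train_step,
    steps_per_run=100,
    observers=[rb_observer])

# Each call to run() invokes PyDriver.run() under the hood, but I never
# call collect_env.reset() myself between episodes.
collect_actor.run()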
Here is the PyDriver.run() method:
def run(
    self,
    time_step: ts.TimeStep,
    policy_state: types.NestedArray = ()
) -> Tuple[ts.TimeStep, types.NestedArray]:
  num_steps = 0
  num_episodes = 0
  while num_steps < self._max_steps and num_episodes < self._max_episodes:
    # For now we reset the policy_state for non batched envs.
    if not self.env.batched and time_step.is_first() and num_episodes > 0:
      policy_state = self._policy.get_initial_state(self.env.batch_size or 1)

    action_step = self.policy.action(time_step, policy_state)
    next_time_step = self.env.step(action_step.action)

    # When using observer (for the purpose of training), only the previous
    # policy_state is useful. Therefore substitute it in the PolicyStep and
    # consume it w/ the observer.
    action_step_with_previous_state = action_step._replace(state=policy_state)
    traj = trajectory.from_transition(
        time_step, action_step_with_previous_state, next_time_step)
    for observer in self._transition_observers:
      observer((time_step, action_step_with_previous_state, next_time_step))
    for observer in self.observers:
      observer(traj)
    for observer in self.info_observers:
      observer(self.env.get_info())

    if self._end_episode_on_boundary:
      num_episodes += np.sum(traj.is_boundary())
    else:
      num_episodes += np.sum(traj.is_last())

    num_steps += np.sum(~traj.is_boundary())

    time_step = next_time_step
    policy_state = action_step.state

  return time_step, policy_state
As you can see, it counts a step whenever the trajectory is not a boundary, and counts an episode when it hits a terminal state (or a boundary, depending on _end_episode_on_boundary). But there is no call to env.reset().
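To make my confusion concrete, here is a minimal sketch of the situation I expected to have to handle myself. CartPole-v1 and the constant action are arbitrary choices of mine; I haven't confirmed what actually happens on the step after is_last():

from tf_agents.environments import suite_gym

env = suite_gym.load('CartPole-v1')
time_step = env.reset()

# Step with a constant action until the episode terminates.
while not time_step.is_last():
  time_step = env.step(env.action_spec().minimum)

# time_step.step_type is now StepType.LAST. My naive expectation is that
# the caller must call env.reset() here before stepping again -- yet
# PyDriver.run() just calls self.env.step() again on the next iteration.
next_time_step = env.step(env.action_spec().minimum)
print(next_time_step.step_type)  # what happens here?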