MLP A2C policy complains that 0 is not greater than 0, or is it that infinity is not greater than 0?


The following error comes up while training a PyTorch model:

ValueError('Expected parameter scale (Tensor of shape (1, 4)) of distribution Normal(loc: torch.Size([1, 4]), scale: torch.Size([1, 4])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:\ntensor([[inf, inf, 0., 0.]])').

My actions have shape (4,) and my observations have shape (3,).
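For reference, here is a minimal sketch of an environment with those spaces. Everything in it is a made-up placeholder, not the actual environment: the class name, reward, and episode length are invented, the [0, 1] action bound is taken from further down in the post, and the classic gym API is assumed rather than gymnasium.

import numpy as np
import gym
from gym import spaces

class BoundedActionEnv(gym.Env):
    """Hypothetical environment: observations of shape (3,), actions of shape (4,) bounded to [0, 1]."""

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32)
        self.action_space = spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32)
        self._t = 0

    def reset(self):
        # Classic gym API (pre-0.26): reset returns only the observation.
        self._t = 0
        return self.observation_space.sample()

    def step(self, action):
        self._t += 1
        obs = self.observation_space.sample()
        reward = 0.0  # placeholder reward, not the reward used in the post
        done = self._t >= 200
        return obs, reward, done, {}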

Does it think that infinity is not > 0, or that 0 is not greater than 0? I don't know why this is happening. The model is simply being trained with model.learn in Stable Baselines 3; a rough sketch of the setup is included after the traceback. It learns for a while, but then fails at this step:

~\anaconda3\envs\\lib\site-packages\stable_baselines3\common\on_policy_algorithm.py in learn(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)
    257 
    258         while self.num_timesteps < total_timesteps:
--> 259             continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
    260 
    261             if continue_training is False:

~\anaconda3\envs\\lib\site-packages\stable_baselines3\common\on_policy_algorithm.py in collect_rollouts(self, env, callback, rollout_buffer, n_rollout_steps)
    167                 # Convert to pytorch tensor or to TensorDict
    168                 obs_tensor = obs_as_tensor(self._last_obs, self.device)
--> 169                 actions, values, log_probs = self.policy(obs_tensor)
    170             actions = actions.cpu().numpy()
    171 

~\anaconda3\envs\\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
   1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used
   1196         full_backward_hooks, non_full_backward_hooks = [], []

~\anaconda3\envs\\lib\site-packages\stable_baselines3\common\policies.py in forward(self, obs, deterministic)
    624         # Evaluate the values for the given observations
    625         values = self.value_net(latent_vf)
--> 626         distribution = self._get_action_dist_from_latent(latent_pi)
    627         actions = distribution.get_actions(deterministic=deterministic)
    628         log_prob = distribution.log_prob(actions)

~\anaconda3\envs\\lib\site-packages\stable_baselines3\common\policies.py in _get_action_dist_from_latent(self, latent_pi)
    654 
    655         if isinstance(self.action_dist, DiagGaussianDistribution):
--> 656             return self.action_dist.proba_distribution(mean_actions, self.log_std)
    657         elif isinstance(self.action_dist, CategoricalDistribution):
    658             # Here mean_actions are the logits before the softmax

~\anaconda3\envs\\lib\site-packages\stable_baselines3\common\distributions.py in proba_distribution(self, mean_actions, log_std)
    162         """
    163         action_std = th.ones_like(mean_actions) * log_std.exp()
--> 164         self.distribution = Normal(mean_actions, action_std)
    165         return self
    166 

~\anaconda3\envs\\lib\site-packages\torch\distributions\normal.py in __init__(self, loc, scale, validate_args)
     54         else:
     55             batch_shape = self.loc.size()
---> 56         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     57 
     58     def expand(self, batch_shape, _instance=None):

~\anaconda3\envs\\lib\site-packages\torch\distributions\distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     55                 if not valid.all():
     56                     raise ValueError(
---> 57                         f"Expected parameter {param} "
     58                         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     59                         f"of distribution {repr(self)} "
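For context, a rough sketch of the kind of training call being described. The environment is the hypothetical BoundedActionEnv sketched above, not the real one, and the A2C hyperparameters are just the Stable Baselines 3 defaults.

from stable_baselines3 import A2C

# Hypothetical setup: BoundedActionEnv is the sketch environment above, not the real one.
env = BoundedActionEnv()

# A2C with the default MLP policy; a DiagGaussianDistribution is used because
# the action space is a continuous Box.
model = A2C("MlpPolicy", env, verbose=1)

# The ValueError in the traceback is raised from inside this call,
# via collect_rollouts -> policy.forward -> proba_distribution.
model.learn(total_timesteps=50_000)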

Keep in mind that my actions satisfy 0 <= a <= 1. Do I need to make the lower bound strictly greater than 0?

I find it hard to tell what exactly it is complaining about, because this code sits deep inside Stable Baselines 3. Could this be a glitch in their package? I expected it to update the weights and keep running, but instead it complains that 0 is not greater than 0. I don't see why this should matter; it should just keep running, shouldn't it?

Thanks for looking.

machine-learning pytorch artificial-intelligence openai-gym
1 Answer

0 votes

I solved it. My problem was that the gym environment I was using rewards constant actions, and since the actions have no standard deviation when they are constant, that is what produced the error. The code runs now!
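For what it's worth, the failure mode is easy to reproduce outside Stable Baselines 3. The GreaterThan(lower_bound=0.0) constraint applies to the Normal distribution's scale (standard deviation), not to the action values, and it is the 0 entries that violate it; inf > 0 is actually true, but PyTorch prints the whole scale tensor in the message, which is why the infs show up as well. A minimal sketch, with made-up log_std values chosen to overflow and underflow exp():

import torch
from torch.distributions import Normal

mean_actions = torch.zeros(1, 4)

# If log_std diverges during training, exp() overflows to inf or underflows to 0.
# These values are illustrative only.
log_std = torch.tensor([[200.0, 200.0, -200.0, -200.0]])
scale = torch.ones_like(mean_actions) * log_std.exp()  # tensor([[inf, inf, 0., 0.]])

try:
    Normal(mean_actions, scale, validate_args=True)
except ValueError as err:
    # "Expected parameter scale ... to satisfy the constraint GreaterThan(lower_bound=0.0),
    #  but found invalid values: tensor([[inf, inf, 0., 0.]])"
    print(err)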
