How to use a tensor as the condition for choosing a return value?


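I'm trying to create a parametric activation layer in tensorflow.keras in which a weight determines which activation function is used. I use if statements together with tf.clip_by_value to let the network choose the activation function. However, when I call the model's `.predict()` method, I get the following error: "using a `tf.Tensor` as a Python `bool` is not allowed."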

Code:

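import tensorflow as tf
from tensorflow.keras import layers as L  # assumed: the question uses L.Layer without showing this import

class Pact(L.Layer):
    def __init__(self, **kwargs):
        super(Pact, self).__init__(**kwargs)

        # A single trainable (1, 1) weight that selects the activation
        self.weight = self.add_weight(shape=(1, 1),
                                      initializer='random_normal')

    def call(self, inputs):
        weight = tf.clip_by_value(self.weight, 0, 1)
        # `weight < 0.1` is a tensor, not a Python bool; in graph mode
        # (e.g. inside predict()) Python's `if` cannot evaluate it,
        # which is what raises the error in the traceback below.
        if weight < 0.1:
            return tf.sin(inputs)
        elif weight < 0.2:
            return tf.math.maximum(0.0, inputs)
        elif weight < 0.3:
            return tf.nn.swish(inputs)
        elif weight < 0.4:
            return tf.math.sigmoid(inputs)
        elif weight < 0.5:
            return tf.math.pow(inputs, 2.0)
        elif weight < 0.6:
            return tf.math.cos(inputs) - inputs
        elif weight < 0.7:
            return tf.math.exp(inputs)
        elif weight < 0.8:
            return -(inputs)
        elif weight < 0.9:
            return tf.math.pow(inputs, 3.0)
        else:
            return inputs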
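The if/elif chain above amounts to picking one of ten branches by an integer index. A minimal sketch, assuming TensorFlow 2.x, of making that choice in a graph-compatible way with `tf.switch_case` (a swapped-in alternative, not something the question or the answer uses); `THRESHOLDS`, `branch_index` and `apply_activation` are hypothetical names. Counting how many thresholds the clipped weight has reached reproduces the chain's boundaries exactly, including a weight of 0.9 or more falling through to the identity branch.

```python
import tensorflow as tf

# Hypothetical constant: the thresholds from the question's if/elif chain.
THRESHOLDS = tf.constant([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])


def branch_index(weight):
    """Return which branch (0..9) of the if/elif chain `weight` selects.

    w < 0.1 -> 0, 0.1 <= w < 0.2 -> 1, ..., w >= 0.9 -> 9 (the `else`).
    """
    w = tf.clip_by_value(tf.reshape(weight, []), 0.0, 1.0)  # (1, 1) -> scalar
    return tf.reduce_sum(tf.cast(w >= THRESHOLDS, tf.int32))


def apply_activation(weight, inputs):
    # One callable per branch, in the same order as the question's chain.
    branches = [
        lambda: tf.math.sin(inputs),
        lambda: tf.math.maximum(0.0, inputs),
        lambda: tf.nn.swish(inputs),
        lambda: tf.math.sigmoid(inputs),
        lambda: tf.math.pow(inputs, 2.0),
        lambda: tf.math.cos(inputs) - inputs,
        lambda: tf.math.exp(inputs),
        lambda: -inputs,
        lambda: tf.math.pow(inputs, 3.0),
        lambda: tf.identity(inputs),
    ]
    # tf.switch_case takes a scalar int32 tensor, so it also works in graph mode.
    return tf.switch_case(branch_index(weight), branch_fns=branches)
```

Because `tf.switch_case` runs only the selected branch, the unused activations are never computed.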

Traceback:

Traceback (most recent call last):
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 778, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/autograph/impl/api.py", line 292, in wrapper
    return func(*args, **kwargs)
  File "/home/ai/Projects/Class/layers.py", line 99, in call
    if weight < 0.1:
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 757, in __bool__
    self._disallow_bool_casting()
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 520, in _disallow_bool_casting
    "using a `tf.Tensor` as a Python `bool`")
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 505, in _disallow_when_autograph_disabled
    " Try decorating it directly with @tf.function.".format(task))
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph is disabled in this function. Try decorating it directly with @tf.function.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "ddpg.py", line 137, in <module>
    action = ddpg.get_target_action(state)
  File "ddpg.py", line 63, in get_target_action
    return self.target_actor.predict(state)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py", line 1013, in predict
    use_multiprocessing=use_multiprocessing)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 498, in predict
    workers=workers, use_multiprocessing=use_multiprocessing, **kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 475, in _model_iteration
    total_epochs=1)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py", line 128, in run_one_epoch
    batch_outs = execution_function(iterator)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 98, in execution_function
    distributed_function(input_fn))
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 568, in __call__
    result = self._call(*args, **kwds)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 615, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 497, in _initialize
    *args, **kwds))
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2389, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2703, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py", line 2593, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py", line 978, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py", line 439, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 85, in distributed_function
    per_replica_function, args=args)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 763, in experimental_run_v2
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 1819, in call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/distribute/distribute_lib.py", line 2164, in _call_for_each_replica
    return fn(*args, **kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/autograph/impl/api.py", line 292, in wrapper
    return func(*args, **kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 212, in _predict_on_batch
    result = predict_on_batch(model, x)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2_utils.py", line 556, in predict_on_batch
    return predict_on_batch_fn(inputs)  # pylint: disable=not-callable
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 778, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py", line 717, in call
    convert_kwargs_to_constants=base_layer_utils.call_context().saving)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py", line 891, in _run_internal_graph
    output_tensors = layer(computed_tensors, **kwargs)
  File "/home/ai/anaconda3/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py", line 785, in __call__
    str(e) + '\n"""')
TypeError: You are attempting to use Python control flow in a layer that was not declared to be dynamic. Pass `dynamic=True` to the class constructor.
Encountered error:
"""
using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph is disabled in this function. Try decorating it directly with @tf.function.
"""


python tensorflow keras activation-function
1 Answer
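A vectorized solution to your problem, or at least a possible one, is to replace the Python if/elif chain with tf.where, using one tf.where call per condition.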
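A minimal sketch of that idea, assuming TensorFlow 2.x with `tf.keras`; the class name `PactWhere` and the weight name `selector` are hypothetical. The `tf.where` calls are applied from the last threshold back to the first, so the earliest condition that holds is written last and wins, matching if/elif semantics. The weight is reshaped to a scalar so the condition broadcasts against inputs of any shape.

```python
import tensorflow as tf


class PactWhere(tf.keras.layers.Layer):
    """Hypothetical tf.where-based variant of the question's Pact layer."""

    def build(self, input_shape):
        # Same (1, 1) selector weight as in the question; the name is made up.
        self.weight = self.add_weight(name="selector", shape=(1, 1),
                                      initializer="random_normal")
        super().build(input_shape)

    def call(self, inputs):
        # Clip to [0, 1], then reshape (1, 1) -> scalar for broadcasting.
        w = tf.reshape(tf.clip_by_value(self.weight, 0.0, 1.0), [])
        # (upper threshold, candidate output) pairs in the question's order.
        branches = [
            (0.1, tf.math.sin(inputs)),
            (0.2, tf.math.maximum(0.0, inputs)),
            (0.3, tf.nn.swish(inputs)),
            (0.4, tf.math.sigmoid(inputs)),
            (0.5, tf.math.pow(inputs, 2.0)),
            (0.6, tf.math.cos(inputs) - inputs),
            (0.7, tf.math.exp(inputs)),
            (0.8, -inputs),
            (0.9, tf.math.pow(inputs, 3.0)),
        ]
        out = inputs  # the final `else`: identity
        # Plain Python loop over a Python list, so it is fine in graph mode.
        for threshold, candidate in reversed(branches):
            out = tf.where(w < threshold, candidate, out)
        return out
```

Every candidate activation is computed on every call. Two caveats apply: the comparison `w < threshold` is not differentiable, so the selector weight receives no gradient through this selection and training will not move it; and if an unselected branch overflows (for example `tf.math.exp` on large inputs), the zero upstream gradient multiplied by `inf` can turn the gradients with respect to `inputs` into NaN.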