将 tf.data.Dataset 与 TF Hub 模块一起使用

问题描述 投票:0回答:1

如何使用 tf.data.Dataset 为一个包含 TF Hub 模块(其输入为一维张量)的 tf.keras 模型提供输入数据?

我尝试了以下代码:

import tensorflow as tf import tensorflow_hub as hub embed = "https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1" hub_layer = hub.KerasLayer(embed, output_shape=[20], input_shape=[], dtype=tf.string, trainable=True, name='hub_layer') # From tf hub webpage: "The module takes a batch of sentences in a 1-D tensor of strings as input." input_tensor = tf.keras.Input(shape=(), dtype=tf.string) hub_tensor = hub_layer(input_tensor) x = tf.keras.layers.Dense(16, activation='relu')(hub_tensor)#(x) main_output = tf.keras.layers.Dense(units=4, activation='softmax', name='main_output')(x) model = tf.keras.models.Model(inputs=[input_tensor], outputs=[main_output]) # This works as expected. X_tensor = tf.constant(['Hello World', 'The Quick Brown Fox']) model(X_tensor) # This fails X_ds = tf.data.Dataset.from_tensors(X_tensor) X_ds.element_spec model(X_ds)

期望是数据集中的一维张量将被模型自动提取和使用。

错误消息:

--------------------------------------------------------------------------- ValueError Traceback (most recent call last) in 21 X_ds = tf.data.Dataset.from_tensors(X_tensor) 22 X_ds.element_spec ---> 23 model(X_ds) 24 25 ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs) 966 with base_layer_utils.autocast_context_manager( 967 self._compute_dtype): --> 968 outputs = self.call(cast_inputs, *args, **kwargs) 969 self._handle_activity_regularization(inputs, outputs) 970 self._set_mask_metadata(inputs, outputs, input_masks) ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py in call(self, inputs, training, mask) 717 return self._run_internal_graph( 718 inputs, training=training, mask=mask, --> 719 convert_kwargs_to_constants=base_layer_utils.call_context().saving) 720 721 def compute_output_shape(self, input_shape): ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py in _run_internal_graph(self, inputs, training, mask, convert_kwargs_to_constants) 835 tensor_dict = {} 836 for x, y in zip(self.inputs, inputs): --> 837 y = self._conform_to_reference_input(y, ref_input=x) 838 x_id = str(id(x)) 839 tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id] ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py in _conform_to_reference_input(self, tensor, ref_input) 959 # Dtype handling. 
960 if isinstance(ref_input, (ops.Tensor, composite_tensor.CompositeTensor)): --> 961 tensor = math_ops.cast(tensor, dtype=ref_input.dtype) 962 963 return tensor ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs) 178 """Call target, and fall back on dispatchers if there is a TypeError.""" 179 try: --> 180 return target(*args, **kwargs) 181 except (TypeError, ValueError): 182 # Note: convert_to_eager_tensor currently raises a ValueError, not a ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py in cast(x, dtype, name) 785 # allows some conversions that cast() can't do, e.g. casting numbers to 786 # strings. --> 787 x = ops.convert_to_tensor(x, name="x") 788 if x.dtype.base_dtype != base_type: 789 x = gen_math_ops.cast(x, base_type, name=name) ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types) 1339 1340 if ret is None: -> 1341 ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref) 1342 1343 if ret is NotImplemented: ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in _constant_tensor_conversion_function(v, dtype, name, as_ref) 319 as_ref=False): 320 _ = as_ref --> 321 return constant(v, dtype=dtype, name=name) 322 323 ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in constant(value, dtype, shape, name) 260 """ 261 return _constant_impl(value, dtype, shape, name, verify_shape=False, --> 262 allow_broadcast=True) 263 264 ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in _constant_impl(value, dtype, shape, name, verify_shape, allow_broadcast) 268 ctx = context.context() 269 if ctx.executing_eagerly(): --> 
270 t = convert_to_eager_tensor(value, ctx, dtype) 271 if shape is None: 272 return t ~/projects/email_analysis/email_venv/lib/python3.6/site-packages/tensorflow/python/framework/constant_op.py in convert_to_eager_tensor(value, ctx, dtype) 94 dtype = dtypes.as_dtype(dtype).as_datatype_enum 95 ctx.ensure_initialized() ---> 96 return ops.EagerTensor(value, ctx.device_name, dtype) 97 98 ValueError: Attempt to convert a value (<TensorDataset shapes: (2,), types: tf.string>) with an unsupported type (<class 'tensorflow.python.data.ops.dataset_ops.TensorDataset'>) to a Tensor.

tensorflow2.0 tensorflow-datasets tf.keras tensorflow-hub
1个回答
0
投票
数据集(Dataset)的要点是提供由张量组成的*序列*,例如:

all_data = tf.constant([['Hello', 'World'], ['Brown Fox', 'lazy dog']]) ds = tf.data.Dataset.from_tensor_slices(all_data) for tensor in ds: print(tensor)
输出:

tf.Tensor([b'Hello' b'World'], shape=(2,), dtype=string) tf.Tensor([b'Brown Fox' b'lazy dog'], shape=(2,), dtype=string)

不仅可以打印tensor,还可以用它来计算:

for tensor in ds: print(hub_layer(tensor))

这会输出 2 个张量,每个张量的形状为 (2, 20)。

有关更多信息,请参见https://www.tensorflow.org/guide/data

© www.soinside.com 2019 - 2024. All rights reserved.