InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912


Very much a beginner question, so I hope it's OK.

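I'm trying to train this model from GitHub on the MAPS dataset, and I created new .tfrecords for the training set with the code below. It is based on the code here, but I changed a few things to make room for an additional input (another MIDI file, which I simply call the "tempo MIDI").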

import os

import tensorflow as tf
from magenta.music import midi_io
# aldu = the locally modified audio_label_data_utils.py, whose
# process_record() also accepts the tempo NoteSequence
from magenta.models.onsets_frames_transcription import audio_label_data_utils as aldu


def create_train_set(tempopath, train_list, outdir, min_length, max_length,
                     sample_rate):
  # train_list = list of wav paths selected for
  train_file_pairs = []

  # find matching midi files (a later .midi match overrides a .mid match)
  for wav_path in train_list:
    midi_file = ''
    tempo_midi_file = ''

    if os.path.isfile(wav_path + '.mid'):
      midi_file = wav_path + '.mid'
    if os.path.isfile(wav_path + '.midi'):
      midi_file = wav_path + '.midi'

    # the tempo MIDI files live in a separate directory
    tempo_base = os.path.join(tempopath, os.path.basename(wav_path))
    if os.path.isfile(tempo_base + '_tempo.mid'):
      tempo_midi_file = tempo_base + '_tempo.mid'
    if os.path.isfile(tempo_base + '_tempo.midi'):
      tempo_midi_file = tempo_base + '_tempo.midi'

    wav_file = wav_path + '.wav'
    train_file_pairs.append((wav_file, midi_file, tempo_midi_file))

  train_output_name = os.path.join(outdir, 'train.tfrecord')

  with tf.python_io.TFRecordWriter(train_output_name) as writer:
    for idx, pair in enumerate(train_file_pairs):
      print('{} of {}: {}'.format(idx, len(train_file_pairs), pair[0]))
      # load the wav data
      with tf.gfile.Open(pair[0], 'rb') as f:
        wav_data = f.read()
      # load the midi and tempo midi data and convert both to NoteSequences
      ns = midi_io.midi_file_to_note_sequence(pair[1])
      tempo = midi_io.midi_file_to_note_sequence(pair[2])
      for example in aldu.process_record(
          wav_data, ns, tempo, pair[0], min_length, max_length,
          sample_rate):
        writer.write(example.SerializeToString())
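One thing worth checking in the pairing loop: when no `_tempo.mid`/`_tempo.midi` file is found, `tempo_midi_file` stays `''`, and the problem only surfaces later, when that path is converted to a NoteSequence. The sketch below is a stricter, stdlib-only lookup that reports incomplete triples up front. `find_first_existing` and `pair_files` are hypothetical helper names; it assumes `tempopath` is a directory (hence `os.path.join`) and tries `.midi` before `.mid`, which matches the loop above, where a later `.midi` match overwrites a `.mid` match.

```python
import os


def find_first_existing(candidates):
    """Return the first path in `candidates` that is an existing file, or None."""
    for path in candidates:
        if os.path.isfile(path):
            return path
    return None


def pair_files(tempopath, train_list):
    """Build (wav, midi, tempo_midi) triples and list the wav paths skipped.

    Hypothetical helper: a triple is kept only if the .wav, the score MIDI
    and the tempo MIDI all exist.
    """
    pairs, skipped = [], []
    for wav_path in train_list:
        base = os.path.basename(wav_path)
        midi_file = find_first_existing([wav_path + '.midi', wav_path + '.mid'])
        tempo_file = find_first_existing(
            [os.path.join(tempopath, base + '_tempo.midi'),
             os.path.join(tempopath, base + '_tempo.mid')])
        wav_file = wav_path + '.wav'
        if midi_file and tempo_file and os.path.isfile(wav_file):
            pairs.append((wav_file, midi_file, tempo_file))
        else:
            skipped.append(wav_path)
    return pairs, skipped
```

Printing `skipped` before opening the `TFRecordWriter` makes a missing tempo file visible immediately instead of partway through writing the records.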

The tf.train.Example for each record is built like this:

  example = tf.train.Example(
      features=tf.train.Features(
          feature={
              'id':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[example_id.encode('utf-8')])),
              'sequence':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[ns.SerializeToString()])),
              'audio':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(value=[wav_data])),
              'tempo':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          # assumes the tempo NoteSequence is in scope as `tempo`
                          value=[tempo.SerializeToString()])),
              'velocity_range':
                  tf.train.Feature(
                      bytes_list=tf.train.BytesList(
                          value=[velocity_range.SerializeToString()])),
          }))
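Before writing, comparing the raw payloads that go into each `BytesList` can catch two easy mistakes in a feature dict like this one: an empty value, and a copy-paste slip that stores the same serialized proto under two feature names. This is a minimal stdlib sketch; `audit_features` is a hypothetical helper (not part of TensorFlow or Magenta) that takes a plain `{name: bytes}` dict rather than the `tf.train.Example` itself.

```python
def audit_features(features):
    """Return warnings for empty byte values and duplicated payloads.

    `features` maps feature names to the raw bytes that would be placed in
    each BytesList. Names are checked in sorted order for stable output.
    """
    warnings = []
    seen = {}  # payload bytes -> first feature name that used them
    for name in sorted(features):
        value = features[name]
        if len(value) == 0:
            warnings.append('{}: empty value'.format(name))
            continue
        if value in seen:
            warnings.append('{}: same bytes as {}'.format(name, seen[value]))
        else:
            seen[value] = name
    return warnings
```

In the original code this would be called with the same values used above, e.g. `{'sequence': ns.SerializeToString(), 'tempo': tempo.SerializeToString(), ...}`, and any returned warning printed before `writer.write`.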

However, when I try to train the model, I get the following error (I added print statements to the .py scripts so I can tell where execution is):

Running wav_to_spec from data.py
Running _wav_to_mel in data.py
Running wav_to_num_frames from data.py
Running wav_to_spec from data.py
Running _wav_to_mel in data.py
Running wav_to_num_frames from data.py

E0611 07:56:55.419340  8436 error_handling.py:70] Error recorded from training_loop: Input to reshape is a tensor with 0 values, but the requested shape has 54912
         [[{{node Reshape_8}}]]
         [[IteratorGetNext]]
I0611 07:56:55.420338  8436 error_handling.py:96] training_loop marked as finished
W0611 07:56:55.421335  8436 error_handling.py:130] Reraising captured error
Traceback (most recent call last):
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1356, in _do_call
    return fn(*args)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912
         [[{{node Reshape_8}}]]
         [[IteratorGetNext]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "onsets_frames_transcription_train.py", line 128, in <module>
    console_entry_point()
  File "onsets_frames_transcription_train.py", line 124, in console_entry_point
    tf.app.run(main)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\platform\app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\absl\app.py", line 300, in run
    _run_main(main, args)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\absl\app.py", line 251, in _run_main
    sys.exit(main(argv))
  File "onsets_frames_transcription_train.py", line 120, in main
    additional_trial_info=additional_trial_info)
  File "onsets_frames_transcription_train.py", line 95, in run
    num_steps=FLAGS.num_steps)
  File "C:\Users\User\magenta\magenta\models\onsets_frames_transcription\train_util.py", line 134, in train
    estimator.train(input_fn=transcription_data, max_steps=num_steps)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\tpu_estimator.py", line 2876, in train
    rendezvous.raise_errors()
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\error_handling.py", line 131, in raise_errors
    six.reraise(typ, value, traceback)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\six.py", line 693, in reraise
    raise value
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\tpu\tpu_estimator.py", line 2871, in train
    saving_listeners=saving_listeners)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 367, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1158, in _train_model
    return self._train_model_default(input_fn, hooks, saving_listeners)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1192, in _train_model_default
    saving_listeners)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow_estimator\python\estimator\estimator.py", line 1484, in _train_with_estimator_spec
    _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 754, in run
    run_metadata=run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1252, in run
    run_metadata=run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1353, in run
    raise six.reraise(*original_exc_info)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\six.py", line 693, in reraise
    raise value
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1338, in run
    return self._sess.run(*args, **kwargs)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1411, in run
    run_metadata=run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\training\monitored_session.py", line 1169, in run
    return self._sess.run(*args, **kwargs)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 950, in run
    run_metadata_ptr)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1350, in _do_run
    run_metadata)
  File "C:\Users\User\AppData\Local\Programs\Python\Python36\lib\site-packages\tensorflow\python\client\session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 0 values, but the requested shape has 54912
         [[{{node Reshape_8}}]]
         [[IteratorGetNext]]
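The message itself is an element-count check: a reshape only succeeds when the input holds exactly as many values as the product of the requested dimensions, and here the input was empty. Which tensor `Reshape_8` is cannot be read off the log; one factorization that fits is 54912 = 624 × 88, and 88 is the number of piano keys (MIDI pitches 21 to 108), a common width for per-frame piano labels, so a `[frames, 88]` tensor with 624 frames would need 54912 values. That reading is a guess. The stdlib sketch below only makes the rule explicit; `reshape_fits` is a hypothetical name and handles fully specified shapes only (no `-1`).

```python
from functools import reduce
from operator import mul


def reshape_fits(num_values, shape):
    """True if num_values elements fill a fully specified shape exactly."""
    return num_values == reduce(mul, shape, 1)


# One factorization of 54912; reading it as 624 frames x 88 pitches is a guess.
requested = 624 * 88
print(requested)                       # prints 54912
print(reshape_fits(0, (624, 88)))      # prints False (the case in the log)
print(reshape_fits(54912, (624, 88)))  # prints True
```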

From this I concluded that the problem is in wav_to_num_frames, but this is all the code the function contains:

import wave
import numpy as np
import six

def wav_to_num_frames(wav_audio, frames_per_second):
  """Transforms a wav-encoded audio string into number of frames."""
  print("Running wav_to_num_frames from data")
  with wave.open(six.BytesIO(wav_audio)) as w:  # closes the reader when done
    return np.int32(w.getnframes() / w.getframerate() * frames_per_second)
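The frame count is just the audio duration in seconds times the frames per second, truncated. A stdlib-only rework of the same arithmetic (plain `int()` instead of `np.int32`, which truncates the same way for non-negative values) shows that a WAV with no samples yields 0 frames, which is one way a downstream tensor can end up with 0 values. `make_wav_bytes` and `num_frames_from_wav` are hypothetical names, and the 16 kHz sample rate with a 512-sample hop (31.25 frames per second) is an assumption for illustration.

```python
import io
import wave


def make_wav_bytes(num_samples, sample_rate=16000):
    """Build a silent mono 16-bit WAV in memory (hypothetical test helper)."""
    buf = io.BytesIO()
    with wave.open(buf, 'wb') as w:
        w.setnchannels(1)
        w.setsampwidth(2)
        w.setframerate(sample_rate)
        w.writeframes(b'\x00\x00' * num_samples)
    return buf.getvalue()


def num_frames_from_wav(wav_audio, frames_per_second):
    """Same arithmetic as wav_to_num_frames, with int() instead of np.int32."""
    with wave.open(io.BytesIO(wav_audio)) as w:
        # seconds of audio times frames per second, truncated toward zero
        return int(w.getnframes() / w.getframerate() * frames_per_second)


# Assumed for illustration: 16 kHz audio and a 512-sample hop -> 31.25 frames/s.
FRAMES_PER_SECOND = 16000 / 512

print(num_frames_from_wav(make_wav_bytes(16000), FRAMES_PER_SECOND))  # prints 31
print(num_frames_from_wav(make_wav_bytes(0), FRAMES_PER_SECOND))      # prints 0
```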

When I train the model with .tfrecords created by the original code, I don't get this error, so I don't know what is going wrong.

python tensorflow deep-learning tensorflow-estimator tfrecord
1 Answer (0 votes)

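It turns out the problem was not the .tfrecords I created, but the size of the tensor I had allocated for the newly added data. There is no more specific answer than that, because the fix depends heavily on this particular setup.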
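The answer does not say how the size was corrected. One common way to keep an extra per-frame input consistent with a fixed-shape reshape is to pad or truncate it to the same frame count as the other features before flattening, so it always holds exactly `length × width` values. The stdlib sketch below works under that assumption; `fit_to_frames` and its zero padding are hypothetical and are not the fix used in the answer.

```python
def fit_to_frames(rows, length, width, pad_value=0.0):
    """Pad or truncate a list of per-frame rows to exactly `length` rows.

    Each row must already contain `width` values. Returns a flat list of
    length * width values, ready for a reshape to (length, width).
    """
    for row in rows:
        if len(row) != width:
            raise ValueError(
                'expected rows of width {}, got {}'.format(width, len(row)))
    fitted = list(rows[:length])
    # pad with constant rows; they are only read, so sharing them is harmless
    fitted += [[pad_value] * width] * (length - len(fitted))
    return [value for row in fitted for value in row]
```

With this, even an input that produced no rows at all is padded to the full `length × width` size instead of reaching the reshape as an empty tensor; whether zero padding is a sensible default depends on what the new input represents.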
