What does "TypeError: argument of type 'Adam' is not iterable" mean?


Hello, I'm trying to build a model that outputs a super-resolution image from paired low-resolution and high-resolution input images. The first error I hit was:

UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile afterwards?

I first tried to fix it as suggested in "UserWarning: Discrepancy between trainable weights and collected trainable weights error", but that didn't work, because the ISR model apparently has no compile attribute. After reading the documentation I thought I understood what to do, but then I ran into the error below. I know what "not iterable" means in general; I just don't understand how it applies here.
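As background on the "not iterable" part: `key in obj` only works when `obj` supports membership tests (it implements `__contains__`, or failing that `__iter__`); an ordinary object like a Keras `Adam` instance supports neither, so Python raises exactly this `TypeError`. A minimal sketch with a hypothetical stand-in class (no Keras needed):

```python
class FakeOptimizer:
    """Stands in for a Keras Adam instance: defines neither
    __contains__ nor __iter__, so membership tests fail."""
    pass

opt = FakeOptimizer()

try:
    'beta1' in opt  # membership test on a non-iterable object
    raised = False
    message = ''
except TypeError as e:
    raised = True
    message = str(e)

print(raised)   # True
print(message)  # argument of type 'FakeOptimizer' is not iterable
```

The same test on a dict would succeed, which hints at what the library expected to receive.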

import os
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from ISR.models import RRDN, Discriminator
from ISR.models import Cut_VGG19
from ISR.train import Trainer

# Paths to the HR and LR training/validation images
# (no trailing commas here -- a trailing comma would turn each path into a 1-tuple)
lr_train_dir = 'C:/data/lr_train_150/'
hr_train_dir = 'C:/data/hr_train_150/'
lr_valid_dir = 'C:/data/lr_test_150/'
hr_valid_dir = 'C:/data/hr_test_150/'

lr_train_patch_size = 22 #size of my LR image
layers_to_extract = [5, 9]
scale = 10
hr_train_patch_size = lr_train_patch_size * scale # 220 Size of my HR image


# Instantiate models: RRDN generator, VGG19-based feature extractor, discriminator
# (Cut_VGG19 is the feature extractor and Discriminator is the discriminator;
# assigning them the other way round mislabels both networks)
rrdn = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)
feature_extractor = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
discriminator = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)


# Define optimizer and loss function
optimizer = Adam(1e-4, beta_1=0.9, beta_2=0.999)
#loss = 'mse'
loss_weights = {
  'generator': 0.0,
  'feature_extractor': 0.0833,
  'discriminator': 0.01
}
losses = {
  'generator': 'mae',
  'feature_extractor': 'mse',
  'discriminator': 'binary_crossentropy'
} 
learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}
log_dirs = {'logs': './logs', 'weights': './weights'}
flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}


# Define trainer
trainer = Trainer(generator=rrdn, 
                  discriminator=discriminator,
                  feature_extractor=feature_extractor, 
                  #name='srgan', 
                  log_dirs=log_dirs,
                  #checkpoint_dir='./models', 
                  learning_rate=learning_rate, 
                  losses=losses, 
                  flatness = flatness,
                  loss_weights = loss_weights,
                  adam_optimizer=optimizer,
                  lr_train_dir = 'C:/data/lr_train_150/',
                  hr_train_dir = 'C:/data/hr_train_150/',
                  lr_valid_dir = 'C:/data/lr_test_150/',
                  hr_valid_dir = 'C:/data/hr_test_150/'
                 )

# Train the model
trainer.train(batch_size=16, 
              steps_per_epoch=20, 
              #validation_steps=10, 
              epochs=1, 
              #print_frequency=100
              monitored_metrics={'val_generator_PSNR_Y': 'max'}
             )

The error I get:

TypeError                                 Traceback (most recent call last)
Cell In[18], line 52
     46 flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}
     47 # Define feature extractor
     48 #vgg = Model(inputs=rrdn.input, outputs=rrdn.get_layer('features').output)
     49 #vgg.trainable = False
     50 
     51 # Define trainer
---> 52 trainer = Trainer(generator=rrdn, 
     53                   discriminator=discriminator,
     54                   feature_extractor=feature_extractor, 
     55                   #name='srgan', 
     56                   log_dirs=log_dirs,
     57                   #checkpoint_dir='./models', 
     58                   learning_rate=learning_rate, 
     59                   losses=losses, 
     60                   flatness = flatness,
     61                   loss_weights = loss_weights,
     62                   adam_optimizer=optimizer,
     63                   lr_train_dir = 'C:/data/lr_train_150/',
     64                   hr_train_dir = 'C:/data/hr_train_150/',
     65                   lr_valid_dir = 'C:/data/lr_test_150/',
     66                   hr_valid_dir = 'C:/data/hr_test_150/'
     67                  )
     69 # Train the model
     70 trainer.train(train_lr_dir=lr_train_dir, train_hr_dir=hr_train_dir, 
     71               valid_lr_dir=lr_valid_dir, valid_hr_dir=hr_valid_dir, 
     72               batch_size=16, 
   (...)
     77               monitored_metrics={'val_generator_PSNR_Y': 'max'}
     78              )

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\train\trainer.py:104, in Trainer.__init__(self, generator, discriminator, feature_extractor, lr_train_dir, hr_train_dir, lr_valid_dir, hr_valid_dir, loss_weights, log_dirs, fallback_save_every_n_epochs, dataname, weights_generator, weights_discriminator, n_validation, flatness, learning_rate, adam_optimizer, losses, metrics)
    102 elif self.metrics['generator'] == 'PSNR':
    103     self.metrics['generator'] = PSNR
--> 104 self._parameters_sanity_check()
    105 self.model = self._combine_networks()
    107 self.settings = {}

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\train\trainer.py:163, in Trainer._parameters_sanity_check(self)
    151 check_parameter_keys(
    152     self.learning_rate,
    153     needed_keys=['initial_value'],
    154     optional_keys=['decay_factor', 'decay_frequency'],
    155     default_value=None,
    156 )
    157 check_parameter_keys(
    158     self.flatness,
    159     needed_keys=[],
    160     optional_keys=['min', 'increase_frequency', 'increase', 'max'],
    161     default_value=0.0,
    162 )
--> 163 check_parameter_keys(
    164     self.adam_optimizer,
    165     needed_keys=['beta1', 'beta2'],
    166     optional_keys=['epsilon'],
    167     default_value=None,
    168 )
    169 check_parameter_keys(self.log_dirs, needed_keys=['logs', 'weights'])

File ~\anaconda3\envs\img_tf\lib\site-packages\ISR\utils\utils.py:45, in check_parameter_keys(parameter, needed_keys, optional_keys, default_value)
     43 if needed_keys:
     44     for key in needed_keys:
---> 45         if key not in parameter:
     46             logger.error('{p} is missing key {k}'.format(p=parameter, k=key))
     47             raise

TypeError: argument of type 'Adam' is not iterable
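Judging from the traceback alone, `Trainer._parameters_sanity_check` calls `check_parameter_keys(self.adam_optimizer, needed_keys=['beta1', 'beta2'], ...)`, which does `key not in parameter` — i.e. it treats `adam_optimizer` as a dict of settings, not as a compiled Keras optimizer object. Assuming that is the expected format, a sketch of the fix would be to pass a plain dict instead of an `Adam` instance (note the key spelling `beta1`/`beta2` without underscores, as shown in the traceback):

```python
# Instead of: optimizer = Adam(1e-4, beta_1=0.9, beta_2=0.999)
# pass the optimizer settings as a dict, with the keys the sanity
# check in ISR/train/trainer.py looks for:
adam_optimizer = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': None}

# The membership test that failed on an Adam instance now succeeds:
print('beta1' in adam_optimizer)  # True
print('beta2' in adam_optimizer)  # True
```

This is only inferred from the sanity-check code in the traceback; the ISR documentation's training example would confirm the exact expected keys.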

ISR documentation; basic ISR implementation on Colab. When I tried this on Colab I got the same warning:

UserWarning: Discrepancy between trainable weights and collected trainable weights, did you set model.trainable without calling model.compile afterwards?

I don't think the code works the way it's supposed to, but I can't be sure, since I'm new to this.

python-3.x tensorflow image-processing interrupt