tf.function + custom training function causes memory leak


I have the following model:

class FRAE(tf.keras.Model):
    def __init__(self, latent_dim, shape, ht, n1, n2, n3, n4, bypass=False, trainable=True, **kwargs):
        super(FRAE, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.shape = shape
        self.ht = ht
        self.buffer = tf.Variable(initial_value=tf.zeros(shape=(1, shape[0] * self.ht), dtype=tf.float32), trainable=False)
        self.bypass = bypass
        self.quantizer = None
        self.trainable = trainable

        self.l1 = tf.keras.layers.Dense(n1, activation='tanh', input_shape=shape)
        self.l2 = tf.keras.layers.Dense(n1, activation='tanh')
        self.ls = tf.keras.layers.Dense(latent_dim, activation='swish')

        self.l3 = tf.keras.layers.Dense(n3, activation='tanh')
        self.l4 = tf.keras.layers.Dense(n4, activation='tanh')
        self.l5 = tf.keras.layers.Dense(shape[-1], activation='linear')

    def get_config(self):
        config = super(FRAE, self).get_config().copy()
        config.update({'latent_dim': self.latent_dim, 'bypass': self.bypass, 'quantizer': self.quantizer,
                       "encoder": self.encoder, "buffer": self.buffer,
                       'decoder': self.decoder, "ht": self.ht, "shape": self.shape, "name": self.name})

        return config

    @tf.function(experimental_compile=True)
    def update_buffer(self, new_element):
        n = self.shape[0]
        self.buffer.assign(tf.keras.backend.concatenate([new_element, self.buffer[:, :-n]], axis=1))

    @tf.function(experimental_compile=True)
    def resetBuffer(self):
        self.buffer[:, :].assign(tf.zeros(shape=(1, self.shape[0] * self.ht), dtype=tf.float32))

    @tf.function(experimental_compile=True)
    def call(self, x):
        x = tf.squeeze(x, axis=0)
        decoded = tf.TensorArray(tf.float32, size=tf.shape(x)[0])
        for i in tf.range(tf.shape(x)[0]):
            xexpand = tf.expand_dims(x[i], axis=0)
            xin = tf.concat((xexpand, self.buffer), axis=1)

            encoded = self.ls(self.l2(self.l1(xin)))
            decin = tf.concat([encoded, self.buffer], axis=1)
            y = self.l5(self.l4(self.l3(decin)))
            decoded = decoded.write(i, y)
            i += 1
            # self.update_buffer(tf.squeeze(y))
            self.update_buffer(y)

        tmp = tf.transpose(decoded.stack(), [1, 0, 2])
        return tmp

    @tf.function(experimental_compile=True)
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compute_loss(y=y, y_pred=y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

This actually runs fine. However, I noticed that while training the model, my program's memory usage grows every few seconds. The most likely culprit is the buffer I am using, which is needed to store intermediate outputs. At the moment I use a tf.Variable for this. More precisely, this line

        self.buffer.assign(tf.keras.backend.concatenate([new_element, self.buffer[:, :-n]], axis=1))

appears to be the cause. Is there any alternative? Calling gc.collect() every now and then does nothing. I am using TensorFlow 2.0, so there is no tf.placeholder. What can I use in my situation?
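For context, here is a minimal sketch of one alternative a reader might try (my assumption, not something from the original post): carry the buffer as an ordinary tensor that is threaded through the loop instead of mutating a tf.Variable, so call() has no side effects. The layer names (l1 ... l5, ls) and the shape/ht attributes mirror the model above; the sketch is untested illustration only.

    @tf.function(experimental_compile=True)
    def call(self, x):
        x = tf.squeeze(x, axis=0)
        n = self.shape[0]
        # Local tensor instead of a tf.Variable; it starts from zeros on every call.
        buffer = tf.zeros(shape=(1, n * self.ht), dtype=tf.float32)
        decoded = tf.TensorArray(tf.float32, size=tf.shape(x)[0])
        for i in tf.range(tf.shape(x)[0]):
            xin = tf.concat((tf.expand_dims(x[i], axis=0), buffer), axis=1)
            encoded = self.ls(self.l2(self.l1(xin)))
            y = self.l5(self.l4(self.l3(tf.concat([encoded, buffer], axis=1))))
            decoded = decoded.write(i, y)
            # Functional update: shift the buffer without assigning to a Variable.
            buffer = tf.concat([y, buffer[:, :-n]], axis=1)
        return tf.transpose(decoded.stack(), [1, 0, 2])

Since the buffer is re-created at the start of each call, a per-batch reset callback (see below) would no longer be needed with this variant.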

Edit: I just tested this: plain prediction does not increase memory consumption, so the leak happens during training. My training call is

    stoploss = trainutil.stopAtLossValue()
    resetbuffercallback = trainutil.ResetBufferCallback(frae)
    
    
    frae.fit(train_enc_left, train_enc_left, batch_size=1, epochs=10, callbacks=[stoploss, resetbuffercallback], verbose=1)

The callbacks are defined as

class ResetBufferCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        super(ResetBufferCallback, self).__init__()

    def on_batch_end(self, batch, logs=None):
        self.model.resetBuffer()

class stopAtLossValue(Callback):
    def on_batch_end(self, batch, logs={}):
        THR = 10  # Assign THR with the value at which you want to stop training.
        if logs.get('loss') > THR or math.isnan(logs.get('loss')) is True:
            self.model.stop_training = True

I do not see how the callbacks could cause a memory leak, so the custom train_step is probably the culprit.
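One way to narrow this down (a hypothetical diagnostic on my part, not part of the original post; it assumes psutil is installed) is to log the process's resident memory every few batches and compare runs with and without the custom train_step / @tf.function:

    import os
    import psutil
    import tensorflow as tf

    class MemoryLoggerCallback(tf.keras.callbacks.Callback):
        """Print resident memory every few batches to see whether it grows steadily."""
        def __init__(self, every_n_batches=100):
            super(MemoryLoggerCallback, self).__init__()
            self.every_n_batches = every_n_batches
            self.process = psutil.Process(os.getpid())

        def on_batch_end(self, batch, logs=None):
            if batch % self.every_n_batches == 0:
                rss_mb = self.process.memory_info().rss / 1e6
                print("batch %d: RSS = %.1f MB" % (batch, rss_mb))

Passing this callback to fit() alongside the existing ones makes it easy to see whether the growth disappears when the decorators or the custom training step are removed.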

Edit: So apparently there is a memory leak in this situation when you use @tf.function, see https://github.com/tensorflow/tensorflow/issues/50765. This is very unfortunate, because I need the compilation to get the model fast enough for training, and I cannot get that without compiling the train step function. Does anyone know a way around this?
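As a side note (my assumption, not something stated in the linked issue): steady memory growth with tf.function is also a classic symptom of repeated retracing, for example when the input shape changes every batch. A quick check is to watch the tracing count while fit() runs; if it keeps climbing, an input_signature or padding to a fixed shape may help independently of the XLA issue.

    # If this number keeps increasing during training, the function is being
    # retraced over and over, which also makes memory grow.
    print(frae.train_step.experimental_get_tracing_count())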

tensorflow memory-leaks
1 Answer

What problem is the memory leak actually causing you? Short of waiting for TensorFlow to fix the underlying issue, there is not much you can do to stop the leak itself, but there are a few mitigation strategies.

If the problem is that you run out of memory during training, you can either get more memory or make other changes that reduce your memory footprint enough to get through training.

If the problem is that you do not have enough memory after training and cannot free it, one option is to run the training in a separate process, save the model, and then kill that process. At that point the memory is released and you can reload the model.
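A minimal sketch of that separate-process idea, assuming the model construction and data loading from the question are wrapped in hypothetical helpers build_frae() and load_data():

    import multiprocessing as mp

    def train_in_child(ckpt_path):
        frae = build_frae()                    # hypothetical: rebuilds the FRAE model from the question
        train_enc_left = load_data()           # hypothetical: loads the training data
        frae.compile(optimizer="adam", loss="mse")   # stand-in; the question's compile() call is not shown
        frae.fit(train_enc_left, train_enc_left, batch_size=1, epochs=10, verbose=1)
        frae.save_weights(ckpt_path)           # persist the result before the process exits

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")          # "spawn" avoids inheriting TF state via fork
        p = ctx.Process(target=train_in_child, args=("frae_ckpt",))
        p.start()
        p.join()                               # any leaked memory is reclaimed when the child exits

        frae = build_frae()
        frae.load_weights("frae_ckpt")         # TF-format checkpoints restore into the fresh model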
