我在for循环中创建了一个tf.data.Dataset
,我注意到每次迭代后内存都没有被释放。
有没有办法从TensorFlow请求释放内存?
我尝试使用tf.reset_default_graph()
,我尝试在相关的python对象上调用del
,但这不起作用。
似乎唯一有效的是gc.collect()
。不幸的是,gc.collect
不适用于一些更复杂的例子。
完全可重现的代码:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import psutil
%matplotlib inline
memory_used = []
for i in range(500):
data = tf.data.Dataset.from_tensor_slices(
np.random.uniform(size=(10, 500, 500)))\
.prefetch(64)\
.repeat(-1)\
.batch(3)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
sess.run(next_element)
memory_used.append(psutil.virtual_memory().used / 2 ** 30)
tf.reset_default_graph()
plt.plot(memory_used)
plt.title('Evolution of memory')
plt.xlabel('iteration')
plt.ylabel('memory used (GB)')
您正在创建循环迭代的新python对象(数据集),看起来没有调用垃圾收集器。添加impplicit垃圾收集调用,内存使用情况应该没问题。
除此之外,如其他答案所述,继续在循环之外构建数据对象和会话。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import psutil
import gc
%matplotlib inline
memory_used = []
for i in range(100):
data = tf.data.Dataset.from_tensor_slices(
np.random.uniform(size=(10, 500, 500)))\
.prefetch(64)\
.repeat(-1)\
.batch(3)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
sess.run(next_element)
memory_used.append(psutil.virtual_memory().used / 2 ** 30)
tf.reset_default_graph()
gc.collect()
plt.plot(memory_used)
plt.title('Evolution of memory')
plt.xlabel('iteration')
plt.ylabel('memory used (GB)')
数据集API通过内置迭代器处理迭代,至少在急切模式关闭或TF版本不是2.0时。因此,根本不需要从for循环中的numpy数组创建数据集对象,因为它将图中的值写为tf.constant
。 data = tf.data.TFRecordDataset()
不是这种情况,所以如果你将数据转换为tfrecords格式并在for循环中运行它就不会泄漏内存。
for i in range(500):
data = tf.data.TFRecordDataset('file.tfrecords')\
.prefetch(64)\
.repeat(-1)\
.batch(1)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
with tf.Session() as sess:
sess.run(data_it.initializer)
sess.run(next_element)
memory_used.append(psutil.virtual_memory().used / 2 ** 30)
tf.reset_default_graph()
但正如我所说,没有必要在循环内创建数据集。
data = tf.data.Dataset.from_tensor_slices(
np.random.uniform(size=(10, 500, 500)))\
.prefetch(64)\
.repeat(-1)\
.batch(3)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
for i in range(500):
with tf.Session() as sess:
...