如何在tensorflow 2.0中保存和加载选定的变量?

问题描述 投票:1回答:1

如何将文件中下面显示的tensorflow 2.0中的选定变量保存到另一个代码中的某些已定义变量中?

class manyVariables:
    def __init__(self):
        self.initList = [None]*100
        for i in range(100):
            self.initList[i] = tf.Variable(tf.random.normal([5,5]))
        self.makeSomeMoreVariables()

    def makeSomeMoreVariables(self):
        self.moreList = [None]*10
        for i in range(10):
            self.moreList[i] = tf.Variable(tf.random.normal([3,3]))

    def saveVariables(self):
        # how to save self.initList's 3,55 and 60th elements and self.moreList's 4th element
python tensorflow tensorflow2.0
1个回答
2
投票

在下面的代码中,我将一个名为variables的数组保存到一个带有您选择名称的.txt文件中。该文件与python文件位于同一文件夹中。 open函数中的'wb'表示使用截断进行写入(因此删除以前在文件中的所有内容)并使用字节格式。我使用pickle来处理保存/解析列表。

import pickle

    def saveVariables(self, variables): #where 'variables' is a list of variables
        with open("nameOfYourFile.txt", 'wb+') as file:
           pickle.dump(variables, file)

    def retrieveVariables(self, filename):
        variables = []
        with open(str(filename), 'rb') as file:
            variables = pickle.load(file)
        return variables

要将特定内容保存到文件中,只需将其添加为saveVariables中的变量参数,如下所示:

myVariables = [initList[2], initList[54], initList[59], moreList[3]]
saveVariables(myVariables)

要从具有特定名称的文本文件中检索变量:

myVariables = retrieveVariables("theNameOfYourFile.txt")
thirdEl = myVariables[0]
fiftyFifthEl = myVariables[1]
SixtiethEl = myVariables[2]
fourthEl = myVariables[3]

您可以在课程的任何位置添加这些功能。

但是,为了能够访问示例中的initList / moreList,您应该从函数中返回它们(就像我使用variables列表一样)或者将它们设置为全局。

© www.soinside.com 2019 - 2024. All rights reserved.