import collections
import tensorflow as tf
tf.compat.v1.enable_v2_behavior()
import tensorflow_federated as tff
# Map each simulation client id to the path of its local text-line dataset.
dataset_paths = dict(
    client_0='/tmp/A.txt',
    client_1='/tmp/B.txt',
    client_2='/tmp/C.txt',
)
def create_tf_dataset_for_client_fn(client_id):
    """Return a `tf.data.Dataset` of text lines for one simulation client.

    Args:
      client_id: a key of `dataset_paths` (e.g. 'client_0'). Renamed from
        `id`, which shadowed the builtin; TFF invokes this positionally.

    Returns:
      A `tf.data.TextLineDataset` over the client's file.

    Raises:
      ValueError: if `client_id` has no registered dataset path.
    """
    path = dataset_paths.get(client_id)
    if path is None:
        raise ValueError(f'No dataset for client {client_id}')
    # BUG FIX: the class lives at tf.data.TextLineDataset;
    # tf.data.Dataset has no TextLineDataset attribute, so the original
    # line raised AttributeError before training could even start.
    return tf.data.TextLineDataset(path)
# Build a simulation ClientData that lazily constructs each client's dataset
# from its id via create_tf_dataset_for_client_fn.
# NOTE(review): in recent TFF releases this constructor moved/was renamed
# (e.g. tff.simulation.datasets.ClientData.from_clients_and_tf_fn) — confirm
# against the installed tensorflow_federated version.
source = tff.simulation.ClientData.from_clients_and_fn(
    dataset_paths.keys(), create_tf_dataset_for_client_fn)
def client_data(n):
    """Return the tf.data.Dataset belonging to the n-th client of `source`."""
    selected_id = source.client_ids[n]
    return source.create_tf_dataset_for_client(selected_id)
# BUG FIX: the original used range(10), but only 3 clients are registered,
# so client_data(3) hit an out-of-range index on source.client_ids — a
# crash before the first training round. Iterate only over the clients
# that actually exist.
train_data = [client_data(n) for n in range(len(dataset_paths))]
# Peek at one batch of the first client's dataset as plain numpy values
# (relies on eager execution enabled at the top of the file).
batch = tf.nest.map_structure(lambda x: x.numpy(), next(iter(train_data[0])))
# ... (model_fn definition omitted in this excerpt) ...
# Build the Federated Averaging process from model_fn (defined above, not
# shown in this excerpt) and run the first training round.
iterative_process = tff.learning.build_federated_averaging_process(model_fn)
# BUG FIX: `state` was passed to next() without ever being created, which
# fails immediately; the server state must come from initialize() first.
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, train_data)
print('round 1, metrics={}'.format(metrics))
# When I execute the following lines:
state, metrics = iterative_process.next(state, train_data)
print('round 1, metrics={}'.format(metrics))
# the kernel crashes and restarts, and afterwards takes a very long time to run.
# I am a beginner with TensorFlow Federated. After writing this TFF code, the kernel crashes and restarts as soon as I try to run the first training round, and I do not understand why. This is part of my code...