使用Python线程使model_averaging在联邦学习中独占

问题描述 投票:0回答:1

我正在使用以下代码创建

num_of_clients
线程:

sockets_thread = []
no_of_client = 1

all_data = b""
while True:
    try:
        for i in range(no_of_client):
            connection, client_info = soc.accept() 
            print("\nNew Connection from {client_info}.".format(client_info=client_info))
            socket_thread = SocketThread(connection=connection,
                                     client_info=client_info, 
                                     buffer_size=1024,
                                     recv_timeout=100)
            sockets_thread.append(socket_thread)
        for i in range(no_of_client):    
            sockets_thread[i].start()
            sockets_thread[i].join()
    except:
        soc.close()
        print("(Timeout) Socket Closed Because no Connections Received.\n")
        break

run函数中,有几段代码,如下:

class SocketThread(object):
     def run(self):
           while True: 
                received_data, status = self.recv()
                if status == 0:
                    self.connection.close() 
                    break
     
                self.reply(received_data)

     def reply(self, received_data):
        model = SimpleASR()
        #all threads must averge the model before going to next line
        model_instance = self.model_averaging(model, model_instance)
        print("All threads completed model averging.")
        #now do rest of the things 

在回复函数中,我调用了一个函数。我想以这样的方式编写这段代码,即每个线程在调用此函数后都会继续到下一行。

每个线程必须对模型进行平均,然后继续执行下一行。我知道我必须使用Python条件变量。我怎样才能做到这一点?

以下函数必须是互斥的。

model_instance = self.model_averaging(model, model_instance)

每个线程执行完这段代码后都会进入下一行。

我编写此代码是实现联邦学习算法的一部分。

python multithreading mutex condition-variable federated-learning
1个回答
0
投票

t.join()
可能不会像你想象的那样:

for i in range(no_of_client):    
    sockets_thread[i].start()
    sockets_thread[i].join()

当您的主线程调用

sockets_thread[i].join()
时,该调用将在相关线程完成之前不会返回。您的循环仅在前一个线程完全完成其应该执行的操作后才启动下一个线程。

如果您希望线程同时运行,则不要

join
任何线程,直到您完成所有线程之后:
start

P.S.,如果你这样写,有些人可能会认为它看起来更Pythonic:

for i in range(no_of_client): sockets_thread[i].start() for i in range(no_of_client): sockets_thread[i].join()

如果你使用 Python 的 

for t in sockets_thread: t.start() for t in sockets_thread: t.join() 函数

,可能会更 Pythonic,但我现在不会去那里。

© www.soinside.com 2019 - 2024. All rights reserved.