最近我正在研究一个有关客户端-服务器设置中的安全大型模型推理的项目。在这种情况下,客户端和服务器必须进行大量通信才能进行协作计算。
例如,服务器会频繁地向客户端发送模型参数(即一些矩阵),这就需要在TCP下通过socket可靠地传输大量数据。重复调用
socket.socket.sendall()
和socket.socket.recv()
的直接方法并不合适,因为接收者很难找到一个数据与下一个数据之间的边界(TCP提供字节流服务)。
我这样写:
import socket
import pickle
import torch
class BetterSocket:
def __init__(self, s):
self.socket = s
self.msg_len = 2 ** 12
def sendall(self, obj):
pkl = pickle.dumps(obj)
l = len(pkl)
init_l = len(pkl)
lbs = l.to_bytes(4)
print(f"With LBS = {lbs}")
self.socket.sendall(lbs)
print("Going to send...")
while l > self.msg_len:
self.socket.sendall(pkl[0:self.msg_len])
pkl = pkl[self.msg_len:]
l = l - self.msg_len
self.socket.recv(len(b'0'))
self.socket.sendall(pkl)
print(f"Done. {4 + init_l} bytes have been sent.")
def recv(self):
print("Waiting to receive...")
lbs = self.socket.recv(4)
print(f"LBS {lbs} received.")
l = int.from_bytes(lbs)
print(f"Length {l} received.")
pkl = b''
while l > self.msg_len:
pkl = pkl + self.socket.recv(self.msg_len)
l = l - self.msg_len
self.socket.sendall(b'0')
pkl = pkl + self.socket.recv(l)
obj = pickle.loads(pkl)
return obj
此类有助于传输任何类型的对象。每个“数据单元”的头部附有4个字节,表示数据的整体大小。
while
循环中的“反向确认”似乎是为了平衡发送和接收的速度所必需的,否则会抛出_pickle.UnpicklingError: pickle data was truncated
,这表明接收方错误地分割了字节流。 self.msg_len
设置太大也会出现同样的错误。
但是,上述解决方案太慢了。由于客户端和服务端经常进行“反向确认”,传输一个50000*800的矩阵大约需要5分钟。如何在不严重牺牲性能的情况下正确进行数据传输?请帮助我!
(对不起我的英语...)
pickle 协议已经包含有关单个 pickle 转储大小的信息。有一个套接字方法
socket.makefile
可以将套接字包装在类似文件的对象中,该对象可以直接与 pickle.dump
和 pickle.load
一起使用,从而使读取和写入 pickled 对象变得容易。
这是一个例子:
服务器.py
import socket
import pickle
with socket.socket() as s:
s.bind(('', 5000))
s.listen()
while True:
client, addr = s.accept()
with client, client.makefile('rb') as rfile:
while True:
try:
obj = pickle.load(rfile)
print(f'{addr}: {obj}')
except EOFError: # raised by pickle.load when socket is closed
break
客户端.py
import socket
import pickle
def send_message(sock, obj):
wfile.write(pickle.dumps(obj))
wfile.flush() # ensures buffered writes are sent to socket.
with socket.socket() as s:
s.connect(('localhost', 5000))
with s.makefile('wb') as wfile:
send_message(wfile, [1, 2, 3, 'abc', 'def'])
send_message(wfile, [complex(1,2), complex(3,4)])
send_message(wfile, dict(zip('abc def ghi'.split(), [123, 456, 789])))
客户端运行一次后服务器的输出:
('127.0.0.1', 3010): [1, 2, 3, 'abc', 'def']
('127.0.0.1', 3010): [(1+2j), (3+4j)]
('127.0.0.1', 3010): {'abc': 123, 'def': 456, 'ghi': 789}