-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathserver.py
143 lines (111 loc) · 4.37 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#!/usr/bin/env python3
# modified from https://github.com/zhengzangw/Fed-SINGA/blob/main/src/server/app.py
import socket
from collections import defaultdict
from typing import Dict, List
from singa import tensor
from .proto import interface_pb2 as proto
from .proto.utils import parseargs
from .proto import utils
class Server:
    """Server sends and receives protobuf messages.

    Create and start the server, then use pull and push to communicate with clients.

    Attributes:
        num_clients (int): Number of clients.
        host (str): Host address of the server.
        port (int): Port of the server.
        sock (socket.socket): Listening socket of the server.
        conns (List[socket.socket]): List of num_clients client sockets,
            indexed by the rank each client reported while pairing.
        addrs (List[Any]): Peer address of each client socket as returned
            by ``accept()``, indexed by rank.
        weights (Dict[str, Any]): Weights stored on server, keyed by name.
    """

    def __init__(
        self,
        num_clients: int = 1,
        host: str = "127.0.0.1",
        port: int = 1234,
    ) -> None:
        """Class init method

        The listening socket is created here but is not bound until
        ``start()`` is called.

        Args:
            num_clients (int, optional): Number of clients in training.
            host (str, optional): Host ip address. Defaults to '127.0.0.1'.
            port (int, optional): Port. Defaults to 1234.
        """
        self.num_clients = num_clients
        self.host = host
        self.port = port

        self.sock = socket.socket()
        # One slot per rank; filled in during rank pairing.
        self.conns = [None] * num_clients
        self.addrs = [None] * num_clients

        self.weights = {}

    def __start_connection(self) -> None:
        """Start the network connection of server."""
        self.sock.bind((self.host, self.port))
        self.sock.listen()
        print("Server started.")

    def _register_client(self, rank, conn, addr) -> None:
        """Store one accepted connection in the slot of its global rank.

        Args:
            rank (int): Global rank reported by the client.
            conn: The accepted client socket.
            addr: The peer address returned by ``accept()``.

        Raises:
            ValueError: If rank is outside ``[0, num_clients)`` or a client
                has already been paired with that rank.
        """
        # Explicit range check: a negative rank would otherwise silently
        # wrap around through Python's negative indexing.
        if not 0 <= rank < self.num_clients:
            raise ValueError(
                f"Invalid global rank {rank}: expected 0 <= rank < {self.num_clients}"
            )
        # A duplicate rank would overwrite (and leak) the earlier connection.
        if self.conns[rank] is not None:
            raise ValueError(f"Global rank {rank} is already paired with a client")
        self.conns[rank] = conn
        self.addrs[rank] = addr

    def __start_rank_pairing(self) -> None:
        """Start pair each client to a global rank

        Accepts exactly num_clients connections. Each client is expected to
        send its rank first (read with ``utils.receive_int``).

        Raises:
            ValueError: If a client reports an out-of-range or duplicate rank.
        """
        for _ in range(self.num_clients):
            conn, addr = self.sock.accept()
            try:
                rank = utils.receive_int(conn)
                self._register_client(rank, conn, addr)
            except Exception:
                # Do not leak a connection that could not be paired.
                conn.close()
                raise
            print(f"[Server] Connected by {addr} [global_rank {rank}]")
        # num_clients accepted connections with unique, in-range ranks
        # means every slot of self.conns is filled at this point.

    def start(self) -> None:
        """Start the server.

        This method will first bind and listen on the designated host and port.
        Then it will connect to num_clients clients and maintain the socket.
        In this process, each client shall provide their rank number.
        """
        self.__start_connection()
        self.__start_rank_pairing()

    def close(self) -> None:
        """Close the server.

        Closes every paired client connection and then the listening socket.
        """
        try:
            for conn in self.conns:
                if conn is not None:
                    conn.close()
        finally:
            # Release the listening socket even if closing a client failed.
            self.sock.close()

    # Tensor annotations are quoted so they are not evaluated when the
    # method is defined.
    def aggregate(self, weights: "Dict[str, List[tensor.Tensor]]") -> "Dict[str, tensor.Tensor]":
        """Aggregate collected weights to update server weight.

        For every key, the collected values are summed and divided by
        num_clients (not by the number of values collected for that key).

        Args:
            weights (Dict[str, List[tensor.Tensor]]): The collected weights.

        Returns:
            Dict[str, tensor.Tensor]: Updated weight stored in server.
        """
        for name, collected in weights.items():
            self.weights[name] = sum(collected) / self.num_clients
        return self.weights

    def pull(self) -> None:
        """Server pull weights from clients.

        Namely clients push weights to the server. It is the gather process.
        """
        weights = defaultdict(list)
        # receive weights sequentially, in rank order
        for conn in self.conns:
            message = utils.receive_message(conn, proto.WeightsExchange())
            for name, payload in message.weights.items():
                weights[name].append(utils.deserialize_tensor(payload))
        # aggregation
        self.aggregate(weights)

    def push(self) -> None:
        """Server push weights to clients.

        Namely clients pull weights from server. It is the scatter process.
        """
        message = proto.WeightsExchange()
        message.op_type = proto.SCATTER
        for name, value in self.weights.items():
            message.weights[name] = utils.serialize_tensor(value)

        # The same message is sent to every client.
        for conn in self.conns:
            utils.send_message(conn, message)
if __name__ == "__main__":
    # Script entry point: run the gather/scatter loop for max_epoch rounds.
    args = parseargs()

    server = Server(num_clients=args.num_clients, host=args.host, port=args.port)
    # try/finally so the server is closed even when start/push/pull raises.
    try:
        server.start()

        for i in range(args.max_epoch):
            print(f"On epoch {i}:")
            # Nothing has been aggregated before the first pull, so the
            # first epoch skips the push.
            if i > 0:
                # Push to Clients
                server.push()

            # Collects from Clients
            server.pull()
    finally:
        server.close()