-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
180 lines (148 loc) · 6.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# main.py
import sys
import argparse
import os
import socket
import threading
import time
import subprocess
def wait_for_worker_completion(conn, worker_rank, completion_event):
    """Wait for a worker's "DONE" message, then mark the worker as finished.

    Runs on a per-worker thread in the master. The message is read in a loop
    because TCP is a byte stream and "DONE" may arrive in pieces. The loop
    also stops if the worker closes the connection early (``recv`` returns
    ``b""``), for example because it crashed.

    *completion_event* is set and *conn* is closed in every case: success,
    unexpected data, early disconnect or socket error. This means the master's
    ``event.wait()`` can never block forever on a dead worker.

    Args:
        conn: Connected socket to the worker. Always closed before returning.
        worker_rank: Global rank of the worker; used only in log messages.
        completion_event: ``threading.Event`` to set once the worker is done.

    Returns:
        True if the worker reported "DONE", False otherwise.
    """
    expected = b"DONE"
    data = b""
    try:
        try:
            while len(data) < len(expected):
                chunk = conn.recv(1024)
                if not chunk:  # Peer closed the connection before finishing.
                    break
                data += chunk
        except OSError as e:
            print(f"[Master] Error receiving from worker {worker_rank}: {e}")
        # Compare raw bytes so garbage input cannot raise UnicodeDecodeError
        # and kill this thread before the event is set.
        done = data == expected
        if done:
            print(f"[Master] Worker {worker_rank} has completed its task.")
        else:
            print(f"[Master] Worker {worker_rank} disconnected without reporting completion (received {data!r}).")
        return done
    finally:
        conn.close()
        completion_event.set()
def _fail(message):
    """Print *message* to stderr and terminate the launcher with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _validate_args(args):
    """Check the launcher arguments and return the effective node rank.

    Explicit checks are used instead of ``assert`` because assertions are
    stripped when Python runs with ``-O``, which would silently accept bad
    input. Every failure exits with status 1 via ``_fail``.

    Args:
        args: ``argparse.Namespace`` produced by ``main``'s parser.

    Returns:
        ``args.node_rank``, or 0 when it was omitted on a single-node run.
    """
    nnodes = args.nnodes
    nproc_per_node = args.nproc_per_node
    if nnodes < 1:
        _fail("Error: --nnodes must be a positive integer (>= 1).")
    if nproc_per_node < 1:
        _fail("Error: --nproc_per_node must be a positive integer (>= 1).")
    node_rank = args.node_rank
    if node_rank is None:
        if nnodes > 1:
            _fail("Error: --node_rank must be specified when nnodes > 1")
        node_rank = 0
    if not 0 <= node_rank < nnodes:
        _fail(f"Error: --node_rank must be between 0 and nnodes - 1 (0 <= node_rank < {nnodes}).")
    local_rank = args.local_rank
    if local_rank is not None and not 0 <= local_rank < nproc_per_node:
        _fail(f"Error: --local_rank must be between 0 and nproc_per_node - 1 (0 <= local_rank < {nproc_per_node}).")
    if not 0 < args.master_port < 65536:
        _fail("Error: --master_port must be an integer between 1 and 65535.")
    return node_rank


def _spawn_local_processes(nproc_per_node):
    """Re-run this script once per local rank and wait for every child.

    Each child gets the original command line plus ``--local_rank <i>``.

    Returns:
        0 if every child exited successfully. Otherwise, the return code of
        the first failing child in local-rank order, so the launcher's exit
        status reflects failures. A child killed by a signal (negative
        return code) maps to 1.
    """
    processes = [
        subprocess.Popen([sys.executable, sys.argv[0]] + sys.argv[1:] + ['--local_rank', str(local_rank)])
        for local_rank in range(nproc_per_node)
    ]
    return_codes = [p.wait() for p in processes]
    failed = [code for code in return_codes if code != 0]
    if not failed:
        return 0
    return failed[0] if failed[0] > 0 else 1


def _run_task(task, task_args, *, rank, world_size, local_rank, master_addr, master_port):
    """Run the task script in a subprocess with the distributed env vars set.

    Sets RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT on a copy
    of the current environment, then runs ``python <task> <task_args...>``.

    Returns:
        The ``subprocess.CompletedProcess`` of the task.
    """
    env = os.environ.copy()
    env.update(
        RANK=str(rank),
        WORLD_SIZE=str(world_size),
        LOCAL_RANK=str(local_rank),
        MASTER_ADDR=master_addr,
        MASTER_PORT=str(master_port),
    )
    task_command = [sys.executable, task] + task_args
    print(f"[Rank {rank}] Executing task: {' '.join(task_command)}")
    # Flush so our log line appears before any output from the child.
    sys.stdout.flush()
    sys.stderr.flush()
    return subprocess.run(task_command, env=env)


def _accept_workers(server_socket, expected_workers):
    """Accept connections until *expected_workers* workers have sent their rank.

    A worker's first message is its global rank as decimal text. A
    connection whose handshake cannot be read or parsed is closed (so it does
    not leak) and is not counted.

    Returns:
        A list of ``(conn, worker_rank)`` tuples.
    """
    connections = []
    while len(connections) < expected_workers:
        try:
            conn, _addr = server_socket.accept()
        except OSError as e:
            print(f"[Master] Error accepting connections: {e}")
            continue
        try:
            worker_rank = int(conn.recv(1024).decode())
        except (OSError, UnicodeDecodeError, ValueError) as e:
            print(f"[Master] Error accepting connections: {e}")
            conn.close()
            continue
        print(f"[Master] Connected to worker {worker_rank}")
        connections.append((conn, worker_rank))
    return connections


def _connect_to_master(master_addr, master_port, retry_interval=1.0):
    """Connect to the master, retrying every *retry_interval* seconds.

    The master may not be listening yet, so failed attempts are retried
    indefinitely. Each failed socket is closed before the next attempt so
    retries do not leak file descriptors.

    Returns:
        A connected TCP socket.
    """
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((master_addr, master_port))
            return sock
        except OSError:
            sock.close()
            time.sleep(retry_interval)


def main():
    """Launch ``--task`` as one or more ranks of a simple distributed job.

    Without ``--local_rank``, this process only spawns ``--nproc_per_node``
    copies of itself (one per local rank) and exits with their combined
    status.

    With ``--local_rank``, it acts as a single rank:

    * The rank with node_rank 0 / local_rank 0 is the master. It listens on
      ``--master_port``, waits for every other rank to connect, sends START,
      runs its own task, and waits for each worker's DONE.
    * Every other rank connects to the master, waits for START, runs the
      task and reports DONE.

    Arguments this parser does not recognise are forwarded to the task
    script. Each rank exits with its task's return code.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--master_addr', type=str, default='localhost', help='Address of the master node (hostname or IP)')
    parser.add_argument('--master_port', type=int, default=12345, help='Port for the master server to bind')
    parser.add_argument('--nnodes', type=int, default=1, help='Number of nodes')
    parser.add_argument('--nproc_per_node', type=int, required=True, help='Number of processes per node')
    parser.add_argument('--node_rank', type=int, default=None, help='Rank of the node (0 to nnodes-1)')
    parser.add_argument('--local_rank', type=int, default=None, help='Rank of the process on the node (0 to nproc_per_node-1)')
    parser.add_argument('--task', type=str, required=True, help='Path to the task script (e.g., example_task.py)')
    args, unknown = parser.parse_known_args()

    hostname = socket.gethostname()
    master_addr = args.master_addr
    master_port = args.master_port
    nproc_per_node = args.nproc_per_node
    node_rank = _validate_args(args)
    local_rank = args.local_rank

    # No local rank given: act as the per-node spawner only.
    if local_rank is None:
        sys.exit(_spawn_local_processes(nproc_per_node))

    rank = node_rank * nproc_per_node + local_rank
    world_size = args.nnodes * nproc_per_node
    task_env = dict(
        rank=rank,
        world_size=world_size,
        local_rank=local_rank,
        master_addr=master_addr,
        master_port=master_port,
    )

    if node_rank == 0 and local_rank == 0:
        # Master process. The `with` guarantees the listening socket is
        # closed on every exit path, including sys.exit() on bind failure.
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        with server_socket:
            try:
                if os.name == 'posix':
                    # Allow an immediate relaunch to rebind while the previous
                    # run's port is still in TIME_WAIT.
                    server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                # Bind to all interfaces on the specified port
                server_socket.bind(('', master_port))
                server_socket.listen()
            except OSError as e:
                print(f"[{hostname}] Failed to bind as master: {e}")
                sys.exit(1)
            print(f"[{hostname}] Became master at {hostname}:{master_port}")

            print(f"[Master {hostname}] Waiting for worker connections...")
            connections = _accept_workers(server_socket, world_size - 1)  # Exclude master itself

            print(f"[Master] All workers are connected. Sending START signal to all workers.")
            completion_events = []
            for conn, worker_rank in connections:
                try:
                    conn.sendall("START".encode())
                except OSError as e:
                    # A worker that died after connecting must not abort the
                    # whole job and leave the remaining workers blocked.
                    print(f"[Master] Failed to send START to worker {worker_rank}: {e}")
                    conn.close()
                    continue
                completion_event = threading.Event()
                threading.Thread(
                    target=wait_for_worker_completion,
                    args=(conn, worker_rank, completion_event),
                ).start()
                completion_events.append(completion_event)

            result = _run_task(args.task, unknown, **task_env)

            for event in completion_events:
                event.wait()
            print(f"[Master] All workers have completed their tasks.")
        sys.exit(result.returncode)

    # Worker process
    with _connect_to_master(master_addr, master_port) as s:
        s.sendall(str(rank).encode())
        print(f"[Worker {hostname}] Connected to master at {master_addr}:{master_port}")

        data = s.recv(1024)
        if data != b"START":
            # Previously this fell through and exited 0 without running the task.
            print(f"[Worker {hostname}] Expected START from master, got {data!r}; exiting.", file=sys.stderr)
            sys.exit(1)
        print(f"[Worker {hostname}] Received START signal from master.")

        result = _run_task(args.task, unknown, **task_env)

        try:
            s.sendall("DONE".encode())
        except OSError as e:
            # Still report the task's own exit status below.
            print(f"[Worker {hostname}] Failed to notify master of completion: {e}", file=sys.stderr)
    sys.exit(result.returncode)
# Script entry point: run the launcher only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()