Skip to content

Commit

Permalink
various fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fotstrt committed Nov 8, 2024
1 parent a9da429 commit dae940c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions sailor/run_ft_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,21 @@ def CheckReady(self, request, context):
def Kill(self, request, context):
    """Stop the local training process, if one is marked as running.

    Neither ``request`` nor ``context`` is read by this handler. When
    ``self.training_process_alive`` is set, every process whose command
    line matches ``run_train_custom.py`` is signalled through ``pkill -f``
    and the flag is cleared. A ``KillResponse`` is returned in all cases.
    """
    print("Killing local process ...")

    # Nothing to tear down when no training run is recorded as alive.
    if not self.training_process_alive:
        # TODO: check abort
        return KillResponse()

    print("HERE")
    # TODO: check cleanup
    os.system("pkill -f run_train_custom.py")
    self.training_process_alive = False
    # TODO: check abort
    return KillResponse()

def ConfigurationChange(self, request, context):
    """Apply a new topology and launch training if this node takes part.

    Args:
        request: message carrying an iterable ``topology`` field.
        context: RPC context; not read by this handler.

    Returns:
        A ``WorkerConfigurationResponse``, whether or not a training
        process was launched.

    Raises:
        AssertionError: if a training process is still marked as alive
            (``self.training_process_alive`` is truthy) when the new
            configuration arrives.
    """
    # The alive flag is the single source of truth for "a run is active";
    # it is the same flag that Kill() checks and clears.
    assert not self.training_process_alive
    print(f"Got topology: {request.topology}")

    # check if rank in participants
    topology_list = list(request.topology)
    if self.is_in_topo(topology_list):
        print(f"Starting new process, node rank is {self.node_rank}")
        # Flags use the dashed spelling (--config-file, --world-size,
        # --master-ip) expected by run_train_custom.py's argument parser.
        start_cmd = (
            "python run_train_custom.py"
            f" --config-file {self.script_args.config_file}"
            f" --world-size {self.world_size}"
            f" --rank {self.node_rank}"
            f" --master-ip {self.master_addr}"
        )
        # NOTE(review): os.system() waits for the command to finish, so the
        # flag below is only set after the training command has exited.
        # Confirm whether a non-blocking launch was intended.
        os.system(start_cmd)
        self.training_process_alive = True
    return WorkerConfigurationResponse()
Expand Down
2 changes: 1 addition & 1 deletion sailor/run_train_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def get_args():

os.environ['WORLD_SIZE'] = str(args.world_size)
os.environ['RANK'] = str(args.rank)
os.environ['MASTER_ADDR'] = args.master_addr
os.environ['MASTER_ADDR'] = args.master_ip
os.environ['MASTER_PORT'] = "1234" # TODO


Expand Down

0 comments on commit dae940c

Please sign in to comment.