-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial structure for out of the application workflow orchestration #5
base: develop
Are you sure you want to change the base?
Changes from 4 commits
00df733
1d08ff3
c4d323f
5f9da01
44c15e6
2d93aec
d883442
184f53c
61bfb06
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Builds and installs the AMS-WF Python package with pip from the build tree.
find_package(Python COMPONENTS Interpreter REQUIRED)

add_subdirectory(app)
add_subdirectory(ams)

# pip is pointed at the build directory below, so setup.py must live there too.
configure_file("setup.py" "${CMAKE_CURRENT_BINARY_DIR}/setup.py" COPYONLY)

# Python sources used as dependencies of the PyAMS target.
file(GLOB_RECURSE pyfiles *.py app/*.py ams/*.py)

# detect virtualenv and set Pip args accordingly
set(AMS_PY_APP "${CMAKE_CURRENT_BINARY_DIR}")
if(DEFINED ENV{VIRTUAL_ENV} OR DEFINED ENV{CONDA_PREFIX})
  # Inside a virtualenv/conda environment: install into that environment.
  set(_pip_args)
else()
  # Otherwise install into the per-user site-packages.
  set(_pip_args "--user")
endif()

# Purely informational output: STATUS keeps it out of the CMake warning stream.
message(STATUS "AMS Python Source files are ${pyfiles}")
message(STATUS "AMS Python built cmd is : ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP}")

add_custom_target(PyAMS ALL
  COMMAND ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP}
  WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
  COMMENT "Build AMS-WF Python Modules and Applications"
  DEPENDS ${pyfiles}
  # VERBATIM makes argument escaping platform independent.
  VERBATIM)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other | ||
# AMSLib Project Developers | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
# Copy every Python source found below this directory into the matching
# build directory. Only the file name is kept, so the copies end up flat
# in ${CMAKE_CURRENT_BINARY_DIR}.
file(GLOB_RECURSE pyfiles *.py)
foreach(filename IN LISTS pyfiles)
  get_filename_component(target "${filename}" NAME)
  message(STATUS "Copying ${filename} to ${target}")
  configure_file("${filename}" "${CMAKE_CURRENT_BINARY_DIR}/${target}" COPYONLY)
endforeach()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other | ||
# AMSLib Project Developers | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import flux | ||
import json | ||
from flux.security import SecurityContext | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @milroy I get an import error here, presumably because I am running an older flux version. Can you make the appropriate changes to support both versions? |
||
from flux import job | ||
from flux.job import JobspecV1 | ||
|
||
from .ams_rmq import RMQClient | ||
|
||
|
||
class Orchestrator:
    """
    Base class holding the RabbitMQ connection settings read from a JSON
    server configuration file.
    """

    def __init__(self, server_config_fn, certificate):
        """
        @param server_config_fn Path to a JSON file providing the keys
            "service-vhost", "service-port", "rabbitmq-user" and
            "rabbitmq-password".
        @param certificate Path to the certificate file used to establish
            the connection.
        """
        with open(server_config_fn, "r") as fd:
            server_config = json.load(fd)

        # Store plain values (not 1-tuples) so they can be handed over
        # unchanged to the RMQ client.
        self.host = server_config["service-vhost"]
        self.port = server_config["service-port"]
        self.user = server_config["rabbitmq-user"]
        self.password = server_config["rabbitmq-password"]
        self.certificate = certificate
|
||
class AMSDaemon(Orchestrator):
    """
    Class modeling a rmq-client daemon running on a compute
    allocation that will issue flux run commands
    """

    def __init__(self, server_config_fn, certificate):
        super().__init__(server_config_fn, certificate)

    def getMLJobSpec(self, client):
        """
        Receive a single message from the "test3" queue and return it as the
        ML job specification.

        @param client A connected RMQClient.
        @return The job specification; a textual message is parsed as JSON.
        """
        with client.connect("test3") as channel:
            spec = channel.receive(n_msg=1).pop()
        # Messages arrive as text; __call__ indexes the specification like a
        # dictionary, so decode the JSON payload here.
        if isinstance(spec, (str, bytes, bytearray)):
            spec = json.loads(spec)
        # TODO: Write some simple wrapper class around the 'dict'
        # to correctly create ML Job Specification
        return spec

    def __run(self, flux_handle, client, jobspec):
        """
        Loop forever: wait for one message on the "ml-start" queue, submit
        the (pre-signed) job specification to Flux and wait for the job.

        @param flux_handle Handle to the Flux instance.
        @param client A connected RMQClient.
        @param jobspec The signed job specification to submit.
        """
        with client.connect("ml-start") as channel:
            while True:
                # Currently we ignore the message
                channel.receive(n_msg=1)
                # The job must be submitted as 'waitable' for flux.job.wait
                # to be usable on it.
                jobid = flux.job.submit(
                    flux_handle, jobspec, pre_signed=True, waitable=True
                )
                flux.job.wait(flux_handle, jobid)
                # TODO: Send completed message in RMQ

    def __call__(self):
        """
        Fetch the ML job specification over RMQ, build and sign the
        corresponding Flux job specification and enter the submission loop.
        """
        flux_handle = flux.Flux()

        with RMQClient(self.host, self.port, self.user, self.password, self.certificate) as client:
            # We currently assume a single ML job specification
            ml_job_spec = self.getMLJobSpec(client)

            # Create a Flux Job Specification
            jobspec = JobspecV1.from_command(
                command=ml_job_spec["jobspec"]["command"],
                num_tasks=ml_job_spec["jobspec"]["num_tasks"],
                num_nodes=ml_job_spec["jobspec"]["num_nodes"],
                cores_per_task=ml_job_spec["jobspec"]["cores_per_task"],
                gpus_per_task=ml_job_spec["jobspec"]["gpus_per_task"],
            )

            ctx = SecurityContext()
            signed_jobspec = ctx.sign_wrap_as(
                ml_job_spec["uid"], jobspec.dumps(), mech_type="none"
            ).decode("utf-8")
            # This is a 'busy' loop
            self.__run(flux_handle, client, signed_jobspec)
|
||
class FluxDaemonWrapper(Orchestrator):
    """
    class to start Daemon through Flux
    """

    def __init__(self, server_config_fn, certificate):
        super().__init__(server_config_fn, certificate)

    def getFluxUri(self, client):
        """
        Receive a single message from the "test3" queue and return its
        'ml_uri' entry.

        @param client A connected RMQClient.
        @return The value stored under 'ml_uri' in the received message.
        """
        with client.connect("test3") as channel:
            msg = channel.receive(n_msg=1).pop()
        # Messages arrive as text; decode the JSON payload before indexing.
        if isinstance(msg, (str, bytes, bytearray)):
            msg = json.loads(msg)
        return msg['ml_uri']

    def __call__(self, application_cmd : list):
        """
        Run \p application_cmd through 'flux proxy' against the URI received
        over RMQ.

        @param application_cmd The command to run, as a list of arguments.
        @return The subprocess.CompletedProcess of the executed command.
        @raise TypeError if application_cmd is not a list.
        """
        # Imported locally: the module does not import subprocess at file level.
        import subprocess

        if not isinstance(application_cmd, list):
            raise TypeError('StartDaemon requires application_cmd as a list')

        with RMQClient(self.host, self.port, self.user, self.password, self.certificate) as client:
            self.uri = self.getFluxUri(client)

        flux_cmd = [
            "flux",
            "proxy",
            "--force",
            f"{self.uri}",
            "flux"]
        cmd = flux_cmd + application_cmd
        return subprocess.run(cmd)
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other | ||
# AMSLib Project Developers | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import pika | ||
import ssl | ||
import sys | ||
import os | ||
import logging | ||
import json | ||
|
||
|
||
class RMQClient:
    """
    RMQClient is a class that manages the RMQ client lifecycle.

    The connection to the server is opened and closed through the context
    manager protocol. Queues are accessed through RMQChannel objects returned
    by connect(), which are context managers as well:

        with RMQClient(vhost, port, user, password, cert) as client:
            with client.connect("queue-name") as channel:
                messages = channel.receive(n_msg=1)
    """

    class RMQChannel:
        """
        A wrapper around RMQ channel
        """

        def __init__(self, connection, q_name, logger: logging.Logger = None):
            """
            @param connection The open connection used to create the channel.
            @param q_name The name of the queue this channel operates on.
            @param logger Optional logger; defaults to the module logger.
            """
            self.connection = connection
            self.q_name = q_name
            self.channel = None
            self.logger = logger if logger else logging.getLogger(__name__)

        def __enter__(self):
            self.open()
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            self.close()

        @staticmethod
        def callback(method, properties, body):
            """Return the message \p body decoded as UTF-8 text."""
            return body.decode("utf-8")

        def open(self):
            """Create the channel and declare the queue."""
            self.channel = self.connection.channel()
            self.channel.queue_declare(queue=self.q_name)

        def close(self):
            """Close the channel if it is open."""
            if self.channel is not None and self.channel.is_open:
                self.channel.close()

        def receive(self, n_msg: int = None, accum_msg=None):
            """
            Consume a message on the queue and post processing by calling the callback.
            @param n_msg The number of messages to receive.
                - if n_msg is None, this call will block for ever and will process all messages that arrives
                - if n_msg = 1 for example, this function will block until one message has been processed.
                - if n_msg <= 0, nothing is consumed.
            @param accum_msg Optional list the received messages are appended to;
                a new list is created when omitted.
            @return a list containing all received messages
            """

            # A fresh list per call: a list default would be shared across calls.
            if accum_msg is None:
                accum_msg = []

            if not (self.channel and self.channel.is_open):
                return accum_msg

            if n_msg is not None and n_msg <= 0:
                return accum_msg

            self.logger.info(
                f"Starting to consume messages from queue={self.q_name} ..."
            )
            # we will consume only n_msg and requeue all other messages
            # if there are more messages in the queue.
            # It will block as long as n_msg did not get read
            message_consumed = 0
            for method_frame, properties, body in self.channel.consume(self.q_name):
                # Call the callback on the message parts
                try:
                    accum_msg.append(
                        RMQClient.RMQChannel.callback(
                            method_frame,
                            properties,
                            body,
                        )
                    )
                except Exception as e:
                    self.logger.error(f"Exception {type(e)}: {e}")
                    self.logger.debug("Traceback:", exc_info=True)
                finally:
                    # Acknowledge the message even on failure
                    self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
                    self.logger.warning(
                        f"Consumed message {message_consumed+1}/{method_frame.delivery_tag} (exchange={method_frame.exchange}, routing_key={method_frame.routing_key})"
                    )
                message_consumed += 1
                # Escape out of the loop after n_msg messages
                if n_msg is not None and message_consumed == n_msg:
                    # Cancel the consumer and return any pending messages
                    self.channel.cancel()
                    break
            return accum_msg

        def send(self, text):
            """
            Send \p text
            """

            # Published on the default exchange with the queue name as routing key
            self.channel.basic_publish(exchange="", routing_key=self.q_name, body=text)

            return

        def get_messages(self):
            """Placeholder: currently returns nothing."""
            return  # messages

        def purge(self):
            """Removes all the messages from the queue."""

            if self.channel and self.channel.is_open:
                self.channel.queue_purge(self.q_name)

    def __init__(self, vhost, port, user, password, cert, logger: logging.Logger = None, host: str = None):
        """
        @param vhost The RMQ virtual host.
        @param port The port of the RMQ service.
        @param user The RMQ user name.
        @param password The RMQ password.
        @param cert Path to the CA certificate used to verify the server.
        @param logger Optional logger; defaults to the module logger.
        @param host Optional server host name.
        """
        # CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file):
        # openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' rmq-pds.crt
        self.logger = logger if logger else logging.getLogger(__name__)
        self.cert = cert
        self.context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
        self.context.verify_mode = ssl.CERT_REQUIRED
        self.context.load_verify_locations(self.cert)
        self.vhost = vhost
        # NOTE(review): no host was ever provided to this class although the
        # connection parameters need one; falling back to the first argument
        # is an assumption - confirm against the server configuration schema.
        self.host = host if host is not None else vhost
        self.port = port
        self.user = user
        self.password = password
        self.connection = None

        self.credentials = pika.PlainCredentials(
            self.user, self.password)

        self.connection_params = pika.ConnectionParameters(
            host=self.host,
            port=self.port,
            virtual_host=self.vhost,
            credentials=self.credentials,
            ssl_options=pika.SSLOptions(self.context),
        )

    def __enter__(self):
        self.connection = pika.BlockingConnection(self.connection_params)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Close the connection only when it was opened and is still open
        if self.connection is not None and self.connection.is_open:
            self.connection.close()

    def connect(self, queue):
        """Connect to the queue"""
        return RMQClient.RMQChannel(self.connection, queue, self.logger)
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other | ||
# AMSLib Project Developers | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import argparse | ||
from pathlib import Path | ||
|
||
|
||
def main():
    """Placeholder entry point: print the AMSDBStage banner."""
    banner = "Place Holder for AMSDBStage.py"
    print(banner)


if __name__ == "__main__":
    main()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other | ||
# AMSLib Project Developers | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import sys | ||
import os | ||
from ams.rmq import RMQClient | ||
from ams.orchestrator import AMSDaemon | ||
from ams.orchestrator import FluxDaemonWrapper | ||
|
||
import argparse | ||
|
||
|
||
def main():
    """
    Command line entry point bundling both the AMSDaemon ("start" action) and
    the FluxDaemonWrapper ("wrap" action).
    """
    # NOTE(review): the PR author flagged this entry point as untested.
    parser = argparse.ArgumentParser(
        description="AMS Machine Learning Daemon running on Training allocation"
    )
    parser.add_argument(
        "-a",
        "--action",
        dest='action',
        choices=["start", "wrap"],
        help="Decide whether to start daemon process directly or through flux wrap script",
        required=True,
    )
    parser.add_argument(
        "-c",
        "--certificate",
        dest="certificate",
        help="Path to certificate file to establish connection",
        required=True,
    )
    parser.add_argument(
        "-cfg",
        "--config",
        dest="config",
        help="Path to AMS configuration file",
        required=True,
    )
    options = parser.parse_args()

    if options.action == "wrap":
        # Re-invoke this very script with the "start" action; the wrapper
        # prefixes the command with its flux proxy invocation.
        start_cmd = ["python", __file__, "--action", "start"]
        start_cmd += ["-c", options.certificate]
        start_cmd += ["-cfg", options.config]
        wrapper = FluxDaemonWrapper(options.config, options.certificate)
        wrapper(start_cmd)
        return

    if options.action == "start":
        runner = AMSDaemon(options.config, options.certificate)
        # Busy wait for messages and spawn ML training jobs
        runner()


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("Interrupted")
        # Exit with status 0 on Ctrl-C; fall back to a hard exit if the
        # regular interpreter exit is intercepted.
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other | ||
# AMSLib Project Developers | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import argparse | ||
|
||
def main():
    """Placeholder entry point: print the AMSTrain greeting."""
    greeting = 'Hello from AMSTrain.py'
    print(greeting)


if __name__ == '__main__':
    main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@milroy This is a general 'driver' of the code we are going to run outside of compute nodes. I re-used and restructured your previous commit and tried to abstract out some of the hard-coded paths.