Initial structure for out-of-the-application workflow orchestration #5

Draft: wants to merge 9 commits into base: develop
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -43,6 +43,7 @@ option(WITH_HDF5 "Use HDF5 as a database back end" OFF)
option(WITH_RMQ "Use RabbitMQ as a database back end (require a reachable and running RabbitMQ server service)" OFF)
option(WITH_AMS_DEBUG "Enable verbose messages" OFF)
option(WITH_PERFFLOWASPECT "Use PerfFlowAspect for Profiling" OFF)
option(WITH_WORKFLOW "Install python drivers used by the outer workflow" OFF)
option(BUILD_SHARED_LIBS "Build using shared libraries" ON)

if (WITH_MPI)
25 changes: 25 additions & 0 deletions src/AMSWorkflow/CMakeLists.txt
@@ -0,0 +1,25 @@
find_package(Python COMPONENTS Interpreter REQUIRED)

add_subdirectory(app)
add_subdirectory(ams)

configure_file("setup.py" "${CMAKE_CURRENT_BINARY_DIR}/setup.py" COPYONLY)

file(GLOB_RECURSE pyfiles *.py app/*.py ams/*.py)

# detect virtualenv and set Pip args accordingly
set(AMS_PY_APP "${CMAKE_CURRENT_BINARY_DIR}")
if(DEFINED ENV{VIRTUAL_ENV} OR DEFINED ENV{CONDA_PREFIX})
set(_pip_args)
else()
set(_pip_args "--user")
endif()

message(WARNING "AMS Python Source files are ${pyfiles}")
message(WARNING "AMS Python built cmd is : ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP}")

add_custom_target(PyAMS ALL
COMMAND ${Python_EXECUTABLE} -m pip install ${_pip_args} ${AMS_PY_APP}
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
COMMENT "Build AMS-WF Python Modules and Applications"
DEPENDS ${pyfiles})
11 changes: 11 additions & 0 deletions src/AMSWorkflow/ams/CMakeLists.txt
@@ -0,0 +1,11 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# AMSLib Project Developers
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

file(GLOB_RECURSE pyfiles *.py)
foreach (filename ${pyfiles})
get_filename_component(target "${filename}" NAME)
message(STATUS "Copying ${filename} to ${target}")
configure_file("${filename}" "${CMAKE_CURRENT_BINARY_DIR}/${target}" COPYONLY)
endforeach (filename)
102 changes: 102 additions & 0 deletions src/AMSWorkflow/ams/orchestrator.py
Member Author:
@milroy This is a general 'driver' of the code we are going to run outside of compute nodes. I re-used and restructured your previous commit and tried to abstract out some of the hard-coded paths.
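For reference, a minimal sketch of the message this driver expects on the queue, reconstructed from the keys the code reads below; all values are placeholders:

```python
# Hypothetical ML job specification message (values are placeholders)
ml_job_spec = {
    "uid": 12345,                             # user id used to sign the Flux jobspec
    "ml_uri": "local:///run/flux/local",      # consumed by FluxDaemonWrapper.getFluxUri
    "jobspec": {
        "command": ["python", "train.py"],
        "num_tasks": 1,
        "num_nodes": 1,
        "cores_per_task": 4,
        "gpus_per_task": 1,
    },
}
```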

@@ -0,0 +1,102 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# AMSLib Project Developers
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import flux
import json
import subprocess
from flux.security import SecurityContext
Member Author:
@milroy I get an import error. I am guessing because of running on an old flux version. Can you make the appropriate changes to support both versions?
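One possible way to keep this import working across versions (a sketch, assuming older flux bindings simply do not ship flux.security):

```python
# Hypothetical guarded import; falls back to None on older flux-core bindings
try:
    from flux.security import SecurityContext
except ImportError:
    SecurityContext = None  # signing would then need an alternative path
```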

from flux import job
from flux.job import JobspecV1

from .rmq import RMQClient


class Orchestrator:
def __init__(self, server_config_fn, certificate):
with open(server_config_fn, "r") as fd:
server_config = json.load(fd)

self.host = server_config["service-vhost"]
self.port = server_config["service-port"]
self.user = server_config["rabbitmq-user"]
self.password = server_config["rabbitmq-password"]
self.certificate = certificate

class AMSDaemon(Orchestrator):
"""
Class modeling a rmq-client daemon running on a compute
allocation that will issue flux run commands
"""

def __init__(self, server_config_fn, certificate):
super().__init__(server_config_fn, certificate)

def getMLJobSpec(self, client):
with client.connect("test3") as channel:
spec = channel.receive(n_msg=1).pop()
# TODO: Write some simple wrapper class around the 'dict'
# to correctly create the ML Job Specification
return spec

def __run(self, flux_handle, client, jobspec):
with client.connect("ml-start") as channel:
while True:
# Currently we ignore the message
channel.receive(n_msg=1)
# Submit as waitable so job.wait() can collect the job when it finishes
jobid = flux.job.submit(flux_handle, jobspec, pre_signed=True, waitable=True)
job.wait(flux_handle, jobid)
# TODO: Send completed message in RMQ

def __call__(self):
flux_handle = flux.Flux()

with RMQClient(self.host, self.port, self.user, self.password, self.certificate) as client:
# We currently assume a single ML job specification
ml_job_spec = self.getMLJobSpec(client)

# Create a Flux Job Specification
jobspec = JobspecV1.from_command(
command=ml_job_spec["jobspec"]["command"],
num_tasks=ml_job_spec["jobspec"]["num_tasks"],
num_nodes=ml_job_spec["jobspec"]["num_nodes"],
cores_per_task=ml_job_spec["jobspec"]["cores_per_task"],
gpus_per_task=ml_job_spec["jobspec"]["gpus_per_task"],
)

ctx = SecurityContext()
signed_jobspec = ctx.sign_wrap_as(
ml_job_spec["uid"], jobspec.dumps(), mech_type="none"
).decode("utf-8")
# This is a 'busy' loop
self.__run(flux_handle, client, signed_jobspec)

class FluxDaemonWrapper(Orchestrator):
"""
class to start Daemon through Flux
"""

def __init__(self, server_config_fn, certificate):
super().__init__(server_config_fn, certificate)

def getFluxUri(self, client):
with client.connect("test3") as channel:
msg = channel.receive(n_msg=1).pop()
return msg['ml_uri']

def __call__(self, application_cmd : list):
if not isinstance(application_cmd, list):
raise TypeError('StartDaemon requires application_cmd as a list')

with RMQClient(self.host, self.port, self.user, self.password, self.certificate) as client:
self.uri = self.getFluxUri(client)

flux_cmd = [
"flux",
"proxy",
"--force",
f"{self.uri}",
"flux"]
cmd = flux_cmd + application_cmd
subprocess.run(cmd)

144 changes: 144 additions & 0 deletions src/AMSWorkflow/ams/rmq.py
@@ -0,0 +1,144 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# AMSLib Project Developers
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import pika
import ssl
import sys
import os
import logging
import json
import traceback


class RMQClient:
Member Author:
@lpottier This is a first re-implementation of the RMQClient from our previous hands-on meeting. Connecting to the server happens through a context-manager Python API, and the same applies to opening channels and pulling messages. This is inspired by the orchestrator requirements, but use it as a baseline to support the database client. Ideally, we would like a single API entry point for communicating messages across AMS components.
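A minimal usage sketch of the intended context-manager API (connection settings and the queue name are placeholders):

```python
from ams.rmq import RMQClient

# Placeholder connection settings
vhost, port = "rabbit-host.example.org", 5671
user, password, cert = "ams", "secret", "rmq-pds.crt"

with RMQClient(vhost, port, user, password, cert) as client:
    with client.connect("ml-start") as channel:  # queue name used by the daemon
        channel.send("hello")
        messages = channel.receive(n_msg=1)
```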

"""
RMQClient is a class that manages the RMQ client lifecycle.
"""

class RMQChannel:
"""
A wrapper around RMQ channel
"""

def __init__(self, connection, q_name):
self.connection = connection
self.q_name = q_name
self.logger = logging.getLogger(__name__)

def __enter__(self):
self.open()
return self

def __exit__(self, exc_type, exc_value, exc_tb):
self.close()

@staticmethod
def callback(method, properties, body):
return body.decode("utf-8")

def open(self):
self.channel = self.connection.channel()
self.channel.queue_declare(queue=self.q_name)

def close(self):
self.channel.close()

def receive(self, n_msg: int = None, accum_msg=None):
"""
Consume messages from the queue and post-process them by calling the callback.
@param n_msg The number of messages to receive.
- if n_msg is None, this call will block forever and process all messages that arrive
- if n_msg = 1 for example, this function will block until one message has been processed.
@return a list containing all received messages
"""

if accum_msg is None:
accum_msg = []
if self.channel and self.channel.is_open:
self.logger.info(
f"Starting to consume messages from queue={self.q_name} ..."
)
# We consume only n_msg messages and requeue any other messages
# left in the queue. This blocks until n_msg messages have been read.
if n_msg:
n_msg = max(n_msg, 0)
message_consumed = 0
# Consume n_msg messages and break out
for method_frame, properties, body in self.channel.consume(self.q_name):
# Call the callback on the message parts
try:
accum_msg.append(
self.callback(
method_frame,
properties,
body,
)
)
except Exception as e:
self.logger.error(f"Exception {type(e)}: {e}")
self.logger.debug(traceback.format_exc())
finally:
# Acknowledge the message even on failure
self.channel.basic_ack(delivery_tag=method_frame.delivery_tag)
self.logger.warning(
f"Consumed message {message_consumed+1}/{method_frame.delivery_tag} (exchange={method_frame.exchange}, routing_key={method_frame.routing_key})"
)
message_consumed += 1
# Escape out of the loop after n_msg messages
if message_consumed == n_msg:
# Cancel the consumer and return any pending messages
self.channel.cancel()
break
return accum_msg

def send(self, text):
"""
Send `text` to the queue.
"""

self.channel.basic_publish(exchange="", routing_key=self.q_name, body=text)

return

def get_messages(self):
return # messages

def purge(self):
"""Removes all the messages from the queue."""

if self.channel and self.channel.is_open:
self.channel.queue_purge(self.q_name)

def __init__(self, vhost, port, user, password, cert, logger: logging.Logger = None):
# CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file):
# openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > rmq-pds.crt
self.logger = logger if logger else logging.getLogger(__name__)
self.cert = cert
self.context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
self.context.verify_mode = ssl.CERT_REQUIRED
self.context.load_verify_locations(self.cert)
self.vhost = vhost
self.port = port
self.user = user
self.password = password

self.credentials = pika.PlainCredentials(
self.user, self.password)

self.connection_params = pika.ConnectionParameters(
host=self.vhost,  # NOTE: the config's "service-vhost" is reused as the broker host in this draft
port=self.port,
virtual_host=self.vhost,
credentials=self.credentials,
ssl_options=pika.SSLOptions(self.context),
)

def __enter__(self):
self.connection = pika.BlockingConnection(self.connection_params)
return self

def __exit__(self, exc_type, exc_value, exc_tb):
self.connection.close()

def connect(self, queue):
"""Connect to the queue"""
return self.RMQChannel(self.connection, queue)

15 changes: 15 additions & 0 deletions src/AMSWorkflow/ams_wf/AMSDBStage.py
@@ -0,0 +1,15 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# AMSLib Project Developers
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
from pathlib import Path


def main():
print("Place Holder for AMSDBStage.py")


if __name__ == "__main__":
main()
73 changes: 73 additions & 0 deletions src/AMSWorkflow/ams_wf/AMSOrchestrator.py
@@ -0,0 +1,73 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# AMSLib Project Developers
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import sys
import os
from ams.rmq import RMQClient
from ams.orchestrator import AMSDaemon
from ams.orchestrator import FluxDaemonWrapper

import argparse


def main():
Member Author:
@milroy I have not tested any of this code. Take a look and let me know of any considerations. Effectively I bundle both the AMSDaemon and the FluxDaemon at the same python entry point.
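For reference, a sketch of the equivalent programmatic use of the two actions (file names are placeholders):

```python
# Hypothetical programmatic equivalent of the two CLI actions
from ams.orchestrator import AMSDaemon, FluxDaemonWrapper

config = "ams_config.json"   # placeholder AMS configuration file
cert = "rmq-pds.crt"         # placeholder CA certificate

# "--action start": busy-wait for RMQ messages and submit ML training jobs
AMSDaemon(config, cert)()

# "--action wrap": relaunch the "start" entry point under `flux proxy`
wrapper = FluxDaemonWrapper(config, cert)
wrapper(["python", "AMSOrchestrator.py", "--action", "start", "-c", cert, "-cfg", config])
```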

daemon_actions = ["start", "wrap"]
parser = argparse.ArgumentParser(
description="AMS Machine Learning Daemon running on Training allocation"
)
parser.add_argument(
"-a",
"--action",
dest='action',
choices=daemon_actions,
help="Decide whether to start daemon process directly or through flux wrap script",
required=True,
)

parser.add_argument(
"-c",
"--certificate",
dest="certificate",
help="Path to certificate file to establish connection",
required=True,
)

parser.add_argument(
"-cfg",
"--config",
dest="config",
help="Path to AMS configuration file",
required=True,
)

args = parser.parse_args()
if args.action == "start":
daemon = AMSDaemon(args.config, args.certificate)
# Busy wait for messages and spawn ML training jobs
daemon()
elif args.action == "wrap":
daemon_cmd = [
"python",
__file__,
"--action",
"start",
"-c",
args.certificate,
"-cfg",
args.config,
]
daemon = FluxDaemonWrapper(args.config, args.certificate)
daemon(daemon_cmd)


if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("Interrupted")
try:
sys.exit(0)
except SystemExit:
os._exit(0)
12 changes: 12 additions & 0 deletions src/AMSWorkflow/ams_wf/AMSTrain.py
@@ -0,0 +1,12 @@
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
# AMSLib Project Developers
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse

def main():
print('Hello from AMSTrain.py')

if __name__ == '__main__':
main()