Skip to content

Commit

Permalink
set up hostname and keep track of it - wip
Browse files Browse the repository at this point in the history
  • Loading branch information
kchilleri committed Sep 30, 2024
1 parent 28ef956 commit e40d8e6
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 16 deletions.
29 changes: 20 additions & 9 deletions beeflow/client/bee_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,34 @@ def __init__(self, *args):
self.args = args


def warn(*pargs):
"""Print a red warning message."""
typer.secho(' '.join(pargs), fg=typer.colors.RED, file=sys.stderr)


def db_path():
"""Return the client database path."""
bee_workdir = config_driver.BeeConfig.get('DEFAULT', 'bee_workdir')
return os.path.join(bee_workdir, 'client.db')


def connect_db():
"""Connect to the client database."""
return bdb.connect_db(client_db, db_path())
def setup_hostname(start_hn):
"""Set up front end name when beeflow core start is returned."""
db = bdb.connect_db(client_db, db_path())
db.info.set_hostname(start_hn)


def setup_hostname():
"""Set up front end name when beeflow core start is returned."""
db = connect_db(client_db, db_path())
# hard coding front end name for now
new_hostname = 'front_end_name'
db.info.set_hostname(new_hostname)
def check_hostname(curr_hn, stop = False):
"""Check current front end name matches the one beeflow was started on."""
db = bdb.connect_db(client_db, db_path())
start_hn = db.info.get_hostname()
if start_hn != "":
if curr_hn != start_hn:
warn(f'beeflow was started on "{start_hn}" and you are trying to run a command on "{curr_hn}".')
else:
warn('beeflow has not been started!')
if stop:
db.info.set_hostname("")


def error_exit(msg, include_caller=True):
Expand Down
10 changes: 10 additions & 0 deletions beeflow/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ def start(foreground: bool = typer.Option(False, '--foreground', '-F',

version = importlib.metadata.version("hpc-beeflow")
print(f'Starting beeflow {version}...')
start_hn = socket.gethostname() # hostname when beeflow starts
print(f'Running beeflow on {start_hn}')
bee_client.setup_hostname(start_hn) # add to client db
if not foreground:
print('Run `beeflow core status` for more information.')
# Create the log path if it doesn't exist yet
Expand All @@ -427,11 +430,13 @@ def start(foreground: bool = typer.Option(False, '--foreground', '-F',
@app.command()
def status():
"""Check the status of beeflow and the components."""
status_hn = socket.gethostname() # hostname when beeflow core status returned
resp = cli_connection.send(paths.beeflow_socket(), {'type': 'status'})
if resp is None:
beeflow_log = paths.log_fname('beeflow')
warn('Cannot connect to the beeflow daemon, is it running? Check the '
f'log at "{beeflow_log}".')
bee_client.check_hostname(status_hn)
sys.exit(1)
print('beeflow components:')
for comp, stat in resp['components'].items():
Expand All @@ -441,6 +446,8 @@ def status():
@app.command()
def info():
"""Get information about beeflow's installation."""
info_hn = socket.gethostname() # hostname when beeflow core info returned
bee_client.check_hostname(info_hn)
version = importlib.metadata.version("hpc-beeflow")
print(f"Beeflow version: {version}")
print(f"bee_workflow directory: {paths.workdir()}")
Expand All @@ -450,6 +457,7 @@ def info():
@app.command()
def stop(query='yes'):
"""Stop the current running beeflow daemon."""
stop_hn = socket.gethostname() # hostname when beeflow core stop returned
# Check workflow states; warn if there are active states, pause running workflows
workflow_list = bee_client.get_wf_list()
concern_states = {'Running', 'Initializing', 'Waiting'}
Expand All @@ -476,11 +484,13 @@ def stop(query='yes'):
warn('Error: beeflow is not running on this system. It could be '
'running on a different front end.\n'
f' Check the beeflow log: "{beeflow_log}".')
#bee_client.check_hostname(stop_hn, stop = True)
sys.exit(1)
# As long as it returned something, we should be good
beeflow_log = paths.log_fname('beeflow')
if query == "yes":
print(f'Beeflow has stopped. Check the log at "{beeflow_log}".')
bee_client.check_hostname(stop_hn, stop = True)


def archive_dir(dir_to_archive):
Expand Down
23 changes: 18 additions & 5 deletions beeflow/common/db/client_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ class ClientInfo:
"""Client Info object."""

def __init__(self, db_file):
"""Initialize info and db file."""
self.Info = namedtuple("Info", "id hostname") # noqa Snake Case
self.db_file = db_file
"""Initialize info and db file."""
self.Info = namedtuple("Info", "id hostname") # noqa Snake Case
self.db_file = db_file

def set_hostname(self, new_hostname):
"""Set hostname for current front end."""
Expand All @@ -25,6 +25,13 @@ def get_hostname(self):
hostname = result
return hostname

def get_info(self):
"""Return an info object containing port information."""
stmt = "SELECT * FROM info"
result = bdb.getone(self.db_file, stmt)
info = self.Info(*result)
return info

class ClientDB:
"""Client database."""

Expand All @@ -35,10 +42,16 @@ def __init__(self, db_file):

def _init_tables(self):
"""Initialize the client table if it doesn't exist."""
info_stmt = """CREATE TABLE IF NOT EXISTS info(
info_stmt = """CREATE TABLE IF NOT EXISTS info (
id INTEGER PRIMARY KEY ASC,
hostname TEXT);"""
bdb.create_table(self.db_file, info_stmt)
# bdb.create_table(self.db_file, info_stmt)
if not bdb.table_exists(self.db_file, 'info'):
bdb.create_table(self.db_file, info_stmt)
# insert a new workflow into the database
stmt = """INSERT INTO info (hostname) VALUES(?);"""
tmp = ""
bdb.run(self.db_file, stmt, [tmp])

@property
def info(self):
Expand Down
8 changes: 6 additions & 2 deletions beeflow/tests/test_db_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@ def test_empty(temp_db):
"""Test the empty database."""
db = temp_db

assert len(list(db.info)) == 0
hn = db.info.get_hostname()
assert hn == ""


def test_info(temp_db):
"""Test setting the info."""
db = temp_db

db.info.set_hostname('front_end_name')
assert db.info.get_hostname == 'front_end_name'
hn = db.info.get_hostname()

print("testing get_info: ", db.info.get_info())
assert hn == 'front_end_name'

0 comments on commit e40d8e6

Please sign in to comment.