Move assert statements out of run_smoke_test and into the actual test (for graceful shutdown in case of failure) #318

Merged · 31 commits · Jan 24, 2025
(The diff below shows changes from 21 of the 31 commits.)

Commits
2046b24
mv assertions outside of run_smoke_tests
nerdai Jan 23, 2025
a65329c
add workflow_dispatch trigger for testing
nerdai Jan 23, 2025
8dfe3a7
get rid of ALL assert used in and raise Custom exception instead
nerdai Jan 23, 2025
a2490f1
add verbose flag
nerdai Jan 23, 2025
aead025
change scope to session
nerdai Jan 23, 2025
e8e6d5e
use module instead of session, should still work the same
nerdai Jan 23, 2025
a2cdef1
subset of tests, to test cleanup
nerdai Jan 23, 2025
4b8ebc9
cleanup func
nerdai Jan 23, 2025
2635e79
smaller subset
nerdai Jan 23, 2025
f898bc7
wip
nerdai Jan 23, 2025
6399c72
test entire
nerdai Jan 23, 2025
6eee21b
handle TimeoutError
nerdai Jan 23, 2025
32ef5b1
add cancel cleanup
nerdai Jan 23, 2025
f166d74
unlock all
nerdai Jan 23, 2025
157e18b
rm workflow_dispatch trigger
nerdai Jan 23, 2025
c527c05
clean up getting output from stdout
nerdai Jan 24, 2025
8962e33
working
nerdai Jan 24, 2025
94ab853
add comment
nerdai Jan 24, 2025
69afc0b
test on runner
nerdai Jan 24, 2025
c095753
add retry for flaky test
nerdai Jan 24, 2025
810dd30
removed unnecessary event_loop from tests
nerdai Jan 24, 2025
ab8d3a7
cr
nerdai Jan 24, 2025
f6d3cea
rm workflow_dispatch
nerdai Jan 24, 2025
01fdbd2
use function scope
nerdai Jan 24, 2025
643ed0d
add workflow_dispatch to test function scope
nerdai Jan 24, 2025
791ca2d
rm workflow_dispatch
nerdai Jan 24, 2025
9a5b0e1
revert back to module
nerdai Jan 24, 2025
6d2ea4a
add graceful shutdown of processes
nerdai Jan 24, 2025
d61fc5c
comment
nerdai Jan 24, 2025
1e9b944
increase attempts of flaky test
nerdai Jan 24, 2025
5c5576c
rm workflow_dispatch
nerdai Jan 24, 2025
3 changes: 2 additions & 1 deletion .github/workflows/smoke_tests.yaml
@@ -7,6 +7,7 @@ on:
pull_request:
branches:
- main
workflow_dispatch:

jobs:
test:
@@ -46,4 +47,4 @@ jobs:
- name: Run Script
run: |
source .venv/bin/activate
pytest -m "smoketest"
pytest -m "smoketest" -v
1 change: 0 additions & 1 deletion pyproject.toml
@@ -108,4 +108,3 @@ markers = [
"smoketest: marks tests as smoke tests (deselect with '-m \"not smoketest\"')",
]
asyncio_default_fixture_loop_scope = "session"
asyncio_mode = "auto"
193 changes: 121 additions & 72 deletions tests/smoke_tests/run_smoke_test.py
@@ -1,5 +1,4 @@
import asyncio
import datetime
import json
import logging
import re
@@ -22,6 +21,20 @@
logger = logging.getLogger()

DEFAULT_TOLERANCE = 0.0005
DEFAULT_READ_LOGS_TIMEOUT = 300


# Custom Errors
class SmokeTestAssertError(Exception):
pass


class SmokeTestExecutionError(Exception):
pass


class SmokeTestTimeoutError(Exception):
pass


def postprocess_logs(logs: str) -> str:
@@ -50,7 +63,8 @@ async def run_smoke_test(
client_metrics: dict[str, Any] | None = None,
# assertion params
tolerance: float = DEFAULT_TOLERANCE,
) -> None:
read_logs_timeout: int = DEFAULT_READ_LOGS_TIMEOUT,
) -> tuple[list[str], list[str]]:
nerdai (Collaborator, Author) commented:

This now returns server_errors and client_errors to the caller.
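
To illustrate the intent, a hypothetical test-side sketch (everything except run_smoke_test, the smoketest marker, and the error-list return is illustrative, including the paths and argument names):

import pytest

from tests.smoke_tests.run_smoke_test import run_smoke_test


@pytest.mark.smoketest
async def test_basic_example() -> None:  # hypothetical test; asyncio setup omitted
    server_errors, client_errors = await run_smoke_test(
        server_python_path="examples.basic_example.server",  # illustrative arguments
        client_python_path="examples.basic_example.client",
        config_path="tests/smoke_tests/basic_config.yaml",
        dataset_path="examples/datasets/mnist_data/",
    )
    # Asserting in the test body, rather than inside run_smoke_test, means a
    # failed check surfaces here while fixture teardown can still shut the
    # server and client processes down gracefully.
    assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
    assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"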

"""Runs a smoke test for a given server, client, and dataset configuration.

Uses asyncio to kick off one server instance defined by the `server_python_path` module and N client instances
Expand Down Expand Up @@ -141,6 +155,10 @@ async def run_smoke_test(
client_metrics (dict[str, Any] | None): A dictionary of metrics to be checked against the metrics file
saved by the clients. Should be in the same format as fl4health.reporting.metrics.MetricsReporter.
Default is None.

Returns:
(server_errors, client_errors): (list[str], list[str]): lists of errors from the server and client processes,
respectively.
"""
clear_metrics_folder()

@@ -191,7 +209,8 @@
output_found = False
while not output_found:
try:
assert server_process.stdout is not None, "Server's process stdout is None"
if not (server_process.stdout is not None):
raise SmokeTestExecutionError("Server's process stdout is None")
server_output_in_bytes = await asyncio.wait_for(server_process.stdout.readline(), 20)
server_output = server_output_in_bytes.decode()
logger.debug(f"Server output: {server_output}")
@@ -201,16 +220,16 @@
break

return_code = server_process.returncode
assert return_code is None or (return_code is not None and return_code == 0), (
nerdai (Collaborator, Author) commented on Jan 23, 2025:

In order to remove these assert statements in a way that avoids doing any dangerous Boolean algebra, I employ the following pattern. We can change this if we want...

# old
assert <cond>, <fail_msg>

# new
if not <cond>:
    raise SmokeTestError(<fail_msg>)
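
For a concrete instance of the risk, here is a self-contained sketch built around the return-code check from this diff (check_return_code is an illustrative name, not the PR's code):

class SmokeTestAssertError(Exception):
    pass


def check_return_code(return_code: int | None, msg: str) -> None:
    # The original condition is kept verbatim and only wrapped in `if not (...)`.
    # Hand-negating it via De Morgan ("return_code is not None and
    # return_code != 0") would also work, but that rewrite is exactly where
    # sign errors creep in.
    if not (return_code is None or return_code == 0):
        raise SmokeTestAssertError(msg)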

f"Full output:\n{full_server_output}\n" f"[ASSERT ERROR] Server exited with code {return_code}."
)
if not (return_code is None or (return_code is not None and return_code == 0)):
msg = f"Full output:\n{full_server_output}\n" f"[ASSERT ERROR] Server exited with code {return_code}."
raise SmokeTestAssertError(msg)

if any(startup_message in server_output for startup_message in startup_messages):
output_found = True

assert output_found, (
f"Full output:\n{full_server_output}\n" f"[ASSERT_ERROR] Startup log message not found in server output."
)
if not output_found:
msg = f"Full output:\n{full_server_output}\n" f"[ASSERT_ERROR] Startup log message not found in server output."
raise SmokeTestAssertError(msg)

logger.info("Server started")

@@ -240,74 +259,96 @@
# Collecting the clients output when their processes finish
client_result_tasks = []
for i, client_process in enumerate(client_processes):
client_result_tasks.append(_wait_for_process_to_finish_and_retrieve_logs(client_process, f"Client {i}"))
client_result_tasks.append(
_wait_for_process_to_finish_and_retrieve_logs(client_process, f"Client {i}", read_logs_timeout),
)

full_client_outputs = await asyncio.gather(*client_result_tasks)
logger.info("All clients finished execution")

# Collecting the server output when its process finish
full_server_output = await _wait_for_process_to_finish_and_retrieve_logs(server_process, "Server")
full_server_output = await _wait_for_process_to_finish_and_retrieve_logs(
server_process, "Server", read_logs_timeout
)
full_server_output = postprocess_logs(full_server_output)

logger.info("Server has finished execution")

# server assertions
assert "error" not in full_server_output.lower(), (
f"Full output:\n{full_server_output}\n" "[ASSERT ERROR] Error message found for server."
)
if not ("error" not in full_server_output.lower()):
msg = f"Full output:\n{full_server_output}\n" "[ASSERT ERROR] Error message found for server."
raise SmokeTestAssertError(msg)

if assert_evaluation_logs:
assert f"Federated Evaluation received {config['n_clients']} results and 0 failures" in full_server_output, (
f"Full output:\n{full_server_output}\n" "[ASSERT ERROR] Last FL round message not found for server."
)
assert "Federated Evaluation Finished" in full_server_output, (
f"Full output:\n{full_server_output}\n"
"[ASSERT ERROR] Federated Evaluation Finished message not found for server."
)
if not (f"Federated Evaluation received {config['n_clients']} results and 0 failures" in full_server_output):
msg = f"Full output:\n{full_server_output}\n" "[ASSERT ERROR] Last FL round message not found for server."
raise SmokeTestAssertError(msg)

if not ("Federated Evaluation Finished" in full_server_output):
msg = (
f"Full output:\n{full_server_output}\n"
"[ASSERT ERROR] Federated Evaluation Finished message not found for server."
)
raise SmokeTestAssertError(msg)

else:
assert "[SUMMARY]" in full_server_output, (
f"Full output:\n{full_server_output}\n" "[ASSERT ERROR] [SUMMARY] message not found for server."
)
if not ("[SUMMARY]" in full_server_output):
msg = f"Full output:\n{full_server_output}\n" "[ASSERT ERROR] [SUMMARY] message not found for server."
raise SmokeTestAssertError(msg)
if not assert_evaluation_logs:
assert all(
message in full_server_output
for message in [
"History (loss, distributed):",
"History (metrics, distributed, fit):",
]
), f"Full output:\n{full_server_output}\n[ASSERT ERROR] Metrics message not found for server."
if not (
all(
message in full_server_output
for message in [
"History (loss, distributed):",
"History (metrics, distributed, fit):",
]
)
):
msg = f"Full output:\n{full_server_output}\n[ASSERT ERROR] Metrics message not found for server."
raise SmokeTestAssertError(msg)

else:
assert all(
message in full_server_output for message in ["History (metrics, distributed, evaluate):"]
), f"Full output:\n{full_server_output}\n[ASSERT ERROR] Metrics message not found for server."
if not (all(message in full_server_output for message in ["History (metrics, distributed, evaluate):"])):
msg = f"Full output:\n{full_server_output}\n[ASSERT ERROR] Metrics message not found for server."
raise SmokeTestAssertError(msg)

server_errors = _assert_metrics(MetricType.SERVER, server_metrics, tolerance)
assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"

# client assertions
client_errors = []
for i, full_client_output in enumerate(full_client_outputs):
full_client_output = postprocess_logs(full_client_output)
assert "error" not in full_client_output.lower(), (
f"Full client output:\n{full_client_output}\n" f"[ASSERT ERROR] Error message found for client {i}."
)
assert "Disconnect and shut down" in full_client_output, (
f"Full client output:\n{full_client_output}\n" f"[ASSERT ERROR] Shutdown message not found for client {i}."
)
if assert_evaluation_logs:
assert "Client Evaluation Local Model Metrics" in full_client_output, (
if not ("error" not in full_client_output.lower()):
msg = f"Full client output:\n{full_client_output}\n" f"[ASSERT ERROR] Error message found for client {i}."
raise SmokeTestAssertError(msg)

if not ("Disconnect and shut down" in full_client_output):
msg = (
f"Full client output:\n{full_client_output}\n"
f"[ASSERT ERROR] 'Client Evaluation Local Model Metrics' message not found for client {i}."
f"[ASSERT ERROR] Shutdown message not found for client {i}."
)
raise SmokeTestAssertError(msg)

if assert_evaluation_logs:
if not ("Client Evaluation Local Model Metrics" in full_client_output):
msg = (
f"Full client output:\n{full_client_output}\n"
f"[ASSERT ERROR] 'Client Evaluation Local Model Metrics' message not found for client {i}."
)
raise SmokeTestAssertError(msg)
nerdai (Collaborator, Author) commented:

Tbh, I am not sure if I like the name SmokeTestAssertError; it makes me feel like it's asserting a good output against an expected value.

I think these are probably better as SmokeTestExecutionError? I was just following our naming with [ASSERT ERROR], but I think this convention is confusing.

A collaborator replied:

I'm good with either. I'll leave it to you to decide 😂

nerdai (Collaborator, Author) replied:

I'm lazy now. I'm going to leave it as is lol.


elif not skip_assert_client_fl_rounds:
assert f"Current FL Round: {config['n_server_rounds']}" in full_client_output, (
f"Full client output:\n{full_client_output}\n"
f"[ASSERT ERROR] Last FL round message not found for client {i}."
)
if not (f"Current FL Round: {config['n_server_rounds']}" in full_client_output):
msg = (
f"Full client output:\n{full_client_output}\n"
f"[ASSERT ERROR] Last FL round message not found for client {i}."
)
raise SmokeTestAssertError(msg)

client_errors.extend(_assert_metrics(MetricType.CLIENT, client_metrics, tolerance))
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"

logger.info("All checks passed. Test finished.")
return server_errors, client_errors


async def run_fault_tolerance_smoke_test(
@@ -322,7 +363,8 @@
intermediate_checkpoint_dir: str = "./",
server_name: str = "server",
tolerance: float = DEFAULT_TOLERANCE,
) -> None:
read_logs_timeout: int = DEFAULT_READ_LOGS_TIMEOUT,
) -> tuple[list[str], list[str]]:
"""Runs a smoke test for a given server, client, and dataset configuration.

Args:
Expand All @@ -349,6 +391,10 @@ async def run_fault_tolerance_smoke_test(
saved by the server. Should be in the same format as fl4health.reporting.metrics.MetricsReporter.
client_metrics (dict[str, Any]): A dictionary of metrics to be checked against the metrics file
saved by the clients. Should be in the same format as fl4health.reporting.metrics.MetricsReporter.

Returns:
(server_errors, client_errors): (list[str], list[str]): lists of errors from the server and client processes,
respectively.
"""
clear_metrics_folder()

@@ -426,13 +472,15 @@

client_output_tasks = []
for i in range(len(client_processes)):
client_output_tasks.append(_wait_for_process_to_finish_and_retrieve_logs(client_processes[i], f"Client {i}"))
client_output_tasks.append(
_wait_for_process_to_finish_and_retrieve_logs(client_processes[i], f"Client {i}", read_logs_timeout),
)

_ = await asyncio.gather(*client_output_tasks)

logger.info("All clients finished execution")

await _wait_for_process_to_finish_and_retrieve_logs(server_process, "Server")
await _wait_for_process_to_finish_and_retrieve_logs(server_process, "Server", read_logs_timeout)

logger.info("Server has finished execution")

@@ -461,24 +509,22 @@
client_processes.append(client_process)

for i in range(len(client_processes)):
await _wait_for_process_to_finish_and_retrieve_logs(client_processes[i], f"Client {i}")
await _wait_for_process_to_finish_and_retrieve_logs(client_processes[i], f"Client {i}", read_logs_timeout)

logger.info("All clients finished execution")

await _wait_for_process_to_finish_and_retrieve_logs(server_process, "Server")
await _wait_for_process_to_finish_and_retrieve_logs(server_process, "Server", read_logs_timeout)

logger.info("Server has finished execution")

server_errors = _assert_metrics(MetricType.SERVER, server_metrics, tolerance)
assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"

# client assertions
client_errors = []
for i in range(len(client_processes)):
client_errors.extend(_assert_metrics(MetricType.CLIENT, client_metrics, tolerance))
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"

logger.info("All checks passed. Test finished.")
return server_errors, client_errors


def _preload_dataset(dataset_path: str, config: Config, seed: int | None = None) -> None:
Expand Down Expand Up @@ -521,14 +567,12 @@ async def _wait_for_process_to_finish_and_retrieve_logs(
process_name: str,
timeout: int = 300, # timeout for the whole process to complete
) -> str:
logger.info(f"Collecting output for {process_name}...")
full_output = ""
try:
assert process.stdout
start_time = datetime.datetime.now()

async def get_output_from_stdout(stream_reader: asyncio.streams.StreamReader) -> tuple[str, int | None]:
full_output = ""
while True:
# giving a smaller timeout here just in case it hangs for a long time waiting for a single log line
nerdai (Collaborator, Author) commented on Jan 24, 2025:

I cleaned this up, but note that we weren't actually giving this "inner task" a smaller timeout.

Instead, I outsource this logic to a contained method, get_output_from_stdout(), which reads the stream until completion. This whole process of reading from the stream is what the timeout now applies to. (No more need for the manual elapsed_time = datetime.datetime.now() - start_time computation to check whether the timeout was reached.)
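
For anyone skimming, a minimal self-contained sketch of the pattern (drain_stdout and read_logs are illustrative names, not the PR's exact code):

import asyncio


async def drain_stdout(reader: asyncio.StreamReader) -> str:
    # Read line by line until EOF; no per-line timeout bookkeeping needed.
    chunks = []
    while True:
        line = await reader.readline()
        if not line:
            break
        chunks.append(line.decode())
    return "".join(chunks)


async def read_logs(process: asyncio.subprocess.Process, timeout: int = 300) -> str:
    if process.stdout is None:
        raise RuntimeError("Process stdout is None")  # the PR raises SmokeTestExecutionError
    try:
        # One asyncio.wait_for covers the entire read, replacing the manual
        # elapsed-time check against a start timestamp.
        return await asyncio.wait_for(drain_stdout(process.stdout), timeout=timeout)
    except asyncio.TimeoutError as e:
        raise TimeoutError("Timeout for reading logs reached.") from e  # PR: SmokeTestTimeoutError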

output_in_bytes = await asyncio.wait_for(process.stdout.readline(), timeout=timeout)
output_in_bytes = await stream_reader.readline()
await asyncio.sleep(0) # give control back to loop manager
output = output_in_bytes.decode().replace("\\n", "\n")
logger.debug(f"{process_name} output: {output}")
full_output += output
Expand All @@ -537,21 +581,26 @@ async def _wait_for_process_to_finish_and_retrieve_logs(
if output == "" and return_code is not None:
break

elapsed_time = datetime.datetime.now() - start_time
if elapsed_time.seconds > timeout:
raise Exception(f"Timeout limit of {timeout}s exceeded waiting for {process_name} to finish execution")
return full_output, return_code

logger.info(f"Collecting output for {process_name}...")

try:
if not (process.stdout is not None):
raise SmokeTestExecutionError("Process stdout is None")
full_output, return_code = await asyncio.wait_for(get_output_from_stdout(process.stdout), timeout=timeout)
except asyncio.exceptions.TimeoutError as e:
raise SmokeTestTimeoutError("Timeout for reading logs reached.") from e
except Exception as ex:
logger.error(f"{process_name} output:\n{full_output}")
logger.exception(f"Error collecting {process_name} log messages:")
raise ex

logger.info(f"Output collected for {process_name}")

# checking for clients with failure exit codes
assert return_code is None or (return_code is not None and return_code == 0), (
f"Full output:\n{full_output}\n" f"[ASSERT ERROR] {process_name} exited with code {return_code}."
)
if not (return_code is None or (return_code is not None and return_code == 0)):
msg = f"Full output:\n{full_output}\n" f"[ASSERT ERROR] {process_name} exited with code {return_code}."
raise SmokeTestAssertError(msg)

return full_output
