Commit

smaller subset
nerdai committed Jan 23, 2025
1 parent 8577c1e commit 83db1c8
Showing 1 changed file with 67 additions and 70 deletions.
137 changes: 67 additions & 70 deletions tests/smoke_tests/test_smoke_tests.py
@@ -183,27 +183,33 @@ async def cleanup_test(event_loop: asyncio.AbstractEventLoop) -> None:

@pytest.mark.smoketest
async def test_client_level_dp_breast_cancer(tolerance: float) -> None:
event_loop = asyncio.get_event_loop()
try:
server_errors, client_errors = await run_smoke_test(
server_python_path="examples.dp_fed_examples.client_level_dp_weighted.server",
client_python_path="examples.dp_fed_examples.client_level_dp_weighted.client",
config_path="tests/smoke_tests/client_level_dp_weighted_config.yaml",
dataset_path="examples/datasets/breast_cancer_data/hospital_0.csv",
skip_assert_client_fl_rounds=True,
tolerance=tolerance,
)
except Exception as e:
await cleanup_test(event_loop)
pytest.fail(f"Smoke test execution failed: {e}")

assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"
smoke_test_coro = run_smoke_test(
server_python_path="examples.dp_fed_examples.client_level_dp_weighted.server",
client_python_path="examples.dp_fed_examples.client_level_dp_weighted.client",
config_path="tests/smoke_tests/client_level_dp_weighted_config.yaml",
dataset_path="examples/datasets/breast_cancer_data/hospital_0.csv",
skip_assert_client_fl_rounds=True,
tolerance=tolerance,
)
# asyncio.wait() expects tasks: bare coroutines are deprecated since Python 3.8 and rejected as of 3.11
smoke_test_task = asyncio.ensure_future(smoke_test_coro)
done, unfinished = await asyncio.wait([smoke_test_task], return_when=asyncio.FIRST_EXCEPTION)

for task in done:
e = task.exception()
if e:
pytest.fail(f"Smoke test execution failed: {e}")
else:
server_errors, client_errors = task.result()
assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"

# cleanup: cancel any still-pending tasks and wait for the cancellations to settle
for task in unfinished:
task.cancel()
await asyncio.gather(*unfinished, return_exceptions=True)


@pytest.mark.smoketest
async def test_instance_level_dp_cifar(tolerance: float) -> None:
event_loop = asyncio.get_event_loop()
try:
server_errors, client_errors = await run_smoke_test(
server_python_path="examples.dp_fed_examples.instance_level_dp.server",
@@ -214,35 +220,31 @@ async def test_instance_level_dp_cifar(tolerance: float) -> None:
tolerance=tolerance,
)
except Exception as e:
await cleanup_test(event_loop)
pytest.fail(f"Smoke test execution failed: {e}")

assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"


@pytest.mark.smoketest
async def test_dp_scaffold(tolerance: float) -> None:
event_loop = asyncio.get_event_loop()
try:
server_errors, client_errors = await run_smoke_test(
server_python_path="examples.dp_scaffold_example.server",
client_python_path="examples.dp_scaffold_example.client",
config_path="tests/smoke_tests/dp_scaffold_config.yaml",
dataset_path="examples/datasets/mnist_data/",
tolerance=tolerance,
)
except Exception as e:
await cleanup_test(event_loop)
pytest.fail(f"Smoke test execution failed: {e}")
# @pytest.mark.smoketest
# async def test_dp_scaffold(tolerance: float) -> None:
# try:
# server_errors, client_errors = await run_smoke_test(
# server_python_path="examples.dp_scaffold_example.server",
# client_python_path="examples.dp_scaffold_example.client",
# config_path="tests/smoke_tests/dp_scaffold_config.yaml",
# dataset_path="examples/datasets/mnist_data/",
# tolerance=tolerance,
# )
# except Exception as e:
# pytest.fail(f"Smoke test execution failed: {e}")

assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"
# assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
# assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"


@pytest.mark.smoketest
async def test_fedbn(tolerance: float) -> None:
event_loop = asyncio.get_event_loop()
try:
server_errors, client_errors = await run_smoke_test(
server_python_path="examples.fedbn_example.server",
@@ -252,51 +254,46 @@ async def test_fedbn(tolerance: float) -> None:
tolerance=tolerance,
)
except Exception as e:
await cleanup_test(event_loop)
pytest.fail(f"Smoke test execution failed: {e}")

assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"


@pytest.mark.smoketest
async def test_fed_eval(tolerance: float) -> None:
event_loop = asyncio.get_event_loop()
try:
server_errors, client_errors = await run_smoke_test(
server_python_path="examples.federated_eval_example.server",
client_python_path="examples.federated_eval_example.client",
config_path="tests/smoke_tests/federated_eval_config.yaml",
dataset_path="examples/datasets/cifar_data/",
checkpoint_path="examples/assets/fed_eval_example/best_checkpoint_fczjmljm.pkl",
assert_evaluation_logs=True,
tolerance=tolerance,
)
except Exception as e:
await cleanup_test(event_loop)
pytest.fail(f"Smoke test execution failed: {e}")
# @pytest.mark.smoketest
# async def test_fed_eval(tolerance: float) -> None:
# try:
# server_errors, client_errors = await run_smoke_test(
# server_python_path="examples.federated_eval_example.server",
# client_python_path="examples.federated_eval_example.client",
# config_path="tests/smoke_tests/federated_eval_config.yaml",
# dataset_path="examples/datasets/cifar_data/",
# checkpoint_path="examples/assets/fed_eval_example/best_checkpoint_fczjmljm.pkl",
# assert_evaluation_logs=True,
# tolerance=tolerance,
# )
# except Exception as e:
# pytest.fail(f"Smoke test execution failed: {e}")

assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"
# assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
# assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"


@pytest.mark.smoketest
async def test_fedper_mnist(tolerance: float) -> None:
event_loop = asyncio.get_event_loop()
try:
server_errors, client_errors = await run_smoke_test(
server_python_path="examples.fedper_example.server",
client_python_path="examples.fedper_example.client",
config_path="tests/smoke_tests/fedper_config.yaml",
dataset_path="examples/datasets/mnist_data/",
tolerance=tolerance,
)
except Exception as e:
await cleanup_test(event_loop)
pytest.fail(f"Smoke test execution failed: {e}")
# @pytest.mark.smoketest
# async def test_fedper_mnist(tolerance: float) -> None:
# try:
# server_errors, client_errors = await run_smoke_test(
# server_python_path="examples.fedper_example.server",
# client_python_path="examples.fedper_example.client",
# config_path="tests/smoke_tests/fedper_config.yaml",
# dataset_path="examples/datasets/mnist_data/",
# tolerance=tolerance,
# )
# except Exception as e:
# pytest.fail(f"Smoke test execution failed: {e}")

assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"
# assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}"
# assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}"


# @pytest.mark.smoketest
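
For reference, the pattern this commit introduces in test_client_level_dp_breast_cancer can be reduced to a standalone sketch. The following is a minimal illustration of asyncio.wait with return_when=asyncio.FIRST_EXCEPTION plus cancellation of leftover tasks; run_checks is a hypothetical stand-in for run_smoke_test and is not part of the repository:

import asyncio


async def run_checks() -> tuple[list[str], list[str]]:
    # Hypothetical stand-in for run_smoke_test: returns (server_errors, client_errors).
    await asyncio.sleep(0.1)
    return [], []


async def main() -> None:
    # Wrap the coroutine in a task; asyncio.wait() rejects bare coroutines as of Python 3.11.
    task = asyncio.ensure_future(run_checks())
    # FIRST_EXCEPTION returns as soon as any task raises, or when all of them complete.
    done, unfinished = await asyncio.wait([task], return_when=asyncio.FIRST_EXCEPTION)

    for finished in done:
        exc = finished.exception()
        if exc is not None:
            raise RuntimeError(f"Smoke test execution failed: {exc}")
        server_errors, client_errors = finished.result()
        assert not server_errors and not client_errors

    # Cancel anything still pending and wait for the cancellations to settle.
    for pending in unfinished:
        pending.cancel()
    await asyncio.gather(*unfinished, return_exceptions=True)


asyncio.run(main())

Compared with the try/except plus cleanup_test(event_loop) approach it replaces, this keeps failure handling and the cancellation of still-running tasks in one place.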
