diff --git a/tests/smoke_tests/test_smoke_tests.py b/tests/smoke_tests/test_smoke_tests.py index ffe405c9e..a344625a8 100644 --- a/tests/smoke_tests/test_smoke_tests.py +++ b/tests/smoke_tests/test_smoke_tests.py @@ -183,27 +183,33 @@ async def cleanup_test(event_loop: asyncio.AbstractEventLoop) -> None: @pytest.mark.smoketest async def test_client_level_dp_breast_cancer(tolerance: float) -> None: - event_loop = asyncio.get_event_loop() - try: - server_errors, client_errors = await run_smoke_test( - server_python_path="examples.dp_fed_examples.client_level_dp_weighted.server", - client_python_path="examples.dp_fed_examples.client_level_dp_weighted.client", - config_path="tests/smoke_tests/client_level_dp_weighted_config.yaml", - dataset_path="examples/datasets/breast_cancer_data/hospital_0.csv", - skip_assert_client_fl_rounds=True, - tolerance=tolerance, - ) - except Exception as e: - await cleanup_test(event_loop) - pytest.fail(f"Smoke test execution failed: {e}") - - assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" - assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" + smoke_test_coro = run_smoke_test( + server_python_path="examples.dp_fed_examples.client_level_dp_weighted.server", + client_python_path="examples.dp_fed_examples.client_level_dp_weighted.client", + config_path="tests/smoke_tests/client_level_dp_weighted_config.yaml", + dataset_path="examples/datasets/breast_cancer_data/hospital_0.csv", + skip_assert_client_fl_rounds=True, + tolerance=tolerance, + ) + done, unfinished = await asyncio.wait([smoke_test_coro], return_when=asyncio.FIRST_EXCEPTION) + + for task in done: + e = task.exception() + if e: + pytest.fail(f"Smoke test execution failed: {e}") + else: + server_errors, client_errors = task.result() + assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" + assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" + + # cleanup + for task in unfinished: + task.cancel() + await asyncio.gather(*unfinished) @pytest.mark.smoketest async def test_instance_level_dp_cifar(tolerance: float) -> None: - event_loop = asyncio.get_event_loop() try: server_errors, client_errors = await run_smoke_test( server_python_path="examples.dp_fed_examples.instance_level_dp.server", @@ -214,35 +220,31 @@ async def test_instance_level_dp_cifar(tolerance: float) -> None: tolerance=tolerance, ) except Exception as e: - await cleanup_test(event_loop) pytest.fail(f"Smoke test execution failed: {e}") assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" -@pytest.mark.smoketest -async def test_dp_scaffold(tolerance: float) -> None: - event_loop = asyncio.get_event_loop() - try: - server_errors, client_errors = await run_smoke_test( - server_python_path="examples.dp_scaffold_example.server", - client_python_path="examples.dp_scaffold_example.client", - config_path="tests/smoke_tests/dp_scaffold_config.yaml", - dataset_path="examples/datasets/mnist_data/", - tolerance=tolerance, - ) - except Exception as e: - await cleanup_test(event_loop) - pytest.fail(f"Smoke test execution failed: {e}") +# @pytest.mark.smoketest +# async def test_dp_scaffold(tolerance: float) -> None: +# try: +# server_errors, client_errors = await run_smoke_test( +# server_python_path="examples.dp_scaffold_example.server", +# client_python_path="examples.dp_scaffold_example.client", +# config_path="tests/smoke_tests/dp_scaffold_config.yaml", +# dataset_path="examples/datasets/mnist_data/", +# tolerance=tolerance, +# ) +# except Exception as e: +# pytest.fail(f"Smoke test execution failed: {e}") - assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" - assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" +# assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" +# assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" @pytest.mark.smoketest async def test_fedbn(tolerance: float) -> None: - event_loop = asyncio.get_event_loop() try: server_errors, client_errors = await run_smoke_test( server_python_path="examples.fedbn_example.server", @@ -252,51 +254,46 @@ async def test_fedbn(tolerance: float) -> None: tolerance=tolerance, ) except Exception as e: - await cleanup_test(event_loop) pytest.fail(f"Smoke test execution failed: {e}") assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" -@pytest.mark.smoketest -async def test_fed_eval(tolerance: float) -> None: - event_loop = asyncio.get_event_loop() - try: - server_errors, client_errors = await run_smoke_test( - server_python_path="examples.federated_eval_example.server", - client_python_path="examples.federated_eval_example.client", - config_path="tests/smoke_tests/federated_eval_config.yaml", - dataset_path="examples/datasets/cifar_data/", - checkpoint_path="examples/assets/fed_eval_example/best_checkpoint_fczjmljm.pkl", - assert_evaluation_logs=True, - tolerance=tolerance, - ) - except Exception as e: - await cleanup_test(event_loop) - pytest.fail(f"Smoke test execution failed: {e}") +# @pytest.mark.smoketest +# async def test_fed_eval(tolerance: float) -> None: +# try: +# server_errors, client_errors = await run_smoke_test( +# server_python_path="examples.federated_eval_example.server", +# client_python_path="examples.federated_eval_example.client", +# config_path="tests/smoke_tests/federated_eval_config.yaml", +# dataset_path="examples/datasets/cifar_data/", +# checkpoint_path="examples/assets/fed_eval_example/best_checkpoint_fczjmljm.pkl", +# assert_evaluation_logs=True, +# tolerance=tolerance, +# ) +# except Exception as e: +# pytest.fail(f"Smoke test execution failed: {e}") - assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" - assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" +# assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" +# assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" -@pytest.mark.smoketest -async def test_fedper_mnist(tolerance: float) -> None: - event_loop = asyncio.get_event_loop() - try: - server_errors, client_errors = await run_smoke_test( - server_python_path="examples.fedper_example.server", - client_python_path="examples.fedper_example.client", - config_path="tests/smoke_tests/fedper_config.yaml", - dataset_path="examples/datasets/mnist_data/", - tolerance=tolerance, - ) - except Exception as e: - await cleanup_test(event_loop) - pytest.fail(f"Smoke test execution failed: {e}") +# @pytest.mark.smoketest +# async def test_fedper_mnist(tolerance: float) -> None: +# try: +# server_errors, client_errors = await run_smoke_test( +# server_python_path="examples.fedper_example.server", +# client_python_path="examples.fedper_example.client", +# config_path="tests/smoke_tests/fedper_config.yaml", +# dataset_path="examples/datasets/mnist_data/", +# tolerance=tolerance, +# ) +# except Exception as e: +# pytest.fail(f"Smoke test execution failed: {e}") - assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" - assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" +# assert len(server_errors) == 0, f"Server metrics check failed. Errors: {server_errors}" +# assert len(client_errors) == 0, f"Client metrics check failed. Errors: {client_errors}" # @pytest.mark.smoketest