diff --git a/andromeda/main.py b/andromeda/main.py index db94284..a2e7c0a 100644 --- a/andromeda/main.py +++ b/andromeda/main.py @@ -167,6 +167,15 @@ def csv_reponse_for_observations(fieldnames, observations, user_id): headers={"Content-disposition": f"attachment; filename={filename}"}) +def get_host_url(request): + # HuggingFace Space HOST + space_host = os.environ.get('SPACE_HOST') + if space_host: + return f"https://{space_host}" + else: + return request.host_url + + @app.route('/api/inaturalist//dataset', methods=['POST']) def create_inaturalist_dataset(user_id): add_sat_rgb_data = get_boolean_param(request, "add_sat_rgb_data") @@ -181,7 +190,8 @@ def create_inaturalist_dataset(user_id): dataset_store = DatasetStore(base_directory=UPLOAD_FOLDER) dataset = dataset_store.create_dataset_with_content(csv_content) filename = get_csv_filename(user_id=user_id) - download_url = f"{request.host_url}/api/dataset/{dataset.id}?filename={filename}" + host_url = get_host_url(request=request) + download_url = f"{host_url}/api/dataset/{dataset.id}?filename={filename}" return jsonify({ "id": dataset.id, "url": download_url, diff --git a/andromeda/tests/test_main.py b/andromeda/tests/test_main.py index 76eb702..492b933 100644 --- a/andromeda/tests/test_main.py +++ b/andromeda/tests/test_main.py @@ -150,6 +150,24 @@ def test_create_inaturalist_dataset(self, mock_get_inaturalist_observations): result = client.post(f"/api/inaturalist/bob/dataset") self.assertEqual(result.status_code, 200) self.assertEqual(list(result.json.keys()), ["id", "url", "warnings"]) + self.assertIn("http://localhost", result.json["url"]) + + @patch("main.get_inaturalist_observations") + @patch.dict('main.os.environ', {"SPACE_HOST": "example.org"}) + def test_create_inaturalist_dataset_hugging_face(self, mock_get_inaturalist_observations): + observations = [{"Image_Label": "p1"}] + warnings = ["missing_lat_long"] + fieldnames = ["Image_Label", "Image_Link", "Species", "User", "Date", "Time", + "Seconds", "Place", "Lat", "Long"] + mock_get_inaturalist_observations.return_value = Mock( + data=observations, + fieldnames=fieldnames, + warnings=warnings) + client = app.test_client() + result = client.post(f"/api/inaturalist/bob/dataset") + self.assertEqual(result.status_code, 200) + self.assertEqual(list(result.json.keys()), ["id", "url", "warnings"]) + self.assertIn("https://example.org", result.json["url"]) @patch("main.get_inaturalist_observations") def test_get_inaturalist_csv(self, mock_get_inaturalist_observations):