Skip to content

Commit

Permalink
Extend backup upload API with file name parameter (#5568)
Browse files Browse the repository at this point in the history
* Extend backup upload API with file name parameter

Add a query parameter which allows to specify the file name on upload.
All locations will store the backup with the same file name.

* ruff format

* Update tests to cover bad filename

* Fix ruff check error

* Drop unnecessary logging
  • Loading branch information
agners authored Jan 27, 2025
1 parent 2a8d2d2 commit 1b0aa30
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
14 changes: 13 additions & 1 deletion supervisor/api/backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from aiohttp import web
from aiohttp.hdrs import CONTENT_DISPOSITION
import voluptuous as vol
from voluptuous.humanize import humanize_error

from ..backups.backup import Backup
from ..backups.const import LOCATION_CLOUD_BACKUP, LOCATION_TYPE
Expand Down Expand Up @@ -503,6 +504,14 @@ async def upload(self, request: web.Request):
if location and location != LOCATION_CLOUD_BACKUP:
tmp_path = location.local_where

filename: str | None = None
if ATTR_FILENAME in request.query:
filename = request.query.get(ATTR_FILENAME)
try:
vol.Match(RE_BACKUP_FILENAME)(filename)
except vol.Invalid as ex:
raise APIError(humanize_error(filename, ex)) from None

with TemporaryDirectory(dir=tmp_path.as_posix()) as temp_dir:
tar_file = Path(temp_dir, "backup.tar")
reader = await request.multipart()
Expand All @@ -529,7 +538,10 @@ async def upload(self, request: web.Request):

backup = await asyncio.shield(
self.sys_backups.import_backup(
tar_file, location=location, additional_locations=locations
tar_file,
filename,
location=location,
additional_locations=locations,
)
)

Expand Down
11 changes: 8 additions & 3 deletions supervisor/backups/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def copy_to_additional_locations() -> dict[str | None, Path]:
async def import_backup(
self,
tar_file: Path,
filename: str | None = None,
location: LOCATION_TYPE = None,
additional_locations: list[LOCATION_TYPE] | None = None,
) -> Backup | None:
Expand All @@ -376,9 +377,13 @@ async def import_backup(
return None

# Move backup to destination folder
tar_origin = Path(self._get_base_path(location), f"{backup.slug}.tar")
if filename:
tar_file = Path(self._get_base_path(location), Path(filename).name)
else:
tar_file = Path(self._get_base_path(location), f"{backup.slug}.tar")

try:
backup.tarfile.rename(tar_origin)
backup.tarfile.rename(tar_file)

except OSError as err:
if err.errno == errno.EBADMSG and location in {LOCATION_CLOUD_BACKUP, None}:
Expand All @@ -387,7 +392,7 @@ async def import_backup(
return None

# Load new backup
backup = Backup(self.coresys, tar_origin, backup.slug, location, backup.data)
backup = Backup(self.coresys, tar_file, backup.slug, location, backup.data)
if not await backup.load():
# Remove invalid backup from location it was moved to
backup.tarfile.unlink()
Expand Down
37 changes: 37 additions & 0 deletions tests/api/test_backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,43 @@ async def test_upload_duplicate_backup_new_location(
assert coresys.backups.get("7fed74c8").location is None


@pytest.mark.parametrize(
("filename", "expected_status"),
[("good.tar", 200), ("../bad.tar", 400), ("bad", 400)],
)
@pytest.mark.usefixtures("tmp_supervisor_data")
async def test_upload_with_filename(
api_client: TestClient, coresys: CoreSys, filename: str, expected_status: int
):
"""Test uploading a backup to multiple locations."""
backup_file = get_fixture_path("backup_example.tar")

with backup_file.open("rb") as file, MultipartWriter("form-data") as mp:
mp.append(file)
resp = await api_client.post(
f"/backups/new/upload?filename={filename}", data=mp
)

assert resp.status == expected_status
body = await resp.json()
if expected_status != 200:
assert (
body["message"]
== r"does not match regular expression ^[^\\\/]+\.tar$."
+ f" Got '{filename}'"
)
return

assert body["data"]["slug"] == "7fed74c8"

orig_backup = coresys.config.path_backup / filename
assert orig_backup.exists()
assert coresys.backups.get("7fed74c8").all_locations == {
None: {"path": orig_backup, "protected": False}
}
assert coresys.backups.get("7fed74c8").location is None


@pytest.mark.parametrize(
("method", "url"),
[
Expand Down

0 comments on commit 1b0aa30

Please sign in to comment.