Skip to content

Commit

Permalink
Extend backup upload API with file name parameter
Browse files Browse the repository at this point in the history
Add a query parameter which allows to specify the file name on upload.
All locations will store the backup with the same file name.
  • Loading branch information
agners committed Jan 22, 2025
1 parent 805017e commit 2722f4d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
7 changes: 6 additions & 1 deletion supervisor/api/backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ATTR_DATE,
ATTR_DAYS_UNTIL_STALE,
ATTR_EXTRA,
ATTR_FILENAME,
ATTR_FOLDERS,
ATTR_HOMEASSISTANT,
ATTR_HOMEASSISTANT_EXCLUDE_DATABASE,
Expand Down Expand Up @@ -484,6 +485,10 @@ async def upload(self, request: web.Request):
if location and location != LOCATION_CLOUD_BACKUP:
tmp_path = location.local_where

filename: str | None = None
if ATTR_FILENAME in request.query:
filename = request.query.get(ATTR_FILENAME)

with TemporaryDirectory(dir=tmp_path.as_posix()) as temp_dir:
tar_file = Path(temp_dir, "backup.tar")
reader = await request.multipart()
Expand All @@ -510,7 +515,7 @@ async def upload(self, request: web.Request):

backup = await asyncio.shield(
self.sys_backups.import_backup(
tar_file, location=location, additional_locations=locations
tar_file, filename, location=location, additional_locations=locations
)
)

Expand Down
11 changes: 8 additions & 3 deletions supervisor/backups/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def copy_to_additional_locations() -> dict[str | None, Path]:
async def import_backup(
self,
tar_file: Path,
filename: str | None = None,
location: LOCATION_TYPE = None,
additional_locations: list[LOCATION_TYPE] | None = None,
) -> Backup | None:
Expand All @@ -362,9 +363,13 @@ async def import_backup(
return None

# Move backup to destination folder
tar_origin = Path(self._get_base_path(location), f"{backup.slug}.tar")
if filename:
tar_file = Path(self._get_base_path(location), Path(filename).name)
else:
tar_file = Path(self._get_base_path(location), f"{backup.slug}.tar")

try:
backup.tarfile.rename(tar_origin)
backup.tarfile.rename(tar_file)

except OSError as err:
if err.errno == errno.EBADMSG and location in {LOCATION_CLOUD_BACKUP, None}:
Expand All @@ -373,7 +378,7 @@ async def import_backup(
return None

# Load new backup
backup = Backup(self.coresys, tar_origin, backup.slug, location, backup.data)
backup = Backup(self.coresys, tar_file, backup.slug, location, backup.data)
if not await backup.load():
# Remove invalid backup from location it was moved to
backup.tarfile.unlink()
Expand Down
23 changes: 23 additions & 0 deletions tests/api/test_backups.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,29 @@ async def test_upload_duplicate_backup_new_location(
assert coresys.backups.get("7fed74c8").location is None


@pytest.mark.usefixtures("tmp_supervisor_data")
async def test_upload_with_filename(api_client: TestClient, coresys: CoreSys):
"""Test uploading a backup to multiple locations."""
backup_file = get_fixture_path("backup_example.tar")

with backup_file.open("rb") as file, MultipartWriter("form-data") as mp:
mp.append(file)
resp = await api_client.post(
"/backups/new/upload?filename=abc.tar", data=mp
)

assert resp.status == 200
body = await resp.json()
assert body["data"]["slug"] == "7fed74c8"

orig_backup = coresys.config.path_backup / "abc.tar"
assert orig_backup.exists()
assert coresys.backups.get("7fed74c8").all_locations == {
None: orig_backup,
}
assert coresys.backups.get("7fed74c8").location is None


@pytest.mark.parametrize(
("method", "url"),
[
Expand Down

0 comments on commit 2722f4d

Please sign in to comment.