From 5099f4500477f0484b2bc48bf7aa0dedb458b748 Mon Sep 17 00:00:00 2001 From: Bolke de Bruin Date: Fri, 27 Oct 2023 17:43:15 +0200 Subject: [PATCH] Add extra --- airflow/io/store/path.py | 21 ++++++++++++++++++--- tests/io/store/test_store.py | 5 +++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/airflow/io/store/path.py b/airflow/io/store/path.py index 2fe8ffad99afc..7ee0971855143 100644 --- a/airflow/io/store/path.py +++ b/airflow/io/store/path.py @@ -45,6 +45,13 @@ class ObjectStoragePath(os.PathLike): sep: typing.ClassVar[str] = "/" root_marker: typing.ClassVar[str] = "/" + _store: ObjectStore | None + _bucket: str + _key: str + _conn_id: str | None + _protocol: str + _hash: int | None + __slots__ = ( "_store", "_bucket", @@ -54,18 +61,26 @@ class ObjectStoragePath(os.PathLike): "_hash", ) - def __init__(self, path, conn_id: str | None = None, store: ObjectStore | None = None): + def __init__( + self, path: str | ObjectStoragePath, conn_id: str | None = None, store: ObjectStore | None = None + ): self._conn_id = conn_id self._store = store self._hash = None - self._protocol, self._bucket, self._key = self.split_path(path) + if isinstance(path, ObjectStoragePath): + self._protocol = path._protocol + self._bucket = path._bucket + self._key = path._key + self._store = path._store + else: + self._protocol, self._bucket, self._key = self.split_path(path) if store: self._conn_id = store.conn_id self._protocol = self._protocol if self._protocol else store.protocol - elif self._protocol: + elif self._protocol and not self._store: self._store = attach(self._protocol, conn_id) @classmethod diff --git a/tests/io/store/test_store.py b/tests/io/store/test_store.py index e16796637e64a..db1609f283a80 100644 --- a/tests/io/store/test_store.py +++ b/tests/io/store/test_store.py @@ -61,6 +61,11 @@ def test_init_objectstoragepath(self): assert path.key == "key/part1/part2" assert path._protocol == "file" + path2 = ObjectStoragePath(path / "part3") + assert path2.bucket == "bucket" + assert path2.key == "key/part1/part2/part3" + assert path2._protocol == "file" + def test_read_write(self): o = ObjectStoragePath(f"file:///tmp/{str(uuid.uuid4())}")