From 9192d5969b7e1538045d565100955d3a59edb6b6 Mon Sep 17 00:00:00 2001 From: Andrew Fulton Date: Mon, 31 May 2021 13:03:43 -0400 Subject: [PATCH] fixes issue #18 --- upath/__init__.py | 2 +- upath/core.py | 10 +++++----- upath/implementations/s3.py | 20 ++++++++++++++++++++ upath/tests/implementations/test_s3.py | 12 ++++++++++++ 4 files changed, 38 insertions(+), 6 deletions(-) diff --git a/upath/__init__.py b/upath/__init__.py index 8e0c8ff1..bfb51f2b 100644 --- a/upath/__init__.py +++ b/upath/__init__.py @@ -1,4 +1,4 @@ """Pathlib API extended to use fsspec backends""" -__version__ = "0.0.9" +__version__ = "0.0.10" from upath.core import UPath diff --git a/upath/core.py b/upath/core.py index 6f404cdd..4113b9f8 100644 --- a/upath/core.py +++ b/upath/core.py @@ -10,9 +10,9 @@ class UPath(pathlib.Path): def __new__(cls, *args, **kwargs): if cls is UPath: - new_args = list(args) - first_arg = new_args.pop(0) - parsed_url = urllib.parse.urlparse(first_arg) + args_list = list(args) + url = args_list.pop(0) + parsed_url = urllib.parse.urlparse(url) for key in ["scheme", "netloc"]: val = kwargs.get(key) if val: @@ -34,8 +34,8 @@ def __new__(cls, *args, **kwargs): else: cls = _registry[parsed_url.scheme] kwargs["_url"] = parsed_url - new_args.insert(0, parsed_url.path) - args = tuple(new_args) + args_list.insert(0, parsed_url.path) + args = tuple(args_list) self = cls._from_parts_init(args, init=False) self._init(*args, **kwargs) return self diff --git a/upath/implementations/s3.py b/upath/implementations/s3.py index 21d31524..4b09a9a7 100644 --- a/upath/implementations/s3.py +++ b/upath/implementations/s3.py @@ -24,3 +24,23 @@ def _sub_path(self, name): sp = self.path subed = re.sub(f"^{self._url.netloc}/({sp}|{sp[1:]})/?", "", name) return subed + + def _init(self, *args, template=None, **kwargs): + if kwargs.get("bucket") and kwargs.get("_url"): + bucket = kwargs.pop("bucket") + kwargs["_url"] = kwargs["_url"]._replace(netloc=bucket) + super()._init(*args, template=template, **kwargs) + + def joinpath(self, *args): + if self._url.netloc: + return super().joinpath(*args) + # handles a bucket in the path + else: + path = args[0] + if isinstance(path, list): + args_list = list(*args) + else: + args_list = path.split(self._flavour.sep) + bucket = args_list.pop(0) + self._kwargs["bucket"] = bucket + return super().joinpath(*tuple(args_list)) diff --git a/upath/tests/implementations/test_s3.py b/upath/tests/implementations/test_s3.py index aa50aad4..7721023e 100644 --- a/upath/tests/implementations/test_s3.py +++ b/upath/tests/implementations/test_s3.py @@ -83,3 +83,15 @@ def test_fsspec_compat(self): upath2 = UPath(p2, anon=self.anon, **self.s3so) assert upath2.read_bytes() == content upath2.unlink() + + @pytest.mark.parametrize( + "joiner", [["bucket", "path", "file"], "bucket/path/file"] + ) + def test_no_bucket_joinpath(self, joiner): + path = UPath("s3://", anon=self.anon, **self.s3so) + path = path.joinpath(joiner) + assert str(path) == "s3://bucket/path/file" + + def test_creating_s3path_with_bucket(self): + path = UPath("s3://", bucket="bucket", anon=self.anon, **self.s3so) + assert str(path) == "s3://bucket/"