Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix UPath.rename type signature #258

Merged
merged 3 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions typesafety/test_upath_interface.yml
Original file line number Diff line number Diff line change
Expand Up @@ -565,3 +565,13 @@
from upath import UPath

reveal_type(UPath("abc").walk()) # N: Revealed type is "typing.Iterator[tuple[upath.core.UPath, builtins.list[builtins.str], builtins.list[builtins.str]]]"

- case: upath_rename_extra_kwargs
disable_cache: false
main: |
from upath import UPath

UPath("abc").rename("efg")
UPath("recursive bool").rename("efg", recursive=True)
UPath("maxdepth int").rename("efg", maxdepth=1)
UPath("untyped extras").rename("efg", overwrite=True, something="else")
16 changes: 9 additions & 7 deletions upath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
from typing import Mapping
from typing import Sequence
from typing import TextIO
from typing import TypedDict
from typing import TypeVar
from typing import overload
from urllib.parse import urlsplit

if sys.version_info >= (3, 11):
from typing import Self
from typing import Unpack
else:
from typing_extensions import Self
from typing_extensions import Unpack

from fsspec.registry import get_filesystem_class
from fsspec.spec import AbstractFileSystem
Expand Down Expand Up @@ -94,9 +91,7 @@ def _make_instance(cls, args, kwargs):
return cls(*args, **kwargs)


class _UPathRenameParams(TypedDict, total=False):
recursive: bool
maxdepth: int | None
_unset: Any = object()


# accessors are deprecated
Expand Down Expand Up @@ -1016,7 +1011,10 @@ def rmdir(self, recursive: bool = True) -> None: # fixme: non-standard
def rename(
self,
target: str | os.PathLike[str] | UPath,
**kwargs: Unpack[_UPathRenameParams], # note: non-standard compared to pathlib
*, # note: non-standard compared to pathlib
recursive: bool = _unset,
maxdepth: int | None = _unset,
**kwargs: Any,
) -> Self:
if isinstance(target, str) and self.storage_options:
target = UPath(target, **self.storage_options)
Expand All @@ -1040,6 +1038,10 @@ def rename(
parent = parent.resolve()
target_ = parent.joinpath(os.path.normpath(target))
assert isinstance(target_, type(self)), "identical protocols enforced above"
if recursive is not _unset:
kwargs["recursive"] = recursive
if maxdepth is not _unset:
kwargs["maxdepth"] = maxdepth
self.fs.mv(
self.path,
target_.path,
Expand Down
11 changes: 7 additions & 4 deletions upath/implementations/smb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
import os
import sys
import warnings
from typing import Any

if sys.version_info >= (3, 11):
from typing import Self
from typing import Unpack
else:
from typing_extensions import Self
from typing_extensions import Unpack

import smbprotocol.exceptions

from upath import UPath
from upath.core import _UPathRenameParams

_unset: Any = object()


class SMBPath(UPath):
Expand Down Expand Up @@ -44,7 +44,10 @@ def iterdir(self):
def rename(
self,
target: str | os.PathLike[str] | UPath,
**kwargs: Unpack[_UPathRenameParams], # note: non-standard compared to pathlib
*,
recursive: bool = _unset,
maxdepth: int | None = _unset,
**kwargs: Any,
) -> Self:
if kwargs.pop("recursive", None) is not None:
warnings.warn(
Expand Down
6 changes: 4 additions & 2 deletions upath/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ def xfail_if_version(module, *, reason, **conditions):
def xfail_if_no_ssl_connection(func):
try:
import requests

except ImportError:
return pytest.mark.skip(reason="requests not installed")(func)
try:
requests.get("https://example.com")
except (ImportError, requests.exceptions.SSLError):
except (requests.exceptions.ConnectionError, requests.exceptions.SSLError):
return pytest.mark.xfail(reason="No SSL connection")(func)
else:
return func
Expand Down
Loading