From e813d2a5145948b228a443870da6631448e6b3aa Mon Sep 17 00:00:00 2001 From: Arthur LE MOIGNE Date: Sat, 30 Apr 2022 14:10:28 +0200 Subject: [PATCH 1/2] Add reading of setuptools metadata to find smart_open transport / compressor extensions --- smart_open/compression.py | 15 +++++++ smart_open/tests/fixtures/compressor.py | 10 +++++ smart_open/tests/test_compression.py | 53 +++++++++++++++++++++++ smart_open/tests/test_transport.py | 56 ++++++++++++++++++++++++- smart_open/transport.py | 12 ++++++ 5 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 smart_open/tests/fixtures/compressor.py create mode 100644 smart_open/tests/test_compression.py diff --git a/smart_open/compression.py b/smart_open/compression.py index fe022460..45538a61 100644 --- a/smart_open/compression.py +++ b/smart_open/compression.py @@ -7,6 +7,8 @@ # """Implements the compression layer of the ``smart_open`` library.""" import logging +import importlib +import importlib.metadata import os.path logger = logging.getLogger(__name__) @@ -145,3 +147,16 @@ def compression_wrapper(file_obj, mode, compression): # register_compressor('.bz2', _handle_bz2) register_compressor('.gz', _handle_gzip) + + +def _register_compressor_entry_point(ep): + try: + assert len(ep.name) > 0, "At least one char is required for ep.name" + extension = ".{}".format(ep.name) + register_compressor(extension, ep.load()) + except Exception: + logger.warning("Fail to load smart_open compressor extension: %s (target: %s)", ep.name, ep.value) + + +for ep in importlib.metadata.entry_points().select(group='smart_open_compressor'): + _register_compressor_entry_point(ep) diff --git a/smart_open/tests/fixtures/compressor.py b/smart_open/tests/fixtures/compressor.py new file mode 100644 index 00000000..acc6a772 --- /dev/null +++ b/smart_open/tests/fixtures/compressor.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +"""Some no-op compressor""" + + +def handle_foo(): + ... + + +def handle_bar(): + ... diff --git a/smart_open/tests/test_compression.py b/smart_open/tests/test_compression.py new file mode 100644 index 00000000..d182056e --- /dev/null +++ b/smart_open/tests/test_compression.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +from importlib.metadata import EntryPoint +import pytest + +from smart_open.compression import _COMPRESSOR_REGISTRY, _register_compressor_entry_point + + +def unregister_compressor(ext): + if ext in _COMPRESSOR_REGISTRY: + del _COMPRESSOR_REGISTRY[ext] + + +@pytest.fixture(autouse=True) +def cleanup_compressor(): + unregister_compressor(".foo") + unregister_compressor(".bar") + + +def test_register_valid_entry_point(): + assert ".foo" not in _COMPRESSOR_REGISTRY + assert ".bar" not in _COMPRESSOR_REGISTRY + _register_compressor_entry_point(EntryPoint( + "foo", + "smart_open.tests.fixtures.compressor:handle_bar", + "smart_open_compressor", + )) + _register_compressor_entry_point(EntryPoint( + "bar", + "smart_open.tests.fixtures.compressor:handle_bar", + "smart_open_compressor", + )) + assert ".foo" in _COMPRESSOR_REGISTRY + assert ".bar" in _COMPRESSOR_REGISTRY + + +def test_register_invalid_entry_point_name_do_not_crash(): + _register_compressor_entry_point(EntryPoint( + "", + "smart_open.tests.fixtures.compressor:handle_foo", + "smart_open_compressor", + )) + assert "" not in _COMPRESSOR_REGISTRY + assert "." not in _COMPRESSOR_REGISTRY + + +def test_register_invalid_entry_point_value_do_not_crash(): + _register_compressor_entry_point(EntryPoint( + "foo", + "smart_open.tests.fixtures.compressor:handle_invalid", + "smart_open_compressor", + )) + assert ".foo" not in _COMPRESSOR_REGISTRY diff --git a/smart_open/tests/test_transport.py b/smart_open/tests/test_transport.py index c44b04ab..b39ed840 100644 --- a/smart_open/tests/test_transport.py +++ b/smart_open/tests/test_transport.py @@ -1,16 +1,44 @@ # -*- coding: utf-8 -*- +from importlib.metadata import EntryPoint import pytest import unittest -from smart_open.transport import register_transport, get_transport +from smart_open.transport import ( + register_transport, get_transport, _REGISTRY, _ERRORS, _register_transport_entry_point +) + + +def unregister_transport(x): + if x in _REGISTRY: + del _REGISTRY[x] + if x in _ERRORS: + del _ERRORS[x] + + +def assert_transport_not_registered(scheme): + with pytest.raises(NotImplementedError): + get_transport(scheme) + + +def assert_transport_registered(scheme): + transport = get_transport(scheme) + assert transport.SCHEME == scheme class TransportTest(unittest.TestCase): + def tearDown(self): + unregister_transport("foo") + unregister_transport("missing") def test_registry_requires_declared_schemes(self): with pytest.raises(ValueError): register_transport('smart_open.tests.fixtures.no_schemes_transport') + def test_registry_valid_transport(self): + assert_transport_not_registered("foo") + register_transport('smart_open.tests.fixtures.good_transport') + assert_transport_registered("foo") + def test_registry_errors_on_double_register_scheme(self): register_transport('smart_open.tests.fixtures.good_transport') with pytest.raises(AssertionError): @@ -20,3 +48,29 @@ def test_registry_errors_get_transport_for_module_with_missing_deps(self): register_transport('smart_open.tests.fixtures.missing_deps_transport') with pytest.raises(ImportError): get_transport("missing") + + def test_register_entry_point_valid(self): + assert_transport_not_registered("foo") + _register_transport_entry_point(EntryPoint( + "foo", + "smart_open.tests.fixtures.good_transport", + "smart_open_transport", + )) + assert_transport_registered("foo") + + def test_register_entry_point_catch_bad_data(self): + _register_transport_entry_point(EntryPoint( + "invalid", + "smart_open.some_totaly_invalid_module", + "smart_open_transport", + )) + + def test_register_entry_point_for_module_with_missing_deps(self): + assert_transport_not_registered("missing") + _register_transport_entry_point(EntryPoint( + "missing", + "smart_open.tests.fixtures.missing_deps_transport", + "smart_open_transport", + )) + with pytest.raises(ImportError): + get_transport("missing") diff --git a/smart_open/transport.py b/smart_open/transport.py index 00fb27d7..05ac2257 100644 --- a/smart_open/transport.py +++ b/smart_open/transport.py @@ -11,6 +11,7 @@ """ import importlib +import importlib.metadata import logging import smart_open.local_file @@ -102,5 +103,16 @@ def get_transport(scheme): register_transport('smart_open.ssh') register_transport('smart_open.webhdfs') + +def _register_transport_entry_point(ep): + try: + register_transport(ep.value) + except Exception: + logger.warning("Fail to load smart_open transport extension: %s (target: %s)", ep.name, ep.value) + + +for ep in importlib.metadata.entry_points().select(group='smart_open_transport'): + _register_transport_entry_point(ep) + SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY.keys())) """The transport schemes that the local installation of ``smart_open`` supports.""" From ff281248670eca4bc2ddc69e19cb0606d046e19b Mon Sep 17 00:00:00 2001 From: Arthur LE MOIGNE Date: Sun, 1 May 2022 12:06:06 +0200 Subject: [PATCH 2/2] Handle multiple version of python importlib and add compat library for legacy python versions --- setup.py | 1 + smart_open/compression.py | 6 ++--- smart_open/tests/test_compression.py | 4 +++- smart_open/tests/test_transport.py | 4 +++- smart_open/tests/test_utils.py | 11 +++++++++ smart_open/transport.py | 4 ++-- smart_open/utils.py | 35 ++++++++++++++++++++++++++++ 7 files changed, 58 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index def401b2..7015ed65 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,7 @@ def read(fname): 'all': all_deps, 'http': http_deps, 'webhdfs': http_deps, + ':python_version<"3.8"': ['importlib-metadata'], }, python_requires=">=3.6,<4.0", diff --git a/smart_open/compression.py b/smart_open/compression.py index 45538a61..09baf488 100644 --- a/smart_open/compression.py +++ b/smart_open/compression.py @@ -7,10 +7,10 @@ # """Implements the compression layer of the ``smart_open`` library.""" import logging -import importlib -import importlib.metadata import os.path +from smart_open.utils import find_entry_points + logger = logging.getLogger(__name__) _COMPRESSOR_REGISTRY = {} @@ -158,5 +158,5 @@ def _register_compressor_entry_point(ep): logger.warning("Fail to load smart_open compressor extension: %s (target: %s)", ep.name, ep.value) -for ep in importlib.metadata.entry_points().select(group='smart_open_compressor'): +for ep in find_entry_points('smart_open_compressor'): _register_compressor_entry_point(ep) diff --git a/smart_open/tests/test_compression.py b/smart_open/tests/test_compression.py index d182056e..b8546507 100644 --- a/smart_open/tests/test_compression.py +++ b/smart_open/tests/test_compression.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- -from importlib.metadata import EntryPoint import pytest +from smart_open.utils import importlib_metadata from smart_open.compression import _COMPRESSOR_REGISTRY, _register_compressor_entry_point +EntryPoint = importlib_metadata.EntryPoint + def unregister_compressor(ext): if ext in _COMPRESSOR_REGISTRY: diff --git a/smart_open/tests/test_transport.py b/smart_open/tests/test_transport.py index b39ed840..27236ba5 100644 --- a/smart_open/tests/test_transport.py +++ b/smart_open/tests/test_transport.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- -from importlib.metadata import EntryPoint import pytest import unittest from smart_open.transport import ( register_transport, get_transport, _REGISTRY, _ERRORS, _register_transport_entry_point ) +from smart_open.utils import importlib_metadata + +EntryPoint = importlib_metadata.EntryPoint def unregister_transport(x): diff --git a/smart_open/tests/test_utils.py b/smart_open/tests/test_utils.py index c6be9a2d..8f8251e0 100644 --- a/smart_open/tests/test_utils.py +++ b/smart_open/tests/test_utils.py @@ -59,3 +59,14 @@ def test_check_kwargs(): def test_safe_urlsplit(url, expected): actual = smart_open.utils.safe_urlsplit(url) assert actual == urllib.parse.SplitResult(*expected) + + +def test_find_entry_points(): + # Installed through setup.py tests requirements + eps = smart_open.utils.find_entry_points("pytest11") + eps_names = {ep.name for ep in eps} + assert "rerunfailures" in eps_names + + # Part of setuptools + eps = smart_open.utils.find_entry_points("distutils.commands") + assert len(eps) > 0 diff --git a/smart_open/transport.py b/smart_open/transport.py index 05ac2257..66a0d220 100644 --- a/smart_open/transport.py +++ b/smart_open/transport.py @@ -11,10 +11,10 @@ """ import importlib -import importlib.metadata import logging import smart_open.local_file +from smart_open.utils import find_entry_points logger = logging.getLogger(__name__) @@ -111,7 +111,7 @@ def _register_transport_entry_point(ep): logger.warning("Fail to load smart_open transport extension: %s (target: %s)", ep.name, ep.value) -for ep in importlib.metadata.entry_points().select(group='smart_open_transport'): +for ep in find_entry_points(group='smart_open_transport'): _register_transport_entry_point(ep) SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY.keys())) diff --git a/smart_open/utils.py b/smart_open/utils.py index 4fc6aa84..2be772c5 100644 --- a/smart_open/utils.py +++ b/smart_open/utils.py @@ -10,8 +10,16 @@ import inspect import logging +import sys import urllib.parse +# Library has been added in version 3.8 +# See: https://docs.python.org/3.8/library/importlib.metadata.html#module-importlib.metadata +if sys.version_info >= (3, 8): + import importlib.metadata as importlib_metadata +else: + import importlib_metadata + logger = logging.getLogger(__name__) WORKAROUND_SCHEMES = ['s3', 's3n', 's3u', 's3a', 'gs'] @@ -189,3 +197,30 @@ def safe_urlsplit(url): path = sr.path.replace(placeholder, '?') return urllib.parse.SplitResult(sr.scheme, sr.netloc, path, '', '') + + +def find_entry_points(group): + """Search packages entry points and filter value based + on their group name. + + Parameters + ---------- + group: str + An entry point group name to filter registered entry points. + + Returns + ------- + list[EntryPoint] + Valid registered entry points. + """ + + try: + # Try new filter API (python 3.10+ and importlib_metadata 3.6+). + # + # Check "Compatibility Note" here: + # https://docs.python.org/3.10/library/importlib.metadata.html#entry-points + return importlib_metadata.entry_points(group=group) + except Exception: + # Legacy API + eps = importlib_metadata.entry_points() + return eps.get(group, [])