Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read importlib.metadata to find transport / compressor extensions #697

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def read(fname):
'all': all_deps,
'http': http_deps,
'webhdfs': http_deps,
':python_version<"3.8"': ['importlib-metadata'],
},
python_requires=">=3.6,<4.0",

Expand Down
15 changes: 15 additions & 0 deletions smart_open/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import logging
import os.path

from smart_open.utils import find_entry_points

logger = logging.getLogger(__name__)

_COMPRESSOR_REGISTRY = {}
Expand Down Expand Up @@ -145,3 +147,16 @@ def compression_wrapper(file_obj, mode, compression):
#
register_compressor('.bz2', _handle_bz2)
register_compressor('.gz', _handle_gzip)


def _register_compressor_entry_point(ep):
    """Register the compressor advertised by a single entry point.

    The entry point name, prefixed with a dot, is used as the file
    extension, and the object returned by ``ep.load()`` is handed to
    ``register_compressor`` for that extension.

    Any failure (empty name, unloadable target, rejected registration) is
    logged as a warning and swallowed, so a broken third-party extension
    does not raise out of this function.

    Parameters
    ----------
    ep: EntryPoint
        The entry point to register; its ``name``, ``value`` and ``load()``
        are used.
    """
    try:
        # Raise explicitly rather than ``assert``: assertions are stripped
        # under ``python -O``, which would let an empty name register ".".
        if not ep.name:
            raise ValueError("At least one char is required for ep.name")
        extension = ".{}".format(ep.name)
        register_compressor(extension, ep.load())
    except Exception:
        # exc_info keeps the underlying cause visible to whoever debugs
        # the failing extension.
        logger.warning(
            "Fail to load smart_open compressor extension: %s (target: %s)",
            ep.name,
            ep.value,
            exc_info=True,
        )


# Pick up compressors advertised by installed packages under the
# 'smart_open_compressor' entry point group and register each of them.
for ep in find_entry_points('smart_open_compressor'):
    _register_compressor_entry_point(ep)
10 changes: 10 additions & 0 deletions smart_open/tests/fixtures/compressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""Some no-op compressor"""


def handle_foo():
    """Do nothing; a loadable placeholder target for entry point tests."""
    return None


def handle_bar():
    """Do nothing; a second loadable placeholder target for entry point tests."""
    return None
55 changes: 55 additions & 0 deletions smart_open/tests/test_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-

import pytest

from smart_open.utils import importlib_metadata
from smart_open.compression import _COMPRESSOR_REGISTRY, _register_compressor_entry_point

EntryPoint = importlib_metadata.EntryPoint


def unregister_compressor(ext):
    """Remove ``ext`` from the compressor registry; do nothing if it is absent."""
    if ext not in _COMPRESSOR_REGISTRY:
        return
    del _COMPRESSOR_REGISTRY[ext]


@pytest.fixture(autouse=True)
def cleanup_compressor():
    """Keep the ``.foo`` / ``.bar`` test extensions out of the shared registry.

    The registry is module-level state, so the extensions are removed both
    before each test (clean starting point) and after it (nothing leaks
    into tests that run later in the same session).
    """
    def _purge():
        unregister_compressor(".foo")
        unregister_compressor(".bar")

    _purge()
    yield
    # Teardown: without this the last test's registrations would stay in
    # the registry after this module finishes.
    _purge()


def test_register_valid_entry_point():
    """Loadable entry points end up in the registry under ``.<name>``."""
    assert ".foo" not in _COMPRESSOR_REGISTRY
    assert ".bar" not in _COMPRESSOR_REGISTRY
    # Each entry point targets its own fixture handler; "foo" previously
    # pointed at handle_bar, which looked like a copy-paste slip.
    _register_compressor_entry_point(EntryPoint(
        "foo",
        "smart_open.tests.fixtures.compressor:handle_foo",
        "smart_open_compressor",
    ))
    _register_compressor_entry_point(EntryPoint(
        "bar",
        "smart_open.tests.fixtures.compressor:handle_bar",
        "smart_open_compressor",
    ))
    assert ".foo" in _COMPRESSOR_REGISTRY
    assert ".bar" in _COMPRESSOR_REGISTRY


def test_register_invalid_entry_point_name_do_not_crash():
    """An entry point with an empty name is ignored instead of raising."""
    nameless = EntryPoint(
        "",
        "smart_open.tests.fixtures.compressor:handle_foo",
        "smart_open_compressor",
    )
    _register_compressor_entry_point(nameless)
    # Neither the bare name nor the dotted form may have been registered.
    assert "" not in _COMPRESSOR_REGISTRY
    assert "." not in _COMPRESSOR_REGISTRY


def test_register_invalid_entry_point_value_do_not_crash():
    """An entry point whose target cannot be loaded is ignored instead of raising."""
    # The fixture module defines no ``handle_invalid`` attribute.
    unloadable = EntryPoint(
        "foo",
        "smart_open.tests.fixtures.compressor:handle_invalid",
        "smart_open_compressor",
    )
    _register_compressor_entry_point(unloadable)
    assert ".foo" not in _COMPRESSOR_REGISTRY
58 changes: 57 additions & 1 deletion smart_open/tests/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,45 @@
import pytest
import unittest

from smart_open.transport import register_transport, get_transport
from smart_open.transport import (
register_transport, get_transport, _REGISTRY, _ERRORS, _register_transport_entry_point
)
from smart_open.utils import importlib_metadata

EntryPoint = importlib_metadata.EntryPoint


def unregister_transport(x):
    """Remove ``x`` from both the transport registry and the error table, if present."""
    for table in (_REGISTRY, _ERRORS):
        if x in table:
            del table[x]


def assert_transport_not_registered(scheme):
    """Check that looking up ``scheme`` raises ``NotImplementedError``."""
    expect_unknown_scheme = pytest.raises(NotImplementedError)
    with expect_unknown_scheme:
        get_transport(scheme)


def assert_transport_registered(scheme):
    """Check that the transport found for ``scheme`` reports that same ``SCHEME``."""
    assert get_transport(scheme).SCHEME == scheme


class TransportTest(unittest.TestCase):
def tearDown(self):
    """Unregister the schemes the tests in this class may have registered."""
    for scheme in ("foo", "missing"):
        unregister_transport(scheme)

def test_registry_requires_declared_schemes(self):
    """Registering the ``no_schemes_transport`` fixture raises ``ValueError``."""
    module_path = 'smart_open.tests.fixtures.no_schemes_transport'
    with pytest.raises(ValueError):
        register_transport(module_path)

def test_registry_valid_transport(self):
    """Registering the ``good_transport`` fixture makes the "foo" scheme available."""
    scheme = "foo"
    assert_transport_not_registered(scheme)
    register_transport('smart_open.tests.fixtures.good_transport')
    assert_transport_registered(scheme)

def test_registry_errors_on_double_register_scheme(self):
register_transport('smart_open.tests.fixtures.good_transport')
with pytest.raises(AssertionError):
Expand All @@ -20,3 +50,29 @@ def test_registry_errors_get_transport_for_module_with_missing_deps(self):
register_transport('smart_open.tests.fixtures.missing_deps_transport')
with pytest.raises(ImportError):
get_transport("missing")

def test_register_entry_point_valid(self):
    """An entry point targeting the ``good_transport`` fixture registers "foo"."""
    assert_transport_not_registered("foo")
    good_entry_point = EntryPoint(
        "foo",
        "smart_open.tests.fixtures.good_transport",
        "smart_open_transport",
    )
    _register_transport_entry_point(good_entry_point)
    assert_transport_registered("foo")

def test_register_entry_point_catch_bad_data(self):
    """Registering an entry point with a bogus module path must not raise."""
    # There is no assertion: the test passes as long as the call returns.
    bogus_entry_point = EntryPoint(
        "invalid",
        "smart_open.some_totaly_invalid_module",
        "smart_open_transport",
    )
    _register_transport_entry_point(bogus_entry_point)

def test_register_entry_point_for_module_with_missing_deps(self):
    """After registering ``missing_deps_transport``, lookup raises ``ImportError``."""
    scheme = "missing"
    assert_transport_not_registered(scheme)
    entry_point = EntryPoint(
        "missing",
        "smart_open.tests.fixtures.missing_deps_transport",
        "smart_open_transport",
    )
    _register_transport_entry_point(entry_point)
    # Registration itself does not raise; the error shows up on lookup.
    with pytest.raises(ImportError):
        get_transport(scheme)
11 changes: 11 additions & 0 deletions smart_open/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,14 @@ def test_check_kwargs():
def test_safe_urlsplit(url, expected):
actual = smart_open.utils.safe_urlsplit(url)
assert actual == urllib.parse.SplitResult(*expected)


def test_find_entry_points():
    """``find_entry_points`` returns entries for groups known to the test environment."""
    find = smart_open.utils.find_entry_points

    # Installed through setup.py tests requirements
    pytest_plugin_names = {entry.name for entry in find("pytest11")}
    assert "rerunfailures" in pytest_plugin_names

    # Part of setuptools
    distutils_commands = find("distutils.commands")
    assert len(distutils_commands) > 0
12 changes: 12 additions & 0 deletions smart_open/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import logging

import smart_open.local_file
from smart_open.utils import find_entry_points

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,5 +103,16 @@ def get_transport(scheme):
register_transport('smart_open.ssh')
register_transport('smart_open.webhdfs')


def _register_transport_entry_point(ep):
    """Register the transport module advertised by a single entry point.

    ``ep.value`` is passed to ``register_transport``. Any exception raised
    by that call is logged as a warning and swallowed, so a broken
    third-party extension does not raise out of this function.

    Parameters
    ----------
    ep: EntryPoint
        The entry point to register; its ``name`` and ``value`` are used.
    """
    try:
        register_transport(ep.value)
    except Exception:
        # exc_info keeps the underlying cause visible to whoever debugs
        # the failing extension.
        logger.warning(
            "Fail to load smart_open transport extension: %s (target: %s)",
            ep.name,
            ep.value,
            exc_info=True,
        )


# Pick up transports advertised by installed packages under the
# 'smart_open_transport' entry point group. This runs before the scheme
# snapshot below is taken, so extension schemes are included in it.
for ep in find_entry_points(group='smart_open_transport'):
    _register_transport_entry_point(ep)

SUPPORTED_SCHEMES = tuple(sorted(_REGISTRY))
"""The transport schemes that the local installation of ``smart_open`` supports."""
35 changes: 35 additions & 0 deletions smart_open/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,16 @@

import inspect
import logging
import sys
import urllib.parse

# importlib.metadata joined the standard library in Python 3.8; older
# interpreters fall back to the importlib_metadata backport package.
# See: https://docs.python.org/3.8/library/importlib.metadata.html#module-importlib.metadata
if sys.version_info >= (3, 8):
import importlib.metadata as importlib_metadata
else:
import importlib_metadata

logger = logging.getLogger(__name__)

WORKAROUND_SCHEMES = ['s3', 's3n', 's3u', 's3a', 'gs']
Expand Down Expand Up @@ -189,3 +197,30 @@ def safe_urlsplit(url):

path = sr.path.replace(placeholder, '?')
return urllib.parse.SplitResult(sr.scheme, sr.netloc, path, '', '')


def find_entry_points(group):
    """Return the installed entry points that belong to ``group``.

    Parameters
    ----------
    group: str
        An entry point group name to filter registered entry points.

    Returns
    -------
    list[EntryPoint]
        The entry points registered under ``group``; empty when there are
        none. Always a plain list, whichever metadata API is available.
    """
    try:
        # Try new filter API (python 3.10+ and importlib_metadata 3.6+).
        #
        # Check "Compatibility Note" here:
        # https://docs.python.org/3.10/library/importlib.metadata.html#entry-points
        selected = importlib_metadata.entry_points(group=group)
    except TypeError:
        # Legacy API: ``entry_points()`` accepts no keyword arguments and
        # returns a mapping of group name to entry points. Only TypeError
        # signals the old signature; any other error is a real failure and
        # is allowed to propagate rather than being retried blindly.
        selected = importlib_metadata.entry_points().get(group, [])
    return list(selected)