Skip to content

Commit

Permalink
Cache name resolutions (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
twm authored Sep 7, 2024
1 parent 15fad2e commit 60286e4
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 21 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ tests to ensure network calls are prevented.

## Requirements

- [Pytest](https://github.com/pytest-dev/pytest) 6.2.5 or greater
- [Pytest](https://github.com/pytest-dev/pytest) 7.0 or greater

## Installation

Expand All @@ -43,21 +43,21 @@ Run `pytest --disable-socket`, tests should fail on any access to `socket` or
libraries using socket with a `SocketBlockedError`.

To add this flag as the default behavior, add this section to your
[`pytest.ini`](https://docs.pytest.org/en/6.2.x/customize.html#pytest-ini):
[`pytest.ini`](https://docs.pytest.org/en/stable/reference/customize.html#pytest-ini):

```ini
[pytest]
addopts = --disable-socket
```

or add this to your [`setup.cfg`](https://docs.pytest.org/en/6.2.x/customize.html#setup-cfg):
or add this to your [`setup.cfg`](https://docs.pytest.org/en/stable/reference/customize.html#setup-cfg):

```ini
[tool:pytest]
addopts = --disable-socket
```

or update your [`conftest.py`](https://docs.pytest.org/en/6.2.x/writing_plugins.html#conftest-py-plugins) to include:
or update your [`conftest.py`](https://docs.pytest.org/en/stable/how-to/writing_plugins.html#conftest-py-local-per-directory-plugins) to include:

```python
from pytest_socket import disable_socket
Expand Down
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.8"
pytest = ">=6.2.5"
pytest = ">=7.0.0"

[tool.poetry.dev-dependencies]
coverage = "^7.6"
Expand Down
64 changes: 49 additions & 15 deletions pytest_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import socket
import typing
from collections import defaultdict
from dataclasses import dataclass, field

import pytest

Expand Down Expand Up @@ -56,7 +57,8 @@ def pytest_addoption(parser):
@pytest.fixture
def socket_disabled(pytestconfig):
"""disable socket.socket for duration of this test function"""
disable_socket(allow_unix_socket=pytestconfig.__socket_allow_unix_socket)
socket_config = pytestconfig.stash[_STASH_KEY]
disable_socket(allow_unix_socket=socket_config.allow_unix_socket)
yield


Expand All @@ -67,6 +69,18 @@ def socket_enabled(pytestconfig):
yield


@dataclass
class _PytestSocketConfig:
socket_disabled: bool
socket_force_enabled: bool
allow_unix_socket: bool
allow_hosts: typing.Union[str, typing.List[str], None]
resolution_cache: typing.Dict[str, typing.Set[str]] = field(default_factory=dict)


_STASH_KEY = pytest.StashKey[_PytestSocketConfig]()


def _is_unix_socket(family) -> bool:
try:
is_unix_socket = family == socket.AF_UNIX
Expand Down Expand Up @@ -109,10 +123,12 @@ def pytest_configure(config):
)

# Store the global configs in the `pytest.Config` object.
config.__socket_force_enabled = config.getoption("--force-enable-socket")
config.__socket_disabled = config.getoption("--disable-socket")
config.__socket_allow_unix_socket = config.getoption("--allow-unix-socket")
config.__socket_allow_hosts = config.getoption("--allow-hosts")
config.stash[_STASH_KEY] = _PytestSocketConfig(
socket_force_enabled=config.getoption("--force-enable-socket"),
socket_disabled=config.getoption("--disable-socket"),
allow_unix_socket=config.getoption("--allow-unix-socket"),
allow_hosts=config.getoption("--allow-hosts"),
)


def pytest_runtest_setup(item) -> None:
Expand All @@ -129,12 +145,14 @@ def pytest_runtest_setup(item) -> None:
if not hasattr(item, "fixturenames"):
return

socket_config = item.config.stash[_STASH_KEY]

# If test has the `enable_socket` marker, fixture or
# it's forced from the CLI, we accept this as most explicit.
if (
"socket_enabled" in item.fixturenames
or item.get_closest_marker("enable_socket")
or item.config.__socket_force_enabled
or socket_config.socket_force_enabled
):
enable_socket()
return
Expand All @@ -143,27 +161,34 @@ def pytest_runtest_setup(item) -> None:
if "socket_disabled" in item.fixturenames or item.get_closest_marker(
"disable_socket"
):
disable_socket(item.config.__socket_allow_unix_socket)
disable_socket(socket_config.allow_unix_socket)
return

# Resolve `allow_hosts` behaviors.
hosts = _resolve_allow_hosts(item)

# Finally, check the global config and disable socket if needed.
if item.config.__socket_disabled and not hosts:
disable_socket(item.config.__socket_allow_unix_socket)
if socket_config.socket_disabled and not hosts:
disable_socket(socket_config.allow_unix_socket)


def _resolve_allow_hosts(item):
"""Resolve `allow_hosts` behaviors."""
socket_config = item.config.stash[_STASH_KEY]

mark_restrictions = item.get_closest_marker("allow_hosts")
cli_restrictions = item.config.__socket_allow_hosts
cli_restrictions = socket_config.allow_hosts
hosts = None
if mark_restrictions:
hosts = mark_restrictions.args[0]
elif cli_restrictions:
hosts = cli_restrictions
socket_allow_hosts(hosts, allow_unix_socket=item.config.__socket_allow_unix_socket)

socket_allow_hosts(
hosts,
allow_unix_socket=socket_config.allow_unix_socket,
resolution_cache=socket_config.resolution_cache,
)
return hosts


Expand Down Expand Up @@ -206,28 +231,37 @@ def resolve_hostnames(hostname: str) -> typing.Set[str]:

def normalize_allowed_hosts(
allowed_hosts: typing.List[str],
resolution_cache: typing.Optional[typing.Dict[str, typing.List[str]]] = None,
) -> typing.Dict[str, typing.Set[str]]:
"""Map all items in `allowed_hosts` to IP addresses."""
if resolution_cache is None:
resolution_cache = {}
ip_hosts = defaultdict(set)
for host in allowed_hosts:
host = host.strip()
if is_ipaddress(host):
ip_hosts[host].add(host)
else:
ip_hosts[host].update(resolve_hostnames(host))
continue
if host not in resolution_cache:
resolution_cache[host] = resolve_hostnames(host)
ip_hosts[host].update(resolution_cache[host])

return ip_hosts


def socket_allow_hosts(allowed=None, allow_unix_socket=False):
def socket_allow_hosts(
allowed: typing.Union[str, typing.List[str], None] = None,
allow_unix_socket: bool = False,
resolution_cache: typing.Optional[typing.Dict[str, typing.List[str]]] = None,
) -> None:
"""disable socket.socket.connect() to disable the Internet. useful in testing."""
if isinstance(allowed, str):
allowed = allowed.split(",")

if not isinstance(allowed, list):
return

allowed_ip_hosts_by_host = normalize_allowed_hosts(allowed)
allowed_ip_hosts_by_host = normalize_allowed_hosts(allowed, resolution_cache)
allowed_ip_hosts_and_hostnames = set(
itertools.chain(*allowed_ip_hosts_by_host.values())
) | set(allowed_ip_hosts_by_host.keys())
Expand Down
86 changes: 86 additions & 0 deletions tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import collections
import inspect
import socket

import pytest

from pytest_socket import normalize_allowed_hosts

localhost = "127.0.0.1"

connect_code_template = """
Expand Down Expand Up @@ -85,6 +89,25 @@ def assert_socket_connect(should_pass, **kwargs):
return assert_socket_connect


@pytest.fixture
def getaddrinfo_hosts(monkeypatch):
hosts = []

def _getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
hosts.append(host)
v4 = (
socket.AF_INET,
socket.SOCK_STREAM,
socket.IPPROTO_TCP,
"",
("127.0.0.127", 0),
)
return [v4]

monkeypatch.setattr(socket, "getaddrinfo", _getaddrinfo)
return hosts


def test_help_message(testdir):
result = testdir.runpytest(
"--help",
Expand Down Expand Up @@ -298,3 +321,66 @@ def test_fail_2():
result.assert_outcomes(1, 0, 2)
assert_host_blocked(result, "2.2.2.2")
assert_host_blocked(result, httpbin.host)


def test_normalize_allowed_hosts(getaddrinfo_hosts):
"""normalize_allowed_hosts() produces a map of hosts to IP addresses."""
assert normalize_allowed_hosts(["127.0.0.1", "localhost", "localhost", "::1"]) == {
"::1": {"::1"},
"127.0.0.1": {"127.0.0.1"},
"localhost": {"127.0.0.127"},
}

assert getaddrinfo_hosts == ["localhost"]


def test_normalize_allowed_hosts_cache(getaddrinfo_hosts):
"""normalize_allowed_hosts() caches name resolutions when passed a cache"""
cache = {}

assert normalize_allowed_hosts(["localhost"], cache) == {
"localhost": {"127.0.0.127"}
}
assert cache == {"localhost": {"127.0.0.127"}}
assert getaddrinfo_hosts == ["localhost"]

del getaddrinfo_hosts[:]

assert normalize_allowed_hosts(["localhost", "localhost"], cache) == {
"localhost": {"127.0.0.127"}
}
assert cache == {"localhost": {"127.0.0.127"}}
assert getaddrinfo_hosts == []


def test_name_resolution_cached(testdir, getaddrinfo_hosts):
"""pytest-socket only resolves each allowed name once."""

testdir.makepyfile(
"""
import pytest
import socket
@pytest.mark.allow_hosts('name.internal')
def test_1():
...
@pytest.mark.allow_hosts(['name.internal', 'name.another'])
def test_2():
...
@pytest.mark.allow_hosts('name.internal')
@pytest.mark.parametrize("i", ["3", "4", "5"])
def test_456(i):
...
"""
)

hooks = testdir.inline_run("--allow-hosts=name.internal,name.internal")
[result] = hooks.getcalls("pytest_sessionfinish")
assert result.session.testsfailed == 0

assert collections.Counter(getaddrinfo_hosts) == {
"name.internal": 1,
"name.another": 1,
}

0 comments on commit 60286e4

Please sign in to comment.