Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow to pass a custom client #76

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions src/loafer/ext/aws/bases.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,50 @@
import logging
from contextlib import asynccontextmanager

from aiobotocore.session import get_session

logger = logging.getLogger(__name__)

session = get_session()
DEFAULT_SESSION = None


def _setup_default_session():
global DEFAULT_SESSION # noqa: PLW0603
DEFAULT_SESSION = get_session()


def get_default_session():
if not DEFAULT_SESSION:
_setup_default_session()

return DEFAULT_SESSION


class _BotoClient:
boto_service_name = None

def __init__(self, **client_options):
self._client_options = {
"api_version": client_options.get("api_version"),
"aws_access_key_id": client_options.get("aws_access_key_id"),
"aws_secret_access_key": client_options.get("aws_secret_access_key"),
"aws_session_token": client_options.get("aws_session_token"),
"endpoint_url": client_options.get("endpoint_url"),
"region_name": client_options.get("region_name"),
"use_ssl": client_options.get("use_ssl", True),
"verify": client_options.get("verify"),
}

def get_client(self):
return session.create_client(self.boto_service_name, **self._client_options)
def __init__(self, *, client=None, **client_options):
if client:
self._client = client
else:
self._client_options = {
"api_version": client_options.get("api_version"),
"aws_access_key_id": client_options.get("aws_access_key_id"),
"aws_secret_access_key": client_options.get("aws_secret_access_key"),
"aws_session_token": client_options.get("aws_session_token"),
"endpoint_url": client_options.get("endpoint_url"),
"region_name": client_options.get("region_name"),
"use_ssl": client_options.get("use_ssl", True),
"verify": client_options.get("verify"),
}

@asynccontextmanager
async def get_client(self):
if hasattr(self, "_client"):
yield self._client
else:
async with get_default_session().create_client(self.boto_service_name, **self._client_options) as client:
yield client


class BaseSQSClient(_BotoClient):
Expand Down
10 changes: 8 additions & 2 deletions tests/ext/aws/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def boto_client_sqs(queue_url, sqs_message):

@pytest.fixture
def mock_boto_session_sqs(boto_client_sqs):
return mock.patch("loafer.ext.aws.bases.session.create_client", return_value=ClientContextCreator(boto_client_sqs))
mock_create_client = mock.Mock(return_value=ClientContextCreator(boto_client_sqs))
return mock.patch(
"loafer.ext.aws.bases.get_default_session", return_value=mock.Mock(create_client=mock_create_client)
)


@pytest.fixture
Expand All @@ -76,4 +79,7 @@ def boto_client_sns(sns_publish):

@pytest.fixture
def mock_boto_session_sns(boto_client_sns):
return mock.patch("loafer.ext.aws.bases.session.create_client", return_value=ClientContextCreator(boto_client_sns))
mock_create_client = mock.Mock(return_value=ClientContextCreator(boto_client_sns))
return mock.patch(
"loafer.ext.aws.bases.get_default_session", return_value=mock.Mock(create_client=mock_create_client)
)
29 changes: 25 additions & 4 deletions tests/ext/aws/test_bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,22 @@ async def test_get_queue_url_when_queue_name_is_url(mock_boto_session_sqs, boto_
@pytest.mark.asyncio
async def test_sqs_get_client(mock_boto_session_sqs, base_sqs_client, boto_client_sqs):
with mock_boto_session_sqs as mock_session:
client_generator = base_sqs_client.get_client()
async with base_sqs_client.get_client() as client:
assert boto_client_sqs is client

assert mock_session.called
async with client_generator as client:


@pytest.mark.asyncio
async def test_sqs_get_client_with_custom_client(mock_boto_session_sqs, boto_client_sqs):
base_sqs_client = BaseSQSClient(client=boto_client_sqs)

with mock_boto_session_sqs as mock_session:
async with base_sqs_client.get_client() as client:
assert boto_client_sqs is client

mock_session.assert_not_called()


@pytest.fixture
def base_sns_client():
Expand All @@ -70,7 +81,17 @@ async def test_cache_get_topic_arn_with_arn(base_sns_client):
@pytest.mark.asyncio
async def test_sns_get_client(mock_boto_session_sns, base_sns_client, boto_client_sns):
with mock_boto_session_sns as mock_session:
client_generator = base_sns_client.get_client()
async with base_sns_client.get_client() as client:
assert boto_client_sns is client
assert mock_session.called
async with client_generator as client:


@pytest.mark.asyncio
async def test_sns_get_client_with_custom_client(mock_boto_session_sns, boto_client_sns):
base_sns_client = BaseSNSClient(client=boto_client_sns)

with mock_boto_session_sns as mock_session:
async with base_sns_client.get_client() as client:
assert boto_client_sns is client

mock_session.assert_not_called()