Skip to content

Commit

Permalink
Merge pull request #5661 from uktrade/feature/CLS2-912-add-query-para…
Browse files Browse the repository at this point in the history
…ms-to-eyb-lead-retrieve-view

Add query params to EYB lead retrieve view
  • Loading branch information
oliverjwroberts authored Sep 25, 2024
2 parents e5a63ce + 264e3cb commit 64d2905
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 4 deletions.
4 changes: 4 additions & 0 deletions datahub/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ class Sector(Enum):
'Defence : Land',
'7c432bdc-77e0-49ac-8d5f-1eece499ae2a',
)
mining = Constant(
'Mining',
'a622c9d2-5f95-e211-a939-e4115bead28a',
)
mining_mining_vehicles_transport_equipment = Constant(
'Mining : Mining vehicles, transport and equipment',
'e17c69f9-8c65-457e-9a65-fd7c52a45700',
Expand Down
12 changes: 8 additions & 4 deletions datahub/investment_lead/test/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,15 @@ class Meta:
triage_modified = factory.LazyFunction(timezone.now)
sector = factory.LazyAttribute(lambda o: random.choice(list(Sector.objects.all())))
sector_sub = factory.LazyAttribute(lambda o: f'{o.sector.segment}')
intent = random.choices(EYBLead.IntentChoices.values, k=random.randint(1, 6))
intent = factory.LazyAttribute(
lambda o: random.sample(EYBLead.IntentChoices.values, k=random.randint(1, 4)),
)
intent_other = ''
location_id = constants.UKRegion.wales.value.id
location_city = 'Cardiff'
location_none = False
hiring = random.choice(EYBLead.HiringChoices.values)
spend = random.choice(EYBLead.SpendChoices.values)
hiring = factory.LazyAttribute(lambda o: random.choice(EYBLead.HiringChoices.values))
spend = factory.LazyAttribute(lambda o: random.choice(EYBLead.SpendChoices.values))
spend_other = ''
is_high_value = factory.Faker('pybool')

Expand All @@ -55,7 +57,9 @@ class Meta:
telephone_number = factory.Faker('phone_number')
agree_terms = factory.Faker('pybool')
agree_info_email = factory.Faker('pybool')
landing_timeframe = random.choice(EYBLead.LandingTimeframeChoices.values)
landing_timeframe = factory.LazyAttribute(
lambda o: random.choice(EYBLead.LandingTimeframeChoices.values),
)
company_website = factory.Faker('url')

# Company fields
Expand Down
85 changes: 85 additions & 0 deletions datahub/investment_lead/test/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from rest_framework import status
from rest_framework.reverse import reverse

from datahub.company.test.factories import CompanyFactory
from datahub.core import constants
from datahub.core.test_utils import APITestMixin
from datahub.investment_lead.models import EYBLead
from datahub.investment_lead.test.factories import EYBLeadFactory
from datahub.investment_lead.test.utils import verify_eyb_lead_data
from datahub.metadata.models import Sector


EYB_LEAD_COLLECTION_URL = reverse('api-v4:investment-lead:eyb-lead-collection')
Expand Down Expand Up @@ -153,3 +155,86 @@ def test_pagination(self, test_user_with_view_permissions):
assert response.data['count'] == number_of_leads
assert response.data['next'] is not None
assert len(response.data['results']) == pagination_limit

def test_filter_by_company_name(self, test_user_with_view_permissions):
"""Test filtering EYB leads by company name"""
company_name = 'Mars Exports Ltd'
company = CompanyFactory(name=company_name)
EYBLeadFactory(company=company)
EYBLeadFactory()
api_client = self.create_api_client(user=test_user_with_view_permissions)
response = api_client.get(EYB_LEAD_COLLECTION_URL, data={'company': company_name})
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 1
assert response.data['results'][0]['company']['name'] == company_name

def test_filter_by_sector(self, test_user_with_view_permissions):
"""Test filtering EYB leads by sector id"""
level_0_sector = Sector.objects.get(pk=constants.Sector.mining.value.id)
child_sector = Sector.objects.get(
pk=constants.Sector.mining_mining_vehicles_transport_equipment.value.id,
)
unrelated_sector = Sector.objects.get(pk=constants.Sector.renewable_energy_wind.value.id)
EYBLeadFactory(sector=level_0_sector)
EYBLeadFactory(sector=child_sector)
EYBLeadFactory(sector=unrelated_sector)

api_client = self.create_api_client(user=test_user_with_view_permissions)
response = api_client.get(EYB_LEAD_COLLECTION_URL)
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 3

response = api_client.get(EYB_LEAD_COLLECTION_URL, data={'sector': level_0_sector.pk})
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 2
sector_ids_in_results = set([lead['sector']['id'] for lead in response.data['results']])
assert {str(level_0_sector.pk), str(child_sector.pk)} == sector_ids_in_results

def test_filter_by_non_existing_sector(self, test_user_with_view_permissions):
"""Test filtering EYB leads by non existent sector is handled without error."""
non_existent_sector_uuid = uuid.uuid4()
sector = Sector.objects.get(pk=constants.Sector.renewable_energy_wind.value.id)
EYBLeadFactory(sector=sector)

api_client = self.create_api_client(user=test_user_with_view_permissions)
response = api_client.get(EYB_LEAD_COLLECTION_URL)
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 1

response = api_client.get(
EYB_LEAD_COLLECTION_URL, data={'sector': str(non_existent_sector_uuid)},
)
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 0

def test_filter_by_is_high_value(self, test_user_with_view_permissions):
"""Test filtering EYB leads by is high value status"""
EYBLeadFactory(is_high_value=True)
EYBLeadFactory(is_high_value=False)
EYBLeadFactory(is_high_value=False)
api_client = self.create_api_client(user=test_user_with_view_permissions)

response = api_client.get(EYB_LEAD_COLLECTION_URL)
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 3

response = api_client.get(EYB_LEAD_COLLECTION_URL, data={'value': 'high'})
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 1
assert response.data['results'][0]['is_high_value'] is True

def test_filter_by_is_low_value(self, test_user_with_view_permissions):
"""Test filtering EYB leads by is low value status"""
EYBLeadFactory(is_high_value=True)
EYBLeadFactory(is_high_value=True)
EYBLeadFactory(is_high_value=False)
api_client = self.create_api_client(user=test_user_with_view_permissions)

response = api_client.get(EYB_LEAD_COLLECTION_URL)
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 3

response = api_client.get(EYB_LEAD_COLLECTION_URL, data={'value': 'low'})
assert response.status_code == status.HTTP_200_OK
assert response.data['count'] == 1
assert response.data['results'][0]['is_high_value'] is False
27 changes: 27 additions & 0 deletions datahub/investment_lead/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
CreateEYBLeadSerializer,
RetrieveEYBLeadSerializer,
)
from datahub.metadata.models import Sector


logger = logging.getLogger(__name__)
Expand All @@ -40,6 +41,32 @@ def get_serializer_class(self):
return CreateEYBLeadSerializer
return RetrieveEYBLeadSerializer

def get_queryset(self):
"""Apply filters to queryset based on query parameters (in GET operations)."""
queryset = super().get_queryset()
company_name = self.request.query_params.get('company')
sector_id = self.request.query_params.get('sector')
value = self.request.query_params.get('value')

if company_name:
queryset = queryset.filter(company__name__icontains=company_name)
if sector_id:
try:
# This will be a level 0 sector id;
# We want to find and return all leads with sectors that have this ancestor
sector = Sector.objects.get(pk=sector_id)
descendent_sectors = sector.get_descendants(include_self=True)
queryset = queryset.filter(sector__in=descendent_sectors)
except Exception:
queryset = queryset.none()
if value is not None:
if value.lower().strip() == 'high':
queryset = queryset.filter(is_high_value=True)
if value.lower().strip() == 'low':
queryset = queryset.filter(is_high_value=False)

return queryset

def create(self, request):
"""POST route definition.
Expand Down

0 comments on commit 64d2905

Please sign in to comment.