diff --git a/.gitignore b/.gitignore index 02add4d11..79a4da738 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,7 @@ nosetests.xml coverage.xml *,cover .hypothesis/ +/codecov.sh # Translations *.mo diff --git a/README.md b/README.md index 349a42ccb..a523cf432 100644 --- a/README.md +++ b/README.md @@ -32,25 +32,32 @@ Leeloo uses Docker compose to setup and run all the necessary components. The do docker-compose run leeloo python manage.py loaddata /app/fixtures/datahub_businesstypes.yaml docker-compose run leeloo python manage.py createinitialrevisions ``` +4. Optionally, you can load some test data and update elasticsearch: -4. Create a superuser: + ```shell + docker-compose run leeloo python manage.py loaddata /app/fixtures/test_data.yaml + + docker-compose run leeloo python manage.py sync_es + ``` + +5. Create a superuser: ```shell docker-compose run leeloo python manage.py createsuperuser ``` -5. Run the services: +6. Run the services: ```shell docker-compose up ``` -6. To set up the [data hub frontend app](https://github.com/uktrade/data-hub-fe-beta2), log into the [django admin](http://localhost:8000/admin/oauth2_provider/application/) and add a new oauth application with: +7. To set up the [data hub frontend app](https://github.com/uktrade/data-hub-frontend), log into the [django admin](http://localhost:8000/admin/oauth2_provider/application/) and add a new oauth application with: - Client type: Confidential - Authorization grant type: Resource owner password-based -7. Add the client id / client secret to the frontend .env file +8. Add the client id / client secret to the frontend .env file Local development with Docker ----------------------------- @@ -117,7 +124,13 @@ Dependencies: create database datahub; ``` -8. Configure and populate the db: +8. Make sure you have elasticsearch running locally. If you don't, you can run one in docker: + + ```shell + docker run -p 9200:9200 -e "http.host=0.0.0.0" -e "transport.host=127.0.0.1" elasticsearch:2.3 + ``` + +9. Configure and populate the db: ```shell ./manage.py migrate @@ -128,18 +141,26 @@ Dependencies: ./manage.py createinitialrevisions ``` -9. Start the server: +10. Optionally, you can load some test data and update elasticsearch: + + ```shell + ./manage.py loaddata /app/fixtures/test_data.yaml + + ./manage.py sync_es + ``` + +11. Start the server: ```shell ./manage.py runserver ``` -10. To set up the [data hub frontend app](https://github.com/uktrade/data-hub-fe-beta2), log into the [django admin](http://localhost:8000/admin/oauth2_provider/application/) and add a new oauth application with: +12. To set up the [data hub frontend app](https://github.com/uktrade/data-hub-frontend), log into the [django admin](http://localhost:8000/admin/oauth2_provider/application/) and add a new oauth application with: - Client type: Confidential - Authorization grant type: Resource owner password-based -11. Add the client id / client secret to the frontend .env file +13. Add the client id / client secret to the frontend .env file Local development (without Docker) ---------------------------------- @@ -209,6 +230,12 @@ docker-compose run leeloo python manage.py loaddata /app/fixtures/metadata.yaml docker-compose run leeloo python manage.py loaddata /app/fixtures/datahub_businesstypes.yaml ``` +Update Elasticsearch: + +```shell +docker-compose run leeloo python manage.py sync_es +``` + Dependencies ============ diff --git a/config/api_urls.py b/config/api_urls.py index 9e32c04bd..39a334b61 100644 --- a/config/api_urls.py +++ b/config/api_urls.py @@ -3,11 +3,12 @@ from django.conf.urls import include, url from rest_framework import routers -from datahub.company import views as company_views from datahub.company import urls as company_urls -from datahub.investment import urls as investment_urls +from datahub.company import views as company_views from datahub.interaction import views as interaction_views +from datahub.investment import urls as investment_urls from datahub.leads import urls as leads_urls +from datahub.omis import urls as omis_urls from datahub.search import urls as search_urls from datahub.v2.urls import urlpatterns as v2_urlpatterns @@ -36,5 +37,6 @@ url(r'^', include((company_urls.contact_urls, 'contact'), namespace='contact')), url(r'^', include((company_urls.company_urls, 'company'), namespace='company')), url(r'^', include((company_urls.ch_company_urls, 'ch-company'), namespace='ch-company')), - url(r'^', include((search_urls, 'search'), namespace='search')) + url(r'^', include((search_urls, 'search'), namespace='search')), + url(r'^omis/', include((omis_urls, 'omis'), namespace='omis')) ] diff --git a/config/settings/common.py b/config/settings/common.py index 33a3b1b0e..afbb93e19 100644 --- a/config/settings/common.py +++ b/config/settings/common.py @@ -59,6 +59,7 @@ 'datahub.search.apps.SearchConfig', 'datahub.user', 'datahub.korben', + 'datahub.omis.order' ] INSTALLED_APPS = DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS diff --git a/datahub/company/management/commands/sync_ch.py b/datahub/company/management/commands/sync_ch.py index d247ab1d9..f12a56143 100644 --- a/datahub/company/management/commands/sync_ch.py +++ b/datahub/company/management/commands/sync_ch.py @@ -59,7 +59,8 @@ def filter_irrelevant_ch_columns(row): @contextmanager def open_ch_zipped_csv(fp): - """Enclose all the complicated logic of on-the-fly unzip->csv read in a nice context manager.""" + """Enclose all the complicated logic of on-the-fly unzip->csv read in a nice context manager. + """ with zipfile.ZipFile(fp) as zf: # get the first file from zip, assuming it's the only one from CH csv_name = zf.filelist[0].filename @@ -86,9 +87,12 @@ def iter_ch_csv_from_url(url, tmp_file_creator): def sync_ch(tmp_file_creator, endpoint=None, truncate_first=False): """Do the sync. - We are batching the records instead of letting bulk_create doing it because Django casts the objects into a list + We are batching the records instead of letting bulk_create doing it because Django casts + the objects into a list: https://github.com/django/django/blob/master/django/db/models/query.py#L420 - this would create a list with millions of objects, that will try to be saved in batches in a single transaction + + This would create a list with millions of objects, that will try to be saved in batches + in a single transaction. """ logger.info('Starting CH load...') count = 0 diff --git a/datahub/company/models/company.py b/datahub/company/models/company.py index 64ed9fa18..b7f7cae68 100644 --- a/datahub/company/models/company.py +++ b/datahub/company/models/company.py @@ -54,7 +54,9 @@ class Company(ArchivableModel, CompanyAbstract): company_number = models.CharField(max_length=MAX_LENGTH, blank=True, null=True) id = models.UUIDField(primary_key=True, db_index=True, default=uuid.uuid4) - alias = models.CharField(max_length=MAX_LENGTH, blank=True, null=True, help_text='Trading name') + alias = models.CharField( + max_length=MAX_LENGTH, blank=True, null=True, help_text='Trading name' + ) business_type = models.ForeignKey( metadata_models.BusinessType, blank=True, null=True, on_delete=models.SET_NULL @@ -86,7 +88,9 @@ class Company(ArchivableModel, CompanyAbstract): related_name='company_future_interest_countries' ) description = models.TextField(blank=True, null=True) - website = models.CharField(max_length=MAX_LENGTH, validators=[RelaxedURLValidator], blank=True, null=True) + website = models.CharField( + max_length=MAX_LENGTH, validators=[RelaxedURLValidator], blank=True, null=True + ) uk_region = models.ForeignKey( metadata_models.UKRegion, blank=True, null=True, on_delete=models.SET_NULL @@ -158,15 +162,17 @@ def _validate_trading_address(self): self.trading_address_postcode, self.trading_address_country )) - all_required_trading_address_fields = all(getattr(self, field) - for field in self.REQUIRED_TRADING_ADDRESS_FIELDS) + all_required_trading_address_fields = all( + getattr(self, field) for field in self.REQUIRED_TRADING_ADDRESS_FIELDS + ) if any_trading_address_fields and not all_required_trading_address_fields: return False return True def _generate_trading_address_errors(self): """Generate per field error.""" - empty_fields = [field for field in self.REQUIRED_TRADING_ADDRESS_FIELDS if not getattr(self, field)] + empty_fields = [field for field in self.REQUIRED_TRADING_ADDRESS_FIELDS + if not getattr(self, field)] return {field: ['This field may not be null.'] for field in empty_fields} def _validate_uk_region(self): diff --git a/datahub/company/models/contact.py b/datahub/company/models/contact.py index 008cb7dd9..88e0bfb1a 100644 --- a/datahub/company/models/contact.py +++ b/datahub/company/models/contact.py @@ -68,7 +68,8 @@ def __str__(self): def _generate_address_errors(self): """Generate per field error.""" - empty_fields = [field for field in self.REQUIRED_ADDRESS_FIELDS if not getattr(self, field)] + empty_fields = [field for field in self.REQUIRED_ADDRESS_FIELDS + if not getattr(self, field)] return {field: ['This field may not be null.'] for field in empty_fields} def validate_contact_preferences(self): @@ -92,15 +93,19 @@ def validate_address(self): self.address_postcode, self.address_country )) - all_required_fields_existence = all(getattr(self, field) for field in self.REQUIRED_ADDRESS_FIELDS) + all_required_fields_existence = all( + getattr(self, field) for field in self.REQUIRED_ADDRESS_FIELDS + ) if self.address_same_as_company and some_address_fields_existence: - error_message = 'Please select either address_same_as_company or enter an address manually, not both!' + error_message = ('Please select either address_same_as_company or enter an address ' + 'manually, not both!') raise ValidationError({'address_same_as_company': error_message}) if not self.address_same_as_company: if some_address_fields_existence and not all_required_fields_existence: raise ValidationError(self._generate_address_errors()) elif not some_address_fields_existence: - error_message = 'Please select either address_same_as_company or enter an address manually.' + error_message = ('Please select either address_same_as_company or enter an ' + 'address manually.') raise ValidationError({'address_same_as_company': error_message}) def clean(self): diff --git a/datahub/company/test/factories.py b/datahub/company/test/factories.py index 2521c5f2b..c3dfa9836 100644 --- a/datahub/company/test/factories.py +++ b/datahub/company/test/factories.py @@ -62,8 +62,8 @@ class ContactFactory(factory.django.DjangoModelFactory): id = factory.Sequence(lambda _: str(uuid.uuid4())) title_id = constants.Title.wing_commander.value.id - first_name = factory.Sequence(lambda n: 'name {n}') - last_name = factory.Sequence(lambda n: 'surname {n}') + first_name = factory.Sequence(lambda n: f'name {n}') + last_name = factory.Sequence(lambda n: f'surname {n}') company = factory.SubFactory(CompanyFactory) email = 'foo@bar.com' primary = True diff --git a/datahub/company/test/test_advisor_views.py b/datahub/company/test/test_advisor_views.py index 2d2e3767e..8a39c686f 100644 --- a/datahub/company/test/test_advisor_views.py +++ b/datahub/company/test/test_advisor_views.py @@ -1,11 +1,11 @@ from rest_framework import status from rest_framework.reverse import reverse -from datahub.core.test_utils import LeelooTestCase +from datahub.core.test_utils import APITestMixin from .factories import AdviserFactory -class AdviserTestCase(LeelooTestCase): +class TestAdviser(APITestMixin): """Adviser test case.""" def test_adviser_list_view(self): diff --git a/datahub/company/test/test_company_views.py b/datahub/company/test/test_company_views.py index 47c26579d..ae3f2a831 100644 --- a/datahub/company/test/test_company_views.py +++ b/datahub/company/test/test_company_views.py @@ -3,12 +3,14 @@ from rest_framework.reverse import reverse from datahub.company import models -from datahub.core import constants -from datahub.core.test_utils import LeelooTestCase +from datahub.core.constants import ( + BusinessType, CompanyClassification, Country, HeadquarterType, Sector, UKRegion +) +from datahub.core.test_utils import APITestMixin from .factories import CompaniesHouseCompanyFactory, CompanyFactory -class CompanyTestCase(LeelooTestCase): +class TestCompany(APITestMixin): """Company test case.""" def test_list_companies(self): @@ -31,7 +33,7 @@ def test_detail_company_with_company_number(self): name='Foo ltd.', registered_address_1='Hello st.', registered_address_town='Fooland', - registered_address_country_id=constants.Country.united_states.value.id + registered_address_country_id=Country.united_states.value.id ) company = CompanyFactory( company_number=123, @@ -67,9 +69,9 @@ def test_detail_company_without_company_number(self): name='Foo ltd.', registered_address_1='Hello st.', registered_address_town='Fooland', - registered_address_country_id=constants.Country.united_states.value.id, - headquarter_type_id=constants.HeadquarterType.ukhq.value.id, - classification_id=constants.CompanyClassification.tier_a.value.id, + registered_address_country_id=Country.united_states.value.id, + headquarter_type_id=HeadquarterType.ukhq.value.id, + classification_id=CompanyClassification.tier_a.value.id, ) url = reverse('api-v1:company-detail', kwargs={'pk': company.id}) @@ -88,8 +90,8 @@ def test_detail_company_without_company_number(self): } assert response.data['registered_address_county'] is None assert response.data['registered_address_postcode'] is None - assert response.data['headquarter_type']['name'] == constants.HeadquarterType.ukhq.value.name - assert response.data['classification']['name'] == constants.CompanyClassification.tier_a.value.name + assert response.data['headquarter_type']['name'] == HeadquarterType.ukhq.value.name + assert response.data['classification']['name'] == CompanyClassification.tier_a.value.name def test_update_company(self): """Test company update.""" @@ -97,7 +99,7 @@ def test_update_company(self): name='Foo ltd.', registered_address_1='Hello st.', registered_address_town='Fooland', - registered_address_country_id=constants.Country.united_states.value.id + registered_address_country_id=Country.united_states.value.id ) # now update it @@ -115,18 +117,18 @@ def test_classification_is_ro(self): name='Foo ltd.', registered_address_1='Hello st.', registered_address_town='Fooland', - registered_address_country_id=constants.Country.united_states.value.id, - classification_id=constants.CompanyClassification.tier_a.value.id, + registered_address_country_id=Country.united_states.value.id, + classification_id=CompanyClassification.tier_a.value.id, ) url = reverse('api-v1:company-detail', kwargs={'pk': company.pk}) response = self.api_client.patch(url, { - 'classification': constants.CompanyClassification.tier_b.value.id, + 'classification': CompanyClassification.tier_b.value.id, }) assert response.status_code == 200 # testing that this should be silently ignored error company.refresh_from_db() - assert str(company.classification_id) == constants.CompanyClassification.tier_a.value.id + assert str(company.classification_id) == CompanyClassification.tier_a.value.id def test_add_uk_company(self): """Test add new UK company.""" @@ -134,14 +136,14 @@ def test_add_uk_company(self): response = self.api_client.post(url, { 'name': 'Acme', 'alias': None, - 'business_type': constants.BusinessType.company.value.id, - 'sector': constants.Sector.aerospace_assembly_aircraft.value.id, - 'registered_address_country': constants.Country.united_kingdom.value.id, + 'business_type': BusinessType.company.value.id, + 'sector': Sector.aerospace_assembly_aircraft.value.id, + 'registered_address_country': Country.united_kingdom.value.id, 'registered_address_1': '75 Stramford Road', 'registered_address_town': 'London', - 'uk_region': constants.UKRegion.england.value.id, - 'headquarter_type': constants.HeadquarterType.ghq.value.id, - 'classification': constants.CompanyClassification.tier_a.value.id, + 'uk_region': UKRegion.england.value.id, + 'headquarter_type': HeadquarterType.ghq.value.id, + 'classification': CompanyClassification.tier_a.value.id, }) assert response.status_code == status.HTTP_201_CREATED @@ -153,15 +155,17 @@ def test_add_uk_company_without_uk_region(self): response = self.api_client.post(url, { 'name': 'Acme', 'alias': None, - 'business_type': constants.BusinessType.company.value.id, - 'sector': constants.Sector.aerospace_assembly_aircraft.value.id, - 'registered_address_country': constants.Country.united_kingdom.value.id, + 'business_type': BusinessType.company.value.id, + 'sector': Sector.aerospace_assembly_aircraft.value.id, + 'registered_address_country': Country.united_kingdom.value.id, 'registered_address_1': '75 Stramford Road', 'registered_address_town': 'London', }) assert response.status_code == status.HTTP_400_BAD_REQUEST - assert response.data['errors'] == {'uk_region': ['UK region is required for UK companies.']} + assert response.data['errors'] == { + 'uk_region': ['UK region is required for UK companies.'] + } def test_add_not_uk_company(self): """Test add new not UK company.""" @@ -169,9 +173,9 @@ def test_add_not_uk_company(self): response = self.api_client.post(url, { 'name': 'Acme', 'alias': None, - 'business_type': constants.BusinessType.company.value.id, - 'sector': constants.Sector.aerospace_assembly_aircraft.value.id, - 'registered_address_country': constants.Country.united_states.value.id, + 'business_type': BusinessType.company.value.id, + 'sector': Sector.aerospace_assembly_aircraft.value.id, + 'registered_address_country': Country.united_states.value.id, 'registered_address_1': '75 Stramford Road', 'registered_address_town': 'London', }) @@ -184,13 +188,13 @@ def test_add_company_partial_trading_address(self): url = reverse('api-v1:company-list') response = self.api_client.post(url, { 'name': 'Acme', - 'business_type': constants.BusinessType.company.value.id, - 'sector': constants.Sector.aerospace_assembly_aircraft.value.id, - 'registered_address_country': constants.Country.united_kingdom.value.id, + 'business_type': BusinessType.company.value.id, + 'sector': Sector.aerospace_assembly_aircraft.value.id, + 'registered_address_country': Country.united_kingdom.value.id, 'registered_address_1': '75 Stramford Road', 'registered_address_town': 'London', 'trading_address_1': 'test', - 'uk_region': constants.UKRegion.england.value.id + 'uk_region': UKRegion.england.value.id }) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -204,15 +208,15 @@ def test_add_company_with_trading_address(self): url = reverse('api-v1:company-list') response = self.api_client.post(url, { 'name': 'Acme', - 'business_type': constants.BusinessType.company.value.id, - 'sector': constants.Sector.aerospace_assembly_aircraft.value.id, - 'registered_address_country': constants.Country.united_kingdom.value.id, + 'business_type': BusinessType.company.value.id, + 'sector': Sector.aerospace_assembly_aircraft.value.id, + 'registered_address_country': Country.united_kingdom.value.id, 'registered_address_1': '75 Stramford Road', 'registered_address_town': 'London', - 'trading_address_country': constants.Country.ireland.value.id, + 'trading_address_country': Country.ireland.value.id, 'trading_address_1': '1 Hello st.', 'trading_address_town': 'Dublin', - 'uk_region': constants.UKRegion.england.value.id + 'uk_region': UKRegion.england.value.id }) assert response.status_code == status.HTTP_201_CREATED @@ -222,15 +226,15 @@ def test_add_company_with_website_without_scheme(self): url = reverse('api-v1:company-list') response = self.api_client.post(url, { 'name': 'Acme', - 'business_type': constants.BusinessType.company.value.id, - 'sector': constants.Sector.aerospace_assembly_aircraft.value.id, - 'registered_address_country': constants.Country.united_kingdom.value.id, + 'business_type': BusinessType.company.value.id, + 'sector': Sector.aerospace_assembly_aircraft.value.id, + 'registered_address_country': Country.united_kingdom.value.id, 'registered_address_1': '75 Stramford Road', 'registered_address_town': 'London', - 'trading_address_country': constants.Country.ireland.value.id, + 'trading_address_country': Country.ireland.value.id, 'trading_address_1': '1 Hello st.', 'trading_address_town': 'Dublin', - 'uk_region': constants.UKRegion.england.value.id, + 'uk_region': UKRegion.england.value.id, 'website': 'www.google.com', }) @@ -279,7 +283,7 @@ def test_unarchive_wrong_method(self): assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED -class CHCompanyTestCase(LeelooTestCase): +class TestCHCompany(APITestMixin): """Companies house company test case.""" def test_list_ch_companies(self): @@ -297,7 +301,9 @@ def test_detail_ch_company(self): """Test companies house company detail.""" ch_company = CompaniesHouseCompanyFactory(company_number=123) - url = reverse('api-v1:companieshousecompany-detail', kwargs={'company_number': ch_company.company_number}) + url = reverse('api-v1:companieshousecompany-detail', kwargs={ + 'company_number': ch_company.company_number + }) response = self.api_client.get(url) assert response.status_code == status.HTTP_200_OK @@ -319,15 +325,15 @@ def test_promote_a_ch_company(self): response = self.api_client.post(url, { 'name': 'Acme', 'company_number': 1234567890, - 'business_type': constants.BusinessType.company.value.id, - 'sector': constants.Sector.aerospace_assembly_aircraft.value.id, - 'registered_address_country': constants.Country.united_kingdom.value.id, + 'business_type': BusinessType.company.value.id, + 'sector': Sector.aerospace_assembly_aircraft.value.id, + 'registered_address_country': Country.united_kingdom.value.id, 'registered_address_1': '75 Stramford Road', 'registered_address_town': 'London', - 'trading_address_country': constants.Country.ireland.value.id, + 'trading_address_country': Country.ireland.value.id, 'trading_address_1': '1 Hello st.', 'trading_address_town': 'Dublin', - 'uk_region': constants.UKRegion.england.value.id + 'uk_region': UKRegion.england.value.id }) assert response.status_code == status.HTTP_201_CREATED diff --git a/datahub/company/test/test_company_views_v3.py b/datahub/company/test/test_company_views_v3.py index fd2853858..f01356f10 100644 --- a/datahub/company/test/test_company_views_v3.py +++ b/datahub/company/test/test_company_views_v3.py @@ -1,5 +1,6 @@ from operator import itemgetter +import reversion from django.utils.timezone import now from rest_framework import status from rest_framework.reverse import reverse @@ -11,11 +12,11 @@ BusinessType, CompanyClassification, Country, HeadquarterType, Sector, UKRegion ) -from datahub.core.test_utils import LeelooTestCase +from datahub.core.test_utils import APITestMixin from datahub.investment.test.factories import InvestmentProjectFactory -class CompanyTestCase(LeelooTestCase): +class TestCompany(APITestMixin): """Company test case.""" def test_list_companies(self): @@ -346,7 +347,46 @@ def test_unarchive_company(self): assert response.data['id'] == str(company.id) -class CHCompanyTestCase(LeelooTestCase): +class TestAuditLogView(APITestMixin): + """Tests for the audit log view.""" + + def test_audit_log_view(self): + """Test retrieval of audit log.""" + initial_datetime = now() + with reversion.create_revision(): + company = CompanyFactory( + description='Initial desc', + ) + + reversion.set_comment('Initial') + reversion.set_date_created(initial_datetime) + reversion.set_user(self.user) + + changed_datetime = now() + with reversion.create_revision(): + company.description = 'New desc' + company.save() + + reversion.set_comment('Changed') + reversion.set_date_created(changed_datetime) + reversion.set_user(self.user) + + url = reverse('api-v3:company:audit-item', kwargs={'pk': company.pk}) + + response = self.api_client.get(url) + response_data = response.json()['results'] + + # No need to test the whole response + assert len(response_data) == 1 + entry = response_data[0] + + assert entry['user']['name'] == self.user.name + assert entry['comment'] == 'Changed' + assert entry['timestamp'] == changed_datetime.isoformat() + assert entry['changes']['description'] == ['Initial desc', 'New desc'] + + +class TestCHCompany(APITestMixin): """CH company tests.""" def test_get_ch_company(self): diff --git a/datahub/company/test/test_contact_views.py b/datahub/company/test/test_contact_views.py index da746fa86..fdbdb1e51 100644 --- a/datahub/company/test/test_contact_views.py +++ b/datahub/company/test/test_contact_views.py @@ -1,17 +1,19 @@ import pytest +import reversion +from django.utils.timezone import now from freezegun import freeze_time from rest_framework import status from rest_framework.reverse import reverse from datahub.core import constants -from datahub.core.test_utils import LeelooTestCase +from datahub.core.test_utils import APITestMixin from .factories import CompanyFactory, ContactFactory # mark the whole module for db use pytestmark = pytest.mark.django_db -class AddContactTestCase(LeelooTestCase): +class TestAddContact(APITestMixin): """Add contact test case.""" @freeze_time('2017-04-18 13:25:30.986208+00:00') @@ -198,7 +200,9 @@ def test_fails_without_address(self): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.data == { - 'address_same_as_company': ['Please select either address_same_as_company or enter an address manually.'] + 'address_same_as_company': [ + 'Please select either address_same_as_company or enter an address manually.' + ] } def test_fails_with_only_partial_manual_address(self): @@ -254,7 +258,7 @@ def test_fails_with_contact_preferences_not_set(self): } -class EditContactTestCase(LeelooTestCase): +class TestEditContact(APITestMixin): """Edit contact test case.""" @freeze_time('2017-04-18 13:25:30.986208+00:00') @@ -342,7 +346,7 @@ def test_patch(self): } -class ArchiveContactTestCase(LeelooTestCase): +class TestArchiveContact(APITestMixin): """Archive/unarchive contact test case.""" def test_archive_without_reason(self): @@ -393,7 +397,7 @@ def test_unarchive_wrong_method(self): assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED -class ViewContactTestCase(LeelooTestCase): +class TestViewContact(APITestMixin): """View contact test case.""" @freeze_time('2017-04-18 13:25:30.986208+00:00') @@ -478,7 +482,7 @@ def test_view(self): } -class ContactListTestCase(LeelooTestCase): +class TestContactList(APITestMixin): """List/filter contacts test case.""" def test_all(self): @@ -504,4 +508,44 @@ def test_filter_by_company(self): assert response.status_code == status.HTTP_200_OK assert response.data['count'] == 2 - assert {contact['id'] for contact in response.data['results']} == {contact.id for contact in contacts} + expected_contacts = {contact.id for contact in contacts} + assert {contact['id'] for contact in response.data['results']} == expected_contacts + + +class TestAuditLogView(APITestMixin): + """Tests for the audit log view.""" + + def test_audit_log_view(self): + """Test retrieval of audit log.""" + initial_datetime = now() + with reversion.create_revision(): + contact = ContactFactory( + notes='Initial notes', + ) + + reversion.set_comment('Initial') + reversion.set_date_created(initial_datetime) + reversion.set_user(self.user) + + changed_datetime = now() + with reversion.create_revision(): + contact.notes = 'New notes' + contact.save() + + reversion.set_comment('Changed') + reversion.set_date_created(changed_datetime) + reversion.set_user(self.user) + + url = reverse('api-v3:contact:audit-item', kwargs={'pk': contact.pk}) + + response = self.api_client.get(url) + response_data = response.json()['results'] + + # No need to test the whole response + assert len(response_data) == 1 + entry = response_data[0] + + assert entry['user']['name'] == self.user.name + assert entry['comment'] == 'Changed' + assert entry['timestamp'] == changed_datetime.isoformat() + assert entry['changes']['notes'] == ['Initial notes', 'New notes'] diff --git a/datahub/company/test/test_management_commands.py b/datahub/company/test/test_management_commands.py index a33b4d15b..dfe59eb25 100644 --- a/datahub/company/test/test_management_commands.py +++ b/datahub/company/test/test_management_commands.py @@ -24,7 +24,9 @@ def test_both_flags_passed_to_command(): user1 = AdviserFactory(use_cdms_auth=False) user2 = AdviserFactory(use_cdms_auth=False) with pytest.raises(CommandError) as exception: - management.call_command(manageusers.Command(), user1.email, user2.email, '--enable', '--disable') + management.call_command( + manageusers.Command(), user1.email, user2.email, '--enable', '--disable' + ) assert 'Pass either --enable or --disable not both' in str(exception.value) diff --git a/datahub/company/urls.py b/datahub/company/urls.py index 4c5b7c9b3..d8789115f 100644 --- a/datahub/company/urls.py +++ b/datahub/company/urls.py @@ -3,7 +3,8 @@ from django.conf.urls import url from datahub.company.views import ( - CompaniesHouseCompanyReadOnlyViewSetV1, CompanyViewSetV3, ContactViewSet + CompaniesHouseCompanyReadOnlyViewSetV1, CompanyAuditViewSet, CompanyViewSetV3, + ContactAuditViewSet, ContactViewSet ) # CONTACT @@ -26,6 +27,10 @@ 'post': 'unarchive', }) +contact_audit = ContactAuditViewSet.as_view({ + 'get': 'retrieve', +}) + contact_urls = [ url(r'^contact$', contact_collection, name='list'), url(r'^contact/(?P[0-9a-z-]{36})$', contact_item, name='detail'), @@ -33,6 +38,8 @@ name='archive'), url(r'^contact/(?P[0-9a-z-]{36})/unarchive$', contact_unarchive, name='unarchive'), + url(r'^contact/(?P[0-9a-z-]{36})/audit$', contact_audit, + name='audit-item'), ] # COMPANY @@ -47,6 +54,10 @@ 'patch': 'partial_update' }) +company_audit = CompanyAuditViewSet.as_view({ + 'get': 'retrieve', +}) + company_archive = CompanyViewSetV3.as_view({ 'post': 'archive' }) @@ -70,6 +81,8 @@ name='archive'), url(r'^company/(?P[0-9a-z-]{36})/unarchive$', company_unarchive, name='unarchive'), + url(r'^company/(?P[0-9a-z-]{36})/audit$', company_audit, + name='audit-item'), ] ch_company_urls = [ diff --git a/datahub/company/views.py b/datahub/company/views.py index 4167f162c..4efbd9407 100644 --- a/datahub/company/views.py +++ b/datahub/company/views.py @@ -4,6 +4,7 @@ from rest_framework import mixins, viewsets from datahub.core.mixins import ArchivableViewSetMixin +from datahub.core.serializers import AuditSerializer from datahub.core.viewsets import CoreViewSetV1, CoreViewSetV3 from .models import Advisor, CompaniesHouseCompany, Company, Contact from .serializers import ( @@ -62,6 +63,13 @@ class CompanyViewSetV3(ArchivableViewSetMixin, CoreViewSetV3): ) +class CompanyAuditViewSet(CoreViewSetV3): + """Company audit views.""" + + serializer_class = AuditSerializer + queryset = Company.objects.all() + + class CompaniesHouseCompanyReadOnlyViewSetV1( mixins.ListModelMixin, mixins.RetrieveModelMixin, viewsets.GenericViewSet): """Companies House company GET only views.""" @@ -95,6 +103,13 @@ def get_additional_data(self, create): return data +class ContactAuditViewSet(CoreViewSetV3): + """Contact audit views.""" + + serializer_class = AuditSerializer + queryset = Contact.objects.all() + + class AdviserFilter(FilterSet): """Adviser filter.""" diff --git a/datahub/core/auth.py b/datahub/core/auth.py index 1026fabb6..5836f890a 100644 --- a/datahub/core/auth.py +++ b/datahub/core/auth.py @@ -41,8 +41,10 @@ def authenticate(self, request, username=None, password=None, **kwargs): auth_result = self.validate_cdms_credentials(username, password) if auth_result is True: # user authenticated via CDMS - user.set_password(password) # cache passwd hash for backup auth - user.is_active = True # ensure user can use django backend to auth, in case CDMS fails + # cache passwd hash for backup auth + user.set_password(password) + # ensure user can use django backend to auth, in case CDMS fails + user.is_active = True user.save() return user diff --git a/datahub/core/constants.py b/datahub/core/constants.py index fb3a3946f..d08d3eef4 100644 --- a/datahub/core/constants.py +++ b/datahub/core/constants.py @@ -14,8 +14,12 @@ class BusinessType(Enum): intermediary = Constant('Intermediary', '9bd14e94-5d95-e211-a939-e4115bead28a') partnership = Constant('Partnership', '8b6eaf7e-03e7-e611-bca1-e4115bead28a') sole_trader = Constant('Sole Trader', '99d14e94-5d95-e211-a939-e4115bead28a') - private_limited_company = Constant('Private limited company', '6f75408b-03e7-e611-bca1-e4115bead28a') - public_limited_company = Constant('Public limited company', 'dac8c591-03e7-e611-bca1-e4115bead28a') + private_limited_company = Constant( + 'Private limited company', '6f75408b-03e7-e611-bca1-e4115bead28a' + ) + public_limited_company = Constant( + 'Public limited company', 'dac8c591-03e7-e611-bca1-e4115bead28a' + ) class Country(Enum): @@ -49,12 +53,18 @@ class Country(Enum): bhutan = Constant('Bhutan', 'ab5f66a0-5d95-e211-a939-e4115bead28a') blank = Constant('BLANK', '98c8d93d-5d06-e311-a78e-e4115bead28a') bolivia = Constant('Bolivia', 'ac5f66a0-5d95-e211-a939-e4115bead28a') - bosnia_and_herzegovina = Constant('Bosnia and Herzegovina', 'ad5f66a0-5d95-e211-a939-e4115bead28a') + bosnia_and_herzegovina = Constant( + 'Bosnia and Herzegovina', 'ad5f66a0-5d95-e211-a939-e4115bead28a' + ) botswana = Constant('Botswana', 'ae5f66a0-5d95-e211-a939-e4115bead28a') bouvet_island = Constant('Bouvet Island', 'af5f66a0-5d95-e211-a939-e4115bead28a') brazil = Constant('Brazil', 'b05f66a0-5d95-e211-a939-e4115bead28a') - british_indian_ocean_territory = Constant('British Indian Ocean Territory', 'b15f66a0-5d95-e211-a939-e4115bead28a') - british_virgin_islands = Constant('British Virgin Islands', 'b25f66a0-5d95-e211-a939-e4115bead28a') + british_indian_ocean_territory = Constant( + 'British Indian Ocean Territory', 'b15f66a0-5d95-e211-a939-e4115bead28a' + ) + british_virgin_islands = Constant( + 'British Virgin Islands', 'b25f66a0-5d95-e211-a939-e4115bead28a' + ) brunei = Constant('Brunei', '56af72a6-5d95-e211-a939-e4115bead28a') bulgaria = Constant('Bulgaria', '57af72a6-5d95-e211-a939-e4115bead28a') burkina = Constant('Burkina', '58af72a6-5d95-e211-a939-e4115bead28a') @@ -65,16 +75,22 @@ class Country(Enum): canada = Constant('Canada', '5daf72a6-5d95-e211-a939-e4115bead28a') cape_verde = Constant('Cape Verde', '5eaf72a6-5d95-e211-a939-e4115bead28a') cayman_islands = Constant('Cayman Islands', '5faf72a6-5d95-e211-a939-e4115bead28a') - central_african_republic = Constant('Central African Republic', '60af72a6-5d95-e211-a939-e4115bead28a') + central_african_republic = Constant( + 'Central African Republic', '60af72a6-5d95-e211-a939-e4115bead28a' + ) chad = Constant('Chad', '61af72a6-5d95-e211-a939-e4115bead28a') chile = Constant('Chile', '62af72a6-5d95-e211-a939-e4115bead28a') china = Constant('China', '63af72a6-5d95-e211-a939-e4115bead28a') christmas_island = Constant('Christmas Island', '64af72a6-5d95-e211-a939-e4115bead28a') - cocos_keeling_islands = Constant('Cocos (Keeling) Islands', '65af72a6-5d95-e211-a939-e4115bead28a') + cocos_keeling_islands = Constant( + 'Cocos (Keeling) Islands', '65af72a6-5d95-e211-a939-e4115bead28a' + ) colombia = Constant('Colombia', '66af72a6-5d95-e211-a939-e4115bead28a') comoros = Constant('Comoros', '67af72a6-5d95-e211-a939-e4115bead28a') congo = Constant('Congo', '69af72a6-5d95-e211-a939-e4115bead28a') - congo_democratic_republic = Constant('Congo (Democratic Republic)', '68af72a6-5d95-e211-a939-e4115bead28a') + congo_democratic_republic = Constant( + 'Congo (Democratic Republic)', '68af72a6-5d95-e211-a939-e4115bead28a' + ) cook_islands = Constant('Cook Islands', '6aaf72a6-5d95-e211-a939-e4115bead28a') costa_rica = Constant('Costa Rica', '6baf72a6-5d95-e211-a939-e4115bead28a') croatia = Constant('Croatia', '6caf72a6-5d95-e211-a939-e4115bead28a') @@ -100,7 +116,9 @@ class Country(Enum): france = Constant('France', '82756b9a-5d95-e211-a939-e4115bead28a') french_guiana = Constant('French Guiana', 'dbf682ac-5d95-e211-a939-e4115bead28a') french_polynesia = Constant('French Polynesia', 'dcf682ac-5d95-e211-a939-e4115bead28a') - french_southern_territories = Constant('French Southern Territories', 'ddf682ac-5d95-e211-a939-e4115bead28a') + french_southern_territories = Constant( + 'French Southern Territories', 'ddf682ac-5d95-e211-a939-e4115bead28a' + ) gabon = Constant('Gabon', 'def682ac-5d95-e211-a939-e4115bead28a') gambia, _the = Constant('Gambia, The', 'dff682ac-5d95-e211-a939-e4115bead28a') georgia = Constant('Georgia', 'e0f682ac-5d95-e211-a939-e4115bead28a') @@ -188,7 +206,9 @@ class Country(Enum): nigeria = Constant('Nigeria', '4561b8be-5d95-e211-a939-e4115bead28a') niue = Constant('Niue', '4661b8be-5d95-e211-a939-e4115bead28a') norfolk_island = Constant('Norfolk Island', '4761b8be-5d95-e211-a939-e4115bead28a') - northern_mariana_islands = Constant('Northern Mariana Islands', '4861b8be-5d95-e211-a939-e4115bead28a') + northern_mariana_islands = Constant( + 'Northern Mariana Islands', '4861b8be-5d95-e211-a939-e4115bead28a' + ) norway = Constant('Norway', '4961b8be-5d95-e211-a939-e4115bead28a') occupied_palestinian_territories = Constant('Occupied Palestinian Territories', '35afd8d0-5d95-e211-a939-e4115bead28a') @@ -200,8 +220,9 @@ class Country(Enum): paraguay = Constant('Paraguay', '4f61b8be-5d95-e211-a939-e4115bead28a') peru = Constant('Peru', '5061b8be-5d95-e211-a939-e4115bead28a') philippines = Constant('Philippines', '5161b8be-5d95-e211-a939-e4115bead28a') - pitcairn_henderson_ducie_and_oeno_islands = Constant('Pitcairn, Henderson, Ducie and Oeno Islands', - '5261b8be-5d95-e211-a939-e4115bead28a') + pitcairn_henderson_ducie_and_oeno_islands = Constant( + 'Pitcairn, Henderson, Ducie and Oeno Islands', '5261b8be-5d95-e211-a939-e4115bead28a' + ) poland = Constant('Poland', '5361b8be-5d95-e211-a939-e4115bead28a') portugal = Constant('Portugal', '5461b8be-5d95-e211-a939-e4115bead28a') puerto_rico = Constant('Puerto Rico', '5561b8be-5d95-e211-a939-e4115bead28a') @@ -212,7 +233,9 @@ class Country(Enum): rwanda = Constant('Rwanda', '5a61b8be-5d95-e211-a939-e4115bead28a') samoa = Constant('Samoa', '5b61b8be-5d95-e211-a939-e4115bead28a') san_marino = Constant('San Marino', '5c61b8be-5d95-e211-a939-e4115bead28a') - sao_tome_and_principe = Constant('Sao Tome and Principe', '5d61b8be-5d95-e211-a939-e4115bead28a') + sao_tome_and_principe = Constant( + 'Sao Tome and Principe', '5d61b8be-5d95-e211-a939-e4115bead28a' + ) saudi_arabia = Constant('Saudi Arabia', '1a0be5c4-5d95-e211-a939-e4115bead28a') senegal = Constant('Senegal', '1b0be5c4-5d95-e211-a939-e4115bead28a') serbia = Constant('Serbia', '1c0be5c4-5d95-e211-a939-e4115bead28a') @@ -233,12 +256,16 @@ class Country(Enum): st_kitts_and_nevis = Constant('St Kitts and Nevis', '280be5c4-5d95-e211-a939-e4115bead28a') st_lucia = Constant('St Lucia', '290be5c4-5d95-e211-a939-e4115bead28a') st_martin = Constant('St Martin', '7c756b9a-5d95-e211-a939-e4115bead28a') - st_pierre_and_miquelon = Constant('St Pierre and Miquelon', '2a0be5c4-5d95-e211-a939-e4115bead28a') + st_pierre_and_miquelon = Constant( + 'St Pierre and Miquelon', '2a0be5c4-5d95-e211-a939-e4115bead28a' + ) st_vincent = Constant('St Vincent', '2b0be5c4-5d95-e211-a939-e4115bead28a') sudan = Constant('Sudan', '2c0be5c4-5d95-e211-a939-e4115bead28a') sudan_south = Constant('Sudan, South', '7e756b9a-5d95-e211-a939-e4115bead28a') surinam = Constant('Surinam', '2d0be5c4-5d95-e211-a939-e4115bead28a') - svalbard_and_jan_mayen_islands = Constant('Svalbard and Jan Mayen Islands', '2e0be5c4-5d95-e211-a939-e4115bead28a') + svalbard_and_jan_mayen_islands = Constant( + 'Svalbard and Jan Mayen Islands', '2e0be5c4-5d95-e211-a939-e4115bead28a' + ) swaziland = Constant('Swaziland', '2f0be5c4-5d95-e211-a939-e4115bead28a') sweden = Constant('Sweden', '300be5c4-5d95-e211-a939-e4115bead28a') switzerland = Constant('Switzerland', '310be5c4-5d95-e211-a939-e4115bead28a') @@ -255,7 +282,9 @@ class Country(Enum): tunisia = Constant('Tunisia', 'ad6ee1ca-5d95-e211-a939-e4115bead28a') turkey = Constant('Turkey', 'ae6ee1ca-5d95-e211-a939-e4115bead28a') turkmenistan = Constant('Turkmenistan', 'af6ee1ca-5d95-e211-a939-e4115bead28a') - turks_and_caicos_islands = Constant('Turks and Caicos Islands', 'b06ee1ca-5d95-e211-a939-e4115bead28a') + turks_and_caicos_islands = Constant( + 'Turks and Caicos Islands', 'b06ee1ca-5d95-e211-a939-e4115bead28a' + ) tuvalu = Constant('Tuvalu', 'b16ee1ca-5d95-e211-a939-e4115bead28a') uganda = Constant('Uganda', 'b26ee1ca-5d95-e211-a939-e4115bead28a') ukraine = Constant('Ukraine', 'b36ee1ca-5d95-e211-a939-e4115bead28a') @@ -396,7 +425,9 @@ class UKRegion(Enum): ukti_dubai_hub = Constant('UKTI Dubai Hub', 'e1dd40e9-3dfd-e311-8a2b-e4115bead28a') wales = Constant('Wales', '8d4cd12a-6095-e211-a939-e4115bead28a') west_midlands = Constant('West Midlands', '854cd12a-6095-e211-a939-e4115bead28a') - yorkshire_and_the_humber = Constant('Yorkshire and The Humber', '834cd12a-6095-e211-a939-e4115bead28a') + yorkshire_and_the_humber = Constant( + 'Yorkshire and The Humber', '834cd12a-6095-e211-a939-e4115bead28a' + ) class Service(Enum): @@ -437,8 +468,12 @@ class CompanyClassification(Enum): tier_a1 = Constant("Tier A1 – Tomorrow's Champions", '2b55bb11-9518-e411-985c-e4115bead28a') tier_a2 = Constant('Tier A2 -Global Partners', '7e0c261a-d447-e411-985c-e4115bead28a') tier_b = Constant('Tier B - Global Accounts', 'bb1bf800-8d53-e311-aef3-441ea13961e2') - tier_c = Constant('Tier C - Local Accounts (UKTI Managed)', 'bd1bf800-8d53-e311-aef3-441ea13961e2') - tier_dl = Constant('Tier D - LEP Managed Branch (not IST)', '12798372-8eb4-e511-88b6-e4115bead28a') + tier_c = Constant( + 'Tier C - Local Accounts (UKTI Managed)', 'bd1bf800-8d53-e311-aef3-441ea13961e2' + ) + tier_dl = Constant( + 'Tier D - LEP Managed Branch (not IST)', '12798372-8eb4-e511-88b6-e4115bead28a' + ) tier_dg = Constant('Tier D - POST Identified/Managed', '572dfefe-cd1d-e611-9bdc-e4115bead28a') diff --git a/datahub/core/serializers.py b/datahub/core/serializers.py index 2aee37da2..1d618ac21 100644 --- a/datahub/core/serializers.py +++ b/datahub/core/serializers.py @@ -4,6 +4,7 @@ from django.core.exceptions import ObjectDoesNotExist from rest_framework import serializers from rest_framework.fields import UUIDField +from reversion.models import Version class ConstantModelSerializer(serializers.Serializer): @@ -16,6 +17,61 @@ class Meta: # noqa: D101 fields = '__all__' +class AuditSerializer(serializers.Serializer): + """Generic serializer for audit logs.""" + + def to_representation(self, instance): + """Override serialization process completely to get the Versions.""" + versions = Version.objects.get_for_object(instance) + version_pairs = ( + (versions[n], versions[n + 1]) for n in range(len(versions) - 1) + ) + + return { + 'results': self._construct_changelog(version_pairs), + } + + def _construct_changelog(self, version_pairs): + changelog = [] + + for v_new, v_old in version_pairs: + version_creator = v_new.revision.user + creator_repr = None + if version_creator: + creator_repr = { + 'id': str(version_creator.pk), + 'first_name': version_creator.first_name, + 'last_name': version_creator.last_name, + 'name': version_creator.name, + 'email': version_creator.email, + } + + changelog.append({ + 'user': creator_repr, + 'timestamp': v_new.revision.date_created, + 'comment': v_new.revision.comment or '', + 'changes': self._diff_versions( + v_old.field_dict, v_new.field_dict + ), + }) + + return changelog + + @staticmethod + def _diff_versions(old_version, new_version): + changes = {} + + for field_name, new_value in new_version.items(): + if field_name not in old_version: + changes[field_name] = [None, new_value] + else: + old_value = old_version[field_name] + if old_value != new_value: + changes[field_name] = [old_value, new_value] + + return changes + + class NestedRelatedField(serializers.RelatedField): """DRF serialiser field for foreign keys and many-to-many fields. diff --git a/datahub/core/test/test_auth.py b/datahub/core/test/test_auth.py index 68d8c43e5..8771b5972 100644 --- a/datahub/core/test/test_auth.py +++ b/datahub/core/test/test_auth.py @@ -87,7 +87,11 @@ def test_invalid_cdms_credentials(auth_mock, settings, live_server): auth = requests.auth.HTTPBasicAuth(application.client_id, application.client_secret) response = requests.post( url, - data={'grant_type': 'password', 'username': cdms_user.email, 'password': cdms_user.password}, + data={ + 'grant_type': 'password', + 'username': cdms_user.email, + 'password': cdms_user.password + }, auth=auth ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -110,7 +114,11 @@ def test_cdms_returns_500(mocked_login, live_server): auth = requests.auth.HTTPBasicAuth(application.client_id, application.client_secret) response = requests.post( url, - data={'grant_type': 'password', 'username': cdms_user.email, 'password': cdms_user.password}, + data={ + 'grant_type': 'password', + 'username': cdms_user.email, + 'password': cdms_user.password + }, auth=auth ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -255,7 +263,11 @@ def test_valid_cdms_credentials_user_not_whitelisted(auth_mock, live_server): auth = requests.auth.HTTPBasicAuth(application.client_id, application.client_secret) response = requests.post( url, - data={'grant_type': 'password', 'username': cdms_user.email, 'password': cdms_user.password}, + data={ + 'grant_type': 'password', + 'username': cdms_user.email, + 'password': cdms_user.password + }, auth=auth ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -278,7 +290,11 @@ def test_valid_django_user(auth_mock, live_server): auth = requests.auth.HTTPBasicAuth(application.client_id, application.client_secret) response = requests.post( url, - data={'grant_type': 'password', 'username': django_user.email, 'password': DJANGO_USER_PASSWORD}, + data={ + 'grant_type': 'password', + 'username': django_user.email, + 'password': DJANGO_USER_PASSWORD + }, auth=auth ) assert response.status_code == status.HTTP_200_OK diff --git a/datahub/core/test/test_serializers.py b/datahub/core/test/test_serializers.py index 8ef81de60..87b479c2f 100644 --- a/datahub/core/test/test_serializers.py +++ b/datahub/core/test/test_serializers.py @@ -5,7 +5,29 @@ from django.core.exceptions import ObjectDoesNotExist from rest_framework.exceptions import ValidationError -from datahub.core.serializers import NestedRelatedField +from datahub.core.serializers import AuditSerializer, NestedRelatedField + + +def test_audit_log_diff_algo(): + """Test simple diff algorithm.""" + given = { + 'old': { + 'field1': 'val1', + 'field2': 'val2', + }, + 'new': { + 'field1': 'val1', + 'field2': 'new-val', + 'field3': 'added', + }, + } + + expected = { + 'field2': ['val2', 'new-val'], + 'field3': [None, 'added'], + } + + assert AuditSerializer._diff_versions(given['old'], given['new']) == expected def test_nested_rel_field_to_internal_dict(): diff --git a/datahub/core/test_utils.py b/datahub/core/test_utils.py index 16798d0bf..d7fe5935f 100644 --- a/datahub/core/test_utils.py +++ b/datahub/core/test_utils.py @@ -2,7 +2,6 @@ import pytest from django.contrib.auth import get_user_model -from django.test import TestCase from django.utils.timezone import now from oauth2_provider.models import AccessToken, Application from rest_framework.test import APIClient @@ -25,28 +24,33 @@ def get_test_user(): return test_user -class LeelooTestCase(TestCase): +class APITestMixin: """All the tests using the DB and accessing end points behind auth should use this class.""" pytestmark = pytest.mark.django_db # use db - def setUp(self): - """Set ups some utils.""" - self._user = None - self._application = None - self._token = None - self.user = self.get_user() - self.application = self.get_application() - self.token = self.get_token() - self.api_client = self.get_logged_in_api_client() - - def get_user(self): + @property + def user(self): """Return the user.""" - if self._user: - return self._user - return get_test_user() + if not hasattr(self, '_user'): + self._user = get_test_user() + return self._user - def get_logged_in_api_client(self): + @property + def token(self): + """Get access token for user test.""" + if not hasattr(self, '_token'): + self._token = AccessToken( + user=self.user, + application=self.application, + token='123456789', # unsafe token, just for testing + expires=datetime.datetime.now() + datetime.timedelta(hours=1), + scope='write read' + ) + return self._token.token + + @property + def api_client(self): """ Login using the OAuth2 authentication. @@ -59,32 +63,17 @@ def get_logged_in_api_client(self): client.credentials(Authorization=f'Bearer {self.token}') return client - def get_token(self): - """Get access token for user test.""" - if self._token: - return self._token - - token = AccessToken( - user=self.user, - application=self.application, - token='123456789', # unsafe token, just for testing - expires=datetime.datetime.now() + datetime.timedelta(hours=1), - scope='write read' - ) - return token.token - - def get_application(self): + @property + def application(self): """Return the test application.""" - if self._application: - return self._application - - application, _ = Application.objects.get_or_create( - user=get_test_user(), - client_type=Application.CLIENT_CONFIDENTIAL, - authorization_grant_type=Application.GRANT_PASSWORD, - name='Test client' - ) - return application + if not hasattr(self, '_application'): + self._application, _ = Application.objects.get_or_create( + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_PASSWORD, + name='Test client' + ) + return self._application def synchronous_executor_submit(fn, *args, **kwargs): diff --git a/datahub/core/utils.py b/datahub/core/utils.py index e6b172915..18197f8fa 100644 --- a/datahub/core/utils.py +++ b/datahub/core/utils.py @@ -69,9 +69,12 @@ def get_s3_client(): return s3 -def sign_s3_url(bucket_name, path, method='get_object', expires=3600): +def sign_s3_url(bucket_name, path, method='get_object', expires=3600, client=None): """Sign s3 url using global config, and given expiry in seconds.""" - return get_s3_client().generate_presigned_url( + if client is None: + client = get_s3_client() + + return client.generate_presigned_url( ClientMethod=method, Params={ 'Bucket': bucket_name, @@ -79,3 +82,16 @@ def sign_s3_url(bucket_name, path, method='get_object', expires=3600): }, ExpiresIn=expires, ) + + +def delete_s3_obj(bucket, key, client=None): + """Remove object from S3 Bucket.""" + if client is None: + client = get_s3_client() + + response = client.delete_object( + Bucket=bucket, + Key=key, + ) + + assert response['ResponseMetadata']['HTTPStatusCode'] == 204 diff --git a/datahub/core/viewsets.py b/datahub/core/viewsets.py index 6452ed41d..a5bfcea80 100644 --- a/datahub/core/viewsets.py +++ b/datahub/core/viewsets.py @@ -66,6 +66,7 @@ def get_additional_data(self, create): class CoreViewSetV3(mixins.CreateModelMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, + mixins.DestroyModelMixin, mixins.ListModelMixin, GenericViewSet): """Base class for v3 view sets.""" diff --git a/datahub/dashboard/test/test_views.py b/datahub/dashboard/test/test_views.py index 1020fcfe8..c294b3bca 100644 --- a/datahub/dashboard/test/test_views.py +++ b/datahub/dashboard/test/test_views.py @@ -3,11 +3,11 @@ from datahub.company.test.factories import CompanyFactory from datahub.core import constants -from datahub.core.test_utils import get_test_user, LeelooTestCase +from datahub.core.test_utils import APITestMixin, get_test_user from datahub.interaction.test.factories import InteractionFactory -class DashboardTestCase(LeelooTestCase): +class TestDashboard(APITestMixin): """Dashboard test case.""" def test_intelligent_homepage(self): diff --git a/datahub/dashboard/views.py b/datahub/dashboard/views.py index 3a5ab6a85..77c4ffe69 100644 --- a/datahub/dashboard/views.py +++ b/datahub/dashboard/views.py @@ -30,5 +30,8 @@ def get(self, request, format=None): created_on__gte=days_in_the_past ).order_by('-created_on') - serializer = IntelligentHomepageSerializer({'interactions': interactions, 'contacts': contacts}) + serializer = IntelligentHomepageSerializer({ + 'interactions': interactions, + 'contacts': contacts + }) return Response(data=serializer.data) diff --git a/datahub/documents/models.py b/datahub/documents/models.py index 6278921e0..c12fad954 100644 --- a/datahub/documents/models.py +++ b/datahub/documents/models.py @@ -1,11 +1,17 @@ import uuid +from logging import getLogger from os import path from django.conf import settings -from django.db import models +from django.db import models, transaction +from django.db.models.signals import post_delete +from django.dispatch import receiver +from raven.contrib.django.raven_compat.models import client from datahub.core.models import ArchivableModel, BaseModel -from datahub.core.utils import sign_s3_url +from datahub.core.utils import delete_s3_obj, executor, sign_s3_url + +logger = getLogger(__name__) class Document(BaseModel, ArchivableModel): @@ -61,3 +67,27 @@ def s3_key(self): def __str__(self): """String repr.""" return f'Document(filename="{self.filename}", av_clean={self.av_clean})' + + +@receiver(post_delete, sender=Document) +def document_post_delete(sender, **kwargs): + """Handle document delete.""" + instance = kwargs['instance'] + if instance.uploaded_on is None: + return + + # grab only needed vars for closure, so instance can go out-of-scope + bucket = instance.s3_bucket + key = instance.s3_key + + def delete_document(): + try: + delete_s3_obj(bucket, key) + except Exception: + msg = 'Exception during s3 object removal.' + logger.exception(msg) + client.captureException(msg) + + transaction.on_commit( + lambda: executor.submit(delete_document) + ) diff --git a/datahub/interaction/test/test_interaction_views.py b/datahub/interaction/test/test_interaction_views.py index 375a97523..463a34367 100644 --- a/datahub/interaction/test/test_interaction_views.py +++ b/datahub/interaction/test/test_interaction_views.py @@ -5,12 +5,12 @@ from datahub.company.test.factories import AdviserFactory, CompanyFactory, ContactFactory from datahub.core import constants -from datahub.core.test_utils import LeelooTestCase +from datahub.core.test_utils import APITestMixin from datahub.interaction.test.factories import InteractionFactory from datahub.investment.test.factories import InvestmentProjectFactory -class InteractionTestCase(LeelooTestCase): +class TestInteraction(APITestMixin): """Interaction test case.""" def test_interaction_detail_view(self): diff --git a/datahub/investment/__init__.py b/datahub/investment/__init__.py index e69de29bb..df816583e 100644 --- a/datahub/investment/__init__.py +++ b/datahub/investment/__init__.py @@ -0,0 +1 @@ +default_app_config = 'datahub.investment.apps.InvestmentConfig' diff --git a/datahub/investment/apps.py b/datahub/investment/apps.py new file mode 100644 index 000000000..076d823a4 --- /dev/null +++ b/datahub/investment/apps.py @@ -0,0 +1,14 @@ +from django.apps import AppConfig + + +class InvestmentConfig(AppConfig): + """Configuration class for this app.""" + + name = 'datahub.investment' + + def ready(self): + """Registers the signals for this app. + + This is the preferred way to register signals in the Django documentation. + """ + import datahub.investment.signals # noqa: F401 diff --git a/datahub/investment/models.py b/datahub/investment/models.py index b7c6dfcb1..af4d08e05 100644 --- a/datahub/investment/models.py +++ b/datahub/investment/models.py @@ -3,9 +3,8 @@ import uuid from django.conf import settings +from django.core.exceptions import ObjectDoesNotExist from django.db import models, transaction -from django.db.models.signals import post_save -from django.dispatch import receiver from model_utils import Choices from datahub.core.constants import InvestmentProjectStage @@ -117,8 +116,11 @@ def project_code(self): """ if self.cdms_project_code: return self.cdms_project_code - project_num = self.investmentprojectcode.id - return f'DHP-{project_num:08d}' + try: + project_num = self.investmentprojectcode.id + return f'DHP-{project_num:08d}' + except ObjectDoesNotExist: + return None class IProjectValueAbstract(models.Model): @@ -236,6 +238,10 @@ class InvestmentProjectTeamMember(models.Model): adviser = models.ForeignKey('company.Advisor', on_delete=models.CASCADE, related_name='+') role = models.CharField(max_length=MAX_LENGTH) + def __str__(self): + """Human-readable representation.""" + return f'{self.investment_project} – {self.adviser} – {self.role}' + class Meta: # noqa: D101 unique_together = (('investment_project', 'adviser'),) @@ -305,6 +311,12 @@ class Meta: # noqa: D101 ('project', 'doc_type', 'filename'), ) + def delete(self, using=None, keep_parents=False): + """Ensure document is removed when parent is being deleted.""" + result = super().delete(using, keep_parents) + self.document.delete(using, keep_parents) + return result + @classmethod def create_from_declaration_request(cls, project, field, filename): """Create investment document along with correct Document creation.""" @@ -322,16 +334,3 @@ def create_from_declaration_request(cls, project, field, filename): investment_doc.save() return investment_doc - - -@receiver(post_save, sender=InvestmentProject) -def project_post_save(sender, **kwargs): - """Creates a project code for investment projects on creation. - - Projects with a CDMS project code do not get a new project code. - """ - instance = kwargs['instance'] - created = kwargs['created'] - raw = kwargs['raw'] - if created and not raw and not instance.cdms_project_code: - InvestmentProjectCode.objects.create(project=instance) diff --git a/datahub/investment/serializers.py b/datahub/investment/serializers.py index cc887c550..a8f3ddc87 100644 --- a/datahub/investment/serializers.py +++ b/datahub/investment/serializers.py @@ -1,7 +1,6 @@ """Investment serialisers for views.""" from rest_framework import serializers -from reversion.models import Version import datahub.metadata.models as meta_models from datahub.company.models import Company, Contact @@ -12,17 +11,14 @@ from datahub.investment.validate import validate -class IProjectSerializer(serializers.ModelSerializer): +class IProjectSummarySerializer(serializers.ModelSerializer): """Serialiser for investment project endpoints.""" + incomplete_fields = serializers.SerializerMethodField() project_code = serializers.CharField(read_only=True) - investment_type = NestedRelatedField(meta_models.InvestmentType) stage = NestedRelatedField(meta_models.InvestmentProjectStage, required=False) - # phase is deprecated – remove once front end is using stage - phase = NestedRelatedField(meta_models.InvestmentProjectStage, - required=False, source='stage') project_shareable = serializers.BooleanField(required=True) investor_company = NestedRelatedField( Company, required=True, allow_null=False @@ -60,12 +56,26 @@ class IProjectSerializer(serializers.ModelSerializer): ) archived_by = NestedAdviserField(read_only=True) + def get_incomplete_fields(self, instance): + """Returns the names of the fields that still need to be completed in order to + move to the next stage. + """ + return tuple(validate(instance=instance, next_stage=True)) + def validate(self, data): """Validates the object after individual fields have been validated. Performs stage-dependent validation of the different sections. + + When transitioning stage, all fields required for the new stage are + validated. In other cases, only the fields being modified are validated. + If a project ends up in an invalid state, this avoids the user being + unable to rectify the situation. """ - errors = validate(self.instance, data) + fields = None + if self.partial and 'stage' not in data: + fields = data.keys() + errors = validate(self.instance, data, fields=fields) if errors: raise serializers.ValidationError(errors) @@ -75,6 +85,7 @@ class Meta: # noqa: D101 model = InvestmentProject fields = ( 'id', + 'incomplete_fields', 'name', 'project_code', 'description', @@ -93,7 +104,6 @@ class Meta: # noqa: D101 'approved_non_fdi', 'investment_type', 'stage', - 'phase', # For backwards compatibility 'investor_company', 'intermediate_company', 'client_contacts', @@ -125,61 +135,6 @@ class Meta: # noqa: D101 } -class IProjectAuditSerializer(serializers.Serializer): - """Serializer for Investment Project audit log.""" - - def to_representation(self, instance): - """Overwrite serialization process completely to get the Versions.""" - versions = Version.objects.get_for_object(instance) - version_pairs = ( - (versions[n], versions[n + 1]) for n in range(len(versions) - 1) - ) - - return { - 'results': self._construct_changelog(version_pairs), - } - - def _construct_changelog(self, version_pairs): - changelog = [] - - for v_new, v_old in version_pairs: - version_creator = v_new.revision.user - creator_repr = None - if version_creator: - creator_repr = { - 'id': str(version_creator.pk), - 'first_name': version_creator.first_name, - 'last_name': version_creator.last_name, - 'name': version_creator.name, - 'email': version_creator.email, - } - - changelog.append({ - 'user': creator_repr, - 'timestamp': v_new.revision.date_created, - 'comment': v_new.revision.comment or '', - 'changes': self._diff_versions( - v_old.field_dict, v_new.field_dict - ), - }) - - return changelog - - @staticmethod - def _diff_versions(old_version, new_version): - changes = {} - - for field_name, new_value in new_version.items(): - if field_name not in old_version: - changes[field_name] = [None, new_value] - else: - old_value = old_version[field_name] - if old_value != new_value: - changes[field_name] = [old_value, new_value] - - return changes - - class IProjectValueSerializer(serializers.ModelSerializer): """Serialiser for investment project value objects.""" @@ -296,19 +251,19 @@ class Meta: # noqa: D101 ) -class IProjectUnifiedSerializer(IProjectSerializer, IProjectValueSerializer, - IProjectRequirementsSerializer, IProjectTeamSerializer): +class IProjectSerializer(IProjectSummarySerializer, IProjectValueSerializer, + IProjectRequirementsSerializer, IProjectTeamSerializer): """Serialiser for investment projects, used with the new unified investment endpoint.""" class Meta: # noqa: D101 model = InvestmentProject fields = ( - IProjectSerializer.Meta.fields + + IProjectSummarySerializer.Meta.fields + IProjectValueSerializer.Meta.fields + IProjectRequirementsSerializer.Meta.fields + IProjectTeamSerializer.Meta.fields ) - extra_kwargs = IProjectSerializer.Meta.extra_kwargs + extra_kwargs = IProjectSummarySerializer.Meta.extra_kwargs class IProjectDocumentSerializer(serializers.ModelSerializer): diff --git a/datahub/investment/signals.py b/datahub/investment/signals.py new file mode 100644 index 000000000..3aebf06ab --- /dev/null +++ b/datahub/investment/signals.py @@ -0,0 +1,20 @@ +from django.db.models.signals import post_save +from django.dispatch import receiver + +from datahub.investment.models import InvestmentProject, InvestmentProjectCode + + +@receiver(post_save, sender=InvestmentProject, dispatch_uid='project_post_save') +def project_post_save(sender, **kwargs): + """Creates a project code for investment projects on creation. + + Projects with a CDMS project code do not get a new project code. + + This generates project codes for fixtures loaded via manage.py loaddata + (i.e. when kwargs['raw'] is True), though that may need to change if + fixed project codes are required for that fixtures. + """ + instance = kwargs['instance'] + created = kwargs['created'] + if created and not instance.cdms_project_code: + InvestmentProjectCode.objects.create(project=instance) diff --git a/datahub/investment/test/test_investment_serializers.py b/datahub/investment/test/test_investment_serializers.py deleted file mode 100644 index 56547fccd..000000000 --- a/datahub/investment/test/test_investment_serializers.py +++ /dev/null @@ -1,23 +0,0 @@ -from datahub.investment import serializers - - -def test_audit_log_diff_algo(): - """Test simple diff algorithm.""" - given = { - 'old': { - 'field1': 'val1', - 'field2': 'val2', - }, - 'new': { - 'field1': 'val1', - 'field2': 'new-val', - 'field3': 'added', - }, - } - - expected = { - 'field2': ['val2', 'new-val'], - 'field3': [None, 'added'], - } - - assert serializers.IProjectAuditSerializer._diff_versions(given['old'], given['new']) == expected diff --git a/datahub/investment/test/test_models.py b/datahub/investment/test/test_models.py index 506ed482b..d88397a5e 100644 --- a/datahub/investment/test/test_models.py +++ b/datahub/investment/test/test_models.py @@ -27,6 +27,15 @@ def test_project_code_datahub(): assert project.project_code == f'DHP-{project_num:08d}' +def test_no_project_code(): + """Tests that None is returned when a project code is not set.""" + # cdms_project_code is set and removed to avoid a DH project code + # being generated + project = InvestmentProjectFactory(cdms_project_code='P-79661656') + project.cdms_project_code = None + assert project.project_code is None + + def test_project_manager_team_none(): """Tests project_manager_team for a project without a project manager.""" project = InvestmentProjectFactory() diff --git a/datahub/investment/test/test_validate.py b/datahub/investment/test/test_validate.py index f69d6c932..95627f0e7 100644 --- a/datahub/investment/test/test_validate.py +++ b/datahub/investment/test/test_validate.py @@ -3,7 +3,7 @@ from datahub.company.test.factories import AdviserFactory, ContactFactory from datahub.core import constants from datahub.investment.serializers import ( - IProjectRequirementsSerializer, IProjectSerializer, IProjectTeamSerializer, + IProjectRequirementsSerializer, IProjectSummarySerializer, IProjectTeamSerializer, IProjectValueSerializer ) from datahub.investment.test.factories import InvestmentProjectFactory @@ -19,7 +19,7 @@ def test_validate_project_fail(): investment_type_id=constants.InvestmentType.fdi.value.id, fdi_type_id=None ) - errors = validate(instance=project, fields=IProjectSerializer.Meta.fields) + errors = validate(instance=project, fields=IProjectSummarySerializer.Meta.fields) assert errors == { 'fdi_type': 'This field is required.' } @@ -30,7 +30,7 @@ def test_validate_project_instance_success(): project = InvestmentProjectFactory( client_contacts=[ContactFactory().id, ContactFactory().id] ) - errors = validate(instance=project, fields=IProjectSerializer.Meta.fields) + errors = validate(instance=project, fields=IProjectSummarySerializer.Meta.fields) assert not errors @@ -40,7 +40,7 @@ def test_validate_non_fdi_type(): project = InvestmentProjectFactory( investment_type_id=investment_type_id ) - errors = validate(instance=project, fields=IProjectSerializer.Meta.fields) + errors = validate(instance=project, fields=IProjectSummarySerializer.Meta.fields) assert 'non_fdi_type' in errors assert 'fdi_type' not in errors @@ -51,7 +51,7 @@ def test_validate_fdi_type(): project = InvestmentProjectFactory( investment_type_id=investment_type_id ) - errors = validate(instance=project, fields=IProjectSerializer.Meta.fields) + errors = validate(instance=project, fields=IProjectSummarySerializer.Meta.fields) assert 'fdi_type' in errors assert 'non_fdi_type' not in errors @@ -62,7 +62,7 @@ def test_validate_project_referral_website(): project = InvestmentProjectFactory( referral_source_activity_id=referral_source_id ) - errors = validate(instance=project, fields=IProjectSerializer.Meta.fields) + errors = validate(instance=project, fields=IProjectSummarySerializer.Meta.fields) assert 'referral_source_activity_website' in errors assert 'referral_source_activity_event' not in errors assert 'referral_source_activity_marketing' not in errors @@ -74,7 +74,7 @@ def test_validate_project_referral_event(): project = InvestmentProjectFactory( referral_source_activity_id=referral_source_id ) - errors = validate(instance=project, fields=IProjectSerializer.Meta.fields) + errors = validate(instance=project, fields=IProjectSummarySerializer.Meta.fields) assert 'referral_source_activity_event' in errors assert 'referral_source_activity_website' not in errors assert 'referral_source_activity_marketing' not in errors @@ -86,7 +86,7 @@ def test_validate_project_referral_marketing(): project = InvestmentProjectFactory( referral_source_activity_id=referral_source_id ) - errors = validate(instance=project, fields=IProjectSerializer.Meta.fields) + errors = validate(instance=project, fields=IProjectSummarySerializer.Meta.fields) assert 'referral_source_activity_marketing' in errors assert 'referral_source_activity_website' not in errors assert 'referral_source_activity_event' not in errors @@ -101,7 +101,7 @@ def test_validate_project_update_data(): 'referral_source_activity': referral_source } errors = validate(instance=project, update_data=update_data, - fields=IProjectSerializer.Meta.fields) + fields=IProjectSummarySerializer.Meta.fields) assert 'referral_source_activity_marketing' in errors assert 'referral_source_activity_website' not in errors assert 'referral_source_activity_event' not in errors diff --git a/datahub/investment/test/test_views.py b/datahub/investment/test/test_views.py index 973838f58..5366055fe 100644 --- a/datahub/investment/test/test_views.py +++ b/datahub/investment/test/test_views.py @@ -2,591 +2,32 @@ import re import uuid +from collections import Counter from datetime import datetime from unittest.mock import patch import pytest import reversion +from django.utils.timezone import now from rest_framework import status from rest_framework.reverse import reverse from datahub.company.test.factories import (AdviserFactory, CompanyFactory, ContactFactory) from datahub.core import constants -from datahub.core.test_utils import LeelooTestCase -from datahub.core.utils import executor -from datahub.documents.av_scan import virus_scan_document -from datahub.investment import views -from datahub.investment.models import InvestmentProjectTeamMember, IProjectDocument -from datahub.investment.test.factories import ( - InvestmentProjectFactory, InvestmentProjectTeamMemberFactory -) - - -class InvestmentViewsTestCase(LeelooTestCase): - """Tests for the deprecated project, value, team and requirements views.""" - - def test_list_projects_success(self): - """Test successfully listing projects.""" - project = InvestmentProjectFactory() - url = reverse('api-v3:investment:project') - response = self.api_client.get(url) - - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['count'] == 1 - assert response_data['results'][0]['id'] == str(project.id) - - def test_list_projects_investor_company_success(self): - """Test successfully listing projects for an investor company.""" - company = CompanyFactory() - project = InvestmentProjectFactory(investor_company_id=company.id) - InvestmentProjectFactory() - url = reverse('api-v3:investment:project') - response = self.api_client.get(url, { - 'investor_company_id': str(company.id) - }) - - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['count'] == 1 - assert response_data['results'][0]['id'] == str(project.id) - - def test_create_project_complete_success(self): - """Test successfully creating a project.""" - contacts = [ContactFactory(), ContactFactory()] - investor_company = CompanyFactory() - intermediate_company = CompanyFactory() - adviser = AdviserFactory() - url = reverse('api-v3:investment:project') - aerospace_id = constants.Sector.aerospace_assembly_aircraft.value.id - new_site_id = (constants.FDIType.creation_of_new_site_or_activity - .value.id) - retail_business_activity = constants.InvestmentBusinessActivity.retail - business_activity_id = retail_business_activity.value.id - request_data = { - 'name': 'project name', - 'description': 'project description', - 'nda_signed': False, - 'estimated_land_date': '2020-12-12', - 'project_shareable': False, - 'investment_type': { - 'id': constants.InvestmentType.fdi.value.id - }, - 'stage': { - 'id': constants.InvestmentProjectStage.prospect.value.id - }, - 'business_activities': [{ - 'id': business_activity_id - }], - 'client_contacts': [{ - 'id': str(contacts[0].id) - }, { - 'id': str(contacts[1].id) - }], - 'client_relationship_manager': { - 'id': str(adviser.id) - }, - 'fdi_type': { - 'id': new_site_id - }, - 'investor_company': { - 'id': str(investor_company.id) - }, - 'intermediate_company': { - 'id': str(intermediate_company.id) - }, - 'referral_source_activity': { - 'id': constants.ReferralSourceActivity.cold_call.value.id - }, - 'referral_source_adviser': { - 'id': str(adviser.id) - }, - 'sector': { - 'id': str(aerospace_id) - } - } - response = self.api_client.post(url, data=request_data, format='json') - assert response.status_code == status.HTTP_201_CREATED - response_data = response.json() - assert response_data['name'] == request_data['name'] - assert response_data['description'] == request_data['description'] - assert response_data['nda_signed'] == request_data['nda_signed'] - assert (response_data['estimated_land_date'] == request_data[ - 'estimated_land_date']) - assert re.match('^DHP-\d+$', response_data['project_code']) - - assert (response_data['investment_type']['id'] == request_data[ - 'investment_type']['id']) - assert response_data['investor_company']['id'] == str( - investor_company.id) - assert response_data['intermediate_company']['id'] == str( - intermediate_company.id) - assert response_data['referral_source_adviser']['id'] == str( - adviser.id) - assert response_data['stage']['id'] == request_data['stage']['id'] - assert len(response_data['client_contacts']) == 2 - assert sorted(contact['id'] for contact in response_data[ - 'client_contacts']) == sorted(contact.id for contact in contacts) - assert len(response_data['business_activities']) == 1 - assert (response_data['business_activities'][0]['id'] == - business_activity_id) - - def test_create_project_fail(self): - """Test creating a project with missing required values.""" - url = reverse('api-v3:investment:project') - request_data = {} - response = self.api_client.post(url, data=request_data, format='json') - assert response.status_code == status.HTTP_400_BAD_REQUEST - response_data = response.json() - assert response_data == { - 'business_activities': ['This field is required.'], - 'client_contacts': ['This field is required.'], - 'client_relationship_manager': ['This field is required.'], - 'description': ['This field is required.'], - 'estimated_land_date': ['This field is required.'], - 'investor_company': ['This field is required.'], - 'investment_type': ['This field is required.'], - 'name': ['This field is required.'], - 'nda_signed': ['This field is required.'], - 'project_shareable': ['This field is required.'], - 'referral_source_activity': ['This field is required.'], - 'referral_source_adviser': ['This field is required.'], - 'sector': ['This field is required.'] - } - - def test_create_project_fail_none(self): - """Test creating a project with None for required values.""" - url = reverse('api-v3:investment:project') - request_data = { - 'business_activities': None, - 'client_contacts': None, - 'client_relationship_manager': None, - 'description': None, - 'estimated_land_date': None, - 'investor_company': None, - 'investment_type': None, - 'name': None, - 'nda_signed': None, - 'project_shareable': None, - 'referral_source_activity': None, - 'referral_source_adviser': None, - 'sector': None - } - response = self.api_client.post(url, data=request_data, format='json') - assert response.status_code == status.HTTP_400_BAD_REQUEST - response_data = response.json() - assert response_data == { - 'business_activities': ['This field may not be null.'], - 'client_contacts': ['This field may not be null.'], - 'client_relationship_manager': ['This field may not be null.'], - 'description': ['This field may not be null.'], - 'estimated_land_date': ['This field may not be null.'], - 'investor_company': ['This field may not be null.'], - 'investment_type': ['This field may not be null.'], - 'name': ['This field may not be null.'], - 'nda_signed': ['This field may not be null.'], - 'project_shareable': ['This field may not be null.'], - 'referral_source_activity': ['This field may not be null.'], - 'referral_source_adviser': ['This field may not be null.'], - 'sector': ['This field may not be null.'] - } - - def test_create_project_fail_empty_to_many(self): - """Test creating a project with empty to-many field values.""" - url = reverse('api-v3:investment:project') - request_data = { - 'business_activities': [], - 'client_contacts': [] - } - response = self.api_client.post(url, data=request_data, format='json') - assert response.status_code == status.HTTP_400_BAD_REQUEST - response_data = response.json() - assert response_data.keys() >= { - 'business_activities', 'client_contacts' - } - assert response_data['business_activities'] == [ - 'This list may not be empty.'] - assert response_data['client_contacts'] == [ - 'This list may not be empty.'] - - def test_get_project_success(self): - """Test successfully getting a project.""" - contacts = [ContactFactory().id, ContactFactory().id] - project = InvestmentProjectFactory(client_contacts=contacts) - url = reverse('api-v3:investment:project-item', kwargs={'pk': project.pk}) - response = self.api_client.get(url) - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['id'] == str(project.id) - assert response_data['name'] == project.name - assert response_data['description'] == project.description - assert response_data['nda_signed'] == project.nda_signed - assert response_data['project_code'] == project.project_code - assert (response_data['estimated_land_date'] == - str(project.estimated_land_date)) - assert (response_data['investment_type']['id'] == - str(project.investment_type.id)) - assert (response_data['stage']['id'] == str(project.stage.id)) - assert sorted(contact['id'] for contact in response_data[ - 'client_contacts']) == sorted(contacts) - - def test_patch_project_conditional_failure(self): - """Test updating a project w/ missing conditionally required value.""" - project = InvestmentProjectFactory( - client_contacts=[ContactFactory().id, ContactFactory().id] - ) - url = reverse('api-v3:investment:project-item', kwargs={'pk': project.pk}) - request_data = { - 'investment_type': { - 'id': str(constants.InvestmentType.fdi.value.id) - } - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_400_BAD_REQUEST - response_data = response.json() - assert response_data == { - 'fdi_type': ['This field is required.'] - } - - def test_patch_project_success(self): - """Test successfully partially updating a project.""" - project = InvestmentProjectFactory( - client_contacts=[ContactFactory().id, ContactFactory().id] - ) - url = reverse('api-v3:investment:project-item', kwargs={'pk': project.pk}) - new_contact = ContactFactory() - request_data = { - 'name': 'new name', - 'description': 'new description', - 'client_contacts': [{ - 'id': str(new_contact.id) - }] - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['name'] == request_data['name'] - assert response_data['description'] == request_data['description'] - assert len(response_data['client_contacts']) == 1 - assert response_data['client_contacts'][0]['id'] == str(new_contact.id) - - def test_change_stage_assign_pm_failure(self): - """Tests moving an incomplete project to the Assign PM stage.""" - project = InvestmentProjectFactory() - url = reverse('api-v3:investment:project-item', kwargs={'pk': project.pk}) - request_data = { - 'stage': { - 'id': constants.InvestmentProjectStage.assign_pm.value.id - } - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_400_BAD_REQUEST - response_data = response.json() - assert response_data == { - 'client_cannot_provide_total_investment': [ - 'This field is required.'], - 'number_new_jobs': ['This field is required.'], - 'total_investment': ['This field is required.'], - 'client_considering_other_countries': ['This field is required.'], - 'client_requirements': ['This field is required.'], - 'site_decided': ['This field is required.'], - 'strategic_drivers': ['This field is required.'], - 'uk_region_locations': ['This field is required.'], - } - - def test_change_stage_assign_pm_success(self): - """Tests moving a complete project to the Assign PM stage.""" - strategic_drivers = [ - constants.InvestmentStrategicDriver.access_to_market.value.id - ] - project = InvestmentProjectFactory( - client_contacts=[ContactFactory().id, ContactFactory().id], - client_cannot_provide_total_investment=False, - total_investment=100, - number_new_jobs=0, - client_considering_other_countries=False, - client_requirements='client reqs', - site_decided=False, - strategic_drivers=strategic_drivers, - uk_region_locations=[constants.UKRegion.england.value.id] - ) - url = reverse('api-v3:investment:project-item', kwargs={'pk': project.pk}) - request_data = { - 'stage': { - 'id': constants.InvestmentProjectStage.assign_pm.value.id - } - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_200_OK - - def test_change_stage_active_failure(self): - """Tests moving an incomplete project to the Active stage.""" - project = InvestmentProjectFactory( - client_contacts=[ContactFactory().id, ContactFactory().id] - ) - url = reverse('api-v3:investment:project-item', kwargs={'pk': project.pk}) - request_data = { - 'stage': { - 'id': constants.InvestmentProjectStage.active.value.id - } - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_400_BAD_REQUEST - response_data = response.json() - assert response_data == { - 'client_cannot_provide_total_investment': [ - 'This field is required.'], - 'number_new_jobs': ['This field is required.'], - 'total_investment': ['This field is required.'], - 'client_considering_other_countries': ['This field is required.'], - 'client_requirements': ['This field is required.'], - 'site_decided': ['This field is required.'], - 'strategic_drivers': ['This field is required.'], - 'uk_region_locations': ['This field is required.'], - 'project_assurance_adviser': ['This field is required.'], - 'project_manager': ['This field is required.'], - } - - def test_change_stage_active_success(self): - """Tests moving a complete project to the Active stage.""" - adviser = AdviserFactory() - strategic_drivers = [ - constants.InvestmentStrategicDriver.access_to_market.value.id - ] - project = InvestmentProjectFactory( - client_contacts=[ContactFactory().id, ContactFactory().id], - client_cannot_provide_total_investment=False, - total_investment=100, - number_new_jobs=0, - client_considering_other_countries=False, - client_requirements='client reqs', - site_decided=False, - strategic_drivers=strategic_drivers, - uk_region_locations=[constants.UKRegion.england.value.id], - project_assurance_adviser=adviser, - project_manager=adviser - ) - url = reverse('api-v3:investment:project-item', kwargs={'pk': project.pk}) - request_data = { - 'stage': { - 'id': constants.InvestmentProjectStage.active.value.id - } - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_200_OK - - def test_get_value_success(self): - """Test successfully getting a project value object.""" - project = InvestmentProjectFactory( - client_cannot_provide_foreign_investment=False, - client_cannot_provide_total_investment=False, - total_investment=100, - foreign_equity_investment=100, - government_assistance=True, - number_new_jobs=0, - number_safeguarded_jobs=10, - r_and_d_budget=False, - non_fdi_r_and_d_budget=False, - new_tech_to_uk=False, - export_revenue=True - ) - url = reverse('api-v3:investment:value-item', kwargs={'pk': project.pk}) - response = self.api_client.get(url) - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert (response_data['client_cannot_provide_foreign_investment'] is - False) - assert (response_data['client_cannot_provide_total_investment'] is - False) - assert response_data['total_investment'] == '100' - assert response_data['foreign_equity_investment'] == '100' - assert response_data['government_assistance'] is True - assert response_data['total_investment'] == '100' - assert response_data['number_new_jobs'] == 0 - assert response_data['number_safeguarded_jobs'] == 10 - assert response_data['r_and_d_budget'] is False - assert response_data['non_fdi_r_and_d_budget'] is False - assert response_data['new_tech_to_uk'] is False - assert response_data['export_revenue'] is True - assert response_data['value_complete'] is True - - def test_patch_value_success(self): - """Test successfully partially updating a project value object.""" - salary_id = constants.SalaryRange.below_25000.value.id - project = InvestmentProjectFactory(total_investment=999, - number_new_jobs=100) - url = reverse('api-v3:investment:value-item', kwargs={'pk': project.pk}) - request_data = { - 'number_new_jobs': 555, - 'average_salary': {'id': salary_id}, - 'government_assistance': True - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['number_new_jobs'] == 555 - assert response_data['government_assistance'] is True - assert response_data['total_investment'] == '999' - assert response_data['value_complete'] is False - assert response_data['average_salary']['id'] == salary_id - - def test_get_requirements_success(self): - """Test successfully getting a project requirements object.""" - countries = [ - constants.Country.united_kingdom.value.id, - constants.Country.united_states.value.id - ] - strategic_drivers = [ - constants.InvestmentStrategicDriver.access_to_market.value.id - ] - uk_region_locations = [constants.UKRegion.england.value.id] - project = InvestmentProjectFactory( - client_requirements='client reqs', - site_decided=True, - address_line_1='address 1', - client_considering_other_countries=True, - competitor_countries=countries, - strategic_drivers=strategic_drivers, - uk_region_locations=uk_region_locations - ) - url = reverse('api-v3:investment:requirements-item', - kwargs={'pk': project.pk}) - response = self.api_client.get(url) - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['client_requirements'] == 'client reqs' - assert response_data['site_decided'] is True - assert response_data['client_considering_other_countries'] is True - assert response_data['requirements_complete'] is True - assert response_data['address_line_1'] == 'address 1' - assert sorted(country['id'] for country in response_data[ - 'competitor_countries']) == sorted(countries) - assert sorted(driver['id'] for driver in response_data[ - 'strategic_drivers']) == sorted(strategic_drivers) - - def test_patch_requirements_success(self): - """Test successfully partially updating a requirements object.""" - project = InvestmentProjectFactory(client_requirements='client reqs', - site_decided=True, - address_line_1='address 1') - url = reverse('api-v3:investment:requirements-item', - kwargs={'pk': project.pk}) - request_data = { - 'address_line_1': 'address 1 new', - 'address_line_2': 'address 2 new' - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['requirements_complete'] is False - assert response_data['client_requirements'] == 'client reqs' - assert response_data['site_decided'] is True - assert response_data['address_line_1'] == 'address 1 new' - assert response_data['address_line_2'] == 'address 2 new' - - def test_get_team_success(self): - """Test successfully getting a project requirements object.""" - crm_team = constants.Team.crm.value - huk_team = constants.Team.healthcare_uk.value - pm_adviser = AdviserFactory(dit_team_id=crm_team.id) - pa_adviser = AdviserFactory(dit_team_id=huk_team.id) - project = InvestmentProjectFactory( - project_manager_id=pm_adviser.id, - project_assurance_adviser_id=pa_adviser.id - ) - url = reverse('api-v3:investment:team-item', - kwargs={'pk': project.pk}) - response = self.api_client.get(url) - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data == { - 'project_manager': { - 'id': str(pm_adviser.pk), - 'first_name': pm_adviser.first_name, - 'last_name': pm_adviser.last_name - }, - 'project_assurance_adviser': { - 'id': str(pa_adviser.pk), - 'first_name': pa_adviser.first_name, - 'last_name': pa_adviser.last_name - }, - 'project_manager_team': { - 'id': str(crm_team.id), - 'name': crm_team.name - }, - 'project_assurance_team': { - 'id': str(huk_team.id), - 'name': huk_team.name - }, - 'team_members': [], - 'team_complete': True - } - - def test_get_team_empty(self): - """Test successfully getting an empty project requirements object.""" - project = InvestmentProjectFactory( - stage_id=constants.InvestmentProjectStage.assign_pm.value.id - ) - url = reverse('api-v3:investment:team-item', - kwargs={'pk': project.pk}) - response = self.api_client.get(url) - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data == { - 'project_manager': None, - 'project_assurance_adviser': None, - 'project_manager_team': None, - 'project_assurance_team': None, - 'team_complete': False, - 'team_members': [] - } - - def test_patch_team_success(self): - """Test successfully partially updating a requirements object.""" - crm_team = constants.Team.crm.value - huk_team = constants.Team.healthcare_uk.value - adviser_1 = AdviserFactory(dit_team_id=crm_team.id) - adviser_2 = AdviserFactory(dit_team_id=huk_team.id) - project = InvestmentProjectFactory( - project_manager_id=adviser_1.id, - project_assurance_adviser_id=adviser_2.id - ) - url = reverse('api-v3:investment:team-item', - kwargs={'pk': project.pk}) - request_data = { - 'project_manager': { - 'id': str(adviser_2.id) - } - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data == { - 'project_manager': { - 'id': str(adviser_2.pk), - 'first_name': adviser_2.first_name, - 'last_name': adviser_2.last_name - }, - 'project_assurance_adviser': { - 'id': str(adviser_2.pk), - 'first_name': adviser_2.first_name, - 'last_name': adviser_2.last_name - }, - 'project_manager_team': { - 'id': str(huk_team.id), - 'name': huk_team.name - }, - 'project_assurance_team': { - 'id': str(huk_team.id), - 'name': huk_team.name - }, - 'team_members': [], - 'team_complete': True - } +from datahub.core.test_utils import ( + APITestMixin, synchronous_executor_submit, synchronous_transaction_on_commit +) +from datahub.core.utils import executor +from datahub.documents.av_scan import virus_scan_document +from datahub.investment import views +from datahub.investment.models import InvestmentProjectTeamMember, IProjectDocument +from datahub.investment.test.factories import ( + InvestmentProjectFactory, InvestmentProjectTeamMemberFactory +) -class UnifiedViewsTestCase(LeelooTestCase): +class TestUnifiedViews(APITestMixin): """Tests for the unified investment views.""" def test_list_projects_success(self): @@ -797,6 +238,61 @@ def test_get_project_success(self): assert sorted(contact['id'] for contact in response_data[ 'client_contacts']) == sorted(contacts) + def test_create_project_conditional_failure(self): + """Test creating a project w/ missing conditionally required value.""" + contacts = [ContactFactory(), ContactFactory()] + investor_company = CompanyFactory() + intermediate_company = CompanyFactory() + adviser = AdviserFactory() + url = reverse('api-v3:investment:investment-collection') + aerospace_id = constants.Sector.aerospace_assembly_aircraft.value.id + retail_business_activity_id = constants.InvestmentBusinessActivity.retail.value.id + request_data = { + 'name': 'project name', + 'description': 'project description', + 'nda_signed': False, + 'estimated_land_date': '2020-12-12', + 'project_shareable': False, + 'investment_type': { + 'id': constants.InvestmentType.fdi.value.id + }, + 'stage': { + 'id': constants.InvestmentProjectStage.prospect.value.id + }, + 'business_activities': [{ + 'id': retail_business_activity_id + }], + 'client_contacts': [{ + 'id': str(contacts[0].id) + }, { + 'id': str(contacts[1].id) + }], + 'client_relationship_manager': { + 'id': str(adviser.id) + }, + 'investor_company': { + 'id': str(investor_company.id) + }, + 'intermediate_company': { + 'id': str(intermediate_company.id) + }, + 'referral_source_activity': { + 'id': constants.ReferralSourceActivity.cold_call.value.id + }, + 'referral_source_adviser': { + 'id': str(adviser.id) + }, + 'sector': { + 'id': str(aerospace_id) + } + } + response = self.api_client.post(url, data=request_data, format='json') + assert response.status_code == status.HTTP_400_BAD_REQUEST + response_data = response.json() + assert response_data == { + 'fdi_type': ['This field is required.'] + } + def test_patch_project_conditional_failure(self): """Test updating a project w/ missing conditionally required value.""" project = InvestmentProjectFactory( @@ -806,7 +302,8 @@ def test_patch_project_conditional_failure(self): request_data = { 'investment_type': { 'id': str(constants.InvestmentType.fdi.value.id) - } + }, + 'fdi_type': None } response = self.api_client.patch(url, data=request_data, format='json') assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -879,6 +376,24 @@ def test_patch_project_success(self): assert len(response_data['client_contacts']) == 1 assert response_data['client_contacts'][0]['id'] == str(new_contact.id) + def test_incomplete_fields_prospect(self): + """Tests moving an incomplete project to the Assign PM stage.""" + project = InvestmentProjectFactory() + url = reverse('api-v3:investment:investment-item', kwargs={'pk': project.pk}) + response = self.api_client.get(url) + assert response.status_code == status.HTTP_200_OK + response_data = response.json() + assert Counter(response_data['incomplete_fields']) == Counter(( + 'client_cannot_provide_total_investment', + 'number_new_jobs', + 'total_investment', + 'client_considering_other_countries', + 'client_requirements', + 'site_decided', + 'strategic_drivers', + 'uk_region_locations', + )) + def test_change_stage_assign_pm_failure(self): """Tests moving an incomplete project to the Assign PM stage.""" project = InvestmentProjectFactory() @@ -1067,6 +582,29 @@ def test_change_stage_verify_win_success(self): response = self.api_client.patch(url, data=request_data, format='json') assert response.status_code == status.HTTP_200_OK + def test_invalid_state_validation(self): + """Tests validation when a project that is in an invalid state. + + An invalid state means that fields that are required for the current stage have + not been completed. Generally, this should be impossible as those fields should've + been completed before moving to the current stage. Only the fields being modified + should be validated in this state (unless the stage is being modified). + """ + project = InvestmentProjectFactory( + stage_id=constants.InvestmentProjectStage.active.value.id, + project_manager=None + ) + url = reverse('api-v3:investment:investment-item', kwargs={'pk': project.pk}) + request_data = { + 'project_manager': None + } + response = self.api_client.patch(url, data=request_data, format='json') + assert response.status_code == status.HTTP_400_BAD_REQUEST + response_data = response.json() + assert response_data == { + 'project_manager': ['This field is required.'], + } + def test_get_value_success(self): """Test successfully getting a project value object.""" project = InvestmentProjectFactory( @@ -1270,53 +808,8 @@ def test_patch_team_success(self): assert response_data['team_members'] == [] assert response_data['team_complete'] is True - def test_get_phase_backwards_compatibility(self): - """Tests that phase works as an alias for stage with GET.""" - project = InvestmentProjectFactory( - ) - url = reverse('api-v3:investment:investment-item', kwargs={'pk': project.pk}) - response = self.api_client.get(url) - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['phase'] == { - 'id': constants.InvestmentProjectStage.prospect.value.id, - 'name': constants.InvestmentProjectStage.prospect.value.name - } - assert response_data['phase'] == response_data['stage'] - - def test_patch_phase_backwards_compatibility(self): - """Tests that phase works as an alias for stage with PATCH.""" - strategic_drivers = [ - constants.InvestmentStrategicDriver.access_to_market.value.id - ] - project = InvestmentProjectFactory( - client_contacts=[ContactFactory().id, ContactFactory().id], - client_cannot_provide_total_investment=False, - total_investment=100, - number_new_jobs=0, - client_considering_other_countries=False, - client_requirements='client reqs', - site_decided=False, - strategic_drivers=strategic_drivers, - uk_region_locations=[constants.UKRegion.england.value.id] - ) - url = reverse('api-v3:investment:investment-item', kwargs={'pk': project.pk}) - request_data = { - 'phase': { - 'id': constants.InvestmentProjectStage.assign_pm.value.id - } - } - response = self.api_client.patch(url, data=request_data, format='json') - assert response.status_code == status.HTTP_200_OK - response_data = response.json() - assert response_data['phase'] == { - 'id': constants.InvestmentProjectStage.assign_pm.value.id, - 'name': constants.InvestmentProjectStage.assign_pm.value.name - } - assert response_data['phase'] == response_data['stage'] - -class TeamMemberViewsTestCase(LeelooTestCase): +class TestTeamMemberViews(APITestMixin): """Tests for the team member views.""" def test_add_team_member_nonexistent_project(self): @@ -1499,13 +992,11 @@ def test_delete_team_member_success(self): assert str(new_team_members[0].adviser.pk) == team_members[1].adviser.pk -class AuditLogViewTestCase(LeelooTestCase): +class TestAuditLogView(APITestMixin): """Tests for the audit log view.""" def test_audit_log_view(self): """Test retrieval of audit log.""" - user = self.get_user() - initial_datetime = datetime.utcnow() with reversion.create_revision(): iproject = InvestmentProjectFactory( @@ -1514,7 +1005,7 @@ def test_audit_log_view(self): reversion.set_comment('Initial') reversion.set_date_created(initial_datetime) - reversion.set_user(user) + reversion.set_user(self.user) changed_datetime = datetime.utcnow() with reversion.create_revision(): @@ -1523,7 +1014,7 @@ def test_audit_log_view(self): reversion.set_comment('Changed') reversion.set_date_created(changed_datetime) - reversion.set_user(user) + reversion.set_user(self.user) url = reverse('api-v3:investment:audit-item', kwargs={'pk': iproject.pk}) @@ -1535,13 +1026,14 @@ def test_audit_log_view(self): assert len(response_data) == 1, 'Only one entry in audit log' entry = response_data[0] - assert entry['user']['name'] == user.name, 'Valid user captured' + assert entry['user']['name'] == self.user.name, 'Valid user captured' assert entry['comment'] == 'Changed', 'Comments can be set manually' assert entry['timestamp'] == changed_datetime.isoformat(), 'TS can be set manually' - assert entry['changes']['description'] == ['Initial desc', 'New desc'], 'Changes are reflected' + assert entry['changes']['description'] == ['Initial desc', 'New desc'], \ + 'Changes are reflected' -class ArchiveViewsTestCase(LeelooTestCase): +class TestArchiveViews(APITestMixin): """Tests for the archive and unarchive views.""" def test_archive_project_success(self): @@ -1615,7 +1107,7 @@ def test_unarchive_project_success(self): assert response_data['archived_reason'] == '' -class DocumentViewsTestCase(LeelooTestCase): +class TestDocumentViews(APITestMixin): """Tests for the document views.""" def test_documents_list_is_filtered_by_project(self): @@ -1700,6 +1192,41 @@ def test_document_upload_status(self, mock_submit): assert response.status_code == status.HTTP_200_OK mock_submit.assert_called_once_with(virus_scan_document, str(doc.pk)) + @patch.object(executor, 'submit') + @patch('datahub.core.utils.executor.submit', synchronous_executor_submit) + @patch('django.db.transaction.on_commit', synchronous_transaction_on_commit) + def test_document_delete_of_not_uploaded_doc_does_not_trigger_s3_delete(self, mock_submit): + """Tests document deletion.""" + project = InvestmentProjectFactory() + doc = IProjectDocument.create_from_declaration_request(project, 'fdi_type', 'test.txt') + + url = reverse('api-v3:investment:document-item', + kwargs={'project_pk': project.pk, 'doc_pk': doc.pk}) + + response = self.api_client.delete(url) + assert response.status_code == status.HTTP_204_NO_CONTENT + assert mock_submit.called is False + + @patch('datahub.core.utils.get_s3_client') + @patch('datahub.core.utils.executor.submit', synchronous_executor_submit) + @patch('django.db.transaction.on_commit', synchronous_transaction_on_commit) + def test_document_delete(self, mock_s3): + """Tests document deletion.""" + project = InvestmentProjectFactory() + doc = IProjectDocument.create_from_declaration_request(project, 'fdi_type', 'test.txt') + doc.document.uploaded_on = now() + doc.document.save() + + url = reverse('api-v3:investment:document-item', + kwargs={'project_pk': project.pk, 'doc_pk': doc.pk}) + + response = self.api_client.delete(url) + assert response.status_code == status.HTTP_204_NO_CONTENT + mock_s3().delete_object.assert_called_with( + Bucket=doc.document.s3_bucket, + Key=doc.document.s3_key, + ) + def test_document_upload_status_wrong_status(self): """Tests request validation in the document status endpoint.""" project = InvestmentProjectFactory() @@ -1727,11 +1254,7 @@ def test_document_upload_status_no_status(self): assert 'status' in response.json() -@pytest.mark.parametrize('view_set', (views.IProjectTeamViewSet, - views.IProjectRequirementsViewSet, - views.IProjectViewSet, - views.IProjectValueViewSet, - views.IProjectAuditViewSet)) +@pytest.mark.parametrize('view_set', (views.IProjectAuditViewSet,)) def test_view_set_name(view_set): """Test that the view name is a string.""" assert isinstance(view_set().get_view_name(), str) diff --git a/datahub/investment/urls.py b/datahub/investment/urls.py index 15dd69cea..f0ca89d8c 100644 --- a/datahub/investment/urls.py +++ b/datahub/investment/urls.py @@ -3,8 +3,7 @@ from django.conf.urls import url from datahub.investment.views import ( - IProjectAuditViewSet, IProjectDocumentViewSet, IProjectRequirementsViewSet, - IProjectTeamMembersViewSet, IProjectTeamViewSet, IProjectUnifiedViewSet, IProjectValueViewSet, + IProjectAuditViewSet, IProjectDocumentViewSet, IProjectTeamMembersViewSet, IProjectViewSet ) @@ -18,16 +17,6 @@ 'patch': 'partial_update' }) -unified_project_collection = IProjectUnifiedViewSet.as_view({ - 'get': 'list', - 'post': 'create' -}) - -unified_project_item = IProjectUnifiedViewSet.as_view({ - 'get': 'retrieve', - 'patch': 'partial_update' -}) - project_team_member_collection = IProjectTeamMembersViewSet.as_view({ 'post': 'create', 'delete': 'destroy_all' @@ -39,21 +28,6 @@ 'delete': 'destroy' }) -value_item = IProjectValueViewSet.as_view({ - 'get': 'retrieve', - 'patch': 'partial_update' -}) - -requirements_item = IProjectRequirementsViewSet.as_view({ - 'get': 'retrieve', - 'patch': 'partial_update' -}) - -team_item = IProjectTeamViewSet.as_view({ - 'get': 'retrieve', - 'patch': 'partial_update' -}) - audit_item = IProjectAuditViewSet.as_view({ 'get': 'retrieve', }) @@ -73,6 +47,7 @@ project_document_item = IProjectDocumentViewSet.as_view({ 'get': 'retrieve', + 'delete': 'destroy', }) project_document_callback = IProjectDocumentViewSet.as_view({ @@ -80,10 +55,9 @@ }) urlpatterns = [ - url(r'^investment$', unified_project_collection, name='investment-collection'), - url(r'^investment/(?P[0-9a-z-]{36})$', unified_project_item, + url(r'^investment$', project_collection, name='investment-collection'), + url(r'^investment/(?P[0-9a-z-]{36})$', project_item, name='investment-item'), - url(r'^investment/project$', project_collection, name='project'), url(r'^investment/(?P[0-9a-z-]{36})/archive$', archive_item, name='archive-item'), url(r'^investment/(?P[0-9a-z-]{36})/team-member$', project_team_member_collection, @@ -94,18 +68,10 @@ name='document-collection'), url(r'^investment/(?P[0-9a-z-]{36})/document/(?P[0-9a-z-]{36})$', project_document_item, name='document-item'), - url(r'^investment/(?P[0-9a-z-]{36})/document/(?P[0-9a-z-]{36})/upload-callback$', - project_document_callback, name='document-item-callback'), + url(r'^investment/(?P[0-9a-z-]{36})/document/(?P[0-9a-z-]{36})/' + r'upload-callback$', project_document_callback, name='document-item-callback'), url(r'^investment/(?P[0-9a-z-]{36})/unarchive$', unarchive_item, name='unarchive-item'), - url(r'^investment/(?P[0-9a-z-]{36})/project$', project_item, - name='project-item'), - url(r'^investment/(?P[0-9a-z-]{36})/value$', value_item, - name='value-item'), - url(r'^investment/(?P[0-9a-z-]{36})/requirements$', requirements_item, - name='requirements-item'), - url(r'^investment/(?P[0-9a-z-]{36})/team$', team_item, - name='team-item'), url(r'^investment/(?P[0-9a-z-]{36})/audit$', audit_item, name='audit-item'), ] diff --git a/datahub/investment/views.py b/datahub/investment/views.py index 9abed52df..04860ab58 100644 --- a/datahub/investment/views.py +++ b/datahub/investment/views.py @@ -1,11 +1,13 @@ """Investment views.""" +from django.db import transaction from django.http import Http404 from django.shortcuts import get_object_or_404 from django_filters.rest_framework import DjangoFilterBackend -from rest_framework import mixins, status +from rest_framework import status from rest_framework.response import Response from datahub.core.mixins import ArchivableViewSetMixin +from datahub.core.serializers import AuditSerializer from datahub.core.utils import executor from datahub.core.viewsets import CoreViewSetV3 from datahub.documents.av_scan import virus_scan_document @@ -13,51 +15,15 @@ InvestmentProject, InvestmentProjectTeamMember, IProjectDocument ) from datahub.investment.serializers import ( - IProjectAuditSerializer, IProjectDocumentSerializer, IProjectRequirementsSerializer, - IProjectSerializer, IProjectTeamMemberSerializer, IProjectTeamSerializer, - IProjectUnifiedSerializer, IProjectValueSerializer, UploadStatusSerializer + IProjectDocumentSerializer, IProjectSerializer, IProjectTeamMemberSerializer, + UploadStatusSerializer ) -class IProjectViewSet(ArchivableViewSetMixin, CoreViewSetV3): - """Investment project views. - - This is a subset of the fields on an InvestmentProject object. - - Deprecated. - """ - - serializer_class = IProjectSerializer - queryset = InvestmentProject.objects.select_related( - 'archived_by', - 'investment_type', - 'stage', - 'investor_company', - 'intermediate_company', - 'client_relationship_manager', - 'referral_source_adviser', - 'referral_source_activity', - 'referral_source_activity_website', - 'referral_source_activity_marketing', - 'fdi_type', - 'non_fdi_type', - 'sector' - ).prefetch_related( - 'client_contacts', - 'business_activities' - ) - filter_backends = (DjangoFilterBackend,) - filter_fields = ('investor_company_id',) - - def get_view_name(self): - """Returns the view set name for the DRF UI.""" - return 'Investment projects' - - class IProjectAuditViewSet(CoreViewSetV3): """Investment Project audit views.""" - serializer_class = IProjectAuditSerializer + serializer_class = AuditSerializer queryset = InvestmentProject.objects.all() def get_view_name(self): @@ -65,70 +31,13 @@ def get_view_name(self): return 'Investment project audit log' -class IProjectValueViewSet(CoreViewSetV3): - """Investment project value views. - - This is a subset of the fields on an InvestmentProject object. - - Deprecated. - """ - - serializer_class = IProjectValueSerializer - queryset = InvestmentProject.objects.select_related('average_salary') - - def get_view_name(self): - """Returns the view set name for the DRF UI.""" - return 'Investment project values' - - -class IProjectRequirementsViewSet(CoreViewSetV3): - """Investment project requirements views. - - This is a subset of the fields on an InvestmentProject object. - - Deprecated. - """ - - serializer_class = IProjectRequirementsSerializer - queryset = InvestmentProject.objects.prefetch_related( - 'competitor_countries', - 'uk_region_locations', - 'strategic_drivers' - ) - - def get_view_name(self): - """Returns the view set name for the DRF UI.""" - return 'Investment project requirements' - - -class IProjectTeamViewSet(CoreViewSetV3): - """Investment project team views. - - This is a subset of the fields on an InvestmentProject object. - - Deprecated. - """ - - serializer_class = IProjectTeamSerializer - queryset = InvestmentProject.objects.select_related( - 'project_manager', - 'project_manager__dit_team', - 'project_assurance_adviser', - 'project_assurance_adviser__dit_team' - ) - - def get_view_name(self): - """Returns the view set name for the DRF UI.""" - return 'Investment project teams' - - -class IProjectUnifiedViewSet(ArchivableViewSetMixin, CoreViewSetV3): +class IProjectViewSet(ArchivableViewSetMixin, CoreViewSetV3): """Unified investment project views. This replaces the previous project, value, team and requirements endpoints. """ - serializer_class = IProjectUnifiedSerializer + serializer_class = IProjectSerializer queryset = InvestmentProject.objects.select_related( 'archived_by', 'investment_type', @@ -163,7 +72,7 @@ def get_view_name(self): return 'Investment projects' -class IProjectTeamMembersViewSet(mixins.DestroyModelMixin, CoreViewSetV3): +class IProjectTeamMembersViewSet(CoreViewSetV3): """Investment project team member views.""" serializer_class = IProjectTeamMemberSerializer @@ -222,7 +131,9 @@ class IProjectDocumentViewSet(CoreViewSetV3): def list(self, request, *args, **kwargs): """Custom pre-filtered list.""" - queryset = self.filter_queryset(self.get_queryset().filter(project_id=self.kwargs['project_pk'])) + queryset = self.filter_queryset(self.get_queryset().filter( + project_id=self.kwargs['project_pk']) + ) page = self.paginate_queryset(queryset) if page is not None: @@ -256,6 +167,11 @@ def upload_complete_callback(self, request, *args, **kwargs): }, ) + def perform_destroy(self, instance): + """Perform destroy in transaction/savepoint mode.""" + with transaction.atomic(): + return super().perform_destroy(instance) + def get_object(self): """Ensures that object lookup honors the project pk.""" queryset = self.get_queryset().filter(project__id=self.kwargs['project_pk']) diff --git a/datahub/korben/spec.py b/datahub/korben/spec.py index 408b3b566..d786f96ea 100644 --- a/datahub/korben/spec.py +++ b/datahub/korben/spec.py @@ -114,7 +114,9 @@ ('optevia_TurnoverRange.Id', 'turnover_range_id'), ), concat=( - (('optevia_Address2', 'optevia_Address3', 'optevia_Address4'), 'registered_address_2', ', '), + (('optevia_Address2', 'optevia_Address3', 'optevia_Address4'), + 'registered_address_2', + ', '), ), ), Mapping( diff --git a/datahub/korben/utils.py b/datahub/korben/utils.py index 5ffaf461d..956cc3149 100644 --- a/datahub/korben/utils.py +++ b/datahub/korben/utils.py @@ -59,7 +59,9 @@ def fkey_deps(models): def cdms_datetime_to_datetime(value): - """Parses a cdms datetime as string and returns the equivalent datetime value. Dates in CDMS are always UTC.""" + """Parses a cdms datetime as string and returns the equivalent datetime value. + Dates in CDMS are always UTC. + """ if not value: return None diff --git a/datahub/leads/test/test_views.py b/datahub/leads/test/test_views.py index 4c278833a..cbe5f338d 100644 --- a/datahub/leads/test/test_views.py +++ b/datahub/leads/test/test_views.py @@ -6,13 +6,13 @@ from rest_framework import status from rest_framework.reverse import reverse -from datahub.core.test_utils import LeelooTestCase +from datahub.core.test_utils import APITestMixin from datahub.leads.test.factories import BusinessLeadFactory FROZEN_TIME = '2017-04-18T13:25:30.986208' -class BusinessLeadViewsTestCase(LeelooTestCase): +class TestBusinessLeadViews(APITestMixin): """Business lead views test case.""" def test_list_leads_success(self): diff --git a/datahub/metadata/test/test_views.py b/datahub/metadata/test/test_views.py index 649dd2c2b..bac638b55 100644 --- a/datahub/metadata/test/test_views.py +++ b/datahub/metadata/test/test_views.py @@ -36,7 +36,6 @@ 'investment-business-activity', 'investment-strategic-driver', 'salary-range', - 'investment-project-phase', 'investment-project-stage', ) @@ -65,7 +64,6 @@ 'investment business activity view', 'investment strategic driver view', 'salary range view', - 'investment project phase view', 'investment project stage view', ) @@ -139,13 +137,6 @@ def test_view_name_generation(): '£30,000 – £34,000', '£35,000 and above' ]), - ('investment-project-phase', [ - 'Prospect', - 'Assign PM', - 'Active', - 'Verify win', - 'Won' - ]), ('investment-project-stage', [ 'Prospect', 'Assign PM', @@ -159,7 +150,6 @@ def test_view_name_generation(): 'turnover', 'employee-range', 'salary-range', - 'investment-project-phase', 'investment-project-stage', ) diff --git a/datahub/metadata/views.py b/datahub/metadata/views.py index ed64b2837..5c6f17cc3 100644 --- a/datahub/metadata/views.py +++ b/datahub/metadata/views.py @@ -31,8 +31,6 @@ 'investment-business-activity': models.InvestmentBusinessActivity, 'investment-strategic-driver': models.InvestmentStrategicDriver, 'salary-range': models.SalaryRange, - # deprecated alias for investment-project-stage - 'investment-project-phase': models.InvestmentProjectStage, 'investment-project-stage': models.InvestmentProjectStage, } diff --git a/datahub/omis/__init__.py b/datahub/omis/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/datahub/omis/order/__init__.py b/datahub/omis/order/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/datahub/omis/order/apps.py b/datahub/omis/order/apps.py new file mode 100644 index 000000000..9fe88f97f --- /dev/null +++ b/datahub/omis/order/apps.py @@ -0,0 +1,9 @@ +from django.apps import AppConfig + + +class OrderConfig(AppConfig): + """ + Django App Config for the Order app. + """ + + name = 'order' diff --git a/datahub/omis/order/migrations/0001_initial.py b/datahub/omis/order/migrations/0001_initial.py new file mode 100644 index 000000000..31c54f434 --- /dev/null +++ b/datahub/omis/order/migrations/0001_initial.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.2 on 2017-07-13 15:36 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('company', '0007_remove_mptt'), + ('metadata', '0002_rename_phase_to_stage'), + ] + + operations = [ + migrations.CreateModel( + name='Order', + fields=[ + ('created_on', models.DateTimeField(auto_now_add=True, null=True)), + ('modified_on', models.DateTimeField(auto_now=True, null=True)), + ('id', models.UUIDField(db_index=True, default=uuid.uuid4, primary_key=True, serialize=False)), + ('reference', models.CharField(max_length=100)), + ('company', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='orders', to='company.Company')), + ('contact', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='orders', to='company.Contact')), + ('primary_market', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='orders', to='metadata.Country')), + ], + options={ + 'abstract': False, + }, + ), + ] diff --git a/datahub/omis/order/migrations/__init__.py b/datahub/omis/order/migrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/datahub/omis/order/models.py b/datahub/omis/order/models.py new file mode 100644 index 000000000..e629bb78f --- /dev/null +++ b/datahub/omis/order/models.py @@ -0,0 +1,74 @@ +import uuid + +from django.db import models +from django.utils.crypto import get_random_string +from django.utils.timezone import now + +from datahub.company.models import Company, Contact +from datahub.core.models import BaseModel + +from datahub.metadata.models import Country + + +class Order(BaseModel): + """ + Details regarding an OMIS Order. + """ + + id = models.UUIDField(primary_key=True, db_index=True, default=uuid.uuid4) + reference = models.CharField(max_length=100) + + company = models.ForeignKey( + Company, + related_name="%(class)ss", # noqa: Q000 + on_delete=models.PROTECT, + ) + contact = models.ForeignKey( + Contact, + related_name="%(class)ss", # noqa: Q000 + on_delete=models.PROTECT + ) + + primary_market = models.ForeignKey( + Country, + related_name="%(class)ss", # noqa: Q000 + null=True, + on_delete=models.SET_NULL + ) + + def __str__(self): + """Human-readable representation""" + return self.reference + + def _calculate_reference(self): + """ + Returns a random unused reference of form: + <(3) letters><(3) numbers>/ e.g. GEA962/16 + or RuntimeError if no reference can be generated. + """ + year_suffix = now().strftime('%y') + manager = self.__class__.objects + + max_retries = 10 + tries = 0 + while tries < max_retries: + reference = '{letters}{numbers}/{year}'.format( + letters=get_random_string(length=3, allowed_chars='ACEFHJKMNPRTUVWXY'), + numbers=get_random_string(length=3, allowed_chars='123456789'), + year=year_suffix + ) + if not manager.filter(reference=reference).exists(): + return reference + tries += 1 + + # This should never happen as we have 3.5 milion choices per year + # and it's basically unrealistic to have more than 10 collisions. + raise RuntimeError('Cannot generate random reference') + + def save(self, *args, **kwargs): + """ + Like the django save but it creates a reference if it doesn't exist. + """ + if not self.reference: + self.reference = self._calculate_reference() + return super().save(*args, **kwargs) diff --git a/datahub/omis/order/serializers.py b/datahub/omis/order/serializers.py new file mode 100644 index 000000000..4958396ae --- /dev/null +++ b/datahub/omis/order/serializers.py @@ -0,0 +1,38 @@ +from rest_framework import serializers + +from datahub.company.models import Company, Contact +from datahub.core.serializers import NestedRelatedField +from datahub.metadata.models import Country + +from .models import Order + + +class OrderSerializer(serializers.ModelSerializer): + """Order DRF serializer""" + + id = serializers.UUIDField(read_only=True) + reference = serializers.CharField(read_only=True) + company = NestedRelatedField(Company) + contact = NestedRelatedField(Contact) + primary_market = NestedRelatedField(Country) + + class Meta: # noqa: D101 + model = Order + fields = [ + 'id', + 'reference', + 'company', + 'contact', + 'primary_market' + ] + + def validate(self, data): + """ + Extra check that a contact works at the given company. + """ + if data['contact'].company != data['company']: + raise serializers.ValidationError({ + 'contact': 'The contact does not work at the given company.' + }) + + return data diff --git a/datahub/omis/order/test/__init__.py b/datahub/omis/order/test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/datahub/omis/order/test/factories.py b/datahub/omis/order/test/factories.py new file mode 100644 index 000000000..286d0350a --- /dev/null +++ b/datahub/omis/order/test/factories.py @@ -0,0 +1,20 @@ +"""Model instance factories for order tests.""" + +import uuid + +import factory + +from datahub.company.test.factories import CompanyFactory, ContactFactory +from datahub.core.constants import Country + + +class OrderFactory(factory.django.DjangoModelFactory): + """Order factory.""" + + id = factory.LazyFunction(lambda: str(uuid.uuid4())) + company = factory.SubFactory(CompanyFactory) + contact = factory.SubFactory(ContactFactory) + primary_market_id = Country.france.value.id + + class Meta: + model = 'order.Order' diff --git a/datahub/omis/order/test/test_models.py b/datahub/omis/order/test/test_models.py new file mode 100644 index 000000000..74c055089 --- /dev/null +++ b/datahub/omis/order/test/test_models.py @@ -0,0 +1,66 @@ +from unittest import mock + +import pytest +from freezegun import freeze_time + +from datahub.omis.order.test.factories import OrderFactory + +pytestmark = pytest.mark.django_db + + +class TestOrder: + """ + Tests for the Order model. + """ + + @freeze_time('2017-07-12 13:00:00.000000+00:00') + @mock.patch('datahub.omis.order.models.get_random_string') + def test_generates_reference_if_doesnt_exist(self, mock_get_random_string): + """ + Test that if an Order is saved without reference, the system generates one automatically. + """ + mock_get_random_string.side_effect = [ + 'ABC', '123', 'CBA', '321' + ] + + # create 1st + order = OrderFactory() + assert order.reference == 'ABC123/17' + + # create 2nd + order = OrderFactory() + assert order.reference == 'CBA321/17' + + @freeze_time('2017-07-12 13:00:00.000000+00:00') + @mock.patch('datahub.omis.order.models.get_random_string') + def test_doesnt_generate_reference_if_present(self, mock_get_random_string): + """ + Test that when creating a new Order, if the system generates a reference that already + exists, it skips it and generates the next one. + """ + # create existing Order with ref == 'ABC123/17' + OrderFactory(reference='ABC123/17') + + mock_get_random_string.side_effect = [ + 'ABC', '123', 'CBA', '321' + ] + + # ABC123/17 already exists so create CBA321/17 instead + order = OrderFactory() + assert order.reference == 'CBA321/17' + + @freeze_time('2017-07-12 13:00:00.000000+00:00') + @mock.patch('datahub.omis.order.models.get_random_string') + def test_cannot_generate_reference(self, mock_get_random_string): + """ + Test that if there are more than 10 collisions, the generator algorithm raises a + RuntimeError. + """ + max_retries = 10 + OrderFactory(reference='ABC123/17') + + mock_get_random_string.side_effect = ['ABC', '123'] * max_retries + + with pytest.raises(RuntimeError): + for index in range(max_retries): + OrderFactory() diff --git a/datahub/omis/order/test/test_views.py b/datahub/omis/order/test/test_views.py new file mode 100644 index 000000000..b8d879143 --- /dev/null +++ b/datahub/omis/order/test/test_views.py @@ -0,0 +1,135 @@ +import pytest +from freezegun import freeze_time +from rest_framework import status +from rest_framework.reverse import reverse + +from datahub.company.test.factories import CompanyFactory, ContactFactory +from datahub.core import constants +from datahub.core.test_utils import APITestMixin + +from .factories import OrderFactory + +# mark the whole module for db use +pytestmark = pytest.mark.django_db + + +class TestAddOrder(APITestMixin): + """Add Order test case.""" + + @freeze_time('2017-04-18 13:00:00.000000+00:00') + def test_success(self): + """ + Test a successful call to create an Order. + """ + company = CompanyFactory() + contact = ContactFactory(company=company) + country = constants.Country.france.value + + url = reverse('api-v3:omis:order:list') + response = self.api_client.post(url, { + 'company': { + 'id': company.pk + }, + 'contact': { + 'id': contact.pk + }, + 'primary_market': { + 'id': country.id + }, + }, format='json') + + assert response.status_code == status.HTTP_201_CREATED + assert response.json() == { + 'id': response.json()['id'], + 'reference': response.json()['reference'], + 'company': { + 'id': company.pk, + 'name': company.name + }, + 'contact': { + 'id': contact.pk, + 'name': contact.name + }, + 'primary_market': { + 'id': country.id, + 'name': country.name, + } + } + + def test_fails_if_contact_not_from_company(self): + """ + Test that if the contact does not work at the company specified, the validation fails. + """ + company = CompanyFactory() + contact = ContactFactory() # doesn't work at `company` + country = constants.Country.france.value + + url = reverse('api-v3:omis:order:list') + response = self.api_client.post(url, { + 'company': { + 'id': company.pk + }, + 'contact': { + 'id': contact.pk + }, + 'primary_market': { + 'id': country.id + } + }, format='json') + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json() == { + 'contact': ['The contact does not work at the given company.'] + } + + def test_general_validation(self): + """ + Test create an Order general validation. + """ + url = reverse('api-v3:omis:order:list') + response = self.api_client.post(url, {}, format='json') + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.json() == { + 'company': ['This field is required.'], + 'contact': ['This field is required.'], + 'primary_market': ['This field is required.'] + } + + +class TestViewOrder(APITestMixin): + """View order test case.""" + + def test_get(self): + """Test getting an existing order.""" + order = OrderFactory() + + url = reverse('api-v3:omis:order:detail', kwargs={'pk': order.pk}) + response = self.api_client.get(url) + + assert response.status_code == status.HTTP_200_OK + assert response.json() == { + 'id': order.id, + 'reference': order.reference, + 'company': { + 'id': str(order.company.id), + 'name': order.company.name + }, + 'contact': { + 'id': str(order.contact.id), + 'name': order.contact.name + }, + 'primary_market': { + 'id': str(order.primary_market.id), + 'name': order.primary_market.name + } + } + + def test_not_found(self): + """Test 404 when getting a non-existing order""" + url = reverse('api-v3:omis:order:detail', kwargs={ + 'pk': '00000000-0000-0000-0000-000000000000' + }) + response = self.api_client.get(url) + + assert response.status_code == status.HTTP_404_NOT_FOUND diff --git a/datahub/omis/order/urls.py b/datahub/omis/order/urls.py new file mode 100644 index 000000000..d253f0003 --- /dev/null +++ b/datahub/omis/order/urls.py @@ -0,0 +1,19 @@ +"""Company views URL config.""" + +from django.conf.urls import url + +from .views import OrderViewSet + + +order_collection = OrderViewSet.as_view({ + 'post': 'create' +}) + +order_item = OrderViewSet.as_view({ + 'get': 'retrieve' +}) + +urlpatterns = [ + url(r'^order$', order_collection, name='list'), + url(r'^order/(?P[0-9a-z-]{36})$', order_item, name='detail'), +] diff --git a/datahub/omis/order/views.py b/datahub/omis/order/views.py new file mode 100644 index 000000000..3f2dba966 --- /dev/null +++ b/datahub/omis/order/views.py @@ -0,0 +1,15 @@ +from datahub.core.viewsets import CoreViewSetV3 + +from .models import Order +from .serializers import OrderSerializer + + +class OrderViewSet(CoreViewSetV3): + """Order ViewSet""" + + serializer_class = OrderSerializer + queryset = Order.objects.select_related( + 'company', + 'contact', + 'primary_market', + ) diff --git a/datahub/omis/urls.py b/datahub/omis/urls.py new file mode 100644 index 000000000..36d57691a --- /dev/null +++ b/datahub/omis/urls.py @@ -0,0 +1,7 @@ +from django.conf.urls import include, url + +from .order import urls as order_urls + +urlpatterns = [ + url(r'^', include((order_urls, 'order'), namespace='order')), +] diff --git a/datahub/search/elasticsearch.py b/datahub/search/elasticsearch.py index 85262e075..d6fd6b513 100644 --- a/datahub/search/elasticsearch.py +++ b/datahub/search/elasticsearch.py @@ -46,7 +46,29 @@ def get_search_term_query(term): ]) -def get_basic_search_query(term, entities=('company',), offset=0, limit=100): +def remap_sort_field(field): + """Replaces fields to aliases suitable for sorting.""" + name_map = { + 'name': 'name_keyword', + } + return name_map.get(field, field) + + +def get_sort_query(qs, field_order=None): + """Attaches sort query.""" + if field_order is None: + return qs + + tokens = field_order.rsplit(':', maxsplit=1) + order = tokens[1] if len(tokens) > 1 else 'asc' + + qs = qs.sort({ + remap_sort_field(tokens[0]): {'order': order} + }) + return qs + + +def get_basic_search_query(term, entities=('company',), field_order=None, offset=0, limit=100): """Performs basic search looking for name and then _all in entity. Also returns number of results in other entities. @@ -56,6 +78,7 @@ def get_basic_search_query(term, entities=('company',), offset=0, limit=100): s = s.post_filter( Q('bool', should=[Q('term', _type=entity) for entity in entities]) ) + s = get_sort_query(s, field_order=field_order) s.aggs.bucket( 'count_by_type', 'terms', field='_type' @@ -64,7 +87,13 @@ def get_basic_search_query(term, entities=('company',), offset=0, limit=100): return s[offset:offset + limit] -def get_search_by_entity_query(term=None, filters=None, entity=None, ranges=None, offset=0, limit=100): +def get_search_by_entity_query(term=None, + filters=None, + entity=None, + ranges=None, + field_order=None, + offset=0, + limit=100): """Perform filtered search for given terms in given entity.""" query = [Q('term', _type=entity)] if term != '': @@ -90,6 +119,8 @@ def get_search_by_entity_query(term=None, filters=None, entity=None, ranges=None ) s = Search(index=settings.ES_INDEX).query('bool', must=query) + s = get_sort_query(s, field_order=field_order) + s = s.post_filter('bool', must=query_filter) return s[offset:offset + limit] diff --git a/datahub/search/management/commands/sync_es.py b/datahub/search/management/commands/sync_es.py index 0ce55db41..d03ab9bf8 100644 --- a/datahub/search/management/commands/sync_es.py +++ b/datahub/search/management/commands/sync_es.py @@ -19,10 +19,12 @@ def get_dataset(): """Returns dataset that will be synchronised with Elasticsearch.""" - company_prefetch_fields = ('registered_address_country', 'business_type', 'sector', 'employee_range', - 'turnover_range', 'account_manager', 'export_to_countries', 'future_interest_countries', - 'trading_address_country', 'headquarter_type', 'classification', - 'one_list_account_owner',) + company_prefetch_fields = ( + 'registered_address_country', 'business_type', 'sector', 'employee_range', + 'turnover_range', 'account_manager', 'export_to_countries', 'future_interest_countries', + 'trading_address_country', 'headquarter_type', 'classification', + 'one_list_account_owner', + ) company_qs = Company.objects.prefetch_related(*company_prefetch_fields).all().order_by('pk') contact_qs = Contact.objects.all().order_by('pk') @@ -62,7 +64,8 @@ def sync_dataset(item, batch_size=1, stdout=None): rows_processed += num_actions batches_processed += 1 if stdout and batches_processed % 100 == 0: - stdout.write(f'Rows processed: {rows_processed}/{total_rows} {rows_processed*100//total_rows}%') + stdout.write(f'Rows processed: {rows_processed}/{total_rows} ' + f'{rows_processed*100//total_rows}%') if stdout: stdout.write(f'Rows processed: {rows_processed}/{total_rows} 100%. Done!') diff --git a/datahub/search/models.py b/datahub/search/models.py index 4a96c398b..4d9dfc0a0 100644 --- a/datahub/search/models.py +++ b/datahub/search/models.py @@ -48,13 +48,16 @@ def _contact_mapping(field): return Nested(properties={'id': String(index='not_analyzed'), 'first_name': String(copy_to=f'{field}.name'), 'last_name': String(copy_to=f'{field}.name'), - 'name': String(), + 'name': String(index='not_analyzed'), }) def _id_name_mapping(): """Mapping for id name fields.""" - return Nested(properties={'id': String(index='not_analyzed'), 'name': String()}) + return Nested(properties={ + 'id': String(index='not_analyzed'), + 'name': String(index='not_analyzed') + }) def _id_uri_mapping(): @@ -98,7 +101,8 @@ def dbmodel_to_dict(cls, dbmodel): result = {col: fn(getattr(dbmodel, col)) for col, fn in cls.MAPPINGS.items() if getattr(dbmodel, col, None) is not None} - fields = [field for field in dbmodel._meta.get_fields() if field.name not in cls.IGNORED_FIELDS] + fields = [field for field in dbmodel._meta.get_fields() + if field.name not in cls.IGNORED_FIELDS] obj = {f.name: getattr(dbmodel, f.name) for f in fields if f.name not in result} result.update(obj.items()) @@ -124,7 +128,7 @@ class Company(DocType, MapDBModelToDict): archived_reason = String() business_type = _id_name_mapping() classification = _id_name_mapping() - company_number = String() + company_number = String(index='not_analyzed') companies_house_data = _company_mapping() created_on = Date() description = String() @@ -141,14 +145,14 @@ class Company(DocType, MapDBModelToDict): registered_address_country = _id_name_mapping() registered_address_county = String() registered_address_postcode = String() - registered_address_town = String() + registered_address_town = String(index='not_analyzed') sector = _id_name_mapping() trading_address_1 = String() trading_address_2 = String() trading_address_country = _id_name_mapping() trading_address_county = String() trading_address_postcode = String() - trading_address_town = String() + trading_address_town = String(index='not_analyzed') turnover_range = _id_name_mapping() uk_region = _id_name_mapping() uk_based = Boolean() @@ -183,6 +187,7 @@ class Company(DocType, MapDBModelToDict): 'children', 'servicedeliverys', 'investor_investment_projects', 'intermediate_investment_projects', 'investee_projects', 'tree_id', 'lft', 'rght', 'business_leads', 'interactions', + 'orders' ) class Meta: @@ -207,19 +212,19 @@ class Contact(DocType, MapDBModelToDict): first_name = String(copy_to='name') last_name = String(copy_to='name') primary = Boolean() - telephone_countrycode = String() - telephone_number = String() - email = String() + telephone_countrycode = String(index='not_analyzed') + telephone_number = String(index='not_analyzed') + email = String(index='not_analyzed') address_same_as_company = Boolean() address_1 = String() address_2 = String() - address_town = String() - address_county = String() + address_town = String(index='not_analyzed') + address_county = String(index='not_analyzed') address_postcode = String() telephone_alternative = String() email_alternative = String() notes = String() - job_title = String() + job_title = String(index='not_analyzed') contactable_by_dit = Boolean() contactable_by_dit_partners = Boolean() contactable_by_email = Boolean() @@ -239,7 +244,7 @@ class Contact(DocType, MapDBModelToDict): } IGNORED_FIELDS = ( - 'interactions', 'servicedeliverys', 'investment_projects', + 'interactions', 'servicedeliverys', 'investment_projects', 'orders' ) class Meta: @@ -302,7 +307,7 @@ class InvestmentProject(DocType, MapDBModelToDict): referral_source_activity = _id_name_mapping() referral_source_activity_marketing = _id_name_mapping() referral_source_activity_website = _id_name_mapping() - referral_source_activity_event = String() + referral_source_activity_event = String(index='not_analyzed') referral_source_advisor = _contact_mapping('referral_source_advisor') sector = _id_name_mapping() average_salary = _id_name_mapping() @@ -313,7 +318,7 @@ class InvestmentProject(DocType, MapDBModelToDict): 'business_activities': lambda col: [_id_name_dict(c) for c in col.all()], 'client_contacts': lambda col: [_contact_dict(c) for c in col.all()], 'client_relationship_manager': _id_name_dict, - 'team_members': lambda col: [_contact_dict(c) for c in col.all()], + 'team_members': lambda col: [_contact_dict(c.adviser) for c in col.all()], 'fdi_type': _id_name_dict, 'fdi_type_documents': lambda col: [_id_uri_dict(c) for c in col.all()], 'intermediate_company': _id_name_dict, diff --git a/datahub/search/test/test_elasticsearch.py b/datahub/search/test/test_elasticsearch.py index d6c4ad369..a43203009 100644 --- a/datahub/search/test/test_elasticsearch.py +++ b/datahub/search/test/test_elasticsearch.py @@ -161,7 +161,8 @@ def test_search_by_entity_query(): 'path': 'trading_address_country', 'query': { 'term': { - 'trading_address_country.id': '80756b9a-5d95-e211-a939-e4115bead28a' + 'trading_address_country.id': + '80756b9a-5d95-e211-a939-e4115bead28a' } } } @@ -239,6 +240,16 @@ def test_remap_fields(): assert remapped['uk_based'] is False +def test_remap_sort_field(): + """Test sort fields remapping.""" + fields = { + 'name': 'name_keyword' + } + + for key, value in fields.items(): + assert elasticsearch.remap_sort_field(key) == value + + def test_date_range_fields(): """Tests date range fields.""" now = '2017-06-13T09:44:31.062870' diff --git a/datahub/search/test/test_views.py b/datahub/search/test/test_views.py index 5aec1e09d..a958b6546 100644 --- a/datahub/search/test/test_views.py +++ b/datahub/search/test/test_views.py @@ -9,14 +9,14 @@ from datahub.company.test.factories import CompanyFactory from datahub.core import constants from datahub.core.test_utils import ( - LeelooTestCase, synchronous_executor_submit, synchronous_transaction_on_commit, + APITestMixin, synchronous_executor_submit, synchronous_transaction_on_commit, ) pytestmark = pytest.mark.django_db @pytest.mark.usefixtures('setup_data', 'post_save_handlers') -class SearchTestCase(LeelooTestCase): +class TestSearch(APITestMixin): """Tests search views.""" def test_basic_search_all_companies(self): @@ -134,23 +134,25 @@ def test_search_company(self): term = 'abc defg' url = f"{reverse('api-v3:search:company')}?offset=0&limit=100" + united_states_id = constants.Country.united_states.value.id response = self.api_client.post(url, { 'original_query': term, - 'trading_address_country': constants.Country.united_states.value.id, + 'trading_address_country': united_states_id, }) assert response.status_code == status.HTTP_200_OK assert response.data['count'] == 1 assert len(response.data['results']) == 1 - assert response.data['results'][0]['trading_address_country']['id'] == constants.Country.united_states.value.id + assert response.data['results'][0]['trading_address_country']['id'] == united_states_id def test_search_company_no_filters(self): """Tests case where there is no filters provided.""" url = f"{reverse('api-v3:search:company')}?offset=0&limit=100" response = self.api_client.post(url, {}) - assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.status_code == status.HTTP_200_OK + assert len(response.data['results']) > 0 def test_search_foreign_company_json(self): """Tests detailed company search.""" @@ -186,7 +188,8 @@ def test_search_contact_no_filters(self): url = f"{reverse('api-v3:search:contact')}?offset=0&limit=100" response = self.api_client.post(url, {}) - assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.status_code == status.HTTP_200_OK + assert len(response.data['results']) > 0 def test_search_investment_project_json(self): """Tests detailed investment project search.""" @@ -228,7 +231,8 @@ def test_search_investment_project_no_filters(self): url = f"{reverse('api-v3:search:investment_project')}?offset=0&limit=100" response = self.api_client.post(url, {}) - assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.status_code == status.HTTP_200_OK + assert len(response.data['results']) > 0 @mock.patch('datahub.core.utils.executor.submit', synchronous_executor_submit) @mock.patch('django.db.transaction.on_commit', synchronous_transaction_on_commit) @@ -249,8 +253,83 @@ def test_search_results_quality(self): 'entity': 'company' }) + assert response.status_code == status.HTTP_200_OK assert response.data['count'] == 4 assert ['The Advisory', 'The Advisory Group', 'The Risk Advisory Group', 'The Advisories'] == [company['name'] for company in response.data['companies']] + + @mock.patch('datahub.core.utils.executor.submit', synchronous_executor_submit) + @mock.patch('django.db.transaction.on_commit', synchronous_transaction_on_commit) + def test_search_sort_desc(self): + """Tests quality of results.""" + CompanyFactory(name='Water 1').save() + CompanyFactory(name='water 2').save() + CompanyFactory(name='water 3').save() + CompanyFactory(name='Water 4').save() + + connections.get_connection().indices.refresh() + + term = 'Water' + + url = reverse('api-v3:search:basic') + response = self.api_client.get(url, { + 'term': term, + 'sortby': 'name:desc', + 'entity': 'company' + }) + + assert response.status_code == status.HTTP_200_OK + assert response.data['count'] == 4 + assert ['Water 4', + 'water 3', + 'water 2', + 'Water 1'] == [company['name'] for company in response.data['companies']] + + @mock.patch('datahub.core.utils.executor.submit', synchronous_executor_submit) + @mock.patch('django.db.transaction.on_commit', synchronous_transaction_on_commit) + def test_search_sort_asc(self): + """Tests quality of results.""" + CompanyFactory(name='Fire 4').save() + CompanyFactory(name='fire 3').save() + CompanyFactory(name='fire 2').save() + CompanyFactory(name='Fire 1').save() + + connections.get_connection().indices.refresh() + + term = 'Fire' + + url = reverse('api-v3:search:company') + response = self.api_client.post(url, { + 'original_query': term, + 'sortby': 'name:asc' + }) + + assert response.status_code == status.HTTP_200_OK + assert response.data['count'] == 4 + assert ['Fire 1', + 'fire 2', + 'fire 3', + 'Fire 4'] == [company['name'] for company in response.data['results']] + + @mock.patch('datahub.core.utils.executor.submit', synchronous_executor_submit) + @mock.patch('django.db.transaction.on_commit', synchronous_transaction_on_commit) + def test_search_sort_invalid(self): + """Tests quality of results.""" + CompanyFactory(name='Fire 4').save() + CompanyFactory(name='fire 3').save() + CompanyFactory(name='fire 2').save() + CompanyFactory(name='Fire 1').save() + + connections.get_connection().indices.refresh() + + term = 'Fire' + + url = reverse('api-v3:search:company') + response = self.api_client.post(url, { + 'original_query': term, + 'sortby': 'some_field_that_doesnt_exist:asc' + }) + + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/datahub/search/views.py b/datahub/search/views.py index e9c680915..af0905397 100644 --- a/datahub/search/views.py +++ b/datahub/search/views.py @@ -12,6 +12,10 @@ class SearchBasicAPIView(APIView): http_method_names = ('get',) + SORT_BY_FIELDS = ( + 'name', 'created_on', + ) + def get(self, request, format=None): """Performs basic search.""" if 'term' not in request.query_params: @@ -20,7 +24,14 @@ def get(self, request, format=None): entity = request.query_params.get('entity', 'company') if entity not in ('company', 'contact', 'investment_project'): - raise ValidationError('Entity is not one of "company", "contact" or "investment_project".') + raise ValidationError('Entity is not one of "company", "contact" or ' + '"investment_project".') + + sortby = request.query_params.get('sortby') + if sortby: + field = sortby.rsplit(':')[0] + if field not in self.SORT_BY_FIELDS: + raise ValidationError(f'"sortby" field is not one of {self.SORT_BY_FIELDS}.') offset = int(request.query_params.get('offset', 0)) limit = int(request.query_params.get('limit', 100)) @@ -28,6 +39,7 @@ def get(self, request, format=None): results = elasticsearch.get_basic_search_query( term=term, entities=entity.split(','), + field_order=sortby, offset=offset, limit=limit ).execute() @@ -53,6 +65,16 @@ def get(self, request, format=None): class SearchCompanyAPIView(APIView): """Filtered company search view.""" + SORT_BY_FIELDS = ( + 'account_manager.name', 'alias', 'archived', 'archived_by', + 'contacts.name', 'business_type.name', + 'classification.name', 'company_number', 'companies_house_data.company_number', + 'created_on', 'employee_range.name', 'headquarter_type.name', 'id', 'modified_on', + 'name', 'registered_address_town', 'sector.name', 'trading_address_town', + 'turnover_range.name', 'uk_region.name', 'uk_based', + 'export_to_countries.name', 'future_interest_countries.name', + ) + FILTER_FIELDS = ( 'name', 'alias', 'sector', 'account_manager', 'export_to_country', 'future_interest_country', 'description', 'uk_region', 'uk_based', @@ -67,17 +89,21 @@ def post(self, request, format=None): for field in self.FILTER_FIELDS if field in request.data} filters = elasticsearch.remap_fields(filters) - if len(filters.keys()) == 0: - raise ValidationError('Missing required at least one filter.') - original_query = request.data.get('original_query', '') + sortby = request.data.get('sortby') + if sortby: + field = sortby.rsplit(':')[0] + if field not in self.SORT_BY_FIELDS: + raise ValidationError(f'"sortby" field is not one of {self.SORT_BY_FIELDS}.') + offset = int(request.query_params.get('offset', 0)) limit = int(request.query_params.get('limit', 100)) results = elasticsearch.get_search_company_query( term=original_query, filters=filters, + field_order=sortby, offset=offset, limit=limit, ).execute() @@ -93,6 +119,16 @@ def post(self, request, format=None): class SearchContactAPIView(APIView): """Filtered contact search view.""" + SORT_BY_FIELDS = ( + 'archived', 'archived', 'created_on', + 'modified_on', 'id', 'name', 'title.name', 'primary', + 'telephone_countrycode', 'telephone_number', + 'email', 'address_same_as_company', 'address_town', 'address_county', + 'job_title', 'contactable_by_dit', 'contactable_by_dit_partners', 'contactable_by_email', + 'contactable_by_phone', 'address_country.name', 'adviser.name', 'archived_by.name', + 'company.name', + ) + FILTER_FIELDS = ( 'first_name', 'last_name', 'job_title', 'company', 'adviser', 'notes', ) @@ -106,17 +142,21 @@ def post(self, request, format=None): filters = elasticsearch.remap_fields(filters) - if len(filters.keys()) == 0: - raise ValidationError('Missing required at least one filter.') - original_query = request.data.get('original_query', '') + sortby = request.data.get('sortby') + if sortby: + field = sortby.rsplit(':')[0] + if field not in self.SORT_BY_FIELDS: + raise ValidationError(f'"sortby" field is not one of {self.SORT_BY_FIELDS}.') + offset = int(request.data.get('offset', 0)) limit = int(request.data.get('limit', 100)) results = elasticsearch.get_search_contact_query( term=original_query, filters=filters, + field_order=sortby, offset=offset, limit=limit, ).execute() @@ -132,6 +172,27 @@ def post(self, request, format=None): class SearchInvestmentProjectAPIView(APIView): """Filtered investment project search view.""" + SORT_BY_FIELDS = ( + 'id', 'approved_commitment_to_invest', + 'approved_fdi', 'approved_good_value', + 'approved_high_value', 'approved_landed', + 'approved_non_fdi', 'actual_land_date', + 'business_activities.name', 'client_contacts.name', + 'client_relationship_manager.name', 'project_manager.name', + 'project_assurance_adviser.name', 'team_members.name', + 'archived', 'archived_by.name', 'created_on', 'modified_on', + 'estimated_land_date', 'fdi_type.name', 'intermediate_company.name', + 'uk_company.name', 'investor_company.name', 'investment_type.name', 'name', + 'r_and_d_budget', 'non_fdi_r_and_d_budget', 'new_tech_to_uk', 'export_revenue', + 'site_decided', 'nda_signed', 'government_assistance', + 'client_cannot_provide_total_investment', 'total_investment', + 'foreign_equity_investment', 'number_new_jobs', 'non_fdi_type.name', + 'stage.name', 'project_code', 'project_shareable', + 'referral_source_activity.name', 'referral_source_activity_marketing.name', + 'referral_source_activity_website.name', 'referral_source_activity_event', + 'referral_source_advisor.name', 'sector.name', 'average_salary.name', + ) + FILTER_FIELDS = ( 'client_relationship_manager', 'description', 'estimated_land_date_after', 'estimated_land_date_before', 'investor_company', 'investment_type', @@ -144,16 +205,19 @@ def post(self, request, format=None): """Performs filtered contact search.""" filters = {field: request.data[field] for field in self.FILTER_FIELDS if field in request.data} - filters = elasticsearch.remap_fields(filters) - if len(filters.keys()) == 0: - raise ValidationError('Missing required at least one filter.') try: filters, ranges = elasticsearch.date_range_fields(filters) except ValueError: raise ValidationError('Date(s) in incorrect format.') + sortby = request.data.get('sortby') + if sortby: + field = sortby.rsplit(':')[0] + if field not in self.SORT_BY_FIELDS: + raise ValidationError(f'"sortby" field is not one of {self.SORT_BY_FIELDS}.') + offset = int(request.data.get('offset', 0)) limit = int(request.data.get('limit', 100)) @@ -161,6 +225,7 @@ def post(self, request, format=None): term='', filters=filters, ranges=ranges, + field_order=sortby, offset=offset, limit=limit, ).execute() diff --git a/datahub/user/test/test_views.py b/datahub/user/test/test_views.py index c38fd6859..3918c2792 100644 --- a/datahub/user/test/test_views.py +++ b/datahub/user/test/test_views.py @@ -1,10 +1,10 @@ from rest_framework import status from rest_framework.reverse import reverse -from datahub.core.test_utils import get_test_user, LeelooTestCase +from datahub.core.test_utils import APITestMixin, get_test_user -class UserViewTestCase(LeelooTestCase): +class TestUserView(APITestMixin): """User view test case.""" def test_who_am_i_authenticated(self): diff --git a/datahub/v2/repos/service_deliveries.py b/datahub/v2/repos/service_deliveries.py index 0de73ea54..0d71a1463 100644 --- a/datahub/v2/repos/service_deliveries.py +++ b/datahub/v2/repos/service_deliveries.py @@ -41,7 +41,9 @@ def get(self, object_id): entity = self.model_class.objects.get(id=object_id) except self.model_class.DoesNotExist: raise DoesNotExistException() - data = utils.model_to_json_api_data(entity, self.schema_class(), url_builder=self.url_builder) + data = utils.model_to_json_api_data( + entity, self.schema_class(), url_builder=self.url_builder + ) return utils.build_repo_response(data=data) def filter(self, company_id=utils.DEFAULT, contact_id=utils.DEFAULT, offset=0, limit=100): @@ -54,7 +56,8 @@ def filter(self, company_id=utils.DEFAULT, contact_id=utils.DEFAULT, offset=0, l start, end = offset, offset + limit queryset = self.model_class.objects.filter(**filters) entities = list(queryset[start:end]) - data = [utils.model_to_json_api_data(entity, self.schema_class(), self.url_builder) for entity in entities] + data = [utils.model_to_json_api_data(entity, self.schema_class(), self.url_builder) + for entity in entities] return utils.build_repo_response(data=data) def upsert(self, data): @@ -87,7 +90,8 @@ def inject_service_offer(self, data): ) if not service_offer_id: raise RepoDataValidationError( - detail={'relationships.service': 'This combination of service and service provider does not exist.'} + detail={'relationships.service': 'This combination of service and service ' + 'provider does not exist.'} ) else: data['relationships'].update({ diff --git a/datahub/v2/repos/utils.py b/datahub/v2/repos/utils.py index bff498f58..dfe9e5dd8 100644 --- a/datahub/v2/repos/utils.py +++ b/datahub/v2/repos/utils.py @@ -172,8 +172,10 @@ def extract_id_for_relationship_from_data(data, relationship_name): def replace_colander_null(data): """Replace colander.null with None in deserialized data.""" cleaned_data = { - 'attributes': {k: None if v is colander.null else v for k, v in data['attributes'].items()}, - 'relationships': {k: None if v is colander.null else v for k, v in data['relationships'].items()} + 'attributes': {k: None if v is colander.null else v + for k, v in data['attributes'].items()}, + 'relationships': {k: None if v is colander.null else v + for k, v in data['relationships'].items()} } data.update(cleaned_data) return data diff --git a/datahub/v2/tests/repos/test_service_deliveries_repo.py b/datahub/v2/tests/repos/test_service_deliveries_repo.py index f854c5a48..e1ca9656b 100644 --- a/datahub/v2/tests/repos/test_service_deliveries_repo.py +++ b/datahub/v2/tests/repos/test_service_deliveries_repo.py @@ -1,7 +1,6 @@ import uuid import pytest -from django.test import TestCase from django.utils.timezone import now from freezegun import freeze_time @@ -18,7 +17,7 @@ DUMMY_CONFIG = config = {'url_builder': lambda kwargs: None} -class ServiceDeliveriesRepoTestCase(TestCase): +class TestServiceDeliveriesRepo: """Service delivery repo test case.""" def test_get(self): @@ -33,13 +32,16 @@ def test_get(self): assert isinstance(result, RepoResponse) expected_relationships = { 'contact': {'data': {'id': str(service_delivery.contact.pk), 'type': 'Contact'}}, - 'status': {'data': {'id': str(service_delivery.status.pk), 'type': 'ServiceDeliveryStatus'}}, 'company': {'data': {'id': str(service_delivery.company.pk), 'type': 'Company'}}, 'service': {'data': {'id': str(service_delivery.service.pk), 'type': 'Service'}}, 'dit_team': {'data': {'id': str(service_delivery.dit_team.pk), 'type': 'Team'}}, 'uk_region': {'data': {'id': str(service_delivery.uk_region.pk), 'type': 'UKRegion'}}, - 'dit_adviser': {'data': {'id': str( - service_delivery.dit_adviser.pk), 'type': 'Adviser'}} + 'status': { + 'data': {'id': str(service_delivery.status.pk), 'type': 'ServiceDeliveryStatus'} + }, + 'dit_adviser': { + 'data': {'id': str(service_delivery.dit_adviser.pk), 'type': 'Adviser'} + } } assert data['relationships'] == expected_relationships assert data['relationships']['company']['data']['type'] == 'Company' @@ -59,6 +61,7 @@ def test_insert(self): user = get_test_user() company = factories.CompanyFactory() contact = factories.ContactFactory() + offered_id = constants.ServiceDeliveryStatus.offered.value.id data = { 'type': 'ServiceDelivery', 'attributes': { @@ -70,7 +73,7 @@ def test_insert(self): 'status': { 'data': { 'type': 'ServiceDeliveryStatus', - 'id': constants.ServiceDeliveryStatus.offered.value.id + 'id': offered_id } }, 'company': { @@ -119,7 +122,7 @@ def test_insert(self): expected_relationships = { 'dit_adviser': {'data': {'type': 'Adviser', 'id': str(user.pk)}}, 'status': {'data': { - 'type': 'ServiceDeliveryStatus', 'id': constants.ServiceDeliveryStatus.offered.value.id} + 'type': 'ServiceDeliveryStatus', 'id': offered_id} }, 'contact': {'data': {'type': 'Contact', 'id': str(contact.pk)}}, 'dit_team': {'data': {'type': 'Team', 'id': str(service_offer.dit_team.pk)}}, @@ -186,7 +189,9 @@ def test_filter_by_company_id(self): dit_team=service_offer.dit_team, company=company ) - result = ServiceDeliveryDatabaseRepo(config=DUMMY_CONFIG).filter(company_id=str(company.pk)) + result = ServiceDeliveryDatabaseRepo(config=DUMMY_CONFIG).filter( + company_id=str(company.pk) + ) data = result.data assert isinstance(result, RepoResponse) assert len(data) == 1 @@ -205,7 +210,9 @@ def test_filter_by_contact_id(self): dit_team=service_offer.dit_team, contact=contact ) - result = ServiceDeliveryDatabaseRepo(config=DUMMY_CONFIG).filter(contact_id=str(contact.pk)) + result = ServiceDeliveryDatabaseRepo(config=DUMMY_CONFIG).filter( + contact_id=str(contact.pk) + ) data = result.data assert isinstance(result, RepoResponse) assert len(data) == 1 diff --git a/datahub/v2/tests/views/test_parsers.py b/datahub/v2/tests/views/test_parsers.py index 81be43143..52eab5a8c 100644 --- a/datahub/v2/tests/views/test_parsers.py +++ b/datahub/v2/tests/views/test_parsers.py @@ -2,10 +2,10 @@ from rest_framework.reverse import reverse -from datahub.core.test_utils import LeelooTestCase +from datahub.core.test_utils import APITestMixin -class JSONParserTestCase(LeelooTestCase): +class TestJSONParser(APITestMixin): """Test generic parser error through a v2 view.""" def test_data_key_not_in_post_body(self): @@ -65,8 +65,8 @@ def test_data_contains_incorrect_entity_name(self): ) content = json.loads(response.content.decode('utf-8')) expected_content = {'errors': [ - {'detail': 'The resource object\'s type (whatever) is not the type that constitute the collection ' - 'represented by the endpoint (ServiceDelivery).', + {'detail': 'The resource object\'s type (whatever) is not the type that constitute ' + 'the collection represented by the endpoint (ServiceDelivery).', 'source': {'pointer': '/data/detail'} }] } diff --git a/datahub/v2/tests/views/test_service_delivery.py b/datahub/v2/tests/views/test_service_delivery.py index 379bb11d4..afec22469 100644 --- a/datahub/v2/tests/views/test_service_delivery.py +++ b/datahub/v2/tests/views/test_service_delivery.py @@ -7,13 +7,13 @@ from datahub.company.test.factories import CompanyFactory, ContactFactory from datahub.core import constants -from datahub.core.test_utils import LeelooTestCase +from datahub.core.test_utils import APITestMixin from datahub.interaction.test.factories import ServiceDeliveryFactory, ServiceOfferFactory from datahub.metadata.test.factories import EventFactory -class ServiceDeliveryViewTestCase(LeelooTestCase): +class TestServiceDeliveryView(APITestMixin): """Service Delivery view test case.""" def test_service_delivery_detail_view(self): @@ -37,7 +37,12 @@ def test_service_delivery_detail_view_not_found(self): response = self.api_client.get(url) content = json.loads(response.content.decode('utf-8')) assert response.status_code == status.HTTP_404_NOT_FOUND - expected_content = {'errors': [{'source': {'pointer': '/data/detail'}, 'detail': 'Not found.'}]} + expected_content = { + 'errors': [{ + 'source': {'pointer': '/data/detail'}, + 'detail': 'Not found.' + }] + } assert content == expected_content def test_service_delivery_list_view(self): @@ -284,7 +289,8 @@ def test_filter_service_deliveries_by_company(self): response = self.api_client.get(url, data={'company_id': company.pk}) content = json.loads(response.content.decode('utf-8')) assert response.status_code == status.HTTP_200_OK - assert {element['id'] for element in content['data']} == {str(servicedelivery.pk), str(servicedelivery2.pk)} + expected_ids = {str(servicedelivery.pk), str(servicedelivery2.pk)} + assert {element['id'] for element in content['data']} == expected_ids def test_filter_service_deliveries_by_contact(self): """Filter by contact.""" @@ -308,7 +314,8 @@ def test_filter_service_deliveries_by_contact(self): response = self.api_client.get(url, data={'contact_id': contact.pk}) content = json.loads(response.content.decode('utf-8')) assert response.status_code == status.HTTP_200_OK - assert {element['id'] for element in content['data']} == {str(servicedelivery.pk), str(servicedelivery2.pk)} + expected_ids = {str(servicedelivery.pk), str(servicedelivery2.pk)} + assert {element['id'] for element in content['data']} == expected_ids def test_add_service_delivery_incorrect_service_team_event_combination(self): """Test add new service delivery with invalid service/team/even combination.""" diff --git a/fixtures/test_data.yaml b/fixtures/test_data.yaml index 09361130a..3d4c2b93d 100644 --- a/fixtures/test_data.yaml +++ b/fixtures/test_data.yaml @@ -106,6 +106,38 @@ last_name: Rogers dit_team: 7648318c-9698-e211-a939-e4115bead28a +- model: company.advisor + pk: d7493b4e-5d7b-4834-98d9-28b78a74052a + fields: + email: michael.wining@example.com + first_name: Michael + last_name: Wining + dit_team: 7648318c-9698-e211-a939-e4115bead28a + +- model: company.advisor + pk: 5a644146-5298-4741-91f3-d5d7558adf47 + fields: + email: paula.churing@example.com + first_name: Paula + last_name: Churing + dit_team: 7648318c-9698-e211-a939-e4115bead28a + +- model: company.advisor + pk: a80ff5fd-8904-4940-bf96-fe8047e34be5 + fields: + email: barry.oling@example.com + first_name: Barry + last_name: Oling + dit_team: b85bf640-9798-e211-a939-e4115bead28a + +- model: company.advisor + pk: f1eb4363-0a37-4344-bd96-e90abeaf483e + fields: + email: jenny.carey@example.com + first_name: Jenny + last_name: Carey + dit_team: 4672f846-9798-e211-a939-e4115bead28a + - model: company.contact pk: 9b1138ab-ec7b-497f-b8c3-27fed21694ef fields: @@ -194,6 +226,7 @@ fields: name: New hotel (FDI) description: This is a dummy investment project for testing + stage: c9864359-fb1a-4646-a4c1-97d10189fc03 # Assign PM nda_signed: false estimated_land_date: 2020-01-01 investment_type: 3e143372-496c-4d1e-8278-6fdd3da9b48b # FDI @@ -209,7 +242,7 @@ sector: 034be3be-5329-e511-b6bc-e4115bead28a business_activities: - a2dbd807-ae52-421c-8d1d-88adfc7a506b - stage: c9864359-fb1a-4646-a4c1-97d10189fc03 + client_cannot_provide_total_investment: false total_investment: 1000000.0 foreign_equity_investment: 200000.0 government_assistance: true @@ -228,9 +261,10 @@ fields: name: New hotel (Non-FDI) description: This is a dummy investment project for testing + stage: 8a320cc9-ae2e-443e-9d26-2f36452c2ced # Prospect nda_signed: false estimated_land_date: 2020-01-01 - investment_type: 9c364e64-2b28-401b-b2df-50e08b0bca44 # FDI + investment_type: 9c364e64-2b28-401b-b2df-50e08b0bca44 # Non-FDI project_shareable: true not_shareable_reason: '' investor_company: 0f5216e0-849f-11e6-ae22-56b6b6499611 @@ -249,6 +283,7 @@ fields: name: New hotel (commitment to invest) description: This is a dummy investment project for testing + stage: 8a320cc9-ae2e-443e-9d26-2f36452c2ced # Prospect nda_signed: false estimated_land_date: 2020-01-01 investment_type: 031269ab-b7ec-40e9-8a4e-7371404f0622 # Commitment to invest @@ -265,6 +300,153 @@ business_activities: - a2dbd807-ae52-421c-8d1d-88adfc7a506b +- model: investment.investmentproject + pk: 18750b26-a8c3-41b2-8d3a-fb0b930c2270 + fields: + name: New restaurant + description: This is a dummy investment project for testing + stage: 7606cc19-20da-4b74-aba1-2cec0d753ad8 # Active + nda_signed: false + estimated_land_date: 2020-01-01 + investment_type: 3e143372-496c-4d1e-8278-6fdd3da9b48b # FDI + project_shareable: false + not_shareable_reason: Commercially sensitive + investor_company: 0f5216e0-849f-11e6-ae22-56b6b6499611 + client_contacts: + - 952232d2-1d25-4c3a-bcac-2f3a30a94da9 + client_relationship_manager: e83a608e-84a4-11e6-ae22-56b6b6499611 + referral_source_adviser: e83a608e-84a4-11e6-ae22-56b6b6499611 + referral_source_activity: aba8f653-264f-48d8-950e-07f9c418c7b0 + fdi_type: f8447013-cfdc-4f35-a146-6619665388b3 + sector: 034be3be-5329-e511-b6bc-e4115bead28a + business_activities: + - a2dbd807-ae52-421c-8d1d-88adfc7a506b + client_cannot_provide_total_investment: false + total_investment: 1000000.0 + client_cannot_provide_foreign_investment: false + foreign_equity_investment: 200000.0 + government_assistance: true + number_new_jobs: 20 + average_salary: 2943bf3d-32dd-43be-8ad4-969b006dee7b + site_decided: false + client_considering_other_countries: false + uk_region_locations: + - 814cd12a-6095-e211-a939-e4115bead28a + client_requirements: Anywhere + strategic_drivers: + - 382aa6d1-a362-4166-a09d-f579d9f3be75 + project_manager: d7493b4e-5d7b-4834-98d9-28b78a74052a + project_assurance_adviser: 5a644146-5298-4741-91f3-d5d7558adf47 + +- model: investment.investmentprojectteammember + pk: 1 + fields: + investment_project: 18750b26-a8c3-41b2-8d3a-fb0b930c2270 + adviser: a80ff5fd-8904-4940-bf96-fe8047e34be5 + role: Sector adviser + +- model: investment.investmentprojectteammember + pk: 2 + fields: + investment_project: 18750b26-a8c3-41b2-8d3a-fb0b930c2270 + adviser: f1eb4363-0a37-4344-bd96-e90abeaf483e + role: Finance adviser + +- model: investment.investmentproject + pk: ea3a03ba-b239-4956-b2fb-f35c91109674 + fields: + name: New fruit machine + description: This is a dummy investment project for testing + stage: 49b8f6f3-0c50-4150-a965-2c974f3149e3 # Verify win + nda_signed: false + estimated_land_date: 2020-01-01 + investment_type: 3e143372-496c-4d1e-8278-6fdd3da9b48b # FDI + project_shareable: false + not_shareable_reason: Commercially sensitive + investor_company: 0f5216e0-849f-11e6-ae22-56b6b6499611 + client_contacts: + - 952232d2-1d25-4c3a-bcac-2f3a30a94da9 + client_relationship_manager: e83a608e-84a4-11e6-ae22-56b6b6499611 + referral_source_adviser: e83a608e-84a4-11e6-ae22-56b6b6499611 + referral_source_activity: aba8f653-264f-48d8-950e-07f9c418c7b0 + fdi_type: f8447013-cfdc-4f35-a146-6619665388b3 + sector: 034be3be-5329-e511-b6bc-e4115bead28a + business_activities: + - a2dbd807-ae52-421c-8d1d-88adfc7a506b + client_cannot_provide_total_investment: false + total_investment: 1000000.0 + client_cannot_provide_foreign_investment: false + foreign_equity_investment: 200000.0 + government_assistance: true + number_new_jobs: 20 + average_salary: 2943bf3d-32dd-43be-8ad4-969b006dee7b + site_decided: false + number_safeguarded_jobs: 1 + r_and_d_budget: false + non_fdi_r_and_d_budget: true + new_tech_to_uk: true + export_revenue: true + client_considering_other_countries: false + uk_region_locations: + - 814cd12a-6095-e211-a939-e4115bead28a + client_requirements: Anywhere + strategic_drivers: + - 382aa6d1-a362-4166-a09d-f579d9f3be75 + project_manager: d7493b4e-5d7b-4834-98d9-28b78a74052a + project_assurance_adviser: 5a644146-5298-4741-91f3-d5d7558adf47 + uk_company: a73efeba-8499-11e6-ae22-56b6b6499611 + address_line_1: 10 Eastings Road + address_line_2: London + address_line_postcode: W1 2AA + + +- model: investment.investmentproject + pk: 5d341b34-1fc8-4638-b4b1-a0922ebf401e + fields: + name: New airport + description: This is a dummy investment project for testing + stage: 945ea6d1-eee3-4f5b-9144-84a75b71b8e6 # Won + nda_signed: false + estimated_land_date: 2020-01-01 + investment_type: 3e143372-496c-4d1e-8278-6fdd3da9b48b # FDI + project_shareable: false + not_shareable_reason: Commercially sensitive + investor_company: 0f5216e0-849f-11e6-ae22-56b6b6499611 + client_contacts: + - 952232d2-1d25-4c3a-bcac-2f3a30a94da9 + client_relationship_manager: e83a608e-84a4-11e6-ae22-56b6b6499611 + referral_source_adviser: e83a608e-84a4-11e6-ae22-56b6b6499611 + referral_source_activity: aba8f653-264f-48d8-950e-07f9c418c7b0 + fdi_type: f8447013-cfdc-4f35-a146-6619665388b3 + sector: 034be3be-5329-e511-b6bc-e4115bead28a + business_activities: + - a2dbd807-ae52-421c-8d1d-88adfc7a506b + client_cannot_provide_total_investment: false + total_investment: 1000000.0 + client_cannot_provide_foreign_investment: false + foreign_equity_investment: 200000.0 + government_assistance: true + number_new_jobs: 20 + average_salary: 2943bf3d-32dd-43be-8ad4-969b006dee7b + site_decided: false + number_safeguarded_jobs: 1 + r_and_d_budget: false + non_fdi_r_and_d_budget: true + new_tech_to_uk: true + export_revenue: true + client_considering_other_countries: false + uk_region_locations: + - 814cd12a-6095-e211-a939-e4115bead28a + client_requirements: Anywhere + strategic_drivers: + - 382aa6d1-a362-4166-a09d-f579d9f3be75 + project_manager: d7493b4e-5d7b-4834-98d9-28b78a74052a + project_assurance_adviser: 5a644146-5298-4741-91f3-d5d7558adf47 + uk_company: a73efeba-8499-11e6-ae22-56b6b6499611 + address_line_1: 10 Eastings Road + address_line_2: London + address_line_postcode: W1 2AA + - model: interaction.interaction pk: 94d03877-fceb-44a0-8813-17b727d2e7f6 fields: diff --git a/requirements.in b/requirements.in index bc97dd50c..d0fd7a92a 100644 --- a/requirements.in +++ b/requirements.in @@ -3,10 +3,10 @@ lxml==3.8.0 cssselect==1.0.1 # Django and django related -Django==1.11.2 +Django==1.11.3 djangorestframework==3.6.3 django-environ==0.4.3 -django-extensions==1.7.9 +django-extensions==1.8.1 django-filter==1.0.4 django-reversion==2.0.9 django-pglocks==1.0.2 @@ -18,7 +18,7 @@ whitenoise==3.3.0 pyyaml==3.12 colander==1.3.3 pyquery==1.2.17 -python-dateutil==2.6.0 +python-dateutil==2.6.1 # Persistency layer psycopg2==2.7.1 diff --git a/requirements.txt b/requirements.txt index f05f30900..2aeb39398 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ # pip-compile --output-file requirements.txt requirements.in # boto3==1.4.4 -botocore==1.5.75 # via boto3, s3transfer +botocore==1.5.81 # via boto3, s3transfer certifi==2017.4.17 # via requests chardet==3.0.4 # via requests click==6.7 # via pip-tools @@ -16,14 +16,14 @@ cssselect==1.0.1 decorator==4.0.11 # via ipython, traitlets django-debug-toolbar==1.8 django-environ==0.4.3 -django-extensions==1.7.9 +django-extensions==1.8.1 django-filter==1.0.4 django-model-utils==3.0.0 django-mptt==0.8.7 django-oauth-toolkit==1.0.0 django-pglocks==1.0.2 django-reversion==2.0.9 -django==1.11.2 # via django-debug-toolbar, django-environ, django-model-utils, django-oauth-toolkit, django-reversion +django==1.11.3 # via django-debug-toolbar, django-environ, django-model-utils, django-oauth-toolkit, django-reversion djangorestframework==3.6.3 docutils==0.13.1 # via botocore elasticsearch-dsl==2.2.0 @@ -69,8 +69,8 @@ pyquery==1.2.17 pytest-cov==2.5.1 pytest-django==3.1.2 pytest-sugar==0.8.0 -pytest==3.1.2 # via pytest-cov, pytest-django, pytest-sugar -python-dateutil==2.6.0 +pytest==3.1.3 # via pytest-cov, pytest-django, pytest-sugar +python-dateutil==2.6.1 pytz==2017.2 # via django pyyaml==3.12 raven==6.1.0 diff --git a/setup.cfg b/setup.cfg index cad34481d..995573b3f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ # D104: Missing docstring in public package exclude = */migrations/*,__pycache__,manage.py,config/*, ignore = D203, D100, D104 -max-line-length = 119 +max-line-length = 99 [flake8] @@ -21,7 +21,7 @@ max-line-length = 119 # P101: format string does contain unindexed parameters exclude = */migrations/*,__pycache__,manage.py,config/*,conftest.py,factories.py,datahub/korben/test/*,env/* ignore = D203, D100, D104, D200, D205, D400, D401, P101 -max-line-length = 119 +max-line-length = 99 max-complexity = 10 application-import-names = datahub,loading_scripts import_order_style = smarkets