From 41779c7d23603a24759fbc63b8623adefe126a9e Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Wed, 7 Feb 2024 12:50:42 +0000 Subject: [PATCH 01/28] Add endpoint to stream appeal documents --- api/appeals/filters.py | 6 ++ api/appeals/tests/test_views.py | 89 ++++++++++++++++++++++++++++ api/appeals/urls.py | 5 ++ api/appeals/views.py | 16 +++-- api/core/views.py | 17 ++++++ api/documents/views.py | 11 ++-- api/organisations/views/documents.py | 13 ++-- 7 files changed, 139 insertions(+), 18 deletions(-) create mode 100644 api/appeals/filters.py diff --git a/api/appeals/filters.py b/api/appeals/filters.py new file mode 100644 index 0000000000..70c3d63740 --- /dev/null +++ b/api/appeals/filters.py @@ -0,0 +1,6 @@ +from rest_framework import filters + + +class AppealFilter(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return queryset.filter(appeal_id=view.kwargs["pk"]) diff --git a/api/appeals/tests/test_views.py b/api/appeals/tests/test_views.py index 301c98c62a..2a8baec9f8 100644 --- a/api/appeals/tests/test_views.py +++ b/api/appeals/tests/test_views.py @@ -1,5 +1,8 @@ from unittest import mock +from moto import mock_aws + +from django.http import FileResponse from django.urls import reverse from django.utils.timezone import now @@ -180,3 +183,89 @@ def test_get_document_different_organisation(self): response.status_code, status.HTTP_403_FORBIDDEN, ) + + +@mock_aws +class TestAppealDocumentStream(DataTestClient): + def setUp(self): + super().setUp() + self.appeal = AppealFactory() + application = self.create_standard_application_case( + organisation=self.exporter_user.organisation, + ) + application.appeal = self.appeal + application.save() + + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") + + def test_get_document_stream(self): + appeal_document = AppealDocumentFactory( + appeal=self.appeal, + s3_key="thisisakey", + safe=True, + ) + + url = reverse( + "appeals:document_stream", + kwargs={ + "pk": str(self.appeal.pk), + "document_pk": str(appeal_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_get_document_stream_invalid_appeal_pk(self): + appeal_document = AppealDocumentFactory(appeal=self.appeal) + + url = reverse( + "appeals:document_stream", + kwargs={ + "pk": "0f415f8a-e3e8-4c49-b053-ef03b1c477d5", + "document_pk": str(appeal_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual( + response.status_code, + status.HTTP_404_NOT_FOUND, + ) + + def test_get_document_stream_invalid_document_pk(self): + url = reverse( + "appeals:document_stream", + kwargs={ + "pk": str(self.appeal.pk), + "document_pk": "0b551122-1ac2-4ea2-82b3-f1aaf0bf4923", + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual( + response.status_code, + status.HTTP_404_NOT_FOUND, + ) + + def test_get_document_stream_different_organisation(self): + self.appeal.baseapplication.organisation = self.create_organisation_with_exporter_user()[0] + self.appeal.baseapplication.save() + appeal_document = AppealDocumentFactory(appeal=self.appeal) + + url = reverse( + "appeals:document_stream", + kwargs={ + "pk": str(self.appeal.pk), + "document_pk": str(appeal_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual( + response.status_code, + status.HTTP_403_FORBIDDEN, + ) diff --git a/api/appeals/urls.py b/api/appeals/urls.py index d0dbba8889..1eb60b948e 100644 --- a/api/appeals/urls.py +++ b/api/appeals/urls.py @@ -15,4 +15,9 @@ views.AppealDocumentAPIView.as_view(), name="document", ), + path( + "/documents//stream/", + views.AppealDocumentStreamAPIView.as_view(), + name="document_stream", + ), ] diff --git a/api/appeals/views.py b/api/appeals/views.py index 30a402fd8e..96cf168703 100644 --- a/api/appeals/views.py +++ b/api/appeals/views.py @@ -7,7 +7,9 @@ from api.core.authentication import ExporterAuthentication from api.core.permissions import IsExporterInOrganisation +from api.core.views import DocumentStreamAPIView +from .filters import AppealFilter from .models import ( Appeal, AppealDocument, @@ -36,9 +38,15 @@ def get_serializer_context(self): class AppealDocumentAPIView(BaseAppealDocumentAPIView, RetrieveAPIView): + filter_backends = (AppealFilter,) lookup_url_kwarg = "document_pk" + queryset = AppealDocument.objects.all() - def get_queryset(self): - return AppealDocument.objects.filter( - appeal_id=self.kwargs["pk"], - ) + +class AppealDocumentStreamAPIView(BaseAppealDocumentAPIView, DocumentStreamAPIView): + filter_backends = (AppealFilter,) + lookup_url_kwarg = "document_pk" + queryset = AppealDocument.objects.all() + + def get_document(self, instance): + return instance diff --git a/api/core/views.py b/api/core/views.py index 5de0c3db49..fe83d00071 100644 --- a/api/core/views.py +++ b/api/core/views.py @@ -1,8 +1,13 @@ from django.contrib import admin +from django.http import Http404 from django.shortcuts import redirect from django.urls import reverse from django.views.generic import View +from rest_framework.generics import RetrieveAPIView + +from api.documents.libraries.s3_operations import document_download_stream + class LoginProviderView(View): """If user if not logged in then send them to staff sso, otherwise show them vanilla django admin login page""" @@ -12,3 +17,15 @@ def dispatch(self, request): return redirect(reverse("authbroker_client:login")) # to show the "you're not an admin" message. return admin.site.login(request) + + +class DocumentStreamAPIView(RetrieveAPIView): + def get_document(self, instance): + raise NotImplementedError() + + def retrieve(self, request, *args, **kwargs): + instance = self.get_object() + document = self.get_document(instance) + if not document.safe: + raise Http404() + return document_download_stream(document) diff --git a/api/documents/views.py b/api/documents/views.py index c8820a77a7..61bde4f3af 100644 --- a/api/documents/views.py +++ b/api/documents/views.py @@ -10,7 +10,7 @@ from api.cases.generated_documents.signing import get_certificate_data from api.core.authentication import SharedAuthentication from api.core.exceptions import NotFoundError -from api.documents.libraries.s3_operations import document_download_stream +from api.core.views import DocumentStreamAPIView from api.documents.models import Document from api.documents.serializers import DocumentViewSerializer from api.documents import permissions @@ -47,15 +47,14 @@ def get(self, request): return response -class DocumentStream(RetrieveAPIView): +class DocumentStream(DocumentStreamAPIView): """ Get streamed content of a document. """ authentication_classes = (SharedAuthentication,) - queryset = Document.objects.filter(safe=True) + queryset = Document.objects.all() permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) - def retrieve(self, request, *args, **kwargs): - document = self.get_object() - return document_download_stream(document) + def get_document(self, instance): + return instance diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 8d1807fcc2..27ab30ba49 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -1,12 +1,11 @@ from rest_framework import viewsets -from rest_framework.generics import RetrieveAPIView from django.http import JsonResponse from api.audit_trail import service as audit_trail_service from api.audit_trail.enums import AuditType from api.core.authentication import SharedAuthentication -from api.documents.libraries.s3_operations import document_download_stream +from api.core.views import DocumentStreamAPIView from api.organisations import ( filters, models, @@ -85,14 +84,12 @@ def update(self, request, pk, document_on_application_pk): return JsonResponse({"document": serializer.data}, status=200) -class DocumentOnOrganisationStreamView(RetrieveAPIView): +class DocumentOnOrganisationStreamView(DocumentStreamAPIView): authentication_classes = (SharedAuthentication,) filter_backends = (filters.OrganisationFilter,) lookup_url_kwarg = "document_on_application_pk" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) - queryset = models.DocumentOnOrganisation.objects.filter(document__safe=True) + queryset = models.DocumentOnOrganisation.objects.all() - def retrieve(self, request, *args, **kwargs): - document = self.get_object() - document = document.document - return document_download_stream(document) + def get_document(self, instance): + return instance.document From 6ad06abbac8a5977bce556e068431f70358cb06f Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Wed, 7 Feb 2024 16:43:42 +0000 Subject: [PATCH 02/28] Add good document streaming endpoint --- api/goods/filters.py | 6 ++ api/goods/permissions.py | 14 +++ api/goods/tests/test_goods_documents.py | 110 ++++++++++++++++++++++++ api/goods/urls.py | 5 ++ api/goods/views.py | 20 +++++ 5 files changed, 155 insertions(+) create mode 100644 api/goods/filters.py create mode 100644 api/goods/permissions.py diff --git a/api/goods/filters.py b/api/goods/filters.py new file mode 100644 index 0000000000..45cefe7253 --- /dev/null +++ b/api/goods/filters.py @@ -0,0 +1,6 @@ +from rest_framework import filters + + +class GoodFilter(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + return queryset.filter(good_id=view.kwargs["pk"]) diff --git a/api/goods/permissions.py b/api/goods/permissions.py new file mode 100644 index 0000000000..576a62ceab --- /dev/null +++ b/api/goods/permissions.py @@ -0,0 +1,14 @@ +from rest_framework import permissions + +from api.goods.enums import GoodStatus +from api.organisations.libraries.get_organisation import get_request_user_organisation_id + + +class IsDocumentInOrganisation(permissions.BasePermission): + def has_object_permission(self, request, view, obj): + return obj.good.organisation_id == get_request_user_organisation_id(request) + + +class IsGoodDraft(permissions.BasePermission): + def has_object_permission(self, request, view, obj): + return obj.good.status == GoodStatus.DRAFT diff --git a/api/goods/tests/test_goods_documents.py b/api/goods/tests/test_goods_documents.py index b9e70a7550..25222625be 100644 --- a/api/goods/tests/test_goods_documents.py +++ b/api/goods/tests/test_goods_documents.py @@ -1,10 +1,16 @@ from django.urls import reverse from rest_framework import status +from moto import mock_aws +from parameterized import parameterized + +from django.http import FileResponse + from api.applications.models import GoodOnApplication from test_helpers.clients import DataTestClient from api.applications.tests.factories import StandardApplicationFactory +from api.goods.enums import GoodStatus from api.goods.tests.factories import GoodFactory from api.organisations.tests.factories import OrganisationFactory @@ -121,3 +127,107 @@ def test_edit_product_document_description(self): response = self.client.get(url, **self.exporter_headers) self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.json()["document"]["description"], "Updated document description") + + +@mock_aws +class GoodDocumentStreamTests(DataTestClient): + def setUp(self): + super().setUp() + self.good = GoodFactory(organisation=self.organisation) + self.create_default_bucket() + self.put_object_in_default_bucket("thisisakey", b"test") + + def test_get_good_document_stream(self): + good_document = self.create_good_document( + good=self.good, + user=self.exporter_user, + organisation=self.organisation, + s3_key="thisisakey", + name="doc1.pdf", + ) + + url = reverse( + "goods:document_stream", + kwargs={ + "pk": str(self.good.pk), + "doc_pk": str(good_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIsInstance(response, FileResponse) + self.assertEqual(b"".join(response.streaming_content), b"test") + + def test_get_good_document_stream_invalid_good_pk(self): + another_good = GoodFactory(organisation=self.organisation) + good_document = self.create_good_document( + good=self.good, + user=self.exporter_user, + organisation=self.organisation, + s3_key="thisisakey", + name="doc1.pdf", + ) + + url = reverse( + "goods:document_stream", + kwargs={ + "pk": str(another_good.pk), + "doc_pk": str(good_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_get_good_document_stream_other_organisation(self): + other_organisation = self.create_organisation_with_exporter_user()[0] + self.good.organisation = other_organisation + self.good.save() + good_document = self.create_good_document( + good=self.good, + user=self.exporter_user, + organisation=other_organisation, + s3_key="thisisakey", + name="doc1.pdf", + ) + + url = reverse( + "goods:document_stream", + kwargs={ + "pk": str(self.good.pk), + "doc_pk": str(good_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + @parameterized.expand( + [ + GoodStatus.SUBMITTED, + GoodStatus.QUERY, + GoodStatus.VERIFIED, + ], + ) + def test_get_good_document_stream_good_not_draft(self, good_status): + self.good.status = good_status + self.good.save() + good_document = self.create_good_document( + good=self.good, + user=self.exporter_user, + organisation=self.organisation, + s3_key="thisisakey", + name="doc1.pdf", + ) + + url = reverse( + "goods:document_stream", + kwargs={ + "pk": str(self.good.pk), + "doc_pk": str(good_document.pk), + }, + ) + response = self.client.get(url, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) diff --git a/api/goods/urls.py b/api/goods/urls.py index 1cda5dc63f..d51dffa0c7 100644 --- a/api/goods/urls.py +++ b/api/goods/urls.py @@ -23,6 +23,11 @@ views.GoodDocumentDetail.as_view(), name="document", ), + path( + "/documents//stream/", + views.GoodDocumentStream.as_view(), + name="document_stream", + ), path( "document_internal_good_on_application//", views.DocumentGoodOnApplicationInternalView.as_view(), diff --git a/api/goods/views.py b/api/goods/views.py index 9bd33c6d69..05546ad342 100644 --- a/api/goods/views.py +++ b/api/goods/views.py @@ -18,9 +18,11 @@ from api.core.authentication import ExporterAuthentication, SharedAuthentication, GovAuthentication from api.core.exceptions import BadRequestError from api.core.helpers import str_to_bool +from api.core.views import DocumentStreamAPIView from api.documents.libraries.delete_documents_on_bad_request import delete_documents_on_bad_request from api.documents.models import Document from api.goods.enums import GoodStatus, GoodPvGraded, ItemCategory +from api.goods.filters import GoodFilter from api.goods.goods_paginator import GoodListPaginator from api.goods.helpers import ( FIREARMS_CORE_TYPES, @@ -31,6 +33,10 @@ from api.goods.libraries.get_goods import get_good, get_good_document from api.goods.libraries.save_good import create_or_update_good from api.goods.models import Good, GoodDocument +from api.goods.permissions import ( + IsDocumentInOrganisation, + IsGoodDraft, +) from api.goods.serializers import ( GoodAttachingSerializer, GoodCreateSerializer, @@ -539,6 +545,20 @@ def delete(self, request, pk, doc_pk): return JsonResponse({"document": "deleted success"}) +class GoodDocumentStream(DocumentStreamAPIView): + authentication_classes = (ExporterAuthentication,) + filter_backends = (GoodFilter,) + lookup_url_kwarg = "doc_pk" + queryset = GoodDocument.objects.all() + permission_classes = ( + IsDocumentInOrganisation, + IsGoodDraft, + ) + + def get_document(self, instance): + return instance + + class DocumentGoodOnApplicationInternalView(APIView): authentication_classes = (GovAuthentication,) serializer_class = GoodOnApplicationInternalDocumentCreateSerializer From d3afc515af50606a5b29635f7458e4e0ee952e6d Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Fri, 9 Feb 2024 12:50:17 +0000 Subject: [PATCH 03/28] Extract out common filter class for parent object filtering --- api/conf/settings_test.py | 4 ++ api/core/filters.py | 17 +++++++ api/core/tests/apps.py | 6 +++ api/core/tests/migrations/0001_initial.py | 32 ++++++++++++ api/core/tests/migrations/__init__.py | 0 api/core/tests/models.py | 10 ++++ api/core/tests/serializers.py | 12 +++++ api/core/tests/test_filters.py | 62 +++++++++++++++++++++++ api/core/tests/urls.py | 19 +++++++ api/core/tests/views.py | 17 +++++++ api/goods/filters.py | 6 --- api/goods/views.py | 5 +- api/organisations/filters.py | 6 --- api/organisations/views/documents.py | 8 +-- 14 files changed, 187 insertions(+), 17 deletions(-) create mode 100644 api/core/filters.py create mode 100644 api/core/tests/apps.py create mode 100644 api/core/tests/migrations/0001_initial.py create mode 100644 api/core/tests/migrations/__init__.py create mode 100644 api/core/tests/models.py create mode 100644 api/core/tests/serializers.py create mode 100644 api/core/tests/test_filters.py create mode 100644 api/core/tests/urls.py create mode 100644 api/core/tests/views.py delete mode 100644 api/goods/filters.py delete mode 100644 api/organisations/filters.py diff --git a/api/conf/settings_test.py b/api/conf/settings_test.py index e28fcb845a..30b35d95be 100644 --- a/api/conf/settings_test.py +++ b/api/conf/settings_test.py @@ -8,3 +8,7 @@ SUPPRESS_TEST_OUTPUT = True AWS_ENDPOINT_URL = None + +INSTALLED_APPS += [ + "api.core.tests.apps.CoreTestsConfig", +] diff --git a/api/core/filters.py b/api/core/filters.py new file mode 100644 index 0000000000..de1948fecc --- /dev/null +++ b/api/core/filters.py @@ -0,0 +1,17 @@ +from rest_framework import filters + +from django.core.exceptions import ImproperlyConfigured + + +class ParentFilter(filters.BaseFilterBackend): + def filter_queryset(self, request, queryset, view): + parent_id_lookup_field = getattr(view, "parent_id_lookup_field", None) + if not parent_id_lookup_field: + raise ImproperlyConfigured( + f"Cannot use {self.__class__.__name__} on a view which does not have a parent_id_lookup_field attribute" + ) + + lookup = { + parent_id_lookup_field: view.kwargs["pk"], + } + return queryset.filter(**lookup) diff --git a/api/core/tests/apps.py b/api/core/tests/apps.py new file mode 100644 index 0000000000..b293bee58c --- /dev/null +++ b/api/core/tests/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class CoreTestsConfig(AppConfig): + name = "api.core.tests" + label = "api_core_tests" diff --git a/api/core/tests/migrations/0001_initial.py b/api/core/tests/migrations/0001_initial.py new file mode 100644 index 0000000000..33d79b1f55 --- /dev/null +++ b/api/core/tests/migrations/0001_initial.py @@ -0,0 +1,32 @@ +# Generated by Django 4.2.9 on 2024-02-09 14:48 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="ParentModel", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(max_length=255)), + ], + ), + migrations.CreateModel( + name="ChildModel", + fields=[ + ("id", models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("name", models.CharField(max_length=255)), + ( + "parent", + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="api_core_tests.parentmodel"), + ), + ], + ), + ] diff --git a/api/core/tests/migrations/__init__.py b/api/core/tests/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/tests/models.py b/api/core/tests/models.py new file mode 100644 index 0000000000..899a3b7455 --- /dev/null +++ b/api/core/tests/models.py @@ -0,0 +1,10 @@ +from django.db import models + + +class ParentModel(models.Model): + name = models.CharField(max_length=255) + + +class ChildModel(models.Model): + name = models.CharField(max_length=255) + parent = models.ForeignKey(ParentModel, on_delete=models.CASCADE) diff --git a/api/core/tests/serializers.py b/api/core/tests/serializers.py new file mode 100644 index 0000000000..c029a50307 --- /dev/null +++ b/api/core/tests/serializers.py @@ -0,0 +1,12 @@ +from rest_framework import serializers + +from api.core.tests.models import ChildModel + + +class ChildModelSerializer(serializers.ModelSerializer): + class Meta: + model = ChildModel + fields = ( + "id", + "name", + ) diff --git a/api/core/tests/test_filters.py b/api/core/tests/test_filters.py new file mode 100644 index 0000000000..2ac8dcefad --- /dev/null +++ b/api/core/tests/test_filters.py @@ -0,0 +1,62 @@ +import uuid + +from django.core.exceptions import ImproperlyConfigured +from django.test import ( + override_settings, + SimpleTestCase, + TestCase, +) +from django.urls import reverse + +from api.core.tests.models import ( + ChildModel, + ParentModel, +) + + +@override_settings( + ROOT_URLCONF="api.core.tests.urls", +) +class TestMisconfiguredParentFilter(SimpleTestCase): + def test_misconfigured_parent_filter(self): + url = reverse( + "test-misconfigured-parent-filter", + kwargs={ + "pk": str(uuid.uuid4()), + "child_pk": str(uuid.uuid4()), + }, + ) + with self.assertRaises(ImproperlyConfigured): + self.client.get(url) + + +@override_settings( + ROOT_URLCONF="api.core.tests.urls", +) +class TestParentFilter(TestCase): + def test_parent_filter(self): + parent = ParentModel.objects.create(name="parent") + child = ChildModel.objects.create(parent=parent, name="child") + url = reverse( + "test-parent-filter", + kwargs={ + "pk": str(parent.pk), + "child_pk": str(child.pk), + }, + ) + response = self.client.get(url) + self.assertEqual(response.status_code, 200) + + def test_parent_other_parent_filter(self): + parent = ParentModel.objects.create(name="parent") + child = ChildModel.objects.create(parent=parent, name="child") + other_parent = ParentModel.objects.create(name="other_parent") + url = reverse( + "test-parent-filter", + kwargs={ + "pk": str(other_parent.pk), + "child_pk": str(child.pk), + }, + ) + response = self.client.get(url) + self.assertEqual(response.status_code, 404) diff --git a/api/core/tests/urls.py b/api/core/tests/urls.py new file mode 100644 index 0000000000..90accb20f8 --- /dev/null +++ b/api/core/tests/urls.py @@ -0,0 +1,19 @@ +from django.urls import path + +from .views import ( + MisconfiguredParentFilterView, + ParentFilterView, +) + +urlpatterns = [ + path( + "misconfigured-parent//child//", + MisconfiguredParentFilterView.as_view(), + name="test-misconfigured-parent-filter", + ), + path( + "parent//child//", + ParentFilterView.as_view(), + name="test-parent-filter", + ), +] diff --git a/api/core/tests/views.py b/api/core/tests/views.py new file mode 100644 index 0000000000..e0da0c00fd --- /dev/null +++ b/api/core/tests/views.py @@ -0,0 +1,17 @@ +from rest_framework.generics import RetrieveAPIView + +from api.core.filters import ParentFilter +from api.core.tests.models import ChildModel +from api.core.tests.serializers import ChildModelSerializer + + +class MisconfiguredParentFilterView(RetrieveAPIView): + filter_backends = (ParentFilter,) + queryset = ChildModel.objects.all() + + +class ParentFilterView(RetrieveAPIView): + filter_backends = (ParentFilter,) + parent_id_lookup_field = "parent_id" + queryset = ChildModel.objects.all() + serializer_class = ChildModelSerializer diff --git a/api/goods/filters.py b/api/goods/filters.py deleted file mode 100644 index 45cefe7253..0000000000 --- a/api/goods/filters.py +++ /dev/null @@ -1,6 +0,0 @@ -from rest_framework import filters - - -class GoodFilter(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(good_id=view.kwargs["pk"]) diff --git a/api/goods/views.py b/api/goods/views.py index 05546ad342..8445dcdb98 100644 --- a/api/goods/views.py +++ b/api/goods/views.py @@ -18,11 +18,11 @@ from api.core.authentication import ExporterAuthentication, SharedAuthentication, GovAuthentication from api.core.exceptions import BadRequestError from api.core.helpers import str_to_bool +from api.core.filters import ParentFilter from api.core.views import DocumentStreamAPIView from api.documents.libraries.delete_documents_on_bad_request import delete_documents_on_bad_request from api.documents.models import Document from api.goods.enums import GoodStatus, GoodPvGraded, ItemCategory -from api.goods.filters import GoodFilter from api.goods.goods_paginator import GoodListPaginator from api.goods.helpers import ( FIREARMS_CORE_TYPES, @@ -547,7 +547,8 @@ def delete(self, request, pk, doc_pk): class GoodDocumentStream(DocumentStreamAPIView): authentication_classes = (ExporterAuthentication,) - filter_backends = (GoodFilter,) + filter_backends = (ParentFilter,) + parent_id_lookup_field = "good_id" lookup_url_kwarg = "doc_pk" queryset = GoodDocument.objects.all() permission_classes = ( diff --git a/api/organisations/filters.py b/api/organisations/filters.py deleted file mode 100644 index 58a86fab04..0000000000 --- a/api/organisations/filters.py +++ /dev/null @@ -1,6 +0,0 @@ -from rest_framework import filters - - -class OrganisationFilter(filters.BaseFilterBackend): - def filter_queryset(self, request, queryset, view): - return queryset.filter(organisation_id=view.kwargs["pk"]) diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 27ab30ba49..090bddb785 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -5,9 +5,9 @@ from api.audit_trail import service as audit_trail_service from api.audit_trail.enums import AuditType from api.core.authentication import SharedAuthentication +from api.core.filters import ParentFilter from api.core.views import DocumentStreamAPIView from api.organisations import ( - filters, models, permissions, serializers, @@ -16,8 +16,9 @@ class DocumentOnOrganisationView(viewsets.ModelViewSet): authentication_classes = (SharedAuthentication,) - filter_backends = (filters.OrganisationFilter,) + filter_backends = (ParentFilter,) lookup_url_kwarg = "document_on_application_pk" + parent_id_lookup_field = "organisation_id" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) queryset = models.DocumentOnOrganisation.objects.all() serializer_class = serializers.DocumentOnOrganisationSerializer @@ -86,8 +87,9 @@ def update(self, request, pk, document_on_application_pk): class DocumentOnOrganisationStreamView(DocumentStreamAPIView): authentication_classes = (SharedAuthentication,) - filter_backends = (filters.OrganisationFilter,) + filter_backends = (ParentFilter,) lookup_url_kwarg = "document_on_application_pk" + parent_id_lookup_field = "organisation_id" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) queryset = models.DocumentOnOrganisation.objects.all() From bc0b79e1cbc497374f5c51fef1d5a7b447d1e79d Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Mon, 12 Feb 2024 10:27:25 +0000 Subject: [PATCH 04/28] Set an explicit status for good on document stream test --- api/goods/tests/test_goods_documents.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/api/goods/tests/test_goods_documents.py b/api/goods/tests/test_goods_documents.py index 25222625be..0f60582469 100644 --- a/api/goods/tests/test_goods_documents.py +++ b/api/goods/tests/test_goods_documents.py @@ -133,7 +133,10 @@ def test_edit_product_document_description(self): class GoodDocumentStreamTests(DataTestClient): def setUp(self): super().setUp() - self.good = GoodFactory(organisation=self.organisation) + self.good = GoodFactory( + organisation=self.organisation, + status=GoodStatus.DRAFT, + ) self.create_default_bucket() self.put_object_in_default_bucket("thisisakey", b"test") From b874f403f68852b413b35661c81e902c68576263 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Mon, 12 Feb 2024 10:57:04 +0000 Subject: [PATCH 05/28] Update configuration attribute name for parent filter --- api/core/filters.py | 8 ++++---- api/core/tests/views.py | 2 +- api/goods/views.py | 2 +- api/organisations/views/documents.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/api/core/filters.py b/api/core/filters.py index de1948fecc..16eaf164cb 100644 --- a/api/core/filters.py +++ b/api/core/filters.py @@ -5,13 +5,13 @@ class ParentFilter(filters.BaseFilterBackend): def filter_queryset(self, request, queryset, view): - parent_id_lookup_field = getattr(view, "parent_id_lookup_field", None) - if not parent_id_lookup_field: + parent_filter_id_lookup_field = getattr(view, "parent_filter_id_lookup_field", None) + if not parent_filter_id_lookup_field: raise ImproperlyConfigured( - f"Cannot use {self.__class__.__name__} on a view which does not have a parent_id_lookup_field attribute" + f"Cannot use {self.__class__.__name__} on a view which does not have a parent_filter_id_lookup_field attribute" ) lookup = { - parent_id_lookup_field: view.kwargs["pk"], + parent_filter_id_lookup_field: view.kwargs["pk"], } return queryset.filter(**lookup) diff --git a/api/core/tests/views.py b/api/core/tests/views.py index e0da0c00fd..75964d3f84 100644 --- a/api/core/tests/views.py +++ b/api/core/tests/views.py @@ -12,6 +12,6 @@ class MisconfiguredParentFilterView(RetrieveAPIView): class ParentFilterView(RetrieveAPIView): filter_backends = (ParentFilter,) - parent_id_lookup_field = "parent_id" + parent_filter_id_lookup_field = "parent_id" queryset = ChildModel.objects.all() serializer_class = ChildModelSerializer diff --git a/api/goods/views.py b/api/goods/views.py index 8445dcdb98..7c55d26d7f 100644 --- a/api/goods/views.py +++ b/api/goods/views.py @@ -548,7 +548,7 @@ def delete(self, request, pk, doc_pk): class GoodDocumentStream(DocumentStreamAPIView): authentication_classes = (ExporterAuthentication,) filter_backends = (ParentFilter,) - parent_id_lookup_field = "good_id" + parent_filter_id_lookup_field = "good_id" lookup_url_kwarg = "doc_pk" queryset = GoodDocument.objects.all() permission_classes = ( diff --git a/api/organisations/views/documents.py b/api/organisations/views/documents.py index 090bddb785..c0ff504aa8 100644 --- a/api/organisations/views/documents.py +++ b/api/organisations/views/documents.py @@ -17,8 +17,8 @@ class DocumentOnOrganisationView(viewsets.ModelViewSet): authentication_classes = (SharedAuthentication,) filter_backends = (ParentFilter,) + parent_filter_id_lookup_field = "organisation_id" lookup_url_kwarg = "document_on_application_pk" - parent_id_lookup_field = "organisation_id" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) queryset = models.DocumentOnOrganisation.objects.all() serializer_class = serializers.DocumentOnOrganisationSerializer @@ -88,8 +88,8 @@ def update(self, request, pk, document_on_application_pk): class DocumentOnOrganisationStreamView(DocumentStreamAPIView): authentication_classes = (SharedAuthentication,) filter_backends = (ParentFilter,) + parent_filter_id_lookup_field = "organisation_id" lookup_url_kwarg = "document_on_application_pk" - parent_id_lookup_field = "organisation_id" permission_classes = (permissions.IsCaseworkerOrInDocumentOrganisation,) queryset = models.DocumentOnOrganisation.objects.all() From fc6ef9d053dcc64d620d5c4a0b3805851ad32ab9 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 13 Feb 2024 08:59:20 +0000 Subject: [PATCH 06/28] Update UK sanctions list URL --- api/conf/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/conf/settings.py b/api/conf/settings.py index 71812f98fa..aa9d555997 100644 --- a/api/conf/settings.py +++ b/api/conf/settings.py @@ -456,7 +456,7 @@ def _build_redis_url(base_url, db_number, **query_args): { "un_sanctions_file": "https://scsanctions.un.org/resources/xml/en/consolidated.xml", "office_financial_sanctions_file": "https://ofsistorage.blob.core.windows.net/publishlive/2022format/ConList.xml", - "uk_sanctions_file": "https://assets.publishing.service.gov.uk/government/uploads/system/uploads/attachment_data/file/1129559/UK_Sanctions_List.xml", + "uk_sanctions_file": "https://assets.publishing.service.gov.uk/media/65ca02639c5b7f0012951caf/UK_Sanctions_List.xml", }, ) LITE_INTERNAL_NOTIFICATION_EMAILS = env.json("LITE_INTERNAL_NOTIFICATION_EMAILS", {}) From b67ae3d5fd45f55044af06972e7a8008b4bada76 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Thu, 8 Feb 2024 15:24:13 +0000 Subject: [PATCH 07/28] Update view and view test --- api/cases/tests/test_case_ecju_queries.py | 23 ++++++++++++ api/cases/views/views.py | 45 ++++++++++++++++++++++- 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/api/cases/tests/test_case_ecju_queries.py b/api/cases/tests/test_case_ecju_queries.py index d0815d47a9..bbe3b89abf 100644 --- a/api/cases/tests/test_case_ecju_queries.py +++ b/api/cases/tests/test_case_ecju_queries.py @@ -675,3 +675,26 @@ def test_exporter_cannot_delete_documents_of_closed_query(self): self.assertEqual(response.status_code, status.HTTP_200_OK) response = response.json() self.assertIsNotNone(response["document"]["id"]) + + @parameterized.expand(["this is some response text", ""]) + def test_exporter_responding_to_query_creates_case_note_mention_for_caseworker(self, response_text): + case = self.create_standard_application_case(self.organisation) + url = reverse("cases:case_ecju_queries", kwargs={"pk": case.id}) + question_text = "this is the question text" + data = {"question": question_text, "query_type": ECJUQueryType.ECJU} + + response = self.client.post(url, data, **self.gov_headers) + response_data = response.json() + ecju_query = EcjuQuery.objects.get(case=case) + + self.assertEqual(status.HTTP_201_CREATED, response.status_code) + self.assertEqual(response_data["ecju_query_id"], str(ecju_query.id)) + self.assertEqual(question_text, ecju_query.question) + self.assertIsNone(ecju_query.response) + + url = reverse("cases:case_ecju_query", kwargs={"pk": case.id, "ecju_pk": ecju_query.id}) + data = {"response": response_text} + response = self.client.put(url, data, **self.exporter_headers) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(list(response.json().keys()), ["ecju_query", "case_note", "case_note_mentions"]) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 48803d4dd8..3a7d16b049 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -74,6 +74,8 @@ EcjuQueryUserResponseSerializer, EcjuQueryDocumentCreateSerializer, EcjuQueryDocumentViewSerializer, + CaseNoteSerializer, + CaseNoteMentionsSerializer, ) from api.cases.service import get_destinations from api.compliance.helpers import generate_compliance_site_case @@ -101,6 +103,8 @@ from api.users.libraries.get_user import get_user_by_pk from lite_content.lite_api import strings from lite_content.lite_api.strings import Documents, Cases +from api.users.enums import SystemUser +from api.users.models import BaseUser class CaseDetail(APIView): @@ -649,7 +653,46 @@ def put(self, request, pk, ecju_pk): target=serializer.instance.case, payload={"ecju_response": data.get("response")}, ) - return JsonResponse(data={"ecju_query": serializer.data}, status=status.HTTP_201_CREATED) + + # Create case note mention notification for case worker. + # LITE system is the user that creates the case note. + exporter_user_full_name = getattr(get_user_by_pk(request.user.pk), "full_name", "Exporter user") + case_note_data = { + "text": f"{exporter_user_full_name} has responded to a query.", + "case": serializer.instance.case.id, + "user": SystemUser.id, + } + case_note_serializer = CaseNoteSerializer(data=case_note_data) + if case_note_serializer.is_valid(): + case_note_serializer.save() + case_note_mentions_data = [ + {"user": ecju_query.raised_by_user.pk, "case_note": case_note_serializer.instance.id} + ] + case_note_mentions_serializer = CaseNoteMentionsSerializer( + data=case_note_mentions_data, + many=True, + ) + if case_note_mentions_serializer.is_valid(): + case_note_mentions_serializer.save() + audit_trail_service.create_system_user_audit( + verb=AuditType.CREATED_CASE_NOTE_WITH_MENTIONS, target=serializer.instance.case, payload={} + ) + return JsonResponse( + data={ + "ecju_query": serializer.data, + "case_note": case_note_serializer.data, + "case_note_mentions": case_note_mentions_serializer.data, + }, + status=status.HTTP_201_CREATED, + ) + else: + return JsonResponse( + data={"errors": case_note_mentions_serializer.errors}, status=status.HTTP_400_BAD_REQUEST + ) + else: + return JsonResponse( + data={"errors": case_note_serializer.errors}, status=status.HTTP_400_BAD_REQUEST + ) else: return JsonResponse(data={}, status=status.HTTP_200_OK) From 8fe625e3e2226c338416c8410eb0be92c2bc02c3 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Thu, 8 Feb 2024 16:53:59 +0000 Subject: [PATCH 08/28] Refactor mentions code into separate function --- api/cases/views/views.py | 79 +++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 38 deletions(-) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 3a7d16b049..6d3634ff63 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -104,7 +104,6 @@ from lite_content.lite_api import strings from lite_content.lite_api.strings import Documents, Cases from api.users.enums import SystemUser -from api.users.models import BaseUser class CaseDetail(APIView): @@ -656,48 +655,52 @@ def put(self, request, pk, ecju_pk): # Create case note mention notification for case worker. # LITE system is the user that creates the case note. - exporter_user_full_name = getattr(get_user_by_pk(request.user.pk), "full_name", "Exporter user") - case_note_data = { - "text": f"{exporter_user_full_name} has responded to a query.", - "case": serializer.instance.case.id, - "user": SystemUser.id, - } - case_note_serializer = CaseNoteSerializer(data=case_note_data) - if case_note_serializer.is_valid(): - case_note_serializer.save() - case_note_mentions_data = [ - {"user": ecju_query.raised_by_user.pk, "case_note": case_note_serializer.instance.id} - ] - case_note_mentions_serializer = CaseNoteMentionsSerializer( - data=case_note_mentions_data, - many=True, - ) - if case_note_mentions_serializer.is_valid(): - case_note_mentions_serializer.save() - audit_trail_service.create_system_user_audit( - verb=AuditType.CREATED_CASE_NOTE_WITH_MENTIONS, target=serializer.instance.case, payload={} - ) - return JsonResponse( - data={ - "ecju_query": serializer.data, - "case_note": case_note_serializer.data, - "case_note_mentions": case_note_mentions_serializer.data, - }, - status=status.HTTP_201_CREATED, - ) - else: - return JsonResponse( - data={"errors": case_note_mentions_serializer.errors}, status=status.HTTP_400_BAD_REQUEST - ) - else: - return JsonResponse( - data={"errors": case_note_serializer.errors}, status=status.HTTP_400_BAD_REQUEST - ) + mentions_data = self._create_case_note_mention(request, ecju_query, serializer) + if "errors" in mentions_data.keys(): + return JsonResponse(data=mentions_data) + + return JsonResponse( + data={"ecju_query": serializer.data, **mentions_data}, + status=status.HTTP_201_CREATED, + ) else: return JsonResponse(data={}, status=status.HTTP_200_OK) return JsonResponse(data={"errors": serializer.errors}, status=status.HTTP_400_BAD_REQUEST) + def _create_case_note_mention(self, request, ecju_query, ecju_query_serializer): + exporter_user_full_name = getattr(get_user_by_pk(request.user.pk), "full_name", "Exporter user") + case_note_data = { + "text": f"{exporter_user_full_name} has responded to a query.", + "case": ecju_query_serializer.instance.case.id, + "user": SystemUser.id, + } + case_note_serializer = CaseNoteSerializer(data=case_note_data) + if case_note_serializer.is_valid(): + case_note_serializer.save() + case_note_mentions_data = [ + {"user": ecju_query.raised_by_user.pk, "case_note": case_note_serializer.instance.id} + ] + case_note_mentions_serializer = CaseNoteMentionsSerializer( + data=case_note_mentions_data, + many=True, + ) + if case_note_mentions_serializer.is_valid(): + case_note_mentions_serializer.save() + audit_trail_service.create_system_user_audit( + verb=AuditType.CREATED_CASE_NOTE_WITH_MENTIONS, + target=ecju_query_serializer.instance.case, + payload={}, + ) + return { + "case_note": case_note_serializer.data, + "case_note_mentions": case_note_mentions_serializer.data, + } + else: + return {"errors": case_note_mentions_serializer.errors} + else: + return {"errors": case_note_serializer.errors} + class EcjuQueryAddDocument(APIView): authentication_classes = (ExporterAuthentication,) From e4ac6d37331e7732007138e497d28efe7907a905 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Thu, 8 Feb 2024 17:19:53 +0000 Subject: [PATCH 09/28] Fix audit message --- api/cases/views/views.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 6d3634ff63..964e4caa94 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -689,8 +689,9 @@ def _create_case_note_mention(self, request, ecju_query, ecju_query_serializer): case_note_mentions_serializer.save() audit_trail_service.create_system_user_audit( verb=AuditType.CREATED_CASE_NOTE_WITH_MENTIONS, + action_object=case_note_serializer.instance, target=ecju_query_serializer.instance.case, - payload={}, + payload={"mention_users": case_note_mentions_serializer.get_user_mention_names()}, ) return { "case_note": case_note_serializer.data, From 24839dd5689f0062d2b9c37036a3a2c7eb9cd48c Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Thu, 8 Feb 2024 17:52:32 +0000 Subject: [PATCH 10/28] Fix 400 error response --- api/cases/views/views.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 964e4caa94..29e67170f1 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -657,7 +657,7 @@ def put(self, request, pk, ecju_pk): # LITE system is the user that creates the case note. mentions_data = self._create_case_note_mention(request, ecju_query, serializer) if "errors" in mentions_data.keys(): - return JsonResponse(data=mentions_data) + return JsonResponse(data=mentions_data, status=status.HTTP_400_BAD_REQUEST) return JsonResponse( data={"ecju_query": serializer.data, **mentions_data}, From 11ec0ade429e385597e0427094af984e66801fee Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Fri, 9 Feb 2024 12:21:04 +0000 Subject: [PATCH 11/28] Rewrite to create mention using case model --- api/cases/models.py | 20 +++++++++- api/cases/tests/test_case_ecju_queries.py | 26 ++++++++++++- api/cases/views/views.py | 47 ++++------------------- 3 files changed, 51 insertions(+), 42 deletions(-) diff --git a/api/cases/models.py b/api/cases/models.py index d54dc7b9c2..5598a5c247 100644 --- a/api/cases/models.py +++ b/api/cases/models.py @@ -14,7 +14,6 @@ from queryable_properties.managers import QueryablePropertiesManager from queryable_properties.properties import queryable_property - from api.audit_trail.enums import AuditType from api.cases.enums import ( AdviceType, @@ -53,6 +52,7 @@ UserOrganisationRelationship, ExporterNotification, ) +from api.users.enums import SystemUser from lite_content.lite_api import strings denial_reasons_logger = logging.getLogger(settings.DENIAL_REASONS_DELETION_LOGGER) @@ -313,6 +313,24 @@ def set_sub_status(self, sub_status_id): payload={"sub_status": self.sub_status.name, "status": CaseStatusEnum.get_text(self.status.status)}, ) + def create_system_mention(self, case_note_text, mention_user): + """ + Create a LITE system mention e.g. exporter responded to an ECJU query + """ + from api.audit_trail import service as audit_trail_service + + case_note = CaseNote(text=case_note_text, case=self, user=BaseUser.objects.get(id=SystemUser.id)) + case_note.save() + case_note_mentions = CaseNoteMentions(user=mention_user, case_note=case_note) + case_note_mentions.save() + audit_payload = { + "mention_users": [f"{mention_user.full_name} ({mention_user.team.name})"], + "additional_text": case_note_text, + } + audit_trail_service.create_system_user_audit( + verb=AuditType.CREATED_CASE_NOTE_WITH_MENTIONS, action_object=case_note, target=self, payload=audit_payload + ) + class CaseQueue(TimestampableModel): case = models.ForeignKey(Case, related_name="casequeues", on_delete=models.DO_NOTHING) diff --git a/api/cases/tests/test_case_ecju_queries.py b/api/cases/tests/test_case_ecju_queries.py index bbe3b89abf..e6a3012634 100644 --- a/api/cases/tests/test_case_ecju_queries.py +++ b/api/cases/tests/test_case_ecju_queries.py @@ -23,6 +23,7 @@ from api.staticdata.statuses.libraries.get_case_status import get_case_status_by_status from test_helpers.clients import DataTestClient from api.users.tests.factories import ExporterUserFactory +from api.cases.models import CaseNoteMentions faker = Faker() @@ -679,6 +680,8 @@ def test_exporter_cannot_delete_documents_of_closed_query(self): @parameterized.expand(["this is some response text", ""]) def test_exporter_responding_to_query_creates_case_note_mention_for_caseworker(self, response_text): case = self.create_standard_application_case(self.organisation) + + # caseworker raises a query url = reverse("cases:case_ecju_queries", kwargs={"pk": case.id}) question_text = "this is the question text" data = {"question": question_text, "query_type": ECJUQueryType.ECJU} @@ -687,14 +690,35 @@ def test_exporter_responding_to_query_creates_case_note_mention_for_caseworker(s response_data = response.json() ecju_query = EcjuQuery.objects.get(case=case) + self.assertFalse(ecju_query.is_query_closed) self.assertEqual(status.HTTP_201_CREATED, response.status_code) self.assertEqual(response_data["ecju_query_id"], str(ecju_query.id)) self.assertEqual(question_text, ecju_query.question) self.assertIsNone(ecju_query.response) + # exporter responds to the query url = reverse("cases:case_ecju_query", kwargs={"pk": case.id, "ecju_pk": ecju_query.id}) data = {"response": response_text} + response = self.client.put(url, data, **self.exporter_headers) + ecju_query = EcjuQuery.objects.get(case=case) self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertEqual(list(response.json().keys()), ["ecju_query", "case_note", "case_note_mentions"]) + self.assertTrue(ecju_query.is_query_closed) + + # check case note mention is created + case_note_mentions = CaseNoteMentions.objects.first() + case_note = case_note_mentions.case_note + audit_object = Audit.objects.first() + + expected_gov_user = ecju_query.raised_by_user + expected_exporter_user = ecju_query.responded_by_user + expected_mention_users_text = f"{expected_gov_user.full_name} ({expected_gov_user.team.name})" + expected_case_note_text = f"{expected_exporter_user.get_full_name()} has responded to a query." + expected_audit_payload = ( + {"mention_users": [expected_mention_users_text], "additional_text": expected_case_note_text}, + ) + + self.assertEqual(case_note_mentions.user, expected_gov_user) + self.assertEqual(case_note.text, expected_case_note_text) + self.assertEqual(audit_object.payload, expected_audit_payload) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 29e67170f1..c7b44d4407 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -629,6 +629,8 @@ def put(self, request, pk, ecju_pk): data={"errors": "Enter a reason why you are closing the query"}, status=status.HTTP_400_BAD_REQUEST ) + exporter_user_full_name = getattr(get_user_by_pk(request.user.pk), "full_name", "Exporter user") + data = {"responded_by_user": str(request.user.pk)} if request.data.get("response"): @@ -653,14 +655,13 @@ def put(self, request, pk, ecju_pk): payload={"ecju_response": data.get("response")}, ) - # Create case note mention notification for case worker. - # LITE system is the user that creates the case note. - mentions_data = self._create_case_note_mention(request, ecju_query, serializer) - if "errors" in mentions_data.keys(): - return JsonResponse(data=mentions_data, status=status.HTTP_400_BAD_REQUEST) + ecju_query.case.create_system_mention( + case_note_text=f"{exporter_user_full_name} has responded to a query.", + mention_user=ecju_query.raised_by_user, + ) return JsonResponse( - data={"ecju_query": serializer.data, **mentions_data}, + data={"ecju_query": serializer.data}, status=status.HTTP_201_CREATED, ) else: @@ -668,40 +669,6 @@ def put(self, request, pk, ecju_pk): return JsonResponse(data={"errors": serializer.errors}, status=status.HTTP_400_BAD_REQUEST) - def _create_case_note_mention(self, request, ecju_query, ecju_query_serializer): - exporter_user_full_name = getattr(get_user_by_pk(request.user.pk), "full_name", "Exporter user") - case_note_data = { - "text": f"{exporter_user_full_name} has responded to a query.", - "case": ecju_query_serializer.instance.case.id, - "user": SystemUser.id, - } - case_note_serializer = CaseNoteSerializer(data=case_note_data) - if case_note_serializer.is_valid(): - case_note_serializer.save() - case_note_mentions_data = [ - {"user": ecju_query.raised_by_user.pk, "case_note": case_note_serializer.instance.id} - ] - case_note_mentions_serializer = CaseNoteMentionsSerializer( - data=case_note_mentions_data, - many=True, - ) - if case_note_mentions_serializer.is_valid(): - case_note_mentions_serializer.save() - audit_trail_service.create_system_user_audit( - verb=AuditType.CREATED_CASE_NOTE_WITH_MENTIONS, - action_object=case_note_serializer.instance, - target=ecju_query_serializer.instance.case, - payload={"mention_users": case_note_mentions_serializer.get_user_mention_names()}, - ) - return { - "case_note": case_note_serializer.data, - "case_note_mentions": case_note_mentions_serializer.data, - } - else: - return {"errors": case_note_mentions_serializer.errors} - else: - return {"errors": case_note_serializer.errors} - class EcjuQueryAddDocument(APIView): authentication_classes = (ExporterAuthentication,) From 448abe45e31e8195e5f1f894e13a3b94da3850d1 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Fri, 9 Feb 2024 13:54:55 +0000 Subject: [PATCH 12/28] Remove unused imports --- api/cases/views/views.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index c7b44d4407..1bf64cabe6 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -74,8 +74,6 @@ EcjuQueryUserResponseSerializer, EcjuQueryDocumentCreateSerializer, EcjuQueryDocumentViewSerializer, - CaseNoteSerializer, - CaseNoteMentionsSerializer, ) from api.cases.service import get_destinations from api.compliance.helpers import generate_compliance_site_case @@ -103,7 +101,6 @@ from api.users.libraries.get_user import get_user_by_pk from lite_content.lite_api import strings from lite_content.lite_api.strings import Documents, Cases -from api.users.enums import SystemUser class CaseDetail(APIView): From c63bd504b6e32518ab897f10ec608e22cdc9045f Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Fri, 9 Feb 2024 14:11:07 +0000 Subject: [PATCH 13/28] Fix test --- api/cases/tests/test_case_ecju_queries.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/api/cases/tests/test_case_ecju_queries.py b/api/cases/tests/test_case_ecju_queries.py index e6a3012634..10e400b361 100644 --- a/api/cases/tests/test_case_ecju_queries.py +++ b/api/cases/tests/test_case_ecju_queries.py @@ -715,9 +715,10 @@ def test_exporter_responding_to_query_creates_case_note_mention_for_caseworker(s expected_exporter_user = ecju_query.responded_by_user expected_mention_users_text = f"{expected_gov_user.full_name} ({expected_gov_user.team.name})" expected_case_note_text = f"{expected_exporter_user.get_full_name()} has responded to a query." - expected_audit_payload = ( - {"mention_users": [expected_mention_users_text], "additional_text": expected_case_note_text}, - ) + expected_audit_payload = { + "mention_users": [expected_mention_users_text], + "additional_text": expected_case_note_text, + } self.assertEqual(case_note_mentions.user, expected_gov_user) self.assertEqual(case_note.text, expected_case_note_text) From a93a97c29744545af9619403f7d5e1b6d6ea0ae4 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Mon, 12 Feb 2024 11:33:16 +0000 Subject: [PATCH 14/28] Refactor and move function to helpers --- api/cases/helpers.py | 25 ++++++++++++++++++++++++- api/cases/models.py | 19 ------------------- api/cases/views/views.py | 4 +++- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/api/cases/helpers.py b/api/cases/helpers.py index 3bfeb801d5..f82b39e03e 100644 --- a/api/cases/helpers.py +++ b/api/cases/helpers.py @@ -1,9 +1,11 @@ from datetime import timedelta +from api.audit_trail.enums import AuditType from api.common.dates import is_bank_holiday, is_weekend from api.cases.enums import CaseTypeReferenceEnum from api.staticdata.statuses.enums import CaseStatusEnum -from api.users.models import GovUser, GovNotification +from api.users.models import BaseUser, GovUser, GovNotification +from api.users.enums import SystemUser def get_assigned_to_user_case_ids(user: GovUser, queue_id=None): @@ -81,3 +83,24 @@ def can_set_status(case, status): def working_days_in_range(start_date, end_date): dates_in_range = [start_date + timedelta(n) for n in range((end_date - start_date).days)] return len([date for date in dates_in_range if (not is_bank_holiday(date) and not is_weekend(date))]) + + +def create_system_mention(case, case_note_text, mention_user): + """ + Create a LITE system mention e.g. exporter responded to an ECJU query + """ + # to avoid circular import ImportError these must be imported here + from api.cases.models import CaseNote, CaseNoteMentions + from api.audit_trail import service as audit_trail_service + + case_note = CaseNote(text=case_note_text, case=case, user=BaseUser.objects.get(id=SystemUser.id)) + case_note.save() + case_note_mentions = CaseNoteMentions(user=mention_user, case_note=case_note) + case_note_mentions.save() + audit_payload = { + "mention_users": [f"{mention_user.full_name} ({mention_user.team.name})"], + "additional_text": case_note_text, + } + audit_trail_service.create_system_user_audit( + verb=AuditType.CREATED_CASE_NOTE_WITH_MENTIONS, action_object=case_note, target=case, payload=audit_payload + ) diff --git a/api/cases/models.py b/api/cases/models.py index 5598a5c247..529651a457 100644 --- a/api/cases/models.py +++ b/api/cases/models.py @@ -52,7 +52,6 @@ UserOrganisationRelationship, ExporterNotification, ) -from api.users.enums import SystemUser from lite_content.lite_api import strings denial_reasons_logger = logging.getLogger(settings.DENIAL_REASONS_DELETION_LOGGER) @@ -313,24 +312,6 @@ def set_sub_status(self, sub_status_id): payload={"sub_status": self.sub_status.name, "status": CaseStatusEnum.get_text(self.status.status)}, ) - def create_system_mention(self, case_note_text, mention_user): - """ - Create a LITE system mention e.g. exporter responded to an ECJU query - """ - from api.audit_trail import service as audit_trail_service - - case_note = CaseNote(text=case_note_text, case=self, user=BaseUser.objects.get(id=SystemUser.id)) - case_note.save() - case_note_mentions = CaseNoteMentions(user=mention_user, case_note=case_note) - case_note_mentions.save() - audit_payload = { - "mention_users": [f"{mention_user.full_name} ({mention_user.team.name})"], - "additional_text": case_note_text, - } - audit_trail_service.create_system_user_audit( - verb=AuditType.CREATED_CASE_NOTE_WITH_MENTIONS, action_object=case_note, target=self, payload=audit_payload - ) - class CaseQueue(TimestampableModel): case = models.ForeignKey(Case, related_name="casequeues", on_delete=models.DO_NOTHING) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 1bf64cabe6..afcdd626df 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -27,6 +27,7 @@ ) from api.cases.generated_documents.models import GeneratedCaseDocument from api.cases.generated_documents.serializers import AdviceDocumentGovSerializer +from api.cases.helpers import create_system_mention from api.cases.libraries.advice import group_advice from api.cases.libraries.finalise import get_required_decision_document_types from api.cases.libraries.get_case import get_case, get_case_document @@ -652,7 +653,8 @@ def put(self, request, pk, ecju_pk): payload={"ecju_response": data.get("response")}, ) - ecju_query.case.create_system_mention( + create_system_mention( + case=ecju_query.case, case_note_text=f"{exporter_user_full_name} has responded to a query.", mention_user=ecju_query.raised_by_user, ) From 093b06d9fe54dbc95971cbf92fb8923aa4e51812 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Mon, 12 Feb 2024 14:33:24 +0000 Subject: [PATCH 15/28] Rewrite to raise error if exporter user not found --- api/cases/views/views.py | 69 ++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index afcdd626df..81301ab6be 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -11,7 +11,7 @@ from rest_framework.views import APIView from api.applications.models import GoodOnApplication -from api.users.models import BaseNotification +from api.users.models import BaseNotification, ExporterUser from api.applications.serializers.advice import ( CountersignAdviceSerializer, CountryWithFlagsSerializer, @@ -606,11 +606,6 @@ def get(self, request, pk, ecju_pk): return JsonResponse(data={"ecju_query": serializer.data}, status=status.HTTP_200_OK) def put(self, request, pk, ecju_pk): - """ - If not validate only Will update the ecju query instance, with a response, and return the data details. - If validate only, this will return if the data is acceptable or not. - """ - ecju_query = get_ecju_query(ecju_pk) if ecju_query.response: return JsonResponse( @@ -618,53 +613,51 @@ def put(self, request, pk, ecju_pk): status=status.HTTP_400_BAD_REQUEST, ) - is_govuser = hasattr(request.user, "govuser") + is_govuser_request = hasattr(request.user, "govuser") is_blank_response = not bool(request.data.get("response")) - # response is required only when a govuser closes a query - if is_govuser and is_blank_response: + if is_govuser_request and is_blank_response: return JsonResponse( data={"errors": "Enter a reason why you are closing the query"}, status=status.HTTP_400_BAD_REQUEST ) - exporter_user_full_name = getattr(get_user_by_pk(request.user.pk), "full_name", "Exporter user") - data = {"responded_by_user": str(request.user.pk)} - if request.data.get("response"): data.update({"response": request.data["response"]}) serializer = EcjuQueryUserResponseSerializer(instance=ecju_query, data=data, partial=True) - if serializer.is_valid(): - if "validate_only" not in request.data or not request.data["validate_only"]: - serializer.save() - # Delete any notifications against this query - ecju_query_type = ContentType.objects.get_for_model(EcjuQuery) - BaseNotification.objects.filter(object_id=ecju_pk, content_type=ecju_query_type).delete() + serializer.save() - # If the user is a Govuser query is manually being closed by a caseworker - query_verb = AuditType.ECJU_QUERY_MANUALLY_CLOSED if is_govuser else AuditType.ECJU_QUERY_RESPONSE - audit_trail_service.create( - actor=request.user, - verb=query_verb, - action_object=serializer.instance, - target=serializer.instance.case, - payload={"ecju_response": data.get("response")}, - ) + # Delete any notifications against this query + ecju_query_type = ContentType.objects.get_for_model(EcjuQuery) + BaseNotification.objects.filter(object_id=ecju_pk, content_type=ecju_query_type).delete() - create_system_mention( - case=ecju_query.case, - case_note_text=f"{exporter_user_full_name} has responded to a query.", - mention_user=ecju_query.raised_by_user, - ) + # If the user is a govuser then query is manually being closed by a case worker + query_verb = AuditType.ECJU_QUERY_MANUALLY_CLOSED if is_govuser_request else AuditType.ECJU_QUERY_RESPONSE + audit_trail_service.create( + actor=request.user, + verb=query_verb, + action_object=serializer.instance, + target=serializer.instance.case, + payload={"ecju_response": data.get("response")}, + ) - return JsonResponse( - data={"ecju_query": serializer.data}, - status=status.HTTP_201_CREATED, - ) - else: - return JsonResponse(data={}, status=status.HTTP_200_OK) + # if an exporter responds to a query, create a mention notification + # for the case worker that lets them know the query has been responded to + if not is_govuser_request: + try: + exporter_user_full_name = ExporterUser.objects.get(baseuser_ptr_id=request.user.pk).full_name + create_system_mention( + case=ecju_query.case, + case_note_text=f"{exporter_user_full_name} has responded to a query.", + mention_user=ecju_query.raised_by_user, + ) + except ExporterUser.DoesNotExist: + raise NotFoundError({"user": f"ExporterUser not found for pk: {request.user.pk}"}) + + # TODO change this to 200 + return JsonResponse(data={"ecju_query": serializer.data}, status=status.HTTP_201_CREATED) return JsonResponse(data={"errors": serializer.errors}, status=status.HTTP_400_BAD_REQUEST) From 700a08c1b14f68f6882663df1b8db1f21887abf0 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Mon, 12 Feb 2024 14:38:04 +0000 Subject: [PATCH 16/28] Change HTTP_201_CREATED to HTTP_200_OK --- api/cases/tests/test_case_ecju_queries.py | 18 +++++++++--------- api/cases/views/views.py | 3 +-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/api/cases/tests/test_case_ecju_queries.py b/api/cases/tests/test_case_ecju_queries.py index 10e400b361..b4f4926577 100644 --- a/api/cases/tests/test_case_ecju_queries.py +++ b/api/cases/tests/test_case_ecju_queries.py @@ -466,7 +466,7 @@ def _test_exporter_responds_to_query(self, add_documents, query_type): query_response_url = reverse("cases:case_ecju_query", kwargs={"pk": case.id, "ecju_pk": ecju_query.id}) data = {"response": "Attached the requested documents"} response = self.client.put(query_response_url, data, **self.exporter_headers) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = response.json()["ecju_query"] self.assertEqual(response["response"], data["response"]) @@ -499,7 +499,7 @@ def test_caseworker_manually_closes_query(self): self.assertEqual(1, BaseNotification.objects.filter(object_id=ecju_query.id).count()) response = self.client.put(query_response_url, data, **self.gov_headers) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = response.json()["ecju_query"] self.assertEqual(response["response"], data["response"]) @@ -521,7 +521,7 @@ def test_close_query_has_optional_response_exporter(self): data = {"response": ""} response = self.client.put(query_response_url, data, **self.exporter_headers) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) response_ecju_query = response.json()["ecju_query"] self.assertIsNone(response_ecju_query["response"]) @@ -551,7 +551,7 @@ def test_caseworker_manually_closes_query_exporter_responds_raises_error(self): data = {"response": "exporter provided details"} response = self.client.put(query_response_url, data, **self.gov_headers) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = response.json()["ecju_query"] self.assertEqual(response["response"], data["response"]) @@ -568,7 +568,7 @@ def test_caseworker_manually_closes_query_already_closed_raises_error(self): data = {"response": "exporter provided details"} response = self.client.put(query_response_url, data, **self.exporter_headers) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = response.json()["ecju_query"] self.assertEqual(response["response"], data["response"]) @@ -595,7 +595,7 @@ def test_exporter_cannot_respond_to_same_ecju_query_twice(self): url = reverse("cases:case_ecju_query", kwargs={"pk": case.id, "ecju_pk": ecju_query.id}) data = {"response": "Additional details included"} response = self.client.put(url, data, **self.exporter_headers) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = response.json()["ecju_query"] self.assertEqual(response["response"], data["response"]) @@ -631,7 +631,7 @@ def test_exporter_cannot_add_documents_to_closed_query(self): query_response_url = reverse("cases:case_ecju_query", kwargs={"pk": case.id, "ecju_pk": ecju_query.id}) data = {"response": "Attached the requested documents"} response = self.client.put(query_response_url, data, **self.exporter_headers) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = response.json()["ecju_query"] self.assertEqual(response["response"], data["response"]) self.assertEqual(len(response["documents"]), 1) @@ -660,7 +660,7 @@ def test_exporter_cannot_delete_documents_of_closed_query(self): url = reverse("cases:case_ecju_query", kwargs={"pk": case.id, "ecju_pk": ecju_query.id}) data = {"response": "Additional details included"} response = self.client.put(url, data, **self.exporter_headers) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) response = response.json()["ecju_query"] self.assertEqual(response["response"], data["response"]) self.assertEqual(len(response["documents"]), 1) @@ -703,7 +703,7 @@ def test_exporter_responding_to_query_creates_case_note_mention_for_caseworker(s response = self.client.put(url, data, **self.exporter_headers) ecju_query = EcjuQuery.objects.get(case=case) - self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertTrue(ecju_query.is_query_closed) # check case note mention is created diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 81301ab6be..95f1e09592 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -656,8 +656,7 @@ def put(self, request, pk, ecju_pk): except ExporterUser.DoesNotExist: raise NotFoundError({"user": f"ExporterUser not found for pk: {request.user.pk}"}) - # TODO change this to 200 - return JsonResponse(data={"ecju_query": serializer.data}, status=status.HTTP_201_CREATED) + return JsonResponse(data={"ecju_query": serializer.data}, status=status.HTTP_200_OK) return JsonResponse(data={"errors": serializer.errors}, status=status.HTTP_400_BAD_REQUEST) From a6a44a228ce9df878531bb99a762ce0cb933a564 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Mon, 12 Feb 2024 14:43:00 +0000 Subject: [PATCH 17/28] Add docstring --- api/cases/views/views.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 95f1e09592..65355a0bc0 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -606,6 +606,9 @@ def get(self, request, pk, ecju_pk): return JsonResponse(data={"ecju_query": serializer.data}, status=status.HTTP_200_OK) def put(self, request, pk, ecju_pk): + """ + Update an ECJU query to be closed + """ ecju_query = get_ecju_query(ecju_pk) if ecju_query.response: return JsonResponse( From 24393c2af19ab5873e9ece1b0f91e09bfc4ca274 Mon Sep 17 00:00:00 2001 From: Henry Cooksley Date: Tue, 13 Feb 2024 09:26:39 +0000 Subject: [PATCH 18/28] Remove try-except block --- api/cases/tests/test_case_ecju_queries.py | 1 + api/cases/views/views.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/api/cases/tests/test_case_ecju_queries.py b/api/cases/tests/test_case_ecju_queries.py index b4f4926577..fb236eb65c 100644 --- a/api/cases/tests/test_case_ecju_queries.py +++ b/api/cases/tests/test_case_ecju_queries.py @@ -15,6 +15,7 @@ from api.audit_trail.serializers import AuditSerializer from api.cases.enums import ECJUQueryType from api.cases.models import EcjuQuery +from api.core.exceptions import NotFoundError from api.compliance.tests.factories import ComplianceSiteCaseFactory from api.licences.enums import LicenceStatus from api.licences.tests.factories import StandardLicenceFactory diff --git a/api/cases/views/views.py b/api/cases/views/views.py index 65355a0bc0..b00412c126 100644 --- a/api/cases/views/views.py +++ b/api/cases/views/views.py @@ -646,18 +646,15 @@ def put(self, request, pk, ecju_pk): payload={"ecju_response": data.get("response")}, ) - # if an exporter responds to a query, create a mention notification + # If an exporter responds to a query, create a mention notification # for the case worker that lets them know the query has been responded to if not is_govuser_request: - try: - exporter_user_full_name = ExporterUser.objects.get(baseuser_ptr_id=request.user.pk).full_name - create_system_mention( - case=ecju_query.case, - case_note_text=f"{exporter_user_full_name} has responded to a query.", - mention_user=ecju_query.raised_by_user, - ) - except ExporterUser.DoesNotExist: - raise NotFoundError({"user": f"ExporterUser not found for pk: {request.user.pk}"}) + exporter_user_full_name = ExporterUser.objects.get(baseuser_ptr_id=request.user.pk).full_name + create_system_mention( + case=ecju_query.case, + case_note_text=f"{exporter_user_full_name} has responded to a query.", + mention_user=ecju_query.raised_by_user, + ) return JsonResponse(data={"ecju_query": serializer.data}, status=status.HTTP_200_OK) From 33160865ec06190a9eb2c5cd090e4c1119eabc3e Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Mon, 12 Feb 2024 18:00:51 +0000 Subject: [PATCH 19/28] Add app to backup document data in DB --- api/conf/celery.py | 4 ++ api/conf/settings.py | 1 + api/document_data/__init__.py | 0 api/document_data/celery_tasks.py | 50 ++++++++++++++++++++ api/document_data/migrations/0001_initial.py | 40 ++++++++++++++++ api/document_data/migrations/__init__.py | 0 api/document_data/models.py | 12 +++++ 7 files changed, 107 insertions(+) create mode 100644 api/document_data/__init__.py create mode 100644 api/document_data/celery_tasks.py create mode 100644 api/document_data/migrations/0001_initial.py create mode 100644 api/document_data/migrations/__init__.py create mode 100644 api/document_data/models.py diff --git a/api/conf/celery.py b/api/conf/celery.py index d73c1f9232..f13bd48904 100644 --- a/api/conf/celery.py +++ b/api/conf/celery.py @@ -27,4 +27,8 @@ "task": "api.cases.celery_tasks.update_cases_sla", "schedule": crontab(hour=22, minute=30), }, + "backup document data 2am": { + "task": "api.document_data.celery_tasks.backup_document_data", + "schedule": crontab(hour=2, minute=0), + }, } diff --git a/api/conf/settings.py b/api/conf/settings.py index aa9d555997..fe5ea12b3d 100644 --- a/api/conf/settings.py +++ b/api/conf/settings.py @@ -118,6 +118,7 @@ "lite_routing", "api.appeals", "api.assessments", + "api.document_data", ] MIDDLEWARE = [ diff --git a/api/document_data/__init__.py b/api/document_data/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/document_data/celery_tasks.py b/api/document_data/celery_tasks.py new file mode 100644 index 0000000000..4f6a18e1b1 --- /dev/null +++ b/api/document_data/celery_tasks.py @@ -0,0 +1,50 @@ +from botocore.exceptions import ClientError +from celery import shared_task +from celery.utils.log import get_task_logger + +from api.documents.libraries.s3_operations import get_object +from api.documents.models import Document +from api.document_data.models import DocumentData + + +logger = get_task_logger(__name__) + + +MAX_ATTEMPTS = 3 +RETRY_BACKOFF = 180 + + +@shared_task( + autoretry_for=(Exception,), + max_retries=MAX_ATTEMPTS, + retry_backoff=RETRY_BACKOFF, +) +def backup_document_data(): + """Backup document data into the database.""" + for document in Document.objects.filter(safe=True): + try: + file = get_object(document.id, document.s3_key) + except ClientError: + logger.warning(f"Failed to retrieve file '{document.s3_key}' from S3 for document '{document.id}'") + continue + + if not file: + logger.warning(f"Failed to retrieve file '{document.s3_key}' from S3 for document '{document.id}'") + continue + + try: + document_data = DocumentData.objects.get(s3_key=document.s3_key) + except DocumentData.DoesNotExist: + DocumentData.objects.create( + data=file["Body"].read(), + last_modified=file["LastModified"], + s3_key=document.s3_key, + ) + logger.info(f"Created '{document.s3_key}' for document '{document.id}'") + continue + + if file["LastModified"] > document_data.last_modified: + document_data.last_modified = file["LastModified"] + document_data.data = file["Body"].read() + document_data.save() + logger.info(f"Updated '{document.s3_key}' for document '{document.id}'") diff --git a/api/document_data/migrations/0001_initial.py b/api/document_data/migrations/0001_initial.py new file mode 100644 index 0000000000..77dca1a822 --- /dev/null +++ b/api/document_data/migrations/0001_initial.py @@ -0,0 +1,40 @@ +# Generated by Django 4.2.9 on 2024-02-12 17:51 + +from django.db import migrations, models +import django.utils.timezone +import model_utils.fields +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="DocumentData", + fields=[ + ( + "created_at", + model_utils.fields.AutoCreatedField( + default=django.utils.timezone.now, editable=False, verbose_name="created_at" + ), + ), + ( + "updated_at", + model_utils.fields.AutoLastModifiedField( + default=django.utils.timezone.now, editable=False, verbose_name="updated_at" + ), + ), + ("id", models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ("s3_key", models.CharField(max_length=1000)), + ("data", models.BinaryField()), + ("last_modified", models.DateTimeField()), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/api/document_data/migrations/__init__.py b/api/document_data/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/document_data/models.py b/api/document_data/models.py new file mode 100644 index 0000000000..acd5b5d830 --- /dev/null +++ b/api/document_data/models.py @@ -0,0 +1,12 @@ +import uuid + +from django.db import models + +from api.common.models import TimestampableModel + + +class DocumentData(TimestampableModel): + id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) + s3_key = models.CharField(max_length=1000) + data = models.BinaryField() + last_modified = models.DateTimeField() From f85898c1a7229c759be5727955e298a0ef3bea47 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Mon, 12 Feb 2024 18:59:22 +0000 Subject: [PATCH 20/28] Set distinct levels of logging when backing up document data --- api/document_data/celery_tasks.py | 57 +++++++++++++++++++++--- api/documents/libraries/s3_operations.py | 17 +++---- 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/api/document_data/celery_tasks.py b/api/document_data/celery_tasks.py index 4f6a18e1b1..74f6a09f4f 100644 --- a/api/document_data/celery_tasks.py +++ b/api/document_data/celery_tasks.py @@ -21,15 +21,45 @@ ) def backup_document_data(): """Backup document data into the database.""" - for document in Document.objects.filter(safe=True): + + # When running this command by hand it's best to set the logging as follows: + # import logging + # from api.document_data.celery_tasks import logger + # logger.setLevel(logging.DEBUG) + # from api.documents.libraries.s3_operations import logger + # logger.setLevel(logging.WARNING) + # + # This will ensure that you get the debug output of this particular file but + # miss the extra info from the get_object call + + safe_documents = Document.objects.filter(safe=True) + count = safe_documents.count() + logger.debug( + "Backing up %s documents", + count, + ) + for index, document in enumerate(safe_documents, start=1): + logger.debug( + "Processing %s of %s", + index, + count, + ) try: file = get_object(document.id, document.s3_key) except ClientError: - logger.warning(f"Failed to retrieve file '{document.s3_key}' from S3 for document '{document.id}'") + logger.warning( + "Failed to retrieve file '%s' from S3 for document '%s'", + document.s3_key, + document.id, + ) continue if not file: - logger.warning(f"Failed to retrieve file '{document.s3_key}' from S3 for document '{document.id}'") + logger.warning( + "Failed to retrieve file '%s' from S3 for document '%s'", + document.s3_key, + document.id, + ) continue try: @@ -40,11 +70,28 @@ def backup_document_data(): last_modified=file["LastModified"], s3_key=document.s3_key, ) - logger.info(f"Created '{document.s3_key}' for document '{document.id}'") + logger.info( + "Created '%s' for document '%s'", + document.s3_key, + document.id, + ) continue if file["LastModified"] > document_data.last_modified: document_data.last_modified = file["LastModified"] document_data.data = file["Body"].read() document_data.save() - logger.info(f"Updated '{document.s3_key}' for document '{document.id}'") + logger.info( + "Updated '%s' for document '%s'", + document.s3_key, + document.id, + ) + continue + + logger.debug( + "Nothing required for '%s' for document '%s'", + document.s3_key, + document.id, + ) + + logger.debug("Completed backing up documents") diff --git a/api/documents/libraries/s3_operations.py b/api/documents/libraries/s3_operations.py index 0942d45151..d60a408380 100644 --- a/api/documents/libraries/s3_operations.py +++ b/api/documents/libraries/s3_operations.py @@ -10,6 +10,9 @@ from django.http import FileResponse +logger = logging.getLogger(__name__) + + _client = None @@ -35,14 +38,14 @@ def init_s3_client(): def get_object(document_id, s3_key): - logging.info(f"Retrieving file '{s3_key}' on document '{document_id}'") + logger.info(f"Retrieving file '{s3_key}' on document '{document_id}'") try: return _client.get_object(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Key=s3_key) except ReadTimeoutError: - logging.warning(f"Timeout exceeded when retrieving file '{s3_key}' on document '{document_id}'") + logger.warning(f"Timeout exceeded when retrieving file '{s3_key}' on document '{document_id}'") except BotoCoreError as exc: - logging.warning( + logger.warning( f"An unexpected error occurred when retrieving file '{s3_key}' on document '{document_id}': {exc}" ) @@ -56,16 +59,14 @@ def upload_bytes_file(raw_file, s3_key): def delete_file(document_id, s3_key): - logging.info(f"Deleting file '{s3_key}' on document '{document_id}'") + logger.info(f"Deleting file '{s3_key}' on document '{document_id}'") try: _client.delete_object(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Key=s3_key) except ReadTimeoutError: - logging.warning(f"Timeout exceeded when retrieving file '{s3_key}' on document '{document_id}'") + logger.warning(f"Timeout exceeded when retrieving file '{s3_key}' on document '{document_id}'") except BotoCoreError as exc: - logging.warning( - f"An unexpected error occurred when deleting file '{s3_key}' on document '{document_id}': {exc}" - ) + logger.warning(f"An unexpected error occurred when deleting file '{s3_key}' on document '{document_id}': {exc}") def document_download_stream(document): From 7546e7abb26f947844f95edde6c387651b6538ff Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Mon, 12 Feb 2024 19:28:15 +0000 Subject: [PATCH 21/28] Add tests for backup document data task --- api/document_data/tests/__init__.py | 0 api/document_data/tests/test_celery_tasks.py | 179 +++++++++++++++++++ test_helpers/clients.py | 7 + 3 files changed, 186 insertions(+) create mode 100644 api/document_data/tests/__init__.py create mode 100644 api/document_data/tests/test_celery_tasks.py diff --git a/api/document_data/tests/__init__.py b/api/document_data/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/document_data/tests/test_celery_tasks.py b/api/document_data/tests/test_celery_tasks.py new file mode 100644 index 0000000000..36feaae7e7 --- /dev/null +++ b/api/document_data/tests/test_celery_tasks.py @@ -0,0 +1,179 @@ +import datetime + +from unittest import mock + +from moto import mock_aws + +from botocore.exceptions import ClientError + +from django.utils import timezone + +from test_helpers.clients import DataTestClient + +from api.documents.tests.factories import DocumentFactory +from api.document_data.celery_tasks import backup_document_data +from api.document_data.models import DocumentData + + +@mock_aws +class TestBackupDocumentData(DataTestClient): + def setUp(self): + super().setUp() + self.create_default_bucket() + + def test_backup_new_document_data(self): + self.put_object_in_default_bucket("thisisakey", b"test") + DocumentFactory.create( + s3_key="thisisakey", + safe=True, + ) + self.assertEqual( + DocumentData.objects.count(), + 0, + ) + + backup_document_data() + + self.assertEqual( + DocumentData.objects.count(), + 1, + ) + document_data = DocumentData.objects.get() + self.assertEqual( + document_data.s3_key, + "thisisakey", + ) + self.assertEqual( + document_data.data.tobytes(), + b"test", + ) + s3_object = self.get_object_from_default_bucket("thisisakey") + self.assertEqual( + document_data.last_modified, + s3_object["LastModified"], + ) + + def test_update_existing_document_data(self): + self.put_object_in_default_bucket("thisisakey", b"test") + DocumentFactory.create( + s3_key="thisisakey", + safe=True, + ) + self.assertEqual( + DocumentData.objects.count(), + 0, + ) + + backup_document_data() + self.assertEqual( + DocumentData.objects.count(), + 1, + ) + + document_data = DocumentData.objects.get() + document_data.last_modified = timezone.now() - datetime.timedelta(days=5) + document_data.save() + self.put_object_in_default_bucket("thisisakey", b"new contents") + + backup_document_data() + + self.assertEqual( + DocumentData.objects.count(), + 1, + ) + document_data = DocumentData.objects.get() + self.assertEqual( + document_data.s3_key, + "thisisakey", + ) + self.assertEqual( + document_data.data.tobytes(), + b"new contents", + ) + s3_object = self.get_object_from_default_bucket("thisisakey") + self.assertEqual( + document_data.last_modified, + s3_object["LastModified"], + ) + + def test_leave_existing_document_data(self): + self.put_object_in_default_bucket("thisisakey", b"test") + DocumentFactory.create( + s3_key="thisisakey", + safe=True, + ) + self.assertEqual( + DocumentData.objects.count(), + 0, + ) + + backup_document_data() + self.assertEqual( + DocumentData.objects.count(), + 1, + ) + + document_data = DocumentData.objects.get() + document_data.last_modified = original_last_modified = timezone.now() + datetime.timedelta(days=5) + document_data.save() + self.put_object_in_default_bucket("thisisakey", b"new contents") + + backup_document_data() + + self.assertEqual( + DocumentData.objects.count(), + 1, + ) + document_data = DocumentData.objects.get() + self.assertEqual( + document_data.s3_key, + "thisisakey", + ) + self.assertEqual( + document_data.data.tobytes(), + b"test", + ) + self.assertEqual( + document_data.last_modified, + original_last_modified, + ) + + @mock.patch("api.document_data.celery_tasks.get_object") + def test_ignore_client_error(self, mock_get_object): + mock_get_object.side_effect = ClientError({}, "fake operation") + + self.put_object_in_default_bucket("thisisakey", b"test") + DocumentFactory.create( + s3_key="thisisakey", + safe=True, + ) + self.assertEqual( + DocumentData.objects.count(), + 0, + ) + + backup_document_data() + self.assertEqual( + DocumentData.objects.count(), + 0, + ) + + @mock.patch("api.document_data.celery_tasks.get_object") + def test_ignore_get_object_returning_none(self, mock_get_object): + mock_get_object.return_value = None + + self.put_object_in_default_bucket("thisisakey", b"test") + DocumentFactory.create( + s3_key="thisisakey", + safe=True, + ) + self.assertEqual( + DocumentData.objects.count(), + 0, + ) + + backup_document_data() + self.assertEqual( + DocumentData.objects.count(), + 0, + ) diff --git a/test_helpers/clients.py b/test_helpers/clients.py index ac75b20839..9d8c917a48 100644 --- a/test_helpers/clients.py +++ b/test_helpers/clients.py @@ -1047,6 +1047,13 @@ def put_object_in_default_bucket(self, key, body): Body=body, ) + def get_object_from_default_bucket(self, key): + s3 = init_s3_client() + return s3.get_object( + Bucket=settings.AWS_STORAGE_BUCKET_NAME, + Key=key, + ) + @pytest.mark.performance # we need to set debug to true otherwise we can't see the amount of queries From 1bb23f762574f13e66db28a63a4c3477c5614cd9 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 13 Feb 2024 07:29:12 +0000 Subject: [PATCH 22/28] Allow backing up document data to db to be skipped via env variable --- api/conf/settings.py | 2 ++ api/document_data/celery_tasks.py | 6 ++++++ api/document_data/tests/test_celery_tasks.py | 22 ++++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/api/conf/settings.py b/api/conf/settings.py index fe5ea12b3d..32760c48ad 100644 --- a/api/conf/settings.py +++ b/api/conf/settings.py @@ -484,3 +484,5 @@ def _build_redis_url(base_url, db_number, **query_args): CONTENT_DATA_MIGRATION_DIR = Path(BASE_DIR).parent / "lite_content/lite_api/migrations" + +BACKUP_DOCUMENT_DATA_TO_DB = env("BACKUP_DOCUMENT_DATA_TO_DB", default=True) diff --git a/api/document_data/celery_tasks.py b/api/document_data/celery_tasks.py index 74f6a09f4f..34988fbcaa 100644 --- a/api/document_data/celery_tasks.py +++ b/api/document_data/celery_tasks.py @@ -2,6 +2,8 @@ from celery import shared_task from celery.utils.log import get_task_logger +from django.conf import settings + from api.documents.libraries.s3_operations import get_object from api.documents.models import Document from api.document_data.models import DocumentData @@ -32,6 +34,10 @@ def backup_document_data(): # This will ensure that you get the debug output of this particular file but # miss the extra info from the get_object call + if not settings.BACKUP_DOCUMENT_DATA_TO_DB: + logger.info("Skipping backup document data to db") + return + safe_documents = Document.objects.filter(safe=True) count = safe_documents.count() logger.debug( diff --git a/api/document_data/tests/test_celery_tasks.py b/api/document_data/tests/test_celery_tasks.py index 36feaae7e7..9459fc6d4b 100644 --- a/api/document_data/tests/test_celery_tasks.py +++ b/api/document_data/tests/test_celery_tasks.py @@ -6,6 +6,7 @@ from botocore.exceptions import ClientError +from django.test import override_settings from django.utils import timezone from test_helpers.clients import DataTestClient @@ -177,3 +178,24 @@ def test_ignore_get_object_returning_none(self, mock_get_object): DocumentData.objects.count(), 0, ) + + @override_settings( + BACKUP_DOCUMENT_DATA_TO_DB=False, + ) + def test_stop_backup_new_document_data(self): + self.put_object_in_default_bucket("thisisakey", b"test") + DocumentFactory.create( + s3_key="thisisakey", + safe=True, + ) + self.assertEqual( + DocumentData.objects.count(), + 0, + ) + + backup_document_data() + + self.assertEqual( + DocumentData.objects.count(), + 0, + ) From e3e19d2a11b1195308b14898ec50a17834bf5025 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 13 Feb 2024 07:44:20 +0000 Subject: [PATCH 23/28] Add tests for error paths on s3 operations --- .../libraries/tests/test_s3_operations.py | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/api/documents/libraries/tests/test_s3_operations.py b/api/documents/libraries/tests/test_s3_operations.py index 53d58e6fb3..61603da4e8 100644 --- a/api/documents/libraries/tests/test_s3_operations.py +++ b/api/documents/libraries/tests/test_s3_operations.py @@ -1,8 +1,15 @@ +import logging + from contextlib import contextmanager from unittest.mock import Mock, patch from moto import mock_aws +from botocore.exceptions import ( + BotoCoreError, + ReadTimeoutError, +) + from django.conf import settings from django.http import FileResponse from django.test import override_settings, SimpleTestCase @@ -92,6 +99,40 @@ def test_get_object(self, mock_client): self.assertEqual(returned_object, mock_object) mock_client.get_object.assert_called_with(Bucket="test-bucket", Key="s3-key") + @patch("api.documents.libraries.s3_operations._client") + def test_get_object_read_timeout_error(self, mock_client): + mock_client.get_object.side_effect = ReadTimeoutError( + endpoint_url="endpoint_url", + ) + + with self.assertLogs( + "api.documents.libraries.s3_operations", + logging.WARNING, + ) as al: + returned_object = get_object("document-id", "s3-key") + + self.assertIsNone(returned_object) + self.assertIn( + "WARNING:api.documents.libraries.s3_operations:Timeout exceeded when retrieving file 's3-key' on document 'document-id'", + al.output, + ) + + @patch("api.documents.libraries.s3_operations._client") + def test_get_object_boto_core_error(self, mock_client): + mock_client.get_object.side_effect = BotoCoreError() + + with self.assertLogs( + "api.documents.libraries.s3_operations", + logging.WARNING, + ) as al: + returned_object = get_object("document-id", "s3-key") + + self.assertIsNone(returned_object) + self.assertIn( + "WARNING:api.documents.libraries.s3_operations:An unexpected error occurred when retrieving file 's3-key' on document 'document-id': An unspecified error occurred", + al.output, + ) + @contextmanager def _create_bucket(s3): @@ -121,6 +162,38 @@ def test_delete_file(self): keys = [o["Key"] for o in objs.get("Contents", [])] self.assertNotIn("s3-key", keys) + @patch("api.documents.libraries.s3_operations._client") + def test_delete_file_read_timeout_error(self, mock_client): + mock_client.delete_object.side_effect = ReadTimeoutError( + endpoint_url="endpoint_url", + ) + + with self.assertLogs( + "api.documents.libraries.s3_operations", + logging.WARNING, + ) as al: + delete_file("document-id", "s3-key") + + self.assertIn( + "WARNING:api.documents.libraries.s3_operations:Timeout exceeded when retrieving file 's3-key' on document 'document-id'", + al.output, + ) + + @patch("api.documents.libraries.s3_operations._client") + def test_delete_file_boto_core_error(self, mock_client): + mock_client.delete_object.side_effect = BotoCoreError() + + with self.assertLogs( + "api.documents.libraries.s3_operations", + logging.WARNING, + ) as al: + delete_file("document-id", "s3-key") + + self.assertIn( + "WARNING:api.documents.libraries.s3_operations:An unexpected error occurred when deleting file 's3-key' on document 'document-id': An unspecified error occurred", + al.output, + ) + @mock_aws class S3OperationsUploadBytesFileTests(SimpleTestCase): From a9828c3165739445645cc89ac5e4d95a73ef8e20 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 13 Feb 2024 08:53:17 +0000 Subject: [PATCH 24/28] Use logging formatting correctly in s3 operations --- api/documents/libraries/s3_operations.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/api/documents/libraries/s3_operations.py b/api/documents/libraries/s3_operations.py index d60a408380..5253a17685 100644 --- a/api/documents/libraries/s3_operations.py +++ b/api/documents/libraries/s3_operations.py @@ -43,10 +43,17 @@ def get_object(document_id, s3_key): try: return _client.get_object(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Key=s3_key) except ReadTimeoutError: - logger.warning(f"Timeout exceeded when retrieving file '{s3_key}' on document '{document_id}'") + logger.warning( + "Timeout exceeded when retrieving file '%s' on document '%s'", + s3_key, + document_id, + ) except BotoCoreError as exc: logger.warning( - f"An unexpected error occurred when retrieving file '{s3_key}' on document '{document_id}': {exc}" + "An unexpected error occurred when retrieving file '%s' on document '%s': %s", + s3_key, + document_id, + exc, ) @@ -64,9 +71,18 @@ def delete_file(document_id, s3_key): try: _client.delete_object(Bucket=settings.AWS_STORAGE_BUCKET_NAME, Key=s3_key) except ReadTimeoutError: - logger.warning(f"Timeout exceeded when retrieving file '{s3_key}' on document '{document_id}'") + logger.warning( + "Timeout exceeded when retrieving file '%s' on document '%s'", + s3_key, + document_id, + ) except BotoCoreError as exc: - logger.warning(f"An unexpected error occurred when deleting file '{s3_key}' on document '{document_id}': {exc}") + logger.warning( + "An unexpected error occurred when deleting file '%s' on document '%s': %s", + s3_key, + document_id, + exc, + ) def document_download_stream(document): From 98f132a3281d136f4510ab9813b864378edee2cc Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 13 Feb 2024 11:46:12 +0000 Subject: [PATCH 25/28] Make s3_key unique for document data --- api/document_data/migrations/0001_initial.py | 4 ++-- api/document_data/models.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/api/document_data/migrations/0001_initial.py b/api/document_data/migrations/0001_initial.py index 77dca1a822..4775bf7314 100644 --- a/api/document_data/migrations/0001_initial.py +++ b/api/document_data/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.9 on 2024-02-12 17:51 +# Generated by Django 4.2.9 on 2024-02-13 11:45 from django.db import migrations, models import django.utils.timezone @@ -29,7 +29,7 @@ class Migration(migrations.Migration): ), ), ("id", models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), - ("s3_key", models.CharField(max_length=1000)), + ("s3_key", models.CharField(max_length=1000, unique=True)), ("data", models.BinaryField()), ("last_modified", models.DateTimeField()), ], diff --git a/api/document_data/models.py b/api/document_data/models.py index acd5b5d830..4d142d4b9e 100644 --- a/api/document_data/models.py +++ b/api/document_data/models.py @@ -7,6 +7,6 @@ class DocumentData(TimestampableModel): id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) - s3_key = models.CharField(max_length=1000) + s3_key = models.CharField(unique=True, max_length=1000) data = models.BinaryField() last_modified = models.DateTimeField() From 5518ce2e5ddcf860805bd7fb942e7121f2ced7fc Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 13 Feb 2024 12:33:23 +0000 Subject: [PATCH 26/28] Store content type in document data backup --- api/document_data/celery_tasks.py | 4 ++++ api/document_data/migrations/0001_initial.py | 3 ++- api/document_data/models.py | 1 + api/document_data/tests/test_celery_tasks.py | 8 ++++++++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/api/document_data/celery_tasks.py b/api/document_data/celery_tasks.py index 34988fbcaa..050de97480 100644 --- a/api/document_data/celery_tasks.py +++ b/api/document_data/celery_tasks.py @@ -75,23 +75,27 @@ def backup_document_data(): data=file["Body"].read(), last_modified=file["LastModified"], s3_key=document.s3_key, + content_type=file["ContentType"], ) logger.info( "Created '%s' for document '%s'", document.s3_key, document.id, ) + file["Body"].close() continue if file["LastModified"] > document_data.last_modified: document_data.last_modified = file["LastModified"] document_data.data = file["Body"].read() + document_data.content_type = file["ContentType"] document_data.save() logger.info( "Updated '%s' for document '%s'", document.s3_key, document.id, ) + file["Body"].close() continue logger.debug( diff --git a/api/document_data/migrations/0001_initial.py b/api/document_data/migrations/0001_initial.py index 4775bf7314..0e2ed67335 100644 --- a/api/document_data/migrations/0001_initial.py +++ b/api/document_data/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.9 on 2024-02-13 11:45 +# Generated by Django 4.2.9 on 2024-02-13 12:30 from django.db import migrations, models import django.utils.timezone @@ -32,6 +32,7 @@ class Migration(migrations.Migration): ("s3_key", models.CharField(max_length=1000, unique=True)), ("data", models.BinaryField()), ("last_modified", models.DateTimeField()), + ("content_type", models.CharField(max_length=255)), ], options={ "abstract": False, diff --git a/api/document_data/models.py b/api/document_data/models.py index 4d142d4b9e..adfe94a465 100644 --- a/api/document_data/models.py +++ b/api/document_data/models.py @@ -10,3 +10,4 @@ class DocumentData(TimestampableModel): s3_key = models.CharField(unique=True, max_length=1000) data = models.BinaryField() last_modified = models.DateTimeField() + content_type = models.CharField(max_length=255) diff --git a/api/document_data/tests/test_celery_tasks.py b/api/document_data/tests/test_celery_tasks.py index 9459fc6d4b..89bfafe72e 100644 --- a/api/document_data/tests/test_celery_tasks.py +++ b/api/document_data/tests/test_celery_tasks.py @@ -53,6 +53,10 @@ def test_backup_new_document_data(self): document_data.last_modified, s3_object["LastModified"], ) + self.assertEqual( + document_data.content_type, + s3_object["ContentType"], + ) def test_update_existing_document_data(self): self.put_object_in_default_bucket("thisisakey", b"test") @@ -96,6 +100,10 @@ def test_update_existing_document_data(self): document_data.last_modified, s3_object["LastModified"], ) + self.assertEqual( + document_data.content_type, + s3_object["ContentType"], + ) def test_leave_existing_document_data(self): self.put_object_in_default_bucket("thisisakey", b"test") From e20ecfa70417d9f004f55db35174218f103c0c0c Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Tue, 13 Feb 2024 12:50:15 +0000 Subject: [PATCH 27/28] Optimise memory usage for backup document data --- api/document_data/celery_tasks.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/api/document_data/celery_tasks.py b/api/document_data/celery_tasks.py index 050de97480..fe5c6e36c9 100644 --- a/api/document_data/celery_tasks.py +++ b/api/document_data/celery_tasks.py @@ -44,45 +44,46 @@ def backup_document_data(): "Backing up %s documents", count, ) - for index, document in enumerate(safe_documents, start=1): + for index, (document_id, document_s3_key) in enumerate(safe_documents.values_list("pk", "s3_key"), start=1): logger.debug( "Processing %s of %s", index, count, ) try: - file = get_object(document.id, document.s3_key) + file = get_object(document_id, document_s3_key) except ClientError: logger.warning( "Failed to retrieve file '%s' from S3 for document '%s'", - document.s3_key, - document.id, + document_s3_key, + document_id, ) continue if not file: logger.warning( "Failed to retrieve file '%s' from S3 for document '%s'", - document.s3_key, - document.id, + document_s3_key, + document_id, ) continue try: - document_data = DocumentData.objects.get(s3_key=document.s3_key) + document_data = DocumentData.objects.get(s3_key=document_s3_key) except DocumentData.DoesNotExist: DocumentData.objects.create( data=file["Body"].read(), last_modified=file["LastModified"], - s3_key=document.s3_key, + s3_key=document_s3_key, content_type=file["ContentType"], ) logger.info( "Created '%s' for document '%s'", - document.s3_key, - document.id, + document_s3_key, + document_id, ) file["Body"].close() + del file # Clear this out for garbage collection continue if file["LastModified"] > document_data.last_modified: @@ -92,16 +93,17 @@ def backup_document_data(): document_data.save() logger.info( "Updated '%s' for document '%s'", - document.s3_key, - document.id, + document_s3_key, + document_id, ) file["Body"].close() + del file # Clear this out for garbage collection continue logger.debug( "Nothing required for '%s' for document '%s'", - document.s3_key, - document.id, + document_s3_key, + document_id, ) logger.debug("Completed backing up documents") From 9a0b0d68fced34f9379f97fa6cc130598d3659a8 Mon Sep 17 00:00:00 2001 From: Kevin Carrogan Date: Wed, 14 Feb 2024 10:42:24 +0000 Subject: [PATCH 28/28] Ignore lint warning --- api/document_data/celery_tasks.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/api/document_data/celery_tasks.py b/api/document_data/celery_tasks.py index fe5c6e36c9..a995db066d 100644 --- a/api/document_data/celery_tasks.py +++ b/api/document_data/celery_tasks.py @@ -86,17 +86,17 @@ def backup_document_data(): del file # Clear this out for garbage collection continue - if file["LastModified"] > document_data.last_modified: - document_data.last_modified = file["LastModified"] - document_data.data = file["Body"].read() - document_data.content_type = file["ContentType"] + if file["LastModified"] > document_data.last_modified: # noqa: WS03 + document_data.last_modified = file["LastModified"] # noqa: WS03 + document_data.data = file["Body"].read() # noqa: WS03 + document_data.content_type = file["ContentType"] # noqa: WS03 document_data.save() logger.info( "Updated '%s' for document '%s'", document_s3_key, document_id, ) - file["Body"].close() + file["Body"].close() # noqa: WS03 del file # Clear this out for garbage collection continue