diff --git a/lms/models/lti_params.py b/lms/models/lti_params.py index 5a4d27f0c1..26b2ab5960 100644 --- a/lms/models/lti_params.py +++ b/lms/models/lti_params.py @@ -1,4 +1,7 @@ import logging +from datetime import datetime + +from dateutil import parser CLAIM_PREFIX = "https://purl.imsglobal.org/spec/lti/claim" @@ -26,6 +29,13 @@ def __init__(self, v11: dict, v13: dict | None = None): def v11(self): return self + def get_datetime(self, key: str) -> datetime | None: + """Get a datetime from the LTI parameters.""" + try: + return parser.isoparse(self.get(key)) + except (TypeError, ValueError): + return None + @classmethod def from_request(cls, request): """Create an LTIParams from the request.""" diff --git a/lms/services/__init__.py b/lms/services/__init__.py index 41813936ec..0dd1c75ed1 100644 --- a/lms/services/__init__.py +++ b/lms/services/__init__.py @@ -28,6 +28,7 @@ LTILaunchVerificationError, LTIOAuthError, ) +from lms.services.lms_term import LMSTermService from lms.services.lti_grading import LTIGradingService from lms.services.lti_names_roles import LTINamesRolesService from lms.services.lti_registration import LTIRegistrationService @@ -165,6 +166,9 @@ def includeme(config): # noqa: PLR0915 config.register_service_factory( "lms.services.auto_grading.factory", iface=AutoGradingService ) + config.register_service_factory( + "lms.services.lms_term.factory", iface=LMSTermService + ) # Plugins are not the same as top level services but we want to register them as pyramid services too # Importing them here to: diff --git a/lms/services/course.py b/lms/services/course.py index 129a8135fb..e11bbc947e 100644 --- a/lms/services/course.py +++ b/lms/services/course.py @@ -1,9 +1,6 @@ import json from copy import deepcopy -from datetime import datetime -from typing import Mapping -from dateutil import parser from sqlalchemy import Select, select, union from lms.db import full_text_match @@ -28,14 +25,22 @@ ) from lms.product.family import Family from lms.services.grouping import GroupingService +from lms.services.lms_term import LMSTermService from lms.services.upsert import bulk_upsert class CourseService: - def __init__(self, db, application_instance, grouping_service: GroupingService): + def __init__( + self, + db, + application_instance, + grouping_service: GroupingService, + lms_term_service: LMSTermService, + ): self._db = db self._application_instance = application_instance self._grouping_service = grouping_service + self._lms_term_service = lms_term_service def any_with_setting(self, group, key, value=True) -> bool: """ @@ -314,7 +319,10 @@ def _upsert_lms_course(self, course: Course, lti_params: LTIParams) -> LMSCourse "https://purl.imsglobal.org/spec/lti-nrps/claim/namesroleservice", {} ).get("context_memberships_url") - course_starts_at, course_ends_at = self._get_course_dates(lti_params) + course_starts_at = lti_params.get_datetime("custom_course_starts") + course_ends_at = lti_params.get_datetime("custom_course_ends") + + lms_term = self._lms_term_service.get_term(lti_params) lms_api_course_id = self._get_api_id_from_launch(lti_params) lms_course = bulk_upsert( @@ -330,6 +338,7 @@ def _upsert_lms_course(self, course: Course, lti_params: LTIParams) -> LMSCourse "starts_at": course_starts_at, "ends_at": course_ends_at, "lms_api_course_id": lms_api_course_id, + "lms_term_id": lms_term.id if lms_term else None, } ], index_elements=["h_authority_provided_id"], @@ -340,6 +349,7 @@ def _upsert_lms_course(self, course: Course, lti_params: LTIParams) -> LMSCourse "starts_at", "ends_at", "lms_api_course_id", + "lms_term_id", ], ).one() bulk_upsert( @@ -431,22 +441,6 @@ def _get_copied_from_course(self, lti_params) -> Course | None: return None - def _get_course_dates( - self, lti_params: Mapping - ) -> tuple[datetime | None, datetime | None]: - """Get the dates for the current curse, None if not available.""" - try: - course_starts_at = parser.isoparse(lti_params.get("custom_course_starts")) - except (TypeError, ValueError): - course_starts_at = None - - try: - course_ends_at = parser.isoparse(lti_params.get("custom_course_ends")) - except (TypeError, ValueError): - course_ends_at = None - - return course_starts_at, course_ends_at - def _get_api_id_from_launch(self, lti_params: LTIParams) -> str | None: """Get the API ID from the launch params. @@ -464,4 +458,5 @@ def course_service_factory(_context, request): request.lti_user.application_instance if request.lti_user else None ), grouping_service=request.find_service(name="grouping"), + lms_term_service=request.find_service(LMSTermService), ) diff --git a/lms/services/lms_term.py b/lms/services/lms_term.py new file mode 100644 index 0000000000..f8e228ea9a --- /dev/null +++ b/lms/services/lms_term.py @@ -0,0 +1,48 @@ +from lms.models import LMSTerm, LTIParams +from lms.services.upsert import bulk_upsert + + +class LMSTermService: + def __init__(self, db): + self._db = db + + def get_term(self, lti_params: LTIParams) -> LMSTerm | None: + term_starts_at = lti_params.get_datetime("custom_term_start") + term_ends_at = lti_params.get_datetime("custom_term_end") + term_id = lti_params.get("custom_term_id") + term_name = lti_params.get("custom_term_name") + guid = lti_params["tool_consumer_instance_guid"] + + if not any([term_starts_at, term_ends_at]): + # We need to have at least one date to consider a term. + return None + + if term_id: + # If we get an ID from the LMS we'll use it as the key. + # We'll scope it the installs GUID + key = f"{guid}:{term_id}" + else: + # Otherwise we'll use the name and dates as part of the key + key = f"{guid}:{term_name if term_name else '-'}:{term_starts_at if term_starts_at else '-'}:{term_ends_at if term_ends_at else '-'}" + + values = [ + { + "name": term_name, + "tool_consumer_instance_guid": guid, + "starts_at": term_starts_at, + "ends_at": term_ends_at, + "key": key, + "lms_id": term_id, + } + ] + return bulk_upsert( + self._db, + model_class=LMSTerm, + values=values, + index_elements=["key"], + update_columns=["updated", "name", "starts_at", "ends_at"], + ).first() + + +def factory(_context, request): + return LMSTermService(db=request.db) diff --git a/tests/unit/lms/services/course_test.py b/tests/unit/lms/services/course_test.py index 5f62fc2b97..791cb61896 100644 --- a/tests/unit/lms/services/course_test.py +++ b/tests/unit/lms/services/course_test.py @@ -193,6 +193,7 @@ def test_upsert_course( custom_course_ends, course_ends_at, custom_canvas_api_id, + lms_term_service, ): lti_params["custom_course_starts"] = custom_course_starts lti_params["custom_course_ends"] = custom_course_ends @@ -233,6 +234,7 @@ def test_upsert_course( "starts_at": course_starts_at, "ends_at": course_ends_at, "lms_api_course_id": custom_canvas_api_id, + "lms_term_id": lms_term_service.get_term.return_value.id, } ], index_elements=["h_authority_provided_id"], @@ -243,6 +245,7 @@ def test_upsert_course( "starts_at", "ends_at", "lms_api_course_id", + "lms_term_id", ], ), call().one(), @@ -500,11 +503,12 @@ def grouping_service(self, grouping_service): return grouping_service @pytest.fixture - def svc(self, db_session, application_instance, grouping_service): + def svc(self, db_session, application_instance, grouping_service, lms_term_service): return CourseService( db=db_session, application_instance=application_instance, grouping_service=grouping_service, + lms_term_service=lms_term_service, ) @pytest.fixture @@ -533,13 +537,16 @@ def lti_params(self): class TestCourseServiceFactory: - def test_it(self, pyramid_request, grouping_service, CourseService): + def test_it( + self, pyramid_request, grouping_service, CourseService, lms_term_service + ): svc = course_service_factory(sentinel.context, pyramid_request) CourseService.assert_called_once_with( db=pyramid_request.db, application_instance=pyramid_request.lti_user.application_instance, grouping_service=grouping_service, + lms_term_service=lms_term_service, ) assert svc == CourseService.return_value diff --git a/tests/unit/lms/services/lms_term_test.py b/tests/unit/lms/services/lms_term_test.py new file mode 100644 index 0000000000..085f52333e --- /dev/null +++ b/tests/unit/lms/services/lms_term_test.py @@ -0,0 +1,92 @@ +from datetime import datetime +from unittest.mock import sentinel + +import pytest + +from lms.services.lms_term import LMSTermService, factory + + +class TestLMSTermService: + def test_get_term_not_enought_data(self, svc, pyramid_request): + assert not svc.get_term(pyramid_request.lti_params) + + def test_get_term(self, svc, pyramid_request): + term_starts = datetime(2020, 1, 1) + term_ends = datetime(2020, 6, 1) + term_name = "NICE TERM" + lti_params = pyramid_request.lti_params + lti_params["custom_term_start"] = term_starts.isoformat() + lti_params["custom_term_end"] = term_ends.isoformat() + lti_params["custom_term_name"] = term_name + + term = svc.get_term(pyramid_request.lti_params) + + assert term.starts_at == term_starts + assert term.ends_at == term_ends + assert term.name == term_name + assert ( + term.tool_consumer_instance_guid + == lti_params["tool_consumer_instance_guid"] + ) + + @pytest.mark.parametrize( + "name,start,end,term_id,expected", + [ + ( + "NICE TERM", + "2020-01-01 00:00:00", + "2020-06-01 00:00:00", + None, + "TEST_TOOL_CONSUMER_INSTANCE_GUID:NICE TERM:2020-01-01 00:00:00:2020-06-01 00:00:00", + ), + ( + "NICE TERM", + None, + "2020-06-01 00:00:00", + None, + "TEST_TOOL_CONSUMER_INSTANCE_GUID:NICE TERM:-:2020-06-01 00:00:00", + ), + ( + "NICE TERM", + "2020-01-01 00:00:00", + None, + None, + "TEST_TOOL_CONSUMER_INSTANCE_GUID:NICE TERM:2020-01-01 00:00:00:-", + ), + ( + "NICE TERM", + "2020-01-01 00:00:00", + "2020-06-01 00:00:00", + "TERM_ID", + "TEST_TOOL_CONSUMER_INSTANCE_GUID:TERM_ID", + ), + ], + ) + def test_get_term_key( + self, svc, pyramid_request, name, start, end, term_id, expected + ): + lti_params = pyramid_request.lti_params + lti_params["custom_term_start"] = start + lti_params["custom_term_end"] = end + lti_params["custom_term_name"] = name + lti_params["custom_term_id"] = term_id + + term = svc.get_term(pyramid_request.lti_params) + + assert term.key == expected + + @pytest.fixture() + def svc(self, pyramid_request): + return LMSTermService(db=pyramid_request.db) + + +class TestFactory: + def test_it(self, pyramid_request, LMSTermService): + service = factory(sentinel.context, pyramid_request) + + LMSTermService.assert_called_once_with(db=pyramid_request.db) + assert service == LMSTermService.return_value + + @pytest.fixture + def LMSTermService(self, patch): + return patch("lms.services.lms_term.LMSTermService") diff --git a/tests/unit/services.py b/tests/unit/services.py index 7053586117..b1d830bbd0 100644 --- a/tests/unit/services.py +++ b/tests/unit/services.py @@ -37,6 +37,7 @@ from lms.services.jwt import JWTService from lms.services.jwt_oauth2_token import JWTOAuth2TokenService from lms.services.launch_verifier import LaunchVerifier +from lms.services.lms_term import LMSTermService from lms.services.lti_grading import LTIGradingService from lms.services.lti_h import LTIHService from lms.services.lti_names_roles import LTINamesRolesService @@ -89,6 +90,7 @@ "jwt_service", "jwt_oauth2_token_service", "launch_verifier", + "lms_term_service", "lti_grading_service", "lti_h_service", "lti_names_roles_service", @@ -356,6 +358,11 @@ def lti_registration_service(mock_service): return mock_service(LTIRegistrationService) +@pytest.fixture +def lms_term_service(mock_service): + return mock_service(LMSTermService) + + @pytest.fixture def lti_role_service(mock_service): return mock_service(LTIRoleService)