diff --git a/CHANGELOG.md b/CHANGELOG.md index 3dd2e31c7a..653de43e95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Fix flaky "duplicated email" importing fixtures tests [#3176](https://github.com/opendatateam/udata/pull/3176) - Fix deprecated CircleCI config [#3181](https://github.com/opendatateam/udata/pull/3181) - Use proper RESTful Hydra API endpoints [#3178](https://github.com/opendatateam/udata/pull/3178) +- Add a "filter by organization badge" for datasets, dataservices, reuses and organizations [#3155](https://github.com/opendatateam/udata/pull/3155] ## 9.2.4 (2024-10-22) diff --git a/udata/api_fields.py b/udata/api_fields.py index 699d0ee078..37c1ed986c 100644 --- a/udata/api_fields.py +++ b/udata/api_fields.py @@ -1,3 +1,6 @@ +import functools +from typing import Any, Dict + import flask_restx.fields as restx_fields import mongoengine import mongoengine.fields as mongo_fields @@ -114,13 +117,15 @@ def constructor_read(**kwargs): # But we want to keep the `constructor_write` to allow changing the list. def constructor_write(**kwargs): return restx_fields.List(field_write, **kwargs) + elif isinstance( field, (mongo_fields.GenericReferenceField, mongoengine.fields.GenericLazyReferenceField) ): def constructor(**kwargs): return restx_fields.Nested(lazy_reference, **kwargs) - elif isinstance(field, (mongo_fields.ReferenceField, mongo_fields.LazyReferenceField)): + + elif isinstance(field, mongo_fields.ReferenceField | mongo_fields.LazyReferenceField): # For reference we accept while writing a String representing the ID of the referenced model. # For reading, if the user supplied a `nested_fields` (RestX model), we use it to convert # the referenced model, if not we return a String (and RestX will call the `str()` of the model @@ -142,6 +147,7 @@ def constructor_read(**kwargs): def constructor(**kwargs): return restx_fields.Nested(nested_fields, **kwargs) + elif hasattr(field.document_type_obj, "__read_fields__"): def constructor_read(**kwargs): @@ -149,6 +155,7 @@ def constructor_read(**kwargs): def constructor_write(**kwargs): return restx_fields.Nested(field.document_type_obj.__write_fields__, **kwargs) + else: raise ValueError( f"EmbeddedDocumentField `{key}` requires a `nested_fields` param to serialize/deserialize or a `@generate_fields()` definition." @@ -200,8 +207,12 @@ def wrapper(cls): read_fields = {} write_fields = {} ref_fields = {} - sortables = kwargs.get("additionalSorts", []) + sortables = kwargs.get("additional_sorts", []) + filterables = [] + additional_filters = get_fields_with_additional_filters( + kwargs.get("additional_filters", {}) + ) read_fields["id"] = restx_fields.String(required=True, readonly=True) @@ -209,38 +220,48 @@ def wrapper(cls): sortable_key = info.get("sortable", False) if sortable_key: sortables.append( - {"key": sortable_key if isinstance(sortable_key, str) else key, "value": key} + { + "key": sortable_key if isinstance(sortable_key, str) else key, + "value": key, + } ) filterable = info.get("filterable", None) + if filterable is not None: - if "key" not in filterable: - filterable["key"] = key - if "column" not in filterable: - filterable["column"] = key - - if "constraints" not in filterable: - filterable["constraints"] = [] - if isinstance( - field, (mongo_fields.ReferenceField, mongo_fields.LazyReferenceField) - ) or ( - isinstance(field, mongo_fields.ListField) - and isinstance( - field.field, - (mongo_fields.ReferenceField, mongo_fields.LazyReferenceField), - ) - ): - filterable["constraints"].append("objectid") + filterables.append(compute_filter(key, field, info, filterable)) + + additional_filter = additional_filters.get(key, None) + if additional_filter: + if not isinstance( + field, mongo_fields.ReferenceField | mongo_fields.LazyReferenceField + ): + raise Exception("Cannot use additional_filters on not a ref.") + + ref_model = field.document_type + + for child in additional_filter.get("children", []): + inner_field = getattr(ref_model, child["key"]) - if "type" not in filterable: - filterable["type"] = str - if isinstance(field, mongo_fields.BooleanField): - filterable["type"] = boolean + column = f"{key}__{child['key']}" + child["key"] = f"{key}_{child['key']}" + filterable = compute_filter(column, inner_field, info, child) - # We may add more information later here: - # - type of mongo query to execute (right now only simple =) + # Since MongoDB is not capable of doing joins with a column like `organization__slug` we need to + # do a custom filter by splitting the query in two. - filterables.append(filterable) + def query(filterable, query, value): + # We use the computed `filterable["column"]` here because the `compute_filter` function + # could have added a default filter at the end (for example `organization__badges` converted + # in `organization__badges__kind`) + parts = filterable["column"].split("__", 1) + models = ref_model.objects.filter(**{parts[1]: value}).only("id") + return query.filter(**{f"{parts[0]}__in": models}) + + # do a query-based filter instead of a column based one + filterable["query"] = functools.partial(query, filterable) + + filterables.append(filterable) read, write = convert_db_to_field(key, field, info) @@ -330,9 +351,10 @@ def make_lambda(method): for filterable in filterables: parser.add_argument( - filterable["key"], + filterable.get("label", filterable["key"]), type=filterable["type"], location="args", + choices=filterable.get("choices", None), ) cls.__index_parser__ = parser @@ -360,18 +382,23 @@ def apply_sort_filters_and_pagination(base_query): base_query = base_query.search_text(phrase_query) for filterable in filterables: - if args.get(filterable["key"]) is not None: - for constraint in filterable["constraints"]: + filter = args.get(filterable.get("label", filterable["key"])) + if filter is not None: + for constraint in filterable.get("constraints", []): if constraint == "objectid" and not ObjectId.is_valid( args[filterable["key"]] ): api.abort(400, f'`{filterable["key"]}` must be an identifier') - base_query = base_query.filter( - **{ - filterable["column"]: args[filterable["key"]], - } - ) + query = filterable.get("query", None) + if query: + base_query = filterable["query"](base_query, filter) + else: + base_query = base_query.filter( + **{ + filterable["column"]: filter, + } + ) if paginable: base_query = base_query.paginate(args["page"], args["page_size"]) @@ -379,6 +406,7 @@ def apply_sort_filters_and_pagination(base_query): return base_query cls.apply_sort_filters_and_pagination = apply_sort_filters_and_pagination + cls.__additional_class_info__ = kwargs return cls return wrapper @@ -417,12 +445,12 @@ def patch(obj, request): value = model_attribute.from_input(value) elif isinstance(model_attribute, mongoengine.fields.ListField) and isinstance( model_attribute.field, - (mongo_fields.ReferenceField, mongo_fields.LazyReferenceField), + mongo_fields.ReferenceField | mongo_fields.LazyReferenceField, ): # TODO `wrap_primary_key` do Mongo request, do a first pass to fetch all documents before calling it (to avoid multiple queries). value = [wrap_primary_key(key, model_attribute.field, id) for id in value] elif isinstance( - model_attribute, (mongo_fields.ReferenceField, mongo_fields.LazyReferenceField) + model_attribute, mongo_fields.ReferenceField | mongo_fields.LazyReferenceField ): value = wrap_primary_key(key, model_attribute, value) elif isinstance( @@ -517,3 +545,81 @@ def wrap_primary_key( raise ValueError( f"Unknown ID field type {id_field.__class__} for {document_type} (ID field name is {id_field_name}, value was {value})" ) + + +def get_fields_with_additional_filters(additional_filters: Dict[str, str]) -> Dict[str, Any]: + """ + Right now we only support additional filters like "organization.badges". + + The goal of this function is to key the additional filters by the first part (`organization`) to + be able to compute them when we loop over all the fields (`title`, `organization`…) + + The `additional_filters` property is a dict: {"label": "key"}, for example {"organization_badge": "organization.badges"}. + The `label` will be the name of the parser arg, like `?organization_badge=public-service`, which makes more + sense than `?organization_badges=public-service`. + """ + results: dict = {} + for label, key in additional_filters.items(): + parts = key.split(".") + if len(parts) == 2: + parent = parts[0] + child = parts[1] + + if parent not in results: + results[parent] = {"children": []} + + results[parent]["children"].append( + { + "label": label, + "key": child, + "type": str, + } + ) + else: + raise Exception(f"Do not support `additional_filters` without two parts: {key}.") + + return results + + +def compute_filter(column: str, field, info, filterable): + # "key" is the param key in the URL + if "key" not in filterable: + filterable["key"] = column + + # If we do a filter on a embed document, get the class info + # of this document to see if there is a default filter value + embed_info = None + if isinstance(field, mongo_fields.EmbeddedDocumentField): + embed_info = field.get("__additional_class_info__", None) + elif isinstance(field, mongo_fields.EmbeddedDocumentListField): + embed_info = getattr(field.field.document_type, "__additional_class_info__", None) + + if embed_info and embed_info.get("default_filterable_field", None): + # There is a default filterable field so append it to the column and replace the + # field to use the inner one (for example using the `kind` `StringField` instead of + # the embed `Badge` field.) + filterable["column"] = f"{column}__{embed_info['default_filterable_field']}" + field = getattr(field.field.document_type, embed_info["default_filterable_field"]) + else: + filterable["column"] = column + + if "constraints" not in filterable: + filterable["constraints"] = [] + + if isinstance(field, mongo_fields.ReferenceField | mongo_fields.LazyReferenceField) or ( + isinstance(field, mongo_fields.ListField) + and isinstance(field.field, mongo_fields.ReferenceField | mongo_fields.LazyReferenceField) + ): + filterable["constraints"].append("objectid") + + if "type" not in filterable: + if isinstance(field, mongo_fields.BooleanField): + filterable["type"] = boolean + else: + filterable["type"] = str + + filterable["choices"] = info.get("choices", None) + if hasattr(field, "choices") and field.choices: + filterable["choices"] = field.choices + + return filterable diff --git a/udata/core/badges/factories.py b/udata/core/badges/factories.py index 05f1c09601..e5b257fb97 100644 --- a/udata/core/badges/factories.py +++ b/udata/core/badges/factories.py @@ -2,14 +2,12 @@ from udata.factories import ModelFactory -from .models import Badge - -def badge_factory(model): +def badge_factory(model_): class BadgeFactory(ModelFactory): class Meta: - model = Badge + model = model_._fields["badges"].field.document_type - kind = FuzzyChoice(model.__badges__.keys()) + kind = FuzzyChoice(model_.__badges__) return BadgeFactory diff --git a/udata/core/badges/forms.py b/udata/core/badges/forms.py index 22209cd9cd..5e6d063345 100644 --- a/udata/core/badges/forms.py +++ b/udata/core/badges/forms.py @@ -1,6 +1,5 @@ from udata.forms import ModelForm, fields, validators from udata.i18n import lazy_gettext as _ -from udata.models import Badge __all__ = ("badge_form",) @@ -9,8 +8,6 @@ def badge_form(model): """A form factory for a given model badges""" class BadgeForm(ModelForm): - model_class = Badge - kind = fields.RadioField( _("Kind"), [validators.DataRequired()], diff --git a/udata/core/badges/models.py b/udata/core/badges/models.py index 7ca2e14650..f8579e01bd 100644 --- a/udata/core/badges/models.py +++ b/udata/core/badges/models.py @@ -3,7 +3,7 @@ from mongoengine.signals import post_save -from udata.api_fields import field +from udata.api_fields import field, generate_fields from udata.auth import current_user from udata.core.badges.fields import badge_fields from udata.mongo import db @@ -12,10 +12,19 @@ log = logging.getLogger(__name__) -__all__ = ("Badge", "BadgeMixin") +__all__ = ["Badge", "BadgeMixin", "BadgesList"] +DEFAULT_BADGES_LIST_PARAMS = { + "readonly": True, + "inner_field_info": {"nested_fields": badge_fields}, +} + + +@generate_fields(default_filterable_field="kind") class Badge(db.EmbeddedDocument): + meta = {"allow_inheritance": True} + # The following field should be overloaded in descendants. kind = db.StringField(required=True) created = db.DateTimeField(default=datetime.utcnow, required=True) created_by = db.ReferenceField("User") @@ -23,30 +32,16 @@ class Badge(db.EmbeddedDocument): def __str__(self): return self.kind - def validate(self, clean=True): - badges = getattr(self._instance, "__badges__", {}) - if self.kind not in badges.keys(): - raise db.ValidationError("Unknown badge type %s" % self.kind) - return super(Badge, self).validate(clean=clean) - class BadgesList(db.EmbeddedDocumentListField): - def __init__(self, *args, **kwargs): - return super(BadgesList, self).__init__(Badge, *args, **kwargs) - - def validate(self, value): - kinds = [b.kind for b in value] - if len(kinds) > len(set(kinds)): - raise db.ValidationError("Duplicate badges for a given kind is not allowed") - return super(BadgesList, self).validate(value) + def __init__(self, badge_model, *args, **kwargs): + return super(BadgesList, self).__init__(badge_model, *args, **kwargs) -class BadgeMixin(object): - badges = field( - BadgesList(), - readonly=True, - inner_field_info={"nested_fields": badge_fields}, - ) +class BadgeMixin: + default_badges_list_params = DEFAULT_BADGES_LIST_PARAMS + # The following field should be overloaded in descendants. + badges = field(BadgesList(Badge), **DEFAULT_BADGES_LIST_PARAMS) def get_badge(self, kind): """Get a badge given its kind if present""" @@ -61,7 +56,7 @@ def add_badge(self, kind): if kind not in getattr(self, "__badges__", {}): msg = "Unknown badge type for {model}: {kind}" raise db.ValidationError(msg.format(model=self.__class__.__name__, kind=kind)) - badge = Badge(kind=kind) + badge = self._fields["badges"].field.document_type(kind=kind) if current_user.is_authenticated: badge.created_by = current_user.id @@ -88,5 +83,5 @@ def toggle_badge(self, kind): def badge_label(self, badge): """Display the badge label for a given kind""" - kind = badge.kind if isinstance(badge, Badge) else badge + kind = badge.kind if isinstance(badge, self.badge) else badge return self.__badges__[kind] diff --git a/udata/core/badges/tests/test_commands.py b/udata/core/badges/tests/test_commands.py index 8f02ab514f..a9e168a156 100644 --- a/udata/core/badges/tests/test_commands.py +++ b/udata/core/badges/tests/test_commands.py @@ -4,7 +4,6 @@ from udata.core.organization.constants import CERTIFIED, PUBLIC_SERVICE from udata.core.organization.factories import OrganizationFactory -from udata.models import Badge @pytest.mark.usefixtures("clean_db") @@ -21,9 +20,9 @@ def test_toggle_badge_on(self, cli): assert org.badges[0].kind == PUBLIC_SERVICE def test_toggle_badge_off(self, cli): - ps_badge = Badge(kind=PUBLIC_SERVICE) - certified_badge = Badge(kind=CERTIFIED) - org = OrganizationFactory(badges=[ps_badge, certified_badge]) + org = OrganizationFactory() + org.add_badge(PUBLIC_SERVICE) + org.add_badge(CERTIFIED) cli("badges", "toggle", str(org.id), PUBLIC_SERVICE) diff --git a/udata/core/badges/tests/test_model.py b/udata/core/badges/tests/test_model.py index d5d7cdc839..6913473831 100644 --- a/udata/core/badges/tests/test_model.py +++ b/udata/core/badges/tests/test_model.py @@ -1,19 +1,31 @@ +from udata.api_fields import field from udata.auth import login_user from udata.core.user.factories import UserFactory from udata.mongo import db from udata.tests import DBTestMixin, TestCase -from ..models import Badge, BadgeMixin +from ..models import Badge, BadgeMixin, BadgesList TEST = "test" OTHER = "other" +BADGES = { + TEST: "Test", + OTHER: "Other", +} -class Fake(db.Document, BadgeMixin): - __badges__ = { - TEST: "Test", - OTHER: "Other", - } + +class FakeBadge(Badge): + kind = db.StringField(required=True, choices=list(BADGES.keys())) + + +class FakeBadgeMixin(BadgeMixin): + badges = field(BadgesList(FakeBadge), **BadgeMixin.default_badges_list_params) + __badges__ = BADGES + + +class Fake(db.Document, FakeBadgeMixin): + pass class BadgeMixinTest(DBTestMixin, TestCase): @@ -22,15 +34,24 @@ def test_attributes(self): fake = Fake.objects.create() self.assertIsInstance(fake.badges, (list, tuple)) + def test_choices(self): + """It should have a choice list on the badge field.""" + self.assertEqual( + Fake._fields["badges"].field.document_type.kind.choices, list(Fake.__badges__.keys()) + ) + def test_get_badge_found(self): """It allow to get a badge by kind if present""" - fake = Fake.objects.create(badges=[Badge(kind=TEST), Badge(kind=OTHER)]) + fake = Fake.objects.create() + fake.add_badge(TEST) + fake.add_badge(OTHER) badge = fake.get_badge(TEST) self.assertEqual(badge.kind, TEST) def test_get_badge_not_found(self): """It should return None if badge is absent""" - fake = Fake.objects.create(badges=[Badge(kind=OTHER)]) + fake = Fake.objects.create() + fake.add_badge(OTHER) badge = fake.get_badge(TEST) self.assertIsNone(badge) @@ -49,7 +70,8 @@ def test_add_badge(self): def test_add_2nd_badge(self): """It should add badges to the top of the list""" - fake = Fake.objects.create(badges=[Badge(kind=OTHER)]) + fake = Fake.objects.create() + fake.add_badge(OTHER) result = fake.add_badge(TEST) @@ -86,8 +108,8 @@ def test_add_unknown_badge(self): def test_remove_badge(self): """It should remove a badge given its kind""" - badge = Badge(kind=TEST) - fake = Fake.objects.create(badges=[badge]) + fake = Fake.objects.create() + fake.add_badge(TEST) fake.remove_badge(TEST) @@ -121,28 +143,15 @@ def test_toggle_add_badge(self): def test_toggle_remove_badge(self): """Toggle should remove a badge given its kind if present""" - badge = Badge(kind=TEST) - fake = Fake.objects.create(badges=[badge]) + fake = Fake.objects.create() + fake.add_badge(TEST) fake.toggle_badge(TEST) self.assertEqual(len(fake.badges), 0) - def test_create_with_badges(self): - """It should allow object creation with badges""" - fake = Fake.objects.create(badges=[Badge(kind=TEST), Badge(kind=OTHER)]) - - self.assertEqual(len(fake.badges), 2) - for badge, kind in zip(fake.badges, (TEST, OTHER)): - self.assertEqual(badge.kind, kind) - self.assertIsNotNone(badge.created) - - def test_create_disallow_duplicate_badges(self): - """It should not allow object creation with duplicate badges""" - with self.assertRaises(db.ValidationError): - Fake.objects.create(badges=[Badge(kind=TEST), Badge(kind=TEST)]) - def test_create_disallow_unknown_badges(self): """It should not allow object creation with unknown badges""" with self.assertRaises(db.ValidationError): - Fake.objects.create(badges=[Badge(kind="unknown")]) + fake = Fake.objects.create() + fake.add_badge("unknown") diff --git a/udata/core/dataservices/models.py b/udata/core/dataservices/models.py index 7174ae8054..6ef4c1a801 100644 --- a/udata/core/dataservices/models.py +++ b/udata/core/dataservices/models.py @@ -95,7 +95,10 @@ class HarvestMetadata(db.EmbeddedDocument): archived_at = field(db.DateTimeField()) -@generate_fields(searchable=True) +@generate_fields( + searchable=True, + additional_filters={"organization_badge": "organization.badges"}, +) class Dataservice(WithMetrics, Owned, db.Document): meta = { "indexes": [ diff --git a/udata/core/dataset/api.py b/udata/core/dataset/api.py index 6d76f1020c..280fd81c70 100644 --- a/udata/core/dataset/api.py +++ b/udata/core/dataset/api.py @@ -36,6 +36,7 @@ from udata.core.dataservices.models import Dataservice from udata.core.dataset.models import CHECKSUM_TYPES from udata.core.followers.api import FollowAPI +from udata.core.organization.models import Organization from udata.core.storages.api import handle_upload, upload_parser from udata.core.topic.models import Topic from udata.linkchecker.checker import check_resource @@ -96,6 +97,12 @@ def __init__(self): self.parser.add_argument("granularity", type=str, location="args") self.parser.add_argument("temporal_coverage", type=str, location="args") self.parser.add_argument("organization", type=str, location="args") + self.parser.add_argument( + "organization_badge", + type=str, + choices=list(Organization.__badges__), + location="args", + ) self.parser.add_argument("owner", type=str, location="args") self.parser.add_argument("format", type=str, location="args") self.parser.add_argument("schema", type=str, location="args") @@ -131,6 +138,9 @@ def parse_filters(datasets, args): if not ObjectId.is_valid(args["organization"]): api.abort(400, "Organization arg must be an identifier") datasets = datasets.filter(organization=args["organization"]) + if args.get("organization_badge"): + orgs = Organization.objects.with_badge(args["organization_badge"]).only("id") + datasets = datasets.filter(organization__in=orgs) if args.get("owner"): if not ObjectId.is_valid(args["owner"]): api.abort(400, "Owner arg must be an identifier") diff --git a/udata/core/dataset/models.py b/udata/core/dataset/models.py index 058c448cd5..f84bd70eca 100644 --- a/udata/core/dataset/models.py +++ b/udata/core/dataset/models.py @@ -14,12 +14,13 @@ from stringdist import rdlevenshtein from werkzeug.utils import cached_property +from udata.api_fields import field from udata.app import cache from udata.core import storages from udata.core.owned import Owned, OwnedQuerySet from udata.frontend.markdown import mdstrip from udata.i18n import lazy_gettext as _ -from udata.models import BadgeMixin, SpatialCoverage, WithMetrics, db +from udata.models import Badge, BadgeMixin, BadgesList, SpatialCoverage, WithMetrics, db from udata.mongo.errors import FieldValidationError from udata.uris import ValidationError, endpoint_for from udata.uris import validate as validate_url @@ -53,6 +54,10 @@ "ResourceSchema", ) +BADGES: dict[str, str] = { + PIVOTAL_DATA: _("Pivotal data"), +} + NON_ASSIGNABLE_SCHEMA_TYPES = ["datapackage"] log = logging.getLogger(__name__) @@ -498,7 +503,16 @@ def save(self, *args, **kwargs): self.dataset.save(*args, **kwargs) -class Dataset(WithMetrics, BadgeMixin, Owned, db.Document): +class DatasetBadge(Badge): + kind = db.StringField(required=True, choices=list(BADGES.keys())) + + +class DatasetBadgeMixin(BadgeMixin): + badges = field(BadgesList(DatasetBadge), **BadgeMixin.default_badges_list_params) + __badges__ = BADGES + + +class Dataset(WithMetrics, DatasetBadgeMixin, Owned, db.Document): title = db.StringField(required=True) acronym = db.StringField(max_length=128) # /!\ do not set directly the slug when creating or updating a dataset @@ -539,10 +553,6 @@ class Dataset(WithMetrics, BadgeMixin, Owned, db.Document): def __str__(self): return self.title or "" - __badges__ = { - PIVOTAL_DATA: _("Pivotal data"), - } - __metrics_keys__ = [ "discussions", "reuses", diff --git a/udata/core/dataset/search.py b/udata/core/dataset/search.py index 653aa4f3d5..c36886724c 100644 --- a/udata/core/dataset/search.py +++ b/udata/core/dataset/search.py @@ -32,8 +32,9 @@ class DatasetSearch(ModelSearchAdapter): filters = { "tag": Filter(), - "badge": Filter(), + "badge": Filter(choices=list(Dataset.__badges__)), "organization": ModelTermsFilter(model=Organization), + "organization_badge": Filter(choices=list(Organization.__badges__)), "owner": ModelTermsFilter(model=User), "license": ModelTermsFilter(model=License), "geozone": ModelTermsFilter(model=GeoZone), @@ -76,6 +77,7 @@ def serialize(cls, dataset): "name": org.name, "public_service": 1 if org.public_service else 0, "followers": org.metrics.get("followers", 0), + "badges": [badge.kind for badge in org.badges], } elif dataset.owner: owner = User.objects(id=dataset.owner.id).first() diff --git a/udata/core/organization/api.py b/udata/core/organization/api.py index 0a19a8496e..61c191e524 100644 --- a/udata/core/organization/api.py +++ b/udata/core/organization/api.py @@ -63,6 +63,15 @@ class OrgApiParser(ModelApiParser): "last_modified": "last_modified", } + def __init__(self): + super().__init__() + self.parser.add_argument( + "badge", + type=str, + choices=list(Organization.__badges__), + location="args", + ) + @staticmethod def parse_filters(organizations, args): if args.get("q"): @@ -72,6 +81,8 @@ def parse_filters(organizations, args): # between tokens whereas an OR is used without it. phrase_query = " ".join([f'"{elem}"' for elem in args["q"].split(" ")]) organizations = organizations.search_text(phrase_query) + if args.get("badge"): + organizations = organizations.with_badge(args["badge"]) return organizations diff --git a/udata/core/organization/models.py b/udata/core/organization/models.py index 047a2fa086..60351ee3d4 100644 --- a/udata/core/organization/models.py +++ b/udata/core/organization/models.py @@ -5,7 +5,8 @@ from mongoengine.signals import post_save, pre_save from werkzeug.utils import cached_property -from udata.core.badges.models import BadgeMixin +from udata.api_fields import field +from udata.core.badges.models import Badge, BadgeMixin, BadgesList from udata.core.metrics.models import WithMetrics from udata.core.storages import avatars, default_image_basename from udata.frontend.markdown import mdstrip @@ -29,6 +30,14 @@ __all__ = ("Organization", "Team", "Member", "MembershipRequest") +BADGES: dict[str, str] = { + PUBLIC_SERVICE: _("Public Service"), + CERTIFIED: _("Certified"), + ASSOCIATION: _("Association"), + COMPANY: _("Company"), + LOCAL_AUTHORITY: _("Local authority"), +} + class Team(db.EmbeddedDocument): name = db.StringField(required=True) @@ -82,8 +91,20 @@ def hidden(self): def get_by_id_or_slug(self, id_or_slug): return self(slug=id_or_slug).first() or self(id=id_or_slug).first() + def with_badge(self, kind): + return self(badges__kind=kind) + + +class OrganizationBadge(Badge): + kind = db.StringField(required=True, choices=list(BADGES.keys())) + -class Organization(WithMetrics, BadgeMixin, db.Datetimed, db.Document): +class OrganizationBadgeMixin(BadgeMixin): + badges = field(BadgesList(OrganizationBadge), **BadgeMixin.default_badges_list_params) + __badges__ = BADGES + + +class Organization(WithMetrics, OrganizationBadgeMixin, db.Datetimed, db.Document): name = db.StringField(required=True) acronym = db.StringField(max_length=128) slug = db.SlugField( @@ -126,14 +147,6 @@ class Organization(WithMetrics, BadgeMixin, db.Datetimed, db.Document): def __str__(self): return self.name or "" - __badges__ = { - PUBLIC_SERVICE: _("Public Service"), - CERTIFIED: _("Certified"), - ASSOCIATION: _("Association"), - COMPANY: _("Company"), - LOCAL_AUTHORITY: _("Local authority"), - } - __metrics_keys__ = [ "datasets", "members", diff --git a/udata/core/organization/search.py b/udata/core/organization/search.py index dcec9ffc63..4ce4bc2bea 100644 --- a/udata/core/organization/search.py +++ b/udata/core/organization/search.py @@ -3,7 +3,7 @@ from udata import search from udata.core.organization.api import DEFAULT_SORTING, OrgApiParser from udata.models import Organization -from udata.search.fields import Filter +from udata.search.fields import ModelTermsFilter from udata.utils import to_iso_datetime __all__ = ("OrganizationSearch",) @@ -22,7 +22,11 @@ class OrganizationSearch(search.ModelSearchAdapter): "created": "created_at", } - filters = {"badge": Filter()} + filters = { + "badge": ModelTermsFilter( + model=Organization, field_name="badges", choices=list(Organization.__badges__) + ), + } @classmethod def is_indexable(cls, org): diff --git a/udata/core/reuse/api.py b/udata/core/reuse/api.py index de854e2175..310283a1b1 100644 --- a/udata/core/reuse/api.py +++ b/udata/core/reuse/api.py @@ -13,6 +13,7 @@ from udata.core.badges.fields import badge_fields from udata.core.dataset.api_fields import dataset_ref_fields from udata.core.followers.api import FollowAPI +from udata.core.organization.models import Organization from udata.core.reuse.constants import REUSE_TOPICS, REUSE_TYPES from udata.core.storages.api import ( image_parser, @@ -49,6 +50,12 @@ def __init__(self): self.parser.add_argument("dataset", type=str, location="args") self.parser.add_argument("tag", type=str, location="args") self.parser.add_argument("organization", type=str, location="args") + self.parser.add_argument( + "organization_badge", + type=str, + choices=list(Organization.__badges__), + location="args", + ) self.parser.add_argument("owner", type=str, location="args") self.parser.add_argument("type", type=str, location="args") self.parser.add_argument("topic", type=str, location="args") @@ -79,6 +86,9 @@ def parse_filters(reuses, args): if not ObjectId.is_valid(args["organization"]): api.abort(400, "Organization arg must be an identifier") reuses = reuses.filter(organization=args["organization"]) + if args.get("organization_badge"): + orgs = Organization.objects.with_badge(args["organization_badge"]) + reuses = reuses.filter(organization__in=orgs) if args.get("owner"): if not ObjectId.is_valid(args["owner"]): api.abort(400, "Owner arg must be an identifier") diff --git a/udata/core/reuse/models.py b/udata/core/reuse/models.py index b5c1f4a46c..8a33cff151 100644 --- a/udata/core/reuse/models.py +++ b/udata/core/reuse/models.py @@ -9,7 +9,7 @@ from udata.core.storages import default_image_basename, images from udata.frontend.markdown import mdstrip from udata.i18n import lazy_gettext as _ -from udata.models import BadgeMixin, WithMetrics, db +from udata.models import Badge, BadgeMixin, BadgesList, WithMetrics, db from udata.mongo.errors import FieldValidationError from udata.uris import endpoint_for from udata.utils import hash_url @@ -18,6 +18,8 @@ __all__ = ("Reuse",) +BADGES: dict[str, str] = {} + class ReuseQuerySet(OwnedQuerySet): def visible(self): @@ -33,15 +35,25 @@ def check_url_does_not_exists(url): raise FieldValidationError(_("This URL is already registered"), field="url") +class ReuseBadge(Badge): + kind = db.StringField(required=True, choices=list(BADGES.keys())) + + +class ReuseBadgeMixin(BadgeMixin): + badges = field(BadgesList(ReuseBadge), **BadgeMixin.default_badges_list_params) + __badges__ = BADGES + + @generate_fields( searchable=True, - additionalSorts=[ + additional_sorts=[ {"key": "datasets", "value": "metrics.datasets"}, {"key": "followers", "value": "metrics.followers"}, {"key": "views", "value": "metrics.views"}, ], + additional_filters={"organization_badge": "organization.badges"}, ) -class Reuse(db.Datetimed, WithMetrics, BadgeMixin, Owned, db.Document): +class Reuse(db.Datetimed, WithMetrics, ReuseBadgeMixin, Owned, db.Document): title = field( db.StringField(required=True), sortable=True, @@ -124,8 +136,6 @@ class Reuse(db.Datetimed, WithMetrics, BadgeMixin, Owned, db.Document): def __str__(self): return self.title or "" - __badges__ = {} - __metrics_keys__ = [ "discussions", "datasets", diff --git a/udata/core/reuse/search.py b/udata/core/reuse/search.py index 3774e11647..c9df9d4fb4 100644 --- a/udata/core/reuse/search.py +++ b/udata/core/reuse/search.py @@ -29,9 +29,10 @@ class ReuseSearch(ModelSearchAdapter): filters = { "tag": Filter(), "organization": ModelTermsFilter(model=Organization), + "organization_badge": Filter(choices=list(Organization.__badges__)), "owner": ModelTermsFilter(model=User), "type": Filter(), - "badge": Filter(), + "badge": Filter(choices=list(Reuse.__badges__)), "featured": BoolFilter(), "topic": Filter(), "archived": BoolFilter(), @@ -65,6 +66,7 @@ def serialize(cls, reuse: Reuse) -> dict: "name": org.name, "public_service": 1 if org.public_service else 0, "followers": org.metrics.get("followers", 0), + "badges": [badge.kind for badge in org.badges], } elif reuse.owner: owner = User.objects(id=reuse.owner.id).first() diff --git a/udata/search/fields.py b/udata/search/fields.py index a725705c8d..b2b434246b 100644 --- a/udata/search/fields.py +++ b/udata/search/fields.py @@ -19,8 +19,12 @@ class Filter: - @staticmethod - def as_request_parser_kwargs(): + def __init__(self, choices=None): + self.choices = choices + + def as_request_parser_kwargs(self): + if self.choices: + return {"type": clean_string, "choices": self.choices} return {"type": clean_string} @@ -31,9 +35,10 @@ def as_request_parser_kwargs(): class ModelTermsFilter(Filter): - def __init__(self, model, field_name="id"): + def __init__(self, model, field_name="id", choices=None): self.model = model self.field_name = field_name + super().__init__(choices=choices) @property def model_field(self): diff --git a/udata/tests/api/test_dataservices_api.py b/udata/tests/api/test_dataservices_api.py index b2556a92a5..0ce8af2da3 100644 --- a/udata/tests/api/test_dataservices_api.py +++ b/udata/tests/api/test_dataservices_api.py @@ -3,6 +3,7 @@ import pytest from flask import url_for +import udata.core.organization.constants as org_constants from udata.core.dataservices.factories import DataserviceFactory from udata.core.dataservices.models import Dataservice from udata.core.dataset.factories import DatasetFactory, LicenseFactory @@ -10,7 +11,7 @@ from udata.core.organization.models import Member from udata.core.user.factories import UserFactory from udata.i18n import gettext as _ -from udata.tests.helpers import assert200, assert_redirects +from udata.tests.helpers import assert200, assert400, assert_redirects from . import APITestCase @@ -18,6 +19,25 @@ class DataserviceAPITest(APITestCase): modules = [] + def test_dataservices_api_list_with_filters(self): + """Should filters dataservices results based on query filters""" + org = OrganizationFactory() + org_public_service = OrganizationFactory() + org_public_service.add_badge(org_constants.PUBLIC_SERVICE) + + _dataservice = DataserviceFactory(organization=org) + dataservice_public_service = DataserviceFactory(organization=org_public_service) + + response = self.get( + url_for("api.dataservices", organization_badge=org_constants.PUBLIC_SERVICE) + ) + assert200(response) + assert len(response.json["data"]) == 1 + assert response.json["data"][0]["id"] == str(dataservice_public_service.id) + + response = self.get(url_for("api.dataservices", organization_badge="bad-badge")) + assert400(response) + def test_dataservice_api_create(self): user = self.login() datasets = DatasetFactory.create_batch(3) diff --git a/udata/tests/api/test_datasets_api.py b/udata/tests/api/test_datasets_api.py index 08ebfdcf91..5f8163c153 100644 --- a/udata/tests/api/test_datasets_api.py +++ b/udata/tests/api/test_datasets_api.py @@ -8,6 +8,7 @@ import requests_mock from flask import url_for +import udata.core.organization.constants as org_constants from udata.api import fields from udata.app import cache from udata.core import storages @@ -148,6 +149,8 @@ def test_dataset_api_list_with_filters(self): """Should filters datasets results based on query filters""" owner = UserFactory() org = OrganizationFactory() + org_public_service = OrganizationFactory() + org_public_service.add_badge(org_constants.PUBLIC_SERVICE) [DatasetFactory() for i in range(2)] @@ -167,6 +170,7 @@ def test_dataset_api_list_with_filters(self): owner_dataset = DatasetFactory(owner=owner) org_dataset = DatasetFactory(organization=org) + org_dataset_public_service = DatasetFactory(organization=org_public_service) schema_dataset = DatasetFactory( resources=[ @@ -247,6 +251,17 @@ def test_dataset_api_list_with_filters(self): response = self.get(url_for("api.datasets", organization="org-id")) self.assert400(response) + # filter on organization badge + response = self.get( + url_for("api.datasets", organization_badge=org_constants.PUBLIC_SERVICE) + ) + self.assert200(response) + self.assertEqual(len(response.json["data"]), 1) + self.assertEqual(response.json["data"][0]["id"], str(org_dataset_public_service.id)) + + response = self.get(url_for("api.datasets", organization_badge="bad-badge")) + self.assert400(response) + # filter on schema response = self.get(url_for("api.datasets", schema="my-schema")) self.assert200(response) diff --git a/udata/tests/api/test_organizations_api.py b/udata/tests/api/test_organizations_api.py index bd55885700..962d16e91c 100644 --- a/udata/tests/api/test_organizations_api.py +++ b/udata/tests/api/test_organizations_api.py @@ -3,6 +3,7 @@ import pytest from flask import url_for +import udata.core.organization.constants as org_constants from udata.core.badges.factories import badge_factory from udata.core.badges.signals import on_badge_added, on_badge_removed from udata.core.dataset.factories import DatasetFactory @@ -41,6 +42,20 @@ def test_organization_api_list(self, api): assert200(response) len(response.json["data"]) == len(organizations) + def test_organization_api_list_with_filters(self, api): + """It should filter the organization list""" + _org = OrganizationFactory() + org_public_service = OrganizationFactory() + org_public_service.add_badge(org_constants.PUBLIC_SERVICE) + + response = api.get(url_for("api.organizations", badge=org_constants.PUBLIC_SERVICE)) + assert200(response) + assert len(response.json["data"]) == 1 + assert response.json["data"][0]["id"] == str(org_public_service.id) + + response = api.get(url_for("api.organizations", badge="bad-badge")) + assert400(response) + def test_organization_role_api_get(self, api): """It should fetch an organization's roles list from the API""" response = api.get(url_for("api.org_roles")) @@ -818,10 +833,6 @@ class OrganizationBadgeAPITest: @pytest.fixture(autouse=True) def setUp(self, api, clean_db): - # Register at least two badges - Organization.__badges__["test-1"] = "Test 1" - Organization.__badges__["test-2"] = "Test 2" - self.factory = badge_factory(Organization) self.user = api.login(AdminFactory()) self.organization = OrganizationFactory() diff --git a/udata/tests/api/test_reuses_api.py b/udata/tests/api/test_reuses_api.py index b1927913ee..2782a9f1e2 100644 --- a/udata/tests/api/test_reuses_api.py +++ b/udata/tests/api/test_reuses_api.py @@ -4,6 +4,7 @@ from flask import url_for from werkzeug.test import TestResponse +import udata.core.organization.constants as org_constants from udata.core.badges.factories import badge_factory from udata.core.dataset.factories import DatasetFactory from udata.core.organization.factories import OrganizationFactory @@ -67,12 +68,17 @@ def test_reuse_api_list_with_filters(self, api): """Should filters reuses results based on query filters""" owner = UserFactory() org = OrganizationFactory() + org_public_service = OrganizationFactory() + org_public_service.add_badge(org_constants.PUBLIC_SERVICE) [ReuseFactory(topic="health", type="api") for i in range(2)] tag_reuse = ReuseFactory(tags=["my-tag", "other"], topic="health", type="api") owner_reuse = ReuseFactory(owner=owner, topic="health", type="api") org_reuse = ReuseFactory(organization=org, topic="health", type="api") + org_reuse_public_service = ReuseFactory( + organization=org_public_service, topic="health", type="api" + ) featured_reuse = ReuseFactory(featured=True, topic="health", type="api") topic_reuse = ReuseFactory(topic="transport_and_mobility", type="api") type_reuse = ReuseFactory(topic="health", type="application") @@ -125,6 +131,15 @@ def test_reuse_api_list_with_filters(self, api): response = api.get(url_for("api.reuses", organization="org-id")) assert400(response) + # filter on organization badge + response = api.get(url_for("api.reuses", organization_badge=org_constants.PUBLIC_SERVICE)) + assert200(response) + assert len(response.json["data"]) == 1 + assert response.json["data"][0]["id"] == str(org_reuse_public_service.id) + + response = api.get(url_for("api.reuses", organization_badge="bad-badge")) + assert400(response) + def test_reuse_api_list_filter_private(self, api) -> None: """Should filters reuses results based on the `private` filter""" user = UserFactory() diff --git a/udata/tests/apiv2/test_datasets.py b/udata/tests/apiv2/test_datasets.py index 824d854487..1eaa1f380f 100644 --- a/udata/tests/apiv2/test_datasets.py +++ b/udata/tests/apiv2/test_datasets.py @@ -2,6 +2,7 @@ from flask import url_for +import udata.core.organization.constants as org_constants from udata.core.dataset.apiv2 import DEFAULT_PAGE_SIZE from udata.core.dataset.factories import ( CommunityResourceFactory, @@ -44,6 +45,22 @@ def test_get_dataset(self): assert data["community_resources"]["type"] == "GET" assert data["community_resources"]["total"] == 0 + def test_search_dataset(self): + org = OrganizationFactory() + org.add_badge(org_constants.CERTIFIED) + org_public_service = OrganizationFactory() + org_public_service.add_badge(org_constants.PUBLIC_SERVICE) + _dataset_org = DatasetFactory(organization=org) + dataset_org_public_service = DatasetFactory(organization=org_public_service) + + response = self.get( + url_for("apiv2.dataset_search", organization_badge=org_constants.PUBLIC_SERVICE) + ) + self.assert200(response) + data = response.json["data"] + assert len(data) == 1 + assert data[0]["id"] == str(dataset_org_public_service.id) + class DatasetResourceAPIV2Test(APITestCase): def test_get_specific(self): diff --git a/udata/tests/organization/test_organization_model.py b/udata/tests/organization/test_organization_model.py index b451e0e2b4..eaef542081 100644 --- a/udata/tests/organization/test_organization_model.py +++ b/udata/tests/organization/test_organization_model.py @@ -1,8 +1,10 @@ from datetime import datetime +import udata.core.organization.constants as org_constants from udata.core.dataset.factories import DatasetFactory, HiddenDatasetFactory from udata.core.followers.signals import on_follow, on_unfollow from udata.core.organization.factories import OrganizationFactory +from udata.core.organization.models import Organization from udata.core.reuse.factories import ReuseFactory, VisibleReuseFactory from udata.core.user.factories import UserFactory from udata.models import Dataset, Follow, Member, Reuse @@ -50,3 +52,22 @@ def test_organization_metrics(self): assert org.get_metrics()["datasets"] == 0 assert org.get_metrics()["reuses"] == 0 assert org.get_metrics()["followers"] == 0 + + def test_organization_queryset_with_badge(self): + org_public_service = OrganizationFactory() + org_public_service.add_badge(org_constants.PUBLIC_SERVICE) + org_certified_association = OrganizationFactory() + org_certified_association.add_badge(org_constants.CERTIFIED) + org_certified_association.add_badge(org_constants.ASSOCIATION) + + public_services = list(Organization.objects.with_badge(org_constants.PUBLIC_SERVICE)) + assert len(public_services) == 1 + assert org_public_service in public_services + + certified = list(Organization.objects.with_badge(org_constants.CERTIFIED)) + assert len(certified) == 1 + assert org_certified_association in certified + + associations = list(Organization.objects.with_badge(org_constants.ASSOCIATION)) + assert len(associations) == 1 + assert org_certified_association in associations diff --git a/udata/tests/site/test_site_metrics.py b/udata/tests/site/test_site_metrics.py index 27d9aee11e..4fe36a7c36 100644 --- a/udata/tests/site/test_site_metrics.py +++ b/udata/tests/site/test_site_metrics.py @@ -9,7 +9,6 @@ from udata.core.reuse.factories import VisibleReuseFactory from udata.core.site.factories import SiteFactory from udata.harvest.tests.factories import HarvestSourceFactory -from udata.models import Badge @pytest.mark.usefixtures("clean_db") @@ -52,8 +51,7 @@ def test_resources_metric(self, app): def test_badges_metric(self, app): site = SiteFactory.create(id=app.config["SITE_ID"]) - ps_badge = Badge(kind=PUBLIC_SERVICE) - public_services = [OrganizationFactory(badges=[ps_badge]) for _ in range(2)] + public_services = [OrganizationFactory().add_badge(PUBLIC_SERVICE) for _ in range(2)] for _ in range(3): OrganizationFactory()