Skip to content

Commit

Permalink
use the object-level validation with extra-context for feed type vali…
Browse files Browse the repository at this point in the history
…dation
  • Loading branch information
regulartim committed Dec 13, 2024
1 parent 679cfde commit 47875d2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
13 changes: 10 additions & 3 deletions api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@ def validate(self, data):


@cache
def feed_type_validation(feed_type: str, valid_feed_types: frozenset) -> bool:
logger.debug(f"FeedsResponseSerializer - validation feed_type: '{feed_type}'")
def feed_type_validation(feed_type: str, valid_feed_types: frozenset) -> str:
if feed_type not in valid_feed_types:
logger.info(f"Feed type {feed_type} not in feed_choices {valid_feed_types}")
raise serializers.ValidationError(f"Invalid feed_type: {feed_type}")
return True
return feed_type


class FeedsSerializer(serializers.Serializer):
Expand All @@ -63,6 +62,10 @@ class FeedsSerializer(serializers.Serializer):
age = serializers.ChoiceField(choices=["persistent", "recent"])
format = serializers.ChoiceField(choices=["csv", "json", "txt"], default="json")

def validate_feed_type(self, feed_type):
logger.debug(f"FeedsSerializer - Validation feed_type: '{feed_type}'")
return feed_type_validation(feed_type, self.context["valid_feed_types"])


class FeedsResponseSerializer(serializers.Serializer):
feed_type = serializers.CharField(max_length=120)
Expand All @@ -72,3 +75,7 @@ class FeedsResponseSerializer(serializers.Serializer):
first_seen = serializers.DateField(format="%Y-%m-%d")
last_seen = serializers.DateField(format="%Y-%m-%d")
times_seen = serializers.IntegerField()

def validate_feed_type(self, feed_type):
logger.debug(f"FeedsResponseSerializer - validation feed_type: '{feed_type}'")
return feed_type_validation(feed_type, self.context["valid_feed_types"])
12 changes: 7 additions & 5 deletions api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from datetime import datetime, timedelta

from api.serializers import EnrichmentSerializer, FeedsResponseSerializer, FeedsSerializer, feed_type_validation
from api.serializers import EnrichmentSerializer, FeedsResponseSerializer, FeedsSerializer
from certego_saas.apps.auth.backend import CookieTokenAuthentication
from certego_saas.ext.helpers import parse_humanized_range
from certego_saas.ext.pagination import CustomPageNumberPagination
Expand Down Expand Up @@ -119,10 +119,10 @@ def get_queryset(request, feed_type, valid_feed_types, attack_type, age, format_
"attack_type": attack_type,
"age": age,
"format": format_,
}
},
context={"valid_feed_types": valid_feed_types},
)
serializer.is_valid(raise_exception=True)
feed_type_validation(feed_type, valid_feed_types)

ordering = request.query_params.get("ordering")
# if ordering == "value" replace it with "name" (the corresponding field in the iocs model)
Expand Down Expand Up @@ -240,9 +240,11 @@ def feeds_response(request, iocs, feed_type, valid_feed_types, format_, dict_onl
if SKIP_FEED_VALIDATION:
json_list.append(data_)
continue
serializer_item = FeedsResponseSerializer(data=data_)
serializer_item = FeedsResponseSerializer(
data=data_,
context={"valid_feed_types": valid_feed_types},
)
serializer_item.is_valid(raise_exception=True)
feed_type_validation(feed_type, valid_feed_types)
json_list.append(serializer_item.data)

# check if sorting the results by feed_type
Expand Down

0 comments on commit 47875d2

Please sign in to comment.