Skip to content

Commit

Permalink
Skipping Cache Invalidations for Unsupported Raw Queries
Browse files Browse the repository at this point in the history
  • Loading branch information
Yogesh Kumar committed Oct 30, 2023
1 parent b411e00 commit 578bf81
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
10 changes: 10 additions & 0 deletions cachalot/monkey_patch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import warnings
from collections.abc import Iterable
from functools import wraps
from time import time
Expand Down Expand Up @@ -140,6 +141,15 @@ def inner(cursor, sql, *args, **kwargs):
if getattr(connection, 'raw', True):
if isinstance(sql, bytes):
sql = sql.decode('utf-8')
# in case `sql` is not of type `str`,
# we raise a warning and return.
if not isinstance(sql, str):
warnings.warn(
f"Unsupported sql of type {type(sql)}, "
f"skipping raw query cache invalidation's. "
f"Try setting `CACHALOT_INVALIDATE_RAW` to False."
)
return
sql = sql.lower()
if SQL_DATA_CHANGE_RE.search(sql):
tables = filter_cachable(
Expand Down
16 changes: 16 additions & 0 deletions cachalot/tests/read.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import warnings
from unittest import skipIf
from uuid import UUID
from decimal import Decimal
Expand All @@ -15,6 +16,7 @@
from django.db.models.functions import Coalesce, Now
from django.db.transaction import TransactionManagementError
from django.test import TransactionTestCase, skipUnlessDBFeature, override_settings
from psycopg2 import sql
from pytz import UTC

from cachalot.cache import cachalot_caches
Expand Down Expand Up @@ -1010,6 +1012,20 @@ def test_cursor_execute_bytes(self):
data2,
[('é',) + l for l in Test.objects.values_list(*attnames)])

def test_unsupported_sql_type_execute_with_warning(self):
table_name = Test._meta.db_table
sql_ = sql.SQL("SELECT {field} FROM {table}").format(
field=sql.Identifier('name'),
table=sql.Identifier(table_name),
)
with warnings.catch_warnings(record=True) as warning:
with connection.cursor() as cursor:
cursor.execute(sql_)
data = list(cursor.fetchall())
self.assertEqual(len(warning), 1)
self.assertEqual(warning[0]._category_name, 'UserWarning')
self.assertEqual(len(data), 2)

def test_cursor_execute_no_table(self):
sql = 'SELECT * FROM (SELECT 1 AS id UNION ALL SELECT 2) AS t;'
with self.assertNumQueries(1):
Expand Down

0 comments on commit 578bf81

Please sign in to comment.