diff --git a/cachalot/monkey_patch.py b/cachalot/monkey_patch.py index 60cb9f6b0..b9a531bfa 100644 --- a/cachalot/monkey_patch.py +++ b/cachalot/monkey_patch.py @@ -1,4 +1,5 @@ import re +import warnings from collections.abc import Iterable from functools import wraps from time import time @@ -140,6 +141,15 @@ def inner(cursor, sql, *args, **kwargs): if getattr(connection, 'raw', True): if isinstance(sql, bytes): sql = sql.decode('utf-8') + # in case `sql` is not of type `str`, + # we raise a warning and return. + if not isinstance(sql, str): + warnings.warn( + f"Unsupported sql of type {type(sql)}, " + f"skipping raw query cache invalidation's. " + f"Try setting `CACHALOT_INVALIDATE_RAW` to False." + ) + return sql = sql.lower() if SQL_DATA_CHANGE_RE.search(sql): tables = filter_cachable( diff --git a/cachalot/tests/read.py b/cachalot/tests/read.py index aee07fbab..ae23266cb 100644 --- a/cachalot/tests/read.py +++ b/cachalot/tests/read.py @@ -1,4 +1,5 @@ import datetime +import warnings from unittest import skipIf from uuid import UUID from decimal import Decimal @@ -15,6 +16,7 @@ from django.db.models.functions import Coalesce, Now from django.db.transaction import TransactionManagementError from django.test import TransactionTestCase, skipUnlessDBFeature, override_settings +from psycopg2 import sql from pytz import UTC from cachalot.cache import cachalot_caches @@ -1010,6 +1012,20 @@ def test_cursor_execute_bytes(self): data2, [('é',) + l for l in Test.objects.values_list(*attnames)]) + def test_unsupported_sql_type_execute_with_warning(self): + table_name = Test._meta.db_table + sql_ = sql.SQL("SELECT {field} FROM {table}").format( + field=sql.Identifier('name'), + table=sql.Identifier(table_name), + ) + with warnings.catch_warnings(record=True) as warning: + with connection.cursor() as cursor: + cursor.execute(sql_) + data = list(cursor.fetchall()) + self.assertEqual(len(warning), 1) + self.assertEqual(warning[0]._category_name, 'UserWarning') + self.assertEqual(len(data), 2) + def test_cursor_execute_no_table(self): sql = 'SELECT * FROM (SELECT 1 AS id UNION ALL SELECT 2) AS t;' with self.assertNumQueries(1):