Skip to content

Commit 76e3751

Browse files
charettesfelixxm
authored andcommitted
Refs #33374 -- Adjusted full match condition handling.
Adjusting WhereNode.as_sql() to raise an exception when encoutering a full match just like with empty matches ensures that all case are explicitly handled.
1 parent 4b702c8 commit 76e3751

File tree

11 files changed

+114
-61
lines changed

11 files changed

+114
-61
lines changed

django/core/exceptions.py

+6
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ class EmptyResultSet(Exception):
233233
pass
234234

235235

236+
class FullResultSet(Exception):
237+
"""A database query predicate is matches everything."""
238+
239+
pass
240+
241+
236242
class SynchronousOnlyOperation(Exception):
237243
"""The user tried to call a sync-only function from an async context."""
238244

django/db/backends/mysql/compiler.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from django.core.exceptions import FieldError
1+
from django.core.exceptions import FieldError, FullResultSet
22
from django.db.models.expressions import Col
33
from django.db.models.sql import compiler
44

@@ -40,12 +40,16 @@ def as_sql(self):
4040
"DELETE %s FROM"
4141
% self.quote_name_unless_alias(self.query.get_initial_alias())
4242
]
43-
from_sql, from_params = self.get_from_clause()
43+
from_sql, params = self.get_from_clause()
4444
result.extend(from_sql)
45-
where_sql, where_params = self.compile(where)
46-
if where_sql:
45+
try:
46+
where_sql, where_params = self.compile(where)
47+
except FullResultSet:
48+
pass
49+
else:
4750
result.append("WHERE %s" % where_sql)
48-
return " ".join(result), tuple(from_params) + tuple(where_params)
51+
params.extend(where_params)
52+
return " ".join(result), tuple(params)
4953

5054

5155
class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):

django/db/models/aggregates.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Classes to represent the definitions of aggregate functions.
33
"""
4-
from django.core.exceptions import FieldError
4+
from django.core.exceptions import FieldError, FullResultSet
55
from django.db.models.expressions import Case, Func, Star, When
66
from django.db.models.fields import IntegerField
77
from django.db.models.functions.comparison import Coalesce
@@ -104,8 +104,11 @@ def as_sql(self, compiler, connection, **extra_context):
104104
extra_context["distinct"] = "DISTINCT " if self.distinct else ""
105105
if self.filter:
106106
if connection.features.supports_aggregate_filter_clause:
107-
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
108-
if filter_sql:
107+
try:
108+
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
109+
except FullResultSet:
110+
pass
111+
else:
109112
template = self.filter_template % extra_context.get(
110113
"template", self.template
111114
)

django/db/models/expressions.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from decimal import Decimal
88
from uuid import UUID
99

10-
from django.core.exceptions import EmptyResultSet, FieldError
10+
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
1111
from django.db import DatabaseError, NotSupportedError, connection
1212
from django.db.models import fields
1313
from django.db.models.constants import LOOKUP_SEP
@@ -955,6 +955,8 @@ def as_sql(
955955
if empty_result_set_value is NotImplemented:
956956
raise
957957
arg_sql, arg_params = compiler.compile(Value(empty_result_set_value))
958+
except FullResultSet:
959+
arg_sql, arg_params = compiler.compile(Value(True))
958960
sql_parts.append(arg_sql)
959961
params.extend(arg_params)
960962
data = {**self.extra, **extra_context}
@@ -1367,14 +1369,6 @@ def as_sql(self, compiler, connection, template=None, **extra_context):
13671369
template_params = extra_context
13681370
sql_params = []
13691371
condition_sql, condition_params = compiler.compile(self.condition)
1370-
# Filters that match everything are handled as empty strings in the
1371-
# WHERE clause, but in a CASE WHEN expression they must use a predicate
1372-
# that's always True.
1373-
if condition_sql == "":
1374-
if connection.features.supports_boolean_expr_in_select_clause:
1375-
condition_sql, condition_params = compiler.compile(Value(True))
1376-
else:
1377-
condition_sql, condition_params = "1=1", ()
13781372
template_params["condition"] = condition_sql
13791373
result_sql, result_params = compiler.compile(self.result)
13801374
template_params["result"] = result_sql
@@ -1461,14 +1455,17 @@ def as_sql(
14611455
template_params = {**self.extra, **extra_context}
14621456
case_parts = []
14631457
sql_params = []
1458+
default_sql, default_params = compiler.compile(self.default)
14641459
for case in self.cases:
14651460
try:
14661461
case_sql, case_params = compiler.compile(case)
14671462
except EmptyResultSet:
14681463
continue
1464+
except FullResultSet:
1465+
default_sql, default_params = compiler.compile(case.result)
1466+
break
14691467
case_parts.append(case_sql)
14701468
sql_params.extend(case_params)
1471-
default_sql, default_params = compiler.compile(self.default)
14721469
if not case_parts:
14731470
return default_sql, default_params
14741471
case_joiner = case_joiner or self.case_joiner

django/db/models/fields/__init__.py

-9
Original file line numberDiff line numberDiff line change
@@ -1103,15 +1103,6 @@ def formfield(self, **kwargs):
11031103
defaults = {"form_class": form_class, "required": False}
11041104
return super().formfield(**{**defaults, **kwargs})
11051105

1106-
def select_format(self, compiler, sql, params):
1107-
sql, params = super().select_format(compiler, sql, params)
1108-
# Filters that match everything are handled as empty strings in the
1109-
# WHERE clause, but in SELECT or GROUP BY list they must use a
1110-
# predicate that's always True.
1111-
if sql == "":
1112-
sql = "1"
1113-
return sql, params
1114-
11151106

11161107
class CharField(Field):
11171108
description = _("String (up to %(max_length)s)")

django/db/models/sql/compiler.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from functools import partial
55
from itertools import chain
66

7-
from django.core.exceptions import EmptyResultSet, FieldError
7+
from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
88
from django.db import DatabaseError, NotSupportedError
99
from django.db.models.constants import LOOKUP_SEP
1010
from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
@@ -169,7 +169,7 @@ def get_group_by(self, select, order_by):
169169
expr = Ref(alias, expr)
170170
try:
171171
sql, params = self.compile(expr)
172-
except EmptyResultSet:
172+
except (EmptyResultSet, FullResultSet):
173173
continue
174174
sql, params = expr.select_format(self, sql, params)
175175
params_hash = make_hashable(params)
@@ -287,6 +287,8 @@ def get_select_from_parent(klass_info):
287287
sql, params = "0", ()
288288
else:
289289
sql, params = self.compile(Value(empty_result_set_value))
290+
except FullResultSet:
291+
sql, params = self.compile(Value(True))
290292
else:
291293
sql, params = col.select_format(self, sql, params)
292294
if alias is None and with_col_aliases:
@@ -721,9 +723,16 @@ def as_sql(self, with_limits=True, with_col_aliases=False):
721723
raise
722724
# Use a predicate that's always False.
723725
where, w_params = "0 = 1", []
724-
having, h_params = (
725-
self.compile(self.having) if self.having is not None else ("", [])
726-
)
726+
except FullResultSet:
727+
where, w_params = "", []
728+
try:
729+
having, h_params = (
730+
self.compile(self.having)
731+
if self.having is not None
732+
else ("", [])
733+
)
734+
except FullResultSet:
735+
having, h_params = "", []
727736
result = ["SELECT"]
728737
params = []
729738

@@ -1817,11 +1826,12 @@ def contains_self_reference_subquery(self):
18171826
)
18181827

18191828
def _as_sql(self, query):
1820-
result = ["DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)]
1821-
where, params = self.compile(query.where)
1822-
if where:
1823-
result.append("WHERE %s" % where)
1824-
return " ".join(result), tuple(params)
1829+
delete = "DELETE FROM %s" % self.quote_name_unless_alias(query.base_table)
1830+
try:
1831+
where, params = self.compile(query.where)
1832+
except FullResultSet:
1833+
return delete, ()
1834+
return f"{delete} WHERE {where}", tuple(params)
18251835

18261836
def as_sql(self):
18271837
"""
@@ -1906,8 +1916,11 @@ def as_sql(self):
19061916
"UPDATE %s SET" % qn(table),
19071917
", ".join(values),
19081918
]
1909-
where, params = self.compile(self.query.where)
1910-
if where:
1919+
try:
1920+
where, params = self.compile(self.query.where)
1921+
except FullResultSet:
1922+
params = []
1923+
else:
19111924
result.append("WHERE %s" % where)
19121925
return " ".join(result), tuple(update_params + params)
19131926

django/db/models/sql/datastructures.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Useful auxiliary data structures for query construction. Not useful outside
33
the SQL domain.
44
"""
5+
from django.core.exceptions import FullResultSet
56
from django.db.models.sql.constants import INNER, LOUTER
67

78

@@ -100,8 +101,11 @@ def as_sql(self, compiler, connection):
100101
join_conditions.append("(%s)" % extra_sql)
101102
params.extend(extra_params)
102103
if self.filtered_relation:
103-
extra_sql, extra_params = compiler.compile(self.filtered_relation)
104-
if extra_sql:
104+
try:
105+
extra_sql, extra_params = compiler.compile(self.filtered_relation)
106+
except FullResultSet:
107+
pass
108+
else:
105109
join_conditions.append("(%s)" % extra_sql)
106110
params.extend(extra_params)
107111
if not join_conditions:

django/db/models/sql/where.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import operator
55
from functools import reduce
66

7-
from django.core.exceptions import EmptyResultSet
7+
from django.core.exceptions import EmptyResultSet, FullResultSet
88
from django.db.models.expressions import Case, When
99
from django.db.models.lookups import Exact
1010
from django.utils import tree
@@ -145,6 +145,8 @@ def as_sql(self, compiler, connection):
145145
sql, params = compiler.compile(child)
146146
except EmptyResultSet:
147147
empty_needed -= 1
148+
except FullResultSet:
149+
full_needed -= 1
148150
else:
149151
if sql:
150152
result.append(sql)
@@ -158,24 +160,25 @@ def as_sql(self, compiler, connection):
158160
# counts.
159161
if empty_needed == 0:
160162
if self.negated:
161-
return "", []
163+
raise FullResultSet
162164
else:
163165
raise EmptyResultSet
164166
if full_needed == 0:
165167
if self.negated:
166168
raise EmptyResultSet
167169
else:
168-
return "", []
170+
raise FullResultSet
169171
conn = " %s " % self.connector
170172
sql_string = conn.join(result)
171-
if sql_string:
172-
if self.negated:
173-
# Some backends (Oracle at least) need parentheses
174-
# around the inner SQL in the negated case, even if the
175-
# inner SQL contains just a single expression.
176-
sql_string = "NOT (%s)" % sql_string
177-
elif len(result) > 1 or self.resolved:
178-
sql_string = "(%s)" % sql_string
173+
if not sql_string:
174+
raise FullResultSet
175+
if self.negated:
176+
# Some backends (Oracle at least) need parentheses around the inner
177+
# SQL in the negated case, even if the inner SQL contains just a
178+
# single expression.
179+
sql_string = "NOT (%s)" % sql_string
180+
elif len(result) > 1 or self.resolved:
181+
sql_string = "(%s)" % sql_string
179182
return sql_string, result_params
180183

181184
def get_group_by_cols(self):

docs/ref/exceptions.txt

+11
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ Django core exception classes are defined in ``django.core.exceptions``.
4242
return any results. Most Django projects won't encounter this exception,
4343
but it might be useful for implementing custom lookups and expressions.
4444

45+
``FullResultSet``
46+
-----------------
47+
48+
.. exception:: FullResultSet
49+
50+
.. versionadded:: 4.2
51+
52+
``FullResultSet`` may be raised during query generation if a query will
53+
match everything. Most Django projects won't encounter this exception, but
54+
it might be useful for implementing custom lookups and expressions.
55+
4556
``FieldDoesNotExist``
4657
---------------------
4758

tests/annotations/tests.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,15 @@
2424
When,
2525
)
2626
from django.db.models.expressions import RawSQL
27-
from django.db.models.functions import Coalesce, ExtractYear, Floor, Length, Lower, Trim
27+
from django.db.models.functions import (
28+
Cast,
29+
Coalesce,
30+
ExtractYear,
31+
Floor,
32+
Length,
33+
Lower,
34+
Trim,
35+
)
2836
from django.test import TestCase, skipUnlessDBFeature
2937
from django.test.utils import register_lookup
3038

@@ -282,6 +290,13 @@ def test_full_expression_annotation(self):
282290
self.assertEqual(len(books), Book.objects.count())
283291
self.assertTrue(all(book.selected for book in books))
284292

293+
def test_full_expression_wrapped_annotation(self):
294+
books = Book.objects.annotate(
295+
selected=Coalesce(~Q(pk__in=[]), True),
296+
)
297+
self.assertEqual(len(books), Book.objects.count())
298+
self.assertTrue(all(book.selected for book in books))
299+
285300
def test_full_expression_annotation_with_aggregation(self):
286301
qs = Book.objects.filter(isbn="159059725").annotate(
287302
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
@@ -292,7 +307,7 @@ def test_full_expression_annotation_with_aggregation(self):
292307
def test_aggregate_over_full_expression_annotation(self):
293308
qs = Book.objects.annotate(
294309
selected=ExpressionWrapper(~Q(pk__in=[]), output_field=BooleanField()),
295-
).aggregate(Sum("selected"))
310+
).aggregate(selected__sum=Sum(Cast("selected", IntegerField())))
296311
self.assertEqual(qs["selected__sum"], Book.objects.count())
297312

298313
def test_empty_queryset_annotation(self):

0 commit comments

Comments
 (0)