Skip to content

Commit 20b8683

Browse files
committed
refactor(middleware): Refactor internals of CSPMiddleware so that it's easier to extend existing logic without copy/pasting it into subclass
1 parent f860f6c commit 20b8683

File tree

4 files changed

+124
-57
lines changed

4 files changed

+124
-57
lines changed

csp/contrib/rate_limiting.py

Lines changed: 18 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
from django.conf import settings
77

8-
from csp.middleware import CSPMiddleware
9-
from csp.utils import build_policy
8+
from csp.middleware import CSPMiddleware, PolicyParts
109

1110
if TYPE_CHECKING:
1211
from django.http import HttpRequest, HttpResponseBase
@@ -16,48 +15,27 @@ class RateLimitedCSPMiddleware(CSPMiddleware):
1615
"""A CSP middleware that rate-limits the number of violation reports sent
1716
to report-uri by excluding it from some requests."""
1817

19-
def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
20-
config = getattr(response, "_csp_config", None)
21-
update = getattr(response, "_csp_update", None)
22-
replace = getattr(response, "_csp_replace", {})
23-
nonce = getattr(request, "_csp_nonce", None)
24-
25-
policy = getattr(settings, "CONTENT_SECURITY_POLICY", None)
26-
27-
if policy is None:
28-
return ""
29-
30-
report_percentage = policy.get("REPORT_PERCENTAGE", 100)
31-
remove_report = random.randint(0, 99) >= report_percentage
32-
if remove_report:
33-
replace.update(
34-
{
35-
"report-uri": None,
36-
"report-to": None,
37-
}
38-
)
39-
40-
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
41-
42-
def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
43-
config = getattr(response, "_csp_config_ro", None)
44-
update = getattr(response, "_csp_update_ro", None)
45-
replace = getattr(response, "_csp_replace_ro", {})
46-
nonce = getattr(request, "_csp_nonce", None)
47-
48-
policy = getattr(settings, "CONTENT_SECURITY_POLICY_REPORT_ONLY", None)
18+
def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
19+
policy_parts = super().get_policy_parts(request, response, report_only)
4920

21+
csp_setting_name = "CONTENT_SECURITY_POLICY_REPORT_ONLY" if report_only else "CONTENT_SECURITY_POLICY"
22+
policy = getattr(settings, csp_setting_name, None)
5023
if policy is None:
51-
return ""
24+
return policy_parts
5225

53-
report_percentage = policy.get("REPORT_PERCENTAGE", 100)
54-
remove_report = random.randint(0, 99) >= report_percentage
26+
remove_report = random.randint(0, 99) >= policy.get("REPORT_PERCENTAGE", 100)
5527
if remove_report:
56-
replace.update(
57-
{
28+
if policy_parts.replace is None:
29+
policy_parts.replace = {
5830
"report-uri": None,
5931
"report-to": None,
6032
}
61-
)
62-
63-
return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True)
33+
else:
34+
policy_parts.replace.update(
35+
{
36+
"report-uri": None,
37+
"report-to": None,
38+
}
39+
)
40+
41+
return policy_parts

csp/middleware.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import base64
44
import http.client as http_client
55
import os
6+
import warnings
7+
from dataclasses import asdict, dataclass
68
from functools import partial
79
from typing import TYPE_CHECKING
810

@@ -11,12 +13,21 @@
1113
from django.utils.functional import SimpleLazyObject
1214

1315
from csp.constants import HEADER, HEADER_REPORT_ONLY
14-
from csp.utils import build_policy
16+
from csp.utils import DIRECTIVES_T, build_policy
1517

1618
if TYPE_CHECKING:
1719
from django.http import HttpRequest, HttpResponseBase
1820

1921

22+
@dataclass
23+
class PolicyParts:
24+
# A dataclass is used rather than a namedtuple so that the attributes are mutable
25+
config: DIRECTIVES_T | None = None
26+
update: DIRECTIVES_T | None = None
27+
replace: DIRECTIVES_T | None = None
28+
nonce: str | None = None
29+
30+
2031
class CSPMiddleware(MiddlewareMixin):
2132
"""
2233
Implements the Content-Security-Policy response header, which
@@ -25,6 +36,7 @@ class CSPMiddleware(MiddlewareMixin):
2536
2637
See http://www.w3.org/TR/CSP/
2738
39+
Can be customised by subclassing and extending the get_policy_parts method.
2840
"""
2941

3042
def _make_nonce(self, request: HttpRequest) -> str:
@@ -49,7 +61,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
4961
if response.status_code in exempted_debug_codes and settings.DEBUG:
5062
return response
5163

52-
csp = self.build_policy(request, response)
64+
policy_parts = self.get_policy_parts(request=request, response=response)
65+
csp = build_policy(**asdict(policy_parts))
5366
if csp:
5467
# Only set header if not already set and not an excluded prefix and not exempted.
5568
is_not_exempt = getattr(response, "_csp_exempt", False) is False
@@ -60,7 +73,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
6073
if no_header and is_not_exempt and is_not_excluded:
6174
response[HEADER] = csp
6275

63-
csp_ro = self.build_policy_ro(request, response)
76+
policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True)
77+
csp_ro = build_policy(**asdict(policy_parts_ro), report_only=True)
6478
if csp_ro:
6579
# Only set header if not already set and not an excluded prefix and not exempted.
6680
is_not_exempt = getattr(response, "_csp_exempt_ro", False) is False
@@ -74,15 +88,25 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
7488
return response
7589

7690
def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
77-
config = getattr(response, "_csp_config", None)
78-
update = getattr(response, "_csp_update", None)
79-
replace = getattr(response, "_csp_replace", None)
80-
nonce = getattr(request, "_csp_nonce", None)
81-
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
91+
warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning)
92+
policy_parts = self.get_policy_parts(request=request, response=response, report_only=False)
93+
return build_policy(**asdict(policy_parts))
8294

8395
def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
84-
config = getattr(response, "_csp_config_ro", None)
85-
update = getattr(response, "_csp_update_ro", None)
86-
replace = getattr(response, "_csp_replace_ro", None)
96+
warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning)
97+
policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True)
98+
return build_policy(**asdict(policy_parts_ro), report_only=True)
99+
100+
def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
101+
if report_only:
102+
config = getattr(response, "_csp_config_ro", None)
103+
update = getattr(response, "_csp_update_ro", None)
104+
replace = getattr(response, "_csp_replace_ro", None)
105+
else:
106+
config = getattr(response, "_csp_config", None)
107+
update = getattr(response, "_csp_update", None)
108+
replace = getattr(response, "_csp_replace", None)
109+
87110
nonce = getattr(request, "_csp_nonce", None)
88-
return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True)
111+
112+
return PolicyParts(config, update, replace, nonce)

csp/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@
5252
"block-all-mixed-content": None, # Deprecated.
5353
}
5454

55-
_DIRECTIVES = Dict[str, Any]
55+
DIRECTIVES_T = Dict[str, Any]
5656

5757

58-
def default_config(csp: _DIRECTIVES | None) -> _DIRECTIVES | None:
58+
def default_config(csp: DIRECTIVES_T | None) -> DIRECTIVES_T | None:
5959
if csp is None:
6060
return None
6161
# Make a copy of the passed in config to avoid mutating it, and also to drop any unknown keys.
@@ -66,9 +66,9 @@ def default_config(csp: _DIRECTIVES | None) -> _DIRECTIVES | None:
6666

6767

6868
def build_policy(
69-
config: _DIRECTIVES | None = None,
70-
update: _DIRECTIVES | None = None,
71-
replace: _DIRECTIVES | None = None,
69+
config: DIRECTIVES_T | None = None,
70+
update: DIRECTIVES_T | None = None,
71+
replace: DIRECTIVES_T | None = None,
7272
nonce: str | None = None,
7373
report_only: bool = False,
7474
) -> str:

docs/migration-guide.rst

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,71 @@ decorator now requires parentheses when used with and without arguments. For exa
191191
Look for uses of the following decorators in your code: ``@csp``, ``@csp_update``, ``@csp_replace``,
192192
and ``@csp_exempt``.
193193

194+
Migrating Custom Middleware
195+
===========================
196+
The `CSPMiddleware` has changed in order to support easier extension via subclassing.
197+
198+
The `CSPMiddleware.build_policy` and `CSPMiddleware.build_policy_ro` methods have been deprecated
199+
in 4.0 and replaced with a new method `CSPMiddleware.build_policy_parts`.
200+
201+
.. note::
202+
The deprecated methods will be removed in 4.1.
203+
204+
Unlike the old methods, which returned the built CSP policy header string, `build_policy_parts`
205+
returns a dataclass that can be modified and updated before the policy is built. This allows
206+
custom middleware to modify the policy whilst inheriting behaviour from the base classes.
207+
208+
An existing custom middleware, such as this:
209+
210+
.. code-block:: python
211+
212+
from django.http import HttpRequest, HttpResponseBase
213+
214+
from csp.middleware import CSPMiddleware, PolicyParts
215+
216+
class ACustomMiddleware(CSPMiddleware):
217+
218+
def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
219+
config = getattr(response, "_csp_config", None)
220+
update = getattr(response, "_csp_update", None)
221+
replace = getattr(response, "_csp_replace", {})
222+
nonce = getattr(request, "_csp_nonce", None)
223+
224+
# ... do custom CSP policy logic ...
225+
226+
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
227+
228+
def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
229+
config = getattr(response, "_csp_config_ro", None)
230+
update = getattr(response, "_csp_update_ro", None)
231+
replace = getattr(response, "_csp_replace_ro", {})
232+
nonce = getattr(request, "_csp_nonce", None)
233+
234+
# ... do custom CSP report only policy logic ...
235+
236+
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
237+
238+
can be replaced with this:
239+
240+
.. code-block:: python
241+
242+
from django.http import HttpRequest, HttpResponseBase
243+
244+
from csp.middleware import CSPMiddleware, PolicyParts
245+
246+
247+
class ACustomMiddleware(CSPMiddleware):
248+
249+
def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
250+
policy_parts = super().get_policy_parts(request, response, report_only)
251+
252+
if report_only:
253+
# ... do custom CSP report only policy logic ...
254+
else:
255+
# ... do custom CSP policy logic ...
256+
257+
return policy_parts
258+
194259
Conclusion
195260
==========
196261

0 commit comments

Comments
 (0)