Skip to content

Commit c5237c1

Browse files
committed
refactor(middleware): Refactor internals of CSPMiddleware so that it's easier to extend existing logic without copy/pasting it into subclass
1 parent ed0b7a4 commit c5237c1

File tree

2 files changed

+47
-42
lines changed

2 files changed

+47
-42
lines changed

csp/contrib/rate_limiting.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
from django.conf import settings
77

8-
from csp.middleware import CSPMiddleware
9-
from csp.utils import build_policy
8+
from csp.middleware import CSPMiddleware, PolicyParts
109

1110
if TYPE_CHECKING:
1211
from django.http import HttpRequest, HttpResponseBase
@@ -16,38 +15,20 @@ class RateLimitedCSPMiddleware(CSPMiddleware):
1615
"""A CSP middleware that rate-limits the number of violation reports sent
1716
to report-uri by excluding it from some requests."""
1817

19-
def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
20-
config = getattr(response, "_csp_config", None)
21-
update = getattr(response, "_csp_update", None)
22-
replace = getattr(response, "_csp_replace", {})
23-
nonce = getattr(request, "_csp_nonce", None)
24-
25-
policy = getattr(settings, "CONTENT_SECURITY_POLICY", None)
26-
27-
if policy is None:
28-
return ""
29-
30-
report_percentage = policy.get("REPORT_PERCENTAGE", 100)
31-
include_report_uri = random.randint(0, 100) < report_percentage
32-
if not include_report_uri:
33-
replace["report-uri"] = None
34-
35-
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
36-
37-
def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
38-
config = getattr(response, "_csp_config_ro", None)
39-
update = getattr(response, "_csp_update_ro", None)
40-
replace = getattr(response, "_csp_replace_ro", {})
41-
nonce = getattr(request, "_csp_nonce", None)
42-
43-
policy = getattr(settings, "CONTENT_SECURITY_POLICY_REPORT_ONLY", None)
18+
def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
19+
policy_parts = super().get_policy_parts(request, response, report_only)
4420

21+
csp_setting_name = "CONTENT_SECURITY_POLICY_REPORT_ONLY" if report_only else "CONTENT_SECURITY_POLICY"
22+
policy = getattr(settings, csp_setting_name, None)
4523
if policy is None:
46-
return ""
24+
return policy_parts
4725

4826
report_percentage = policy.get("REPORT_PERCENTAGE", 100)
4927
include_report_uri = random.randint(0, 100) < report_percentage
5028
if not include_report_uri:
51-
replace["report-uri"] = None
29+
if policy_parts.replace is None:
30+
policy_parts.replace = {"report-uri": None}
31+
else:
32+
policy_parts.replace["report-uri"] = None
5233

53-
return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True)
34+
return policy_parts

csp/middleware.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import base64
44
import http.client as http_client
55
import os
6+
import warnings
7+
from dataclasses import asdict, dataclass
68
from functools import partial
79
from typing import TYPE_CHECKING
810

@@ -11,12 +13,21 @@
1113
from django.utils.functional import SimpleLazyObject
1214

1315
from csp.constants import HEADER, HEADER_REPORT_ONLY
14-
from csp.utils import build_policy
16+
from csp.utils import _DIRECTIVES, build_policy
1517

1618
if TYPE_CHECKING:
1719
from django.http import HttpRequest, HttpResponseBase
1820

1921

22+
@dataclass
23+
class PolicyParts:
24+
# A dataclass is used rather than a namedtuple so that the attributes are mutable
25+
config: _DIRECTIVES = None
26+
update: _DIRECTIVES = None
27+
replace: _DIRECTIVES = None
28+
nonce: str | None = None
29+
30+
2031
class CSPMiddleware(MiddlewareMixin):
2132
"""
2233
Implements the Content-Security-Policy response header, which
@@ -25,6 +36,7 @@ class CSPMiddleware(MiddlewareMixin):
2536
2637
See http://www.w3.org/TR/CSP/
2738
39+
Can be customised by subclassing and extending the get_policy_parts method.
2840
"""
2941

3042
def _make_nonce(self, request: HttpRequest) -> str:
@@ -49,7 +61,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
4961
if response.status_code in exempted_debug_codes and settings.DEBUG:
5062
return response
5163

52-
csp = self.build_policy(request, response)
64+
policy_parts = self.get_policy_parts(request=request, response=response)
65+
csp = build_policy(**asdict(policy_parts))
5366
if csp:
5467
# Only set header if not already set and not an excluded prefix and not exempted.
5568
is_not_exempt = getattr(response, "_csp_exempt", False) is False
@@ -60,7 +73,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
6073
if no_header and is_not_exempt and is_not_excluded:
6174
response[HEADER] = csp
6275

63-
csp_ro = self.build_policy_ro(request, response)
76+
policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True)
77+
csp_ro = build_policy(**asdict(policy_parts_ro), report_only=True)
6478
if csp_ro:
6579
# Only set header if not already set and not an excluded prefix and not exempted.
6680
is_not_exempt = getattr(response, "_csp_exempt_ro", False) is False
@@ -74,15 +88,25 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
7488
return response
7589

7690
def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
77-
config = getattr(response, "_csp_config", None)
78-
update = getattr(response, "_csp_update", None)
79-
replace = getattr(response, "_csp_replace", None)
80-
nonce = getattr(request, "_csp_nonce", None)
81-
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
91+
warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning)
92+
policy_parts = self.get_policy_parts(request=request, response=response, report_only=False)
93+
return build_policy(**asdict(policy_parts))
8294

8395
def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
84-
config = getattr(response, "_csp_config_ro", None)
85-
update = getattr(response, "_csp_update_ro", None)
86-
replace = getattr(response, "_csp_replace_ro", None)
96+
warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning)
97+
policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True)
98+
return build_policy(**asdict(policy_parts_ro), report_only=True)
99+
100+
def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
101+
if report_only:
102+
config = getattr(response, "_csp_config_ro", None)
103+
update = getattr(response, "_csp_update_ro", None)
104+
replace = getattr(response, "_csp_replace_ro", None)
105+
else:
106+
config = getattr(response, "_csp_config", None)
107+
update = getattr(response, "_csp_update", None)
108+
replace = getattr(response, "_csp_replace", None)
109+
87110
nonce = getattr(request, "_csp_nonce", None)
88-
return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True)
111+
112+
return PolicyParts(config, update, replace, nonce)

0 commit comments

Comments
 (0)