Bases: BaseHTTPMiddleware
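Middleware for CSRF protection. When CSRF protection is enabled
(PREFECT_SERVER_CSRF_PROTECTION_ENABLED), this middleware checks the headers
of every POST, PUT, PATCH, or DELETE request for a CSRF token
(Prefect-Csrf-Token) and a client identifier (Prefect-Csrf-Client). If either
header is missing, or the token does not match the token stored in the
database for the client, the request is rejected with a 403 status code.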
Source code in src/prefect/server/api/middleware.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class CsrfMiddleware(BaseHTTPMiddleware):
    """
    Middleware for CSRF protection.

    When `PREFECT_SERVER_CSRF_PROTECTION_ENABLED` is set, every POST, PUT,
    PATCH, or DELETE request must carry a `Prefect-Csrf-Token` header and a
    `Prefect-Csrf-Client` header. If either header is missing, or the token
    does not match the token stored in the database for that client, the
    request is rejected with a 403 status code. All other requests are passed
    through untouched.
    """

    async def dispatch(
        self, request: Request, call_next: NextMiddlewareFunction
    ) -> Response:
        """
        Validate the CSRF headers of a state-changing request.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request to the next
                middleware / route handler.

        Returns:
            A 403 `JSONResponse` when protection is enabled, the method is
            POST/PUT/PATCH/DELETE, and the token or client header is missing
            or the token does not match the stored one; otherwise the
            response produced by `call_next(request)`.
        """
        # Local import keeps this block self-contained; the module is cached
        # in `sys.modules` after the first request.
        import hmac

        request_needs_csrf_protection = request.method in {
            "POST",
            "PUT",
            "PATCH",
            "DELETE",
        }

        if (
            settings.PREFECT_SERVER_CSRF_PROTECTION_ENABLED.value()
            and request_needs_csrf_protection
        ):
            incoming_token = request.headers.get("Prefect-Csrf-Token")
            incoming_client = request.headers.get("Prefect-Csrf-Client")

            if incoming_token is None:
                return JSONResponse(
                    {"detail": "Missing CSRF token."},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

            if incoming_client is None:
                return JSONResponse(
                    {"detail": "Missing client identifier."},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

            db = provide_database_interface()
            async with db.session_context() as session:
                token = await models.csrf_token.read_token_for_client(
                    session=session, client=incoming_client
                )

                # Compare in constant time so response latency does not leak
                # how many leading characters of a guessed token are correct.
                # Both sides are encoded to bytes because the `str` form of
                # `compare_digest` raises TypeError on non-ASCII input, and
                # header values are attacker-controlled.
                # NOTE(review): assumes `token.token` is a `str` - confirm
                # against the csrf_token model.
                if token is None or not hmac.compare_digest(
                    token.token.encode("utf-8"),
                    incoming_token.encode("utf-8"),
                ):
                    return JSONResponse(
                        {"detail": "Invalid CSRF token or client identifier."},
                        status_code=status.HTTP_403_FORBIDDEN,
                        headers={"Access-Control-Allow-Origin": "*"},
                    )

        return await call_next(request)
|
dispatch(request, call_next)
async
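Dispatch method for the middleware. When CSRF protection is enabled and the
request method is POST, PUT, PATCH, or DELETE, this method checks the request
headers for a CSRF token and a client identifier and compares the token to
the one stored in the database for that client. If either header is missing
or the token does not match, the request is rejected with a 403 status code;
otherwise the request is passed on to the next handler.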
Source code in src/prefect/server/api/middleware.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
async def dispatch(
    self, request: Request, call_next: NextMiddlewareFunction
) -> Response:
    """
    Validate the CSRF headers of a state-changing request.

    The check only runs when `PREFECT_SERVER_CSRF_PROTECTION_ENABLED` is set
    and the request method is POST, PUT, PATCH, or DELETE. In that case the
    `Prefect-Csrf-Token` and `Prefect-Csrf-Client` headers must both be
    present, and the token must match the one stored in the database for the
    client.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request to the next
            middleware / route handler.

    Returns:
        A 403 `JSONResponse` when a required header is missing or the token
        does not match; otherwise the response produced by
        `call_next(request)`.
    """
    # Local import keeps this block self-contained; the module is cached in
    # `sys.modules` after the first request.
    import hmac

    request_needs_csrf_protection = request.method in {
        "POST",
        "PUT",
        "PATCH",
        "DELETE",
    }

    if (
        settings.PREFECT_SERVER_CSRF_PROTECTION_ENABLED.value()
        and request_needs_csrf_protection
    ):
        incoming_token = request.headers.get("Prefect-Csrf-Token")
        incoming_client = request.headers.get("Prefect-Csrf-Client")

        if incoming_token is None:
            return JSONResponse(
                {"detail": "Missing CSRF token."},
                status_code=status.HTTP_403_FORBIDDEN,
            )

        if incoming_client is None:
            return JSONResponse(
                {"detail": "Missing client identifier."},
                status_code=status.HTTP_403_FORBIDDEN,
            )

        db = provide_database_interface()
        async with db.session_context() as session:
            token = await models.csrf_token.read_token_for_client(
                session=session, client=incoming_client
            )

            # Compare in constant time so response latency does not leak how
            # many leading characters of a guessed token are correct. Both
            # sides are encoded to bytes because the `str` form of
            # `compare_digest` raises TypeError on non-ASCII input, and
            # header values are attacker-controlled.
            # NOTE(review): assumes `token.token` is a `str` - confirm
            # against the csrf_token model.
            if token is None or not hmac.compare_digest(
                token.token.encode("utf-8"),
                incoming_token.encode("utf-8"),
            ):
                return JSONResponse(
                    {"detail": "Invalid CSRF token or client identifier."},
                    status_code=status.HTTP_403_FORBIDDEN,
                    headers={"Access-Control-Allow-Origin": "*"},
                )

    return await call_next(request)
|