From 82d7bd62256e2d14dfa8711b74786d8f96c2dd10 Mon Sep 17 00:00:00 2001 From: sunfkny <30853461+sunfkny@users.noreply.github.com> Date: Wed, 13 Nov 2024 02:22:56 +0800 Subject: [PATCH] Add pattern support to param --- ninja/params/models.py | 15 ++++++++++++++- tests/main.py | 5 +++++ tests/test_path.py | 16 ++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/ninja/params/models.py b/ninja/params/models.py index 0f36e6e11..96ae8af66 100644 --- a/ninja/params/models.py +++ b/ninja/params/models.py @@ -1,6 +1,17 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Pattern, + Tuple, + Type, + TypeVar, + Union, +) from django.conf import settings from django.http import HttpRequest @@ -204,6 +215,7 @@ def __init__( examples: Optional[Dict[str, Any]] = None, deprecated: Optional[bool] = None, include_in_schema: Optional[bool] = True, + pattern: Union[str, Pattern[str], None] = None, # param_name: str = None, # param_type: Any = None, **extra: Any, @@ -237,6 +249,7 @@ def __init__( le=le, min_length=min_length, max_length=max_length, + pattern=pattern, json_schema_extra=json_schema_extra, **extra, ) diff --git a/tests/main.py b/tests/main.py index d313d08ea..6535036a5 100644 --- a/tests/main.py +++ b/tests/main.py @@ -135,6 +135,11 @@ def get_path_param_le_ge_int(request, item_id: int = Path(..., le=3, ge=1)): return item_id +@router.get("/path/param-pattern/{item_id}") +def get_path_param_pattern(request, item_id: str = Path(..., pattern="^foo")): + return item_id + + @router.get("/path/param-django-str/{str:item_id}") def get_path_param_django_str(request, item_id): return item_id diff --git a/tests/test_path.py b/tests/test_path.py index 2b325f585..687d6486d 100644 --- a/tests/test_path.py +++ b/tests/test_path.py @@ -172,6 +172,20 @@ def test_text_get(): } +response_not_valid_pattern = { + "detail": [ + { + "ctx": { + "pattern": "^foo", + }, + "loc": ["path", "item_id"], + "msg": "String should match pattern '^foo'", + "type": "string_pattern_mismatch", + } + ] +} + + @pytest.mark.parametrize( "path,expected_status,expected_response", [ @@ -249,6 +263,8 @@ def test_text_get(): ("/path/param-le-ge-int/3", 200, 3), ("/path/param-le-ge-int/4", 422, response_less_than_equal_3), ("/path/param-le-ge-int/2.7", 422, response_not_valid_int_float), + ("/path/param-pattern/foo", 200, "foo"), + ("/path/param-pattern/fo", 422, response_not_valid_pattern), ], ) def test_get_path(path, expected_status, expected_response):