diff --git a/pypika/terms.py b/pypika/terms.py index ce7aed65..a277e1a5 100644 --- a/pypika/terms.py +++ b/pypika/terms.py @@ -3,7 +3,21 @@ import uuid from datetime import date from enum import Enum -from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order from pypika.utils import ( @@ -288,57 +302,111 @@ def get_sql(self, **kwargs: Any) -> str: raise NotImplementedError() +def idx_placeholder_gen(idx: int) -> str: + return str(idx + 1) + + +def named_placeholder_gen(idx: int) -> str: + return f'param{idx + 1}' + + class Parameter(Term): is_aggregate = None def __init__(self, placeholder: Union[str, int]) -> None: super().__init__() - self.placeholder = placeholder + self._placeholder = placeholder + + @property + def placeholder(self): + return self._placeholder def get_sql(self, **kwargs: Any) -> str: return str(self.placeholder) + def update_parameters(self, param_key: Any, param_value: Any, **kwargs): + pass -class QmarkParameter(Parameter): - """Question mark style, e.g. ...WHERE name=?""" + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder - def __init__(self) -> None: - pass - def get_sql(self, **kwargs: Any) -> str: - return "?" +class ListParameter(Parameter): + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None: + super().__init__(placeholder=placeholder) + self._parameters = list() + @property + def placeholder(self) -> str: + if callable(self._placeholder): + return self._placeholder(len(self._parameters)) -class NumericParameter(Parameter): - """Numeric, positional style, e.g. ...WHERE name=:1""" + return str(self._placeholder) - def get_sql(self, **kwargs: Any) -> str: - return ":{placeholder}".format(placeholder=self.placeholder) + def get_parameters(self, **kwargs): + return self._parameters + def update_parameters(self, value: Any, **kwargs): + self._parameters.append(value) -class NamedParameter(Parameter): - """Named style, e.g. ...WHERE name=:name""" + +class DictParameter(Parameter): + def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None: + super().__init__(placeholder=placeholder) + self._parameters = dict() + + @property + def placeholder(self) -> str: + if callable(self._placeholder): + return self._placeholder(len(self._parameters)) + + return str(self._placeholder) + + def get_parameters(self, **kwargs): + return self._parameters + + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder[1:] + + def update_parameters(self, param_key: Any, value: Any, **kwargs): + self._parameters[param_key] = value + + +class QmarkParameter(ListParameter): + def get_sql(self, **kwargs): + return '?' + + +class NumericParameter(ListParameter): + """Numeric, positional style, e.g. ...WHERE name=:1""" def get_sql(self, **kwargs: Any) -> str: return ":{placeholder}".format(placeholder=self.placeholder) -class FormatParameter(Parameter): +class FormatParameter(ListParameter): """ANSI C printf format codes, e.g. ...WHERE name=%s""" - def __init__(self) -> None: - pass - def get_sql(self, **kwargs: Any) -> str: return "%s" -class PyformatParameter(Parameter): +class NamedParameter(DictParameter): + """Named style, e.g. ...WHERE name=:name""" + + def get_sql(self, **kwargs: Any) -> str: + return ":{placeholder}".format(placeholder=self.placeholder) + + +class PyformatParameter(DictParameter): """Python extended format codes, e.g. ...WHERE name=%(name)s""" def get_sql(self, **kwargs: Any) -> str: return "%({placeholder})s".format(placeholder=self.placeholder) + def get_param_key(self, placeholder: Any, **kwargs): + return placeholder[2:-2] + class Negative(Term): def __init__(self, term: Term) -> None: @@ -385,9 +453,44 @@ def get_formatted_value(cls, value: Any, **kwargs): return "null" return str(value) - def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str: - sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) - return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + param_sql = parameter.get_sql(**kwargs) + param_key = parameter.get_param_key(placeholder=param_sql) + + return param_sql, param_key + + def get_sql( + self, + quote_char: Optional[str] = None, + secondary_quote_char: str = "'", + parameter: Parameter = None, + **kwargs: Any, + ) -> str: + if parameter is None: + sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs) + return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs) + + # Don't stringify numbers when using a parameter + if isinstance(self.value, (int, float)): + value_sql = self.value + else: + value_sql = self.get_value_sql(quote_char=quote_char, **kwargs) + param_sql, param_key = self._get_param_data(parameter, **kwargs) + parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs) + + return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs) + + +class ParameterValueWrapper(ValueWrapper): + def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None: + super().__init__(value, alias) + self._parameter = parameter + + def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]: + param_sql = self._parameter.get_sql(**kwargs) + param_key = self._parameter.get_param_key(placeholder=param_sql) + + return param_sql, param_key class JSON(Term): @@ -551,6 +654,7 @@ def __init__( if isinstance(table, str): # avoid circular import at load time from pypika.queries import Table + table = Table(table) self.table = table diff --git a/pypika/tests/test_parameter.py b/pypika/tests/test_parameter.py index e19666a0..c11e9afc 100644 --- a/pypika/tests/test_parameter.py +++ b/pypika/tests/test_parameter.py @@ -1,4 +1,5 @@ import unittest +from datetime import date from pypika import ( FormatParameter, @@ -10,6 +11,7 @@ Query, Tables, ) +from pypika.terms import ListParameter, ParameterValueWrapper class ParametrizedTests(unittest.TestCase): @@ -92,3 +94,113 @@ def test_format_parameter(self): def test_pyformat_parameter(self): self.assertEqual('%(buz)s', PyformatParameter('buz').get_sql()) + + +class ParametrizedTestsWithValues(unittest.TestCase): + table_abc, table_efg = Tables("abc", "efg") + + def test_param_insert(self): + q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo') + + parameter = QmarkParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql) + self.assertEqual([1, 2.2, 'foo'], parameter.get_parameters()) + + def test_param_select_join(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == 'foobar') + .join(self.table_efg) + .on(self.table_abc.id == self.table_efg.abc_id) + .where(self.table_efg.date >= date(2024, 2, 22)) + .limit(10) + ) + + parameter = FormatParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10', + sql, + ) + self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters()) + + def test_param_select_subquery(self): + q = ( + Query.from_(self.table_abc) + .select("*") + .where(self.table_abc.category == 'foobar') + .where( + self.table_abc.id.isin( + Query.from_(self.table_efg) + .select(self.table_efg.abc_id) + .where(self.table_efg.date >= date(2024, 2, 22)) + ) + ) + .limit(10) + ) + + parameter = ListParameter(placeholder=lambda idx: f'&{idx+1}') + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10', + sql, + ) + self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters()) + + def test_join(self): + subquery = ( + Query.from_(self.table_efg) + .select(self.table_efg.fiz, self.table_efg.buz) + .where(self.table_efg.buz == 'buz') + ) + + q = ( + Query.from_(self.table_abc) + .join(subquery) + .on(self.table_abc.bar == subquery.buz) + .select(self.table_abc.foo, subquery.fiz) + .where(self.table_abc.bar == 'bar') + ) + + parameter = NamedParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2', + sql, + ) + self.assertEqual({'param1': 'buz', 'param2': 'bar'}, parameter.get_parameters()) + + def test_join_with_parameter_value_wrapper(self): + subquery = ( + Query.from_(self.table_efg) + .select(self.table_efg.fiz, self.table_efg.buz) + .where(self.table_efg.buz == ParameterValueWrapper(Parameter(':buz'), 'buz')) + ) + + q = ( + Query.from_(self.table_abc) + .join(subquery) + .on(self.table_abc.bar == subquery.buz) + .select(self.table_abc.foo, subquery.fiz) + .where(self.table_abc.bar == ParameterValueWrapper(NamedParameter('bar'), 'bar')) + ) + + parameter = NamedParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual( + 'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)' + ' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar', + sql, + ) + self.assertEqual({':buz': 'buz', 'bar': 'bar'}, parameter.get_parameters()) + + def test_pyformat_parameter(self): + q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo') + + parameter = PyformatParameter() + sql = q.get_sql(parameter=parameter) + self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql) + self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters()) diff --git a/pypika/tests/test_terms.py b/pypika/tests/test_terms.py index 607c4c01..4c7590df 100644 --- a/pypika/tests/test_terms.py +++ b/pypika/tests/test_terms.py @@ -20,7 +20,7 @@ def test_init_with_str_table(self): test_table_name = "test_table" field = Field(name="name", table=test_table_name) self.assertEqual(field.table, Table(name=test_table_name)) - + class FieldHashingTests(TestCase): def test_tabled_eq_fields_equally_hashed(self):