Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⭐ Improve parameterized query support - fixes #793 #794

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 126 additions & 22 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,21 @@
import uuid
from datetime import date
from enum import Enum
from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Type, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
)

from pypika.enums import Arithmetic, Boolean, Comparator, Dialects, Equality, JSONOperators, Matching, Order
from pypika.utils import (
Expand Down Expand Up @@ -288,57 +302,111 @@ def get_sql(self, **kwargs: Any) -> str:
raise NotImplementedError()


def idx_placeholder_gen(idx: int) -> str:
return str(idx + 1)


def named_placeholder_gen(idx: int) -> str:
return f'param{idx + 1}'


class Parameter(Term):
is_aggregate = None

def __init__(self, placeholder: Union[str, int]) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type hint here doesn't include callable? Is that correct still?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, only ListParameter and DictParameter (and classes that inherit from them) support callable.
The base parameter does not.

super().__init__()
self.placeholder = placeholder
self._placeholder = placeholder
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to use a setter here. With that, implementing a check for the correct type can happen on initialization if put in the setter

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow how to do this pythonically? Could you share some pseudo code to show what you mean?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mvanderlee I believe @wd60622 was referring to the following:

@property
def placeholder(self):
      return self._placeholder

@placeholder.setter
def placeholder(self, placeholder):
     # can add checks here if needed
     self._placeholder = placeholder

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, which adds no value and isn't pythonic. This isn't Java/C#.

It would also break parameters. After a parameter has been created and used, any changes to the placeholder will break queries. Which is why I've made it private with a getter only.

I haven't changed the type, so any type checks would have to be part of a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would disagree with you on it not being Pythonic, as it’s a native language feature.

Also, I was trying to be helpful, but given your receptiveness to discussion and tone, I won’t proceed.

Good luck.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for my tone @danielenricocahall it'd been a rough week and I lashed out.

As for pythonic. This video does a good job at explaining language feature vs pythonic.
Raymond Hettinger - Beyond PEP 8 -- Best practices for beautiful intelligible code - https://www.youtube.com/watch?v=wf-BqAjZb8M


@property
def placeholder(self):
return self._placeholder

def get_sql(self, **kwargs: Any) -> str:
return str(self.placeholder)

def update_parameters(self, param_key: Any, param_value: Any, **kwargs):
pass

class QmarkParameter(Parameter):
"""Question mark style, e.g. ...WHERE name=?"""
def get_param_key(self, placeholder: Any, **kwargs):
return placeholder

def __init__(self) -> None:
pass

def get_sql(self, **kwargs: Any) -> str:
return "?"
class ListParameter(Parameter):
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = idx_placeholder_gen) -> None:
super().__init__(placeholder=placeholder)
self._parameters = list()

@property
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

class NumericParameter(Parameter):
"""Numeric, positional style, e.g. ...WHERE name=:1"""
return str(self._placeholder)

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)
def get_parameters(self, **kwargs):
return self._parameters

def update_parameters(self, value: Any, **kwargs):
self._parameters.append(value)

class NamedParameter(Parameter):
"""Named style, e.g. ...WHERE name=:name"""

class DictParameter(Parameter):
def __init__(self, placeholder: Union[str, int, Callable[[int], str]] = named_placeholder_gen) -> None:
super().__init__(placeholder=placeholder)
self._parameters = dict()

@property
def placeholder(self) -> str:
if callable(self._placeholder):
return self._placeholder(len(self._parameters))

return str(self._placeholder)

def get_parameters(self, **kwargs):
return self._parameters

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder[1:]

def update_parameters(self, param_key: Any, value: Any, **kwargs):
self._parameters[param_key] = value


class QmarkParameter(ListParameter):
def get_sql(self, **kwargs):
return '?'


class NumericParameter(ListParameter):
"""Numeric, positional style, e.g. ...WHERE name=:1"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class FormatParameter(Parameter):
class FormatParameter(ListParameter):
"""ANSI C printf format codes, e.g. ...WHERE name=%s"""

def __init__(self) -> None:
pass

def get_sql(self, **kwargs: Any) -> str:
return "%s"


class PyformatParameter(Parameter):
class NamedParameter(DictParameter):
"""Named style, e.g. ...WHERE name=:name"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class PyformatParameter(DictParameter):
"""Python extended format codes, e.g. ...WHERE name=%(name)s"""

def get_sql(self, **kwargs: Any) -> str:
return "%({placeholder})s".format(placeholder=self.placeholder)

def get_param_key(self, placeholder: Any, **kwargs):
return placeholder[2:-2]


class Negative(Term):
def __init__(self, term: Term) -> None:
Expand Down Expand Up @@ -385,9 +453,44 @@ def get_formatted_value(cls, value: Any, **kwargs):
return "null"
return str(value)

def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)
def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = parameter.get_sql(**kwargs)
param_key = parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key

def get_sql(
self,
quote_char: Optional[str] = None,
secondary_quote_char: str = "'",
parameter: Parameter = None,
**kwargs: Any,
) -> str:
if parameter is None:
sql = self.get_value_sql(quote_char=quote_char, secondary_quote_char=secondary_quote_char, **kwargs)
return format_alias_sql(sql, self.alias, quote_char=quote_char, **kwargs)

# Don't stringify numbers when using a parameter
if isinstance(self.value, (int, float)):
value_sql = self.value
else:
value_sql = self.get_value_sql(quote_char=quote_char, **kwargs)
param_sql, param_key = self._get_param_data(parameter, **kwargs)
parameter.update_parameters(param_key=param_key, value=value_sql, **kwargs)

return format_alias_sql(param_sql, self.alias, quote_char=quote_char, **kwargs)


class ParameterValueWrapper(ValueWrapper):
def __init__(self, parameter: Parameter, value: Any, alias: Optional[str] = None) -> None:
super().__init__(value, alias)
self._parameter = parameter

def _get_param_data(self, parameter: Parameter, **kwargs) -> Tuple[str, str]:
param_sql = self._parameter.get_sql(**kwargs)
param_key = self._parameter.get_param_key(placeholder=param_sql)

return param_sql, param_key


class JSON(Term):
Expand Down Expand Up @@ -551,6 +654,7 @@ def __init__(
if isinstance(table, str):
# avoid circular import at load time
from pypika.queries import Table

table = Table(table)
self.table = table

Expand Down
112 changes: 112 additions & 0 deletions pypika/tests/test_parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from datetime import date

from pypika import (
FormatParameter,
Expand All @@ -10,6 +11,7 @@
Query,
Tables,
)
from pypika.terms import ListParameter, ParameterValueWrapper


class ParametrizedTests(unittest.TestCase):
Expand Down Expand Up @@ -92,3 +94,113 @@ def test_format_parameter(self):

def test_pyformat_parameter(self):
self.assertEqual('%(buz)s', PyformatParameter('buz').get_sql())


class ParametrizedTestsWithValues(unittest.TestCase):
table_abc, table_efg = Tables("abc", "efg")

def test_param_insert(self):
q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo')

parameter = QmarkParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (?,?,?)', sql)
self.assertEqual([1, 2.2, 'foo'], parameter.get_parameters())

def test_param_select_join(self):
q = (
Query.from_(self.table_abc)
.select("*")
.where(self.table_abc.category == 'foobar')
.join(self.table_efg)
.on(self.table_abc.id == self.table_efg.abc_id)
.where(self.table_efg.date >= date(2024, 2, 22))
.limit(10)
)

parameter = FormatParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT * FROM "abc" JOIN "efg" ON "abc"."id"="efg"."abc_id" WHERE "abc"."category"=%s AND "efg"."date">=%s LIMIT 10',
sql,
)
self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters())

def test_param_select_subquery(self):
q = (
Query.from_(self.table_abc)
.select("*")
.where(self.table_abc.category == 'foobar')
.where(
self.table_abc.id.isin(
Query.from_(self.table_efg)
.select(self.table_efg.abc_id)
.where(self.table_efg.date >= date(2024, 2, 22))
)
)
.limit(10)
)

parameter = ListParameter(placeholder=lambda idx: f'&{idx+1}')
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT * FROM "abc" WHERE "category"=&1 AND "id" IN (SELECT "abc_id" FROM "efg" WHERE "date">=&2) LIMIT 10',
sql,
)
self.assertEqual(['foobar', '2024-02-22'], parameter.get_parameters())

def test_join(self):
subquery = (
Query.from_(self.table_efg)
.select(self.table_efg.fiz, self.table_efg.buz)
.where(self.table_efg.buz == 'buz')
)

q = (
Query.from_(self.table_abc)
.join(subquery)
.on(self.table_abc.bar == subquery.buz)
.select(self.table_abc.foo, subquery.fiz)
.where(self.table_abc.bar == 'bar')
)

parameter = NamedParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:param1)'
' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:param2',
sql,
)
self.assertEqual({'param1': 'buz', 'param2': 'bar'}, parameter.get_parameters())

def test_join_with_parameter_value_wrapper(self):
subquery = (
Query.from_(self.table_efg)
.select(self.table_efg.fiz, self.table_efg.buz)
.where(self.table_efg.buz == ParameterValueWrapper(Parameter(':buz'), 'buz'))
)

q = (
Query.from_(self.table_abc)
.join(subquery)
.on(self.table_abc.bar == subquery.buz)
.select(self.table_abc.foo, subquery.fiz)
.where(self.table_abc.bar == ParameterValueWrapper(NamedParameter('bar'), 'bar'))
)

parameter = NamedParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual(
'SELECT "abc"."foo","sq0"."fiz" FROM "abc" JOIN (SELECT "fiz","buz" FROM "efg" WHERE "buz"=:buz)'
' "sq0" ON "abc"."bar"="sq0"."buz" WHERE "abc"."bar"=:bar',
sql,
)
self.assertEqual({':buz': 'buz', 'bar': 'bar'}, parameter.get_parameters())

def test_pyformat_parameter(self):
q = Query.into(self.table_abc).columns("a", "b", "c").insert(1, 2.2, 'foo')

parameter = PyformatParameter()
sql = q.get_sql(parameter=parameter)
self.assertEqual('INSERT INTO "abc" ("a","b","c") VALUES (%(param1)s,%(param2)s,%(param3)s)', sql)
self.assertEqual({"param1": 1, "param2": 2.2, "param3": "foo"}, parameter.get_parameters())
2 changes: 1 addition & 1 deletion pypika/tests/test_terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_init_with_str_table(self):
test_table_name = "test_table"
field = Field(name="name", table=test_table_name)
self.assertEqual(field.table, Table(name=test_table_name))


class FieldHashingTests(TestCase):
def test_tabled_eq_fields_equally_hashed(self):
Expand Down