376 lines
11 KiB
Python
376 lines
11 KiB
Python
|
"""
|
||
|
Utility module to manipulate queries
|
||
|
"""
|
||
|
|
||
|
# Copyright (C) 2020 The Psycopg Team
|
||
|
|
||
|
import re
|
||
|
from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
|
||
|
from typing import Sequence, Tuple, Union, TYPE_CHECKING
|
||
|
from functools import lru_cache
|
||
|
|
||
|
from . import pq
|
||
|
from . import errors as e
|
||
|
from .sql import Composable
|
||
|
from .abc import Buffer, Query, Params
|
||
|
from ._enums import PyFormat
|
||
|
from ._encodings import conn_encoding
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from .abc import Transformer
|
||
|
|
||
|
|
||
|
class QueryPart(NamedTuple):
    """
    A fragment of a query, as produced by `_split_query()`.

    Each part holds the query text preceding a placeholder and a description
    of the placeholder itself. The last part of a split query only carries
    the text after the last placeholder: its `!item` is 0 and its `!format`
    is `PyFormat.AUTO`.
    """

    # Query text preceding the placeholder (or the trailing text, for the
    # last part).
    pre: bytes
    # Index of a positional placeholder (int) or name of a named one (str).
    item: Union[int, str]
    # Format requested by the placeholder char: s -> AUTO, t -> TEXT,
    # b -> BINARY (see `_ph_to_fmt`).
    format: PyFormat
|
||
|
|
||
|
|
||
|
class PostgresQuery:
    """
    Helper to convert a Python query and parameters into Postgres format.

    Call `convert()` with a query and its parameters, then possibly `dump()`
    with further sets of parameters for the same query. The outcome is
    exposed in the `query`, `params`, `types` and `formats` attributes.
    """

    __slots__ = """
        query params types formats
        _tx _want_formats _parts _encoding _order
        """.split()

    def __init__(self, transformer: "Transformer"):
        self._tx = transformer

        self.params: Optional[Sequence[Optional[Buffer]]] = None
        # Kept as tuples because they are used as keys (e.g. by prepared
        # statements), so they must be hashable.
        self.types: Tuple[int, ...] = ()

        # What the user asked for (one PyFormat per placeholder) and what is
        # actually sent to Postgres.
        self._want_formats: Optional[List[PyFormat]] = None
        self.formats: Optional[Sequence[pq.Format]] = None

        self._encoding = conn_encoding(transformer.connection)
        self._parts: List[QueryPart]
        self.query = b""
        self._order: Optional[List[str]] = None

    def convert(self, query: Query, vars: Optional[Params]) -> None:
        """
        Prepare *query* and dump its parameters *vars*.

        The query is turned into bytes; if *vars* is not None its Python
        placeholders are also rewritten into Postgres ones. The outcome is
        then available in the `query`, `params`, `types`, `formats`
        attributes.
        """
        if isinstance(query, str):
            bquery = query.encode(self._encoding)
        elif isinstance(query, Composable):
            bquery = query.as_bytes(self._tx)
        else:
            bquery = query

        if vars is None:
            # No parameters: the query is sent as it is.
            self.query = bquery
            self._want_formats = None
            self._order = None
        else:
            converted = _query2pg(bquery, self._encoding)
            self.query, self._want_formats, self._order, self._parts = converted

        self.dump(vars)

    def dump(self, vars: Optional[Params]) -> None:
        """
        Dump a new set of parameters for the query set up by `convert()`.

        Update the `params`, `types` and `formats` attributes.
        """
        if vars is None:
            self.params = None
            self.types = ()
            self.formats = None
            return

        ordered = _validate_and_reorder_params(self._parts, vars, self._order)
        assert self._want_formats is not None
        self.params = self._tx.dump_sequence(ordered, self._want_formats)
        self.types = self._tx.types or ()
        self.formats = self._tx.formats
|
||
|
|
||
|
|
||
|
class PostgresClientQuery(PostgresQuery):
    """
    PostgresQuery subclass merging query and arguments client-side.

    Instead of sending the parameters separately, each one is converted to a
    literal by the transformer and merged into the query text (`query`);
    `params` holds the literals used.
    """

    __slots__ = ("template",)

    def convert(self, query: Query, vars: Optional[Params]) -> None:
        """
        Prepare *query* and merge the parameters *vars* into it.

        With *vars* not None, a ``%``-style template is built from the query
        and filled by `dump()`. The outcome is then available in the
        `query`, `params`, `types`, `formats` attributes.
        """
        if isinstance(query, str):
            bquery = query.encode(self._encoding)
        elif isinstance(query, Composable):
            bquery = query.as_bytes(self._tx)
        else:
            bquery = query

        if vars is None:
            # Nothing to merge: the query is used verbatim.
            self.query = bquery
            self._order = None
        else:
            converted = _query2pg_client(bquery, self._encoding)
            self.template, self._order, self._parts = converted

        self.dump(vars)

    def dump(self, vars: Optional[Params]) -> None:
        """
        Merge a new set of parameters into the template built by `convert()`.

        Update the `params` attribute and, if *vars* is not None, `query`.
        """
        if vars is None:
            self.params = None
            return

        ordered = _validate_and_reorder_params(self._parts, vars, self._order)
        # None values become the SQL NULL keyword; the others are adapted
        # to literals by the transformer.
        self.params = tuple(
            b"NULL" if p is None else self._tx.as_literal(p) for p in ordered
        )
        self.query = self.template % self.params
|
||
|
|
||
|
|
||
|
@lru_cache()
def _query2pg(
    query: bytes, encoding: str
) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
    """
    Convert a Python query into a query Postgres understands.

    Python placeholders (``%s``, ``%(name)s``; ``%t`` and ``%b`` ask for the
    text and binary format) are replaced by Postgres ones (``$1``, ``$2``...).
    A named placeholder used several times is mapped to the same ``$n``.

    Return a tuple of:

    - ``query`` (bytes): the converted query;
    - ``formats``: the format requested for each ``$n`` parameter;
    - ``order``: for named placeholders, the names in the order of their
      first appearance in the query; None for positional placeholders;
    - ``parts``: the query as split by `_split_query()`.

    :raises ProgrammingError: if a name is used with different formats.
    """
    parts = _split_query(query, encoding)
    placeholders = parts[:-1]
    chunks: List[bytes] = []
    formats: List[PyFormat] = []
    order: Optional[List[str]] = None

    if isinstance(parts[0].item, int):
        for part in placeholders:
            assert isinstance(part.item, int)
            chunks += (part.pre, b"$%d" % (part.item + 1))
            formats.append(part.format)

    elif isinstance(parts[0].item, str):
        order = []
        # name -> (Postgres placeholder, format of its first occurrence)
        named: Dict[str, Tuple[bytes, PyFormat]] = {}
        for part in placeholders:
            assert isinstance(part.item, str)
            chunks.append(part.pre)
            if part.item in named:
                pgph, first_format = named[part.item]
                if first_format != part.format:
                    raise e.ProgrammingError(
                        f"placeholder '{part.item}' cannot have different formats"
                    )
            else:
                pgph = b"$%d" % (len(named) + 1)
                named[part.item] = (pgph, part.format)
                order.append(part.item)
                formats.append(part.format)
            chunks.append(pgph)

    # text after the last placeholder
    chunks.append(parts[-1].pre)

    return b"".join(chunks), formats, order, parts
|
||
|
|
||
|
|
||
|
@lru_cache()
def _query2pg_client(
    query: bytes, encoding: str
) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
    """
    Convert a Python query into a template to perform client-side binding.

    Every placeholder (``%s``, ``%b``, ``%t``, ``%(name)X``) becomes a plain
    ``%s``, so that the template can be merged with the adapted literals by
    the ``%`` operator. Every other ``%`` in the query text is escaped as
    ``%%``, so that the merge restores it verbatim.

    Return a tuple of:

    - ``template`` (bytes);
    - ``order``: for named placeholders, the names in the order they appear
      in the query, repeated for every occurrence; None for positional
      placeholders;
    - ``parts``: the query as split by `_split_query()`.
    """
    # Unescape '%%' while splitting, then re-escape every '%' left in the
    # text. This also escapes the stray '%' that the placeholder regexp
    # doesn't match (e.g. at the end of the query or before a newline),
    # which would otherwise make `template % params` fail.
    parts = _split_query(query, encoding, collapse_double_percent=True)
    order: Optional[List[str]] = [] if isinstance(parts[0].item, str) else None
    chunks: List[bytes] = []

    for part in parts[:-1]:
        chunks.append(part.pre.replace(b"%", b"%%"))
        # Each occurrence gets its own %s, even for a repeated name: the
        # name is then repeated in `order` as well.
        chunks.append(b"%s")
        if order is not None:
            assert isinstance(part.item, str)
            order.append(part.item)
        else:
            assert isinstance(part.item, int)

    # last part
    chunks.append(parts[-1].pre.replace(b"%", b"%%"))

    return b"".join(chunks), order, parts
|
||
|
|
||
|
|
||
|
def _validate_and_reorder_params(
    parts: List[QueryPart], vars: Params, order: Optional[List[str]]
) -> Sequence[Any]:
    """
    Verify the compatibility between a query and a set of params.

    :param parts: the query as split by `_split_query()`: one part per
        placeholder plus a final one for the trailing text.
    :param vars: the parameters passed by the user, a sequence (for
        positional placeholders) or a mapping (for named placeholders).
    :param order: the placeholder names, in the order their values must be
        returned; None for positional placeholders.
    :return: the parameter values, in placeholder order.
    :raises TypeError: if *vars* is neither a sequence nor a mapping, or if
        its kind doesn't match the placeholders style.
    :raises ProgrammingError: if the number of parameters doesn't match the
        number of placeholders, or if a named parameter is missing.
    """
    # Try concrete types first (cheap), then abstract types
    t = type(vars)
    if t is list or t is tuple:
        sequence = True
    elif t is dict:
        sequence = False
    elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
        sequence = True
    elif isinstance(vars, Mapping):
        sequence = False
    else:
        raise TypeError(
            "query parameters should be a sequence or a mapping,"
            f" got {type(vars).__name__}"
        )

    nplaceholders = len(parts) - 1

    if sequence:
        if len(vars) != nplaceholders:
            raise e.ProgrammingError(
                f"the query has {nplaceholders} placeholders but"
                f" {len(vars)} parameters were passed"
            )
        if vars and not isinstance(parts[0].item, int):
            raise TypeError("named placeholders require a mapping of parameters")
        return vars  # type: ignore[return-value]

    # A mapping requires named placeholders. Check it even for an empty
    # mapping: otherwise `{}` would silently bind no value to positional
    # placeholders, failing later with a less clear error.
    if nplaceholders and not isinstance(parts[0].item, str):
        raise TypeError(
            "positional placeholders (%s) require a sequence of parameters"
        )
    try:
        return [vars[item] for item in order or ()]  # type: ignore[call-overload]
    except KeyError:
        raise e.ProgrammingError(
            "query parameter missing:"
            f" {', '.join(sorted(i for i in order or () if i not in vars))}"
        ) from None
|
||
|
|
||
|
|
||
|
# Match a '%' followed either by a name in parentheses and a format char
# (e.g. b"%(name)s"; group 1 captures the name) or by any single char
# (e.g. b"%s", b"%%", b"% "). Which of these are valid is decided by
# `_split_query()`. The pattern has no DOTALL flag, so '.' doesn't match a
# newline: a '%' at the end of the query or before a newline is not matched.
_re_placeholder = re.compile(
    rb"""(?x)
        %                       # a literal %
        (?:
            (?:
                \( ([^)]+) \)   # or a name in (braces)
                .               # followed by a format
            )
            |
            (?:.)               # or any char, really
        )
        """
)
|
||
|
|
||
|
|
||
|
def _split_query(
    query: bytes, encoding: str = "ascii", collapse_double_percent: bool = True
) -> List[QueryPart]:
    """
    Split *query* at its placeholders.

    Return one `QueryPart` per placeholder, holding the text preceding it,
    its index (positional placeholders) or name (named placeholders) and the
    format it requests. A final part holds the text after the last
    placeholder; its `!item` is 0 and its `!format` is `PyFormat.AUTO`.

    ``%%`` sequences are not placeholders: they are merged into the
    surrounding text, unescaped to ``%`` if *collapse_double_percent* is
    true, kept as ``%%`` otherwise.

    :raises ProgrammingError: on incomplete placeholders, placeholders other
        than ``%s``, ``%b``, ``%t``, or a mix of positional and named ones.
    """
    rv: List[QueryPart] = []
    phtype = None
    # Text fragments seen since the previous placeholder (including merged
    # "%%"), joined once per part to avoid repeated list deletions or
    # quadratic concatenation.
    pending: List[bytes] = []
    cur = 0

    for m in _re_placeholder.finditer(query):
        start, end = m.span(0)
        pending.append(query[cur:start])
        cur = end

        ph = m.group(0)
        if ph == b"%%":
            # Not a placeholder: merge it with the text around it,
            # unescaping it to '%' if requested.
            pending.append(b"%" if collapse_double_percent else ph)
            continue

        if ph == b"%(":
            raise e.ProgrammingError(
                "incomplete placeholder:"
                f" '{query[start:].split()[0].decode(encoding)}'"
            )
        elif ph == b"% ":
            # explicit message for a typical error
            raise e.ProgrammingError(
                "incomplete placeholder: '%'; if you want to use '%' as an"
                " operator you can double it up, i.e. use '%%'"
            )
        elif ph[-1:] not in b"sbt":
            raise e.ProgrammingError(
                "only '%s', '%b', '%t' are allowed as placeholders, got"
                f" '{ph.decode(encoding)}'"
            )

        # Name, or index among the placeholders found so far
        item: Union[int, str]
        item = m.group(1).decode(encoding) if m.group(1) else len(rv)

        if phtype is None:
            phtype = type(item)
        elif phtype is not type(item):
            raise e.ProgrammingError(
                "positional and named placeholders cannot be mixed"
            )

        rv.append(QueryPart(b"".join(pending), item, _ph_to_fmt[ph[-1:]]))
        pending = []

    # last part: the text after the last placeholder
    pending.append(query[cur:])
    rv.append(QueryPart(b"".join(pending), 0, PyFormat.AUTO))

    return rv
|
||
|
|
||
|
|
||
|
# Map the format char of a placeholder (its last byte, e.g. the "t" of
# b"%t" or b"%(name)t") to the format requested for the parameter.
_ph_to_fmt = {
    b"s": PyFormat.AUTO,
    b"t": PyFormat.TEXT,
    b"b": PyFormat.BINARY,
}
|