"""
Transaction context managers returned by Connection.transaction().
"""

# Copyright (C) 2020 The Psycopg Team
|
||
|
|
||
|
import logging
|
||
|
|
||
|
from types import TracebackType
|
||
|
from typing import Generic, Iterator, Optional, Type, Union, TypeVar, TYPE_CHECKING
|
||
|
|
||
|
from . import pq
|
||
|
from . import sql
|
||
|
from . import errors as e
|
||
|
from .abc import ConnectionType, PQGen
|
||
|
|
||
|
if TYPE_CHECKING:
|
||
|
from typing import Any
|
||
|
from .connection import Connection
|
||
|
from .connection_async import AsyncConnection
|
||
|
|
||
|
logger = logging.getLogger(__name__)

# libpq status values the transaction classes below compare against:
# IDLE decides whether an entered block opens the outer transaction,
# OK decides whether the connection is usable when a block exits.
IDLE = pq.TransactionStatus.IDLE
OK = pq.ConnStatus.OK
|
||
|
|
||
|
|
||
|
class Rollback(Exception):
    """
    Exit the current `Transaction` context immediately and roll back any
    changes made within this context.

    If a transaction context is passed to the constructor, the enclosing
    transaction contexts are rolled back up to, and including, that one.
    """

    __module__ = "psycopg"

    def __init__(
        self,
        transaction: Union["Transaction", "AsyncTransaction", None] = None,
    ):
        # The block the rollback should stop at; None means "the innermost".
        self.transaction = transaction

    def __repr__(self) -> str:
        name = type(self).__qualname__
        return f"{name}({self.transaction!r})"
|
||
|
|
||
|
|
||
|
class OutOfOrderTransactionNesting(e.ProgrammingError):
    """
    Out-of-order transaction nesting detected.

    Raised when a transaction block is committed or rolled back while it is
    not the innermost block currently active on its connection.
    """
|
||
|
|
||
|
|
||
|
class BaseTransaction(Generic[ConnectionType]):
    """
    Implementation shared by `Transaction` and `AsyncTransaction`.

    Entering and exiting a block is expressed as `PQGen` generators
    (`_enter_gen()`, `_exit_gen()`), which the subclasses drive through their
    connection's ``wait()`` method.

    Active blocks are counted on the connection in ``_num_transactions``.
    The block entered while the connection is idle is the outer one: it
    issues the transaction start command and the final COMMIT/ROLLBACK.
    Inner blocks (and a named outer block) use a savepoint.
    """

    def __init__(
        self,
        connection: ConnectionType,
        savepoint_name: Optional[str] = None,
        force_rollback: bool = False,
    ) -> None:
        self._conn = connection
        self.pgconn = self._conn.pgconn
        # Stored as "" when no name is given; a name may be generated on enter.
        self._savepoint_name = savepoint_name or ""
        self.force_rollback = force_rollback
        self._entered = self._exited = False
        # Set on enter: True if this block started the transaction itself.
        self._outer_transaction = False
        # Set on enter: depth of this block in the connection's stack.
        self._stack_index = -1

    @property
    def savepoint_name(self) -> Optional[str]:
        """
        The name of the savepoint; `!None` if handling the main transaction.
        """
        # Yes, it may change on __enter__. No, I don't care, because the
        # un-entered state is outside the public interface.
        # The name is kept internally as "" when absent: map it back to None
        # so the value matches the documented contract and the annotation.
        return self._savepoint_name or None

    def __repr__(self) -> str:
        cls = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        info = pq.misc.connection_summary(self.pgconn)
        if not self._entered:
            status = "inactive"
        elif not self._exited:
            status = "active"
        else:
            status = "terminated"

        sp = f"{self.savepoint_name!r} " if self.savepoint_name else ""
        return f"<{cls} {sp}({status}) {info} at 0x{id(self):x}>"

    def _enter_gen(self) -> PQGen[None]:
        """
        Generator running the commands that open the block.

        :raise TypeError: if this object was already entered once.
        """
        if self._entered:
            raise TypeError("transaction blocks can be used only once")
        self._entered = True

        self._push_savepoint()
        for command in self._get_enter_commands():
            yield from self._conn._exec_command(command)

    def _exit_gen(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> PQGen[bool]:
        """
        Generator closing the block: commit on success, otherwise roll back.

        The generator's return value is the value for ``__exit__``: True if
        the exception raised in the block must be swallowed (a `Rollback`
        aimed at this block), False otherwise.
        """
        # Check identity, not truthiness: an exception object defining
        # __bool__ or __len__ must never be mistaken for a clean exit.
        if exc_val is None and not self.force_rollback:
            yield from self._commit_gen()
            return False
        else:
            # try to rollback, but if there are problems (connection in a bad
            # state) just warn without clobbering the exception bubbling up.
            try:
                return (yield from self._rollback_gen(exc_val))
            except OutOfOrderTransactionNesting:
                # Clobber an exception happened in the block with the exception
                # caused by out-of-order transaction detected, so make the
                # behaviour consistent with _commit_gen and to make sure the
                # user fixes this condition, which is unrelated from
                # operational error that might arise in the block.
                raise
            except Exception as exc2:
                logger.warning("error ignored in rollback of %s: %s", self, exc2)
                return False

    def _commit_gen(self) -> PQGen[None]:
        """
        Generator committing the block (RELEASE for savepoints, COMMIT for
        the outer transaction).

        :raise OutOfOrderTransactionNesting: if the block is not the
            innermost active one. In that case no command is sent.
        """
        ex = self._pop_savepoint("commit")
        self._exited = True
        if ex:
            raise ex

        for command in self._get_commit_commands():
            yield from self._conn._exec_command(command)

    def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
        """
        Generator rolling back the block.

        Return True if `exc_val` is a `Rollback` with no target or targeting
        this block, meaning the exception should be swallowed. Otherwise
        return False so the exception keeps propagating to enclosing blocks.

        :raise OutOfOrderTransactionNesting: if the block is not the
            innermost active one. In that case no command is sent.
        """
        if isinstance(exc_val, Rollback):
            # Lazy %-formatting: str(conn) is only computed if debug is on.
            # Pass the exception explicitly instead of exc_info=True, so the
            # traceback is logged even outside an active exception handler.
            logger.debug(
                "%s: Explicit rollback from: ", self._conn, exc_info=exc_val
            )

        ex = self._pop_savepoint("rollback")
        self._exited = True
        if ex:
            raise ex

        for command in self._get_rollback_commands():
            yield from self._conn._exec_command(command)

        if isinstance(exc_val, Rollback):
            if not exc_val.transaction or exc_val.transaction is self:
                return True  # Swallow the exception

        return False

    def _get_enter_commands(self) -> Iterator[bytes]:
        """Yield the commands opening the block: BEGIN and/or SAVEPOINT."""
        if self._outer_transaction:
            yield self._conn._get_tx_start_command()

        if self._savepoint_name:
            yield (
                sql.SQL("SAVEPOINT {}")
                .format(sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

    def _get_commit_commands(self) -> Iterator[bytes]:
        """Yield the commands committing the block: RELEASE or COMMIT."""
        # A named outer block does not release its savepoint: COMMIT ends it.
        if self._savepoint_name and not self._outer_transaction:
            yield (
                sql.SQL("RELEASE {}")
                .format(sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        if self._outer_transaction:
            assert not self._conn._num_transactions
            yield b"COMMIT"

    def _get_rollback_commands(self) -> Iterator[bytes]:
        """
        Yield the commands rolling back the block: ROLLBACK TO + RELEASE for
        a savepoint, ROLLBACK for the outer transaction.
        """
        if self._savepoint_name and not self._outer_transaction:
            yield (
                sql.SQL("ROLLBACK TO {n}")
                .format(n=sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )
            yield (
                sql.SQL("RELEASE {n}")
                .format(n=sql.Identifier(self._savepoint_name))
                .as_bytes(self._conn)
            )

        if self._outer_transaction:
            assert not self._conn._num_transactions
            yield b"ROLLBACK"

        # Also clear the prepared statements cache.
        # NOTE(review): clear() presumably returns True when server-side
        # maintenance commands are needed; verify in the prepare manager.
        if self._conn._prepared.clear():
            yield from self._conn._prepared.get_maintenance_commands()

    def _push_savepoint(self) -> None:
        """
        Push the transaction on the connection transactions stack.

        Also set the internal state of the object and verify consistency.
        """
        self._outer_transaction = self.pgconn.transaction_status == IDLE
        if self._outer_transaction:
            # outer transaction: if no name it's only a begin, else
            # there will be an additional savepoint
            assert not self._conn._num_transactions
        else:
            # inner transaction: it always has a name
            if not self._savepoint_name:
                self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"

        self._stack_index = self._conn._num_transactions
        self._conn._num_transactions += 1

    def _pop_savepoint(self, action: str) -> Optional[Exception]:
        """
        Pop the transaction from the connection transactions stack.

        Also verify the state consistency. The error is returned, not raised,
        so that callers can finish updating their state first.

        :param action: "commit" or "rollback"; used in the error message.
        :return: `OutOfOrderTransactionNesting` if this block was not the
            innermost one, else `!None`.
        """
        self._conn._num_transactions -= 1
        if self._conn._num_transactions == self._stack_index:
            return None

        return OutOfOrderTransactionNesting(
            f"transaction {action} at the wrong nesting level: {self}"
        )
|
||
|
|
||
|
|
||
|
class Transaction(BaseTransaction["Connection[Any]"]):
    """
    Returned by `Connection.transaction()` to handle a transaction block.
    """

    __module__ = "psycopg"

    _Self = TypeVar("_Self", bound="Transaction")

    @property
    def connection(self) -> "Connection[Any]":
        """The connection the object is managing."""
        return self._conn

    def __enter__(self: _Self) -> _Self:
        # Send the commands opening the block while holding the connection
        # lock, then hand back the block itself for ``with ... as``.
        with self._conn.lock:
            self._conn.wait(self._enter_gen())
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # Connection status not OK: skip commit/rollback entirely and
        # return False, so an exception raised in the block propagates.
        if self.pgconn.status != OK:
            return False

        with self._conn.lock:
            return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
|
||
|
|
||
|
|
||
|
class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
    """
    Returned by `AsyncConnection.transaction()` to handle a transaction block.
    """

    __module__ = "psycopg"

    _Self = TypeVar("_Self", bound="AsyncTransaction")

    @property
    def connection(self) -> "AsyncConnection[Any]":
        """The connection the object is managing."""
        return self._conn

    async def __aenter__(self: _Self) -> _Self:
        # Send the commands opening the block while holding the connection
        # lock, then hand back the block itself for ``async with ... as``.
        async with self._conn.lock:
            await self._conn.wait(self._enter_gen())
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> bool:
        # Connection status not OK: skip commit/rollback entirely and
        # return False, so an exception raised in the block propagates.
        if self.pgconn.status != OK:
            return False

        async with self._conn.lock:
            return await self._conn.wait(
                self._exit_gen(exc_type, exc_val, exc_tb)
            )
|