diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index efbdc80f3f..3664a0c231 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -581,9 +581,18 @@ def connect( raise ValueError("project in url does not match client object project") instance = client.instance(instance_id) - conn = Connection( - instance, instance.database(database_id, pool=pool) if database_id else None - ) + + if database_id: + database = instance.database( + database_id, + pool=pool, + close_inactive_transactions=False, + ) + database._pool.logging_enabled = False + else: + database = None + + conn = Connection(instance, database) if pool is not None: conn._own_pool = False diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index e0e2bfdbd0..b267996d1f 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -38,6 +38,12 @@ + "numeric has a whole component with precision {}" ) +LONG_RUNNING_TRANSACTION_ERR_MSG = "Transaction has been closed as it was running for more than 60 minutes. If transaction is expected to run long, run as batch or partitioned DML." 
+ +# Constants +DELETE_LONG_RUNNING_TRANSACTION_FREQUENCY_SEC = 120 +DELETE_LONG_RUNNING_TRANSACTION_THRESHOLD_SEC = 3600 + def _try_to_coerce_bytes(bytestring): """Try to coerce a byte string into the right thing based on Python diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 41e4460c30..1d65a74ed2 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -29,6 +29,7 @@ from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1._helpers import _retry from google.cloud.spanner_v1._helpers import _check_rst_stream_error +from google.cloud.spanner_v1._helpers import LONG_RUNNING_TRANSACTION_ERR_MSG from google.api_core.exceptions import InternalServerError @@ -144,6 +145,8 @@ def _check_state(self): """ if self.committed is not None: raise ValueError("Batch already committed") + if self._session is None: + raise ValueError(LONG_RUNNING_TRANSACTION_ERR_MSG) def commit(self, return_commit_stats=False, request_options=None): """Commit mutations to the database. diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index eee34361b3..5f6dd75114 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -113,6 +113,11 @@ class Database(object): is `True` to log commit statistics. If not passed, a logger will be created when needed that will log the commit statistics to stdout. + + :type close_inactive_transactions: boolean + :param close_inactive_transactions: (Optional) If set to True, the database will automatically close inactive transactions that have been running for longer than 60 minutes which may cause session leaks. + By default, this is set to False. 
+ :type encryption_config: :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig` @@ -142,6 +147,7 @@ def __init__( ddl_statements=(), pool=None, logger=None, + close_inactive_transactions=False, encryption_config=None, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, database_role=None, @@ -160,6 +166,7 @@ def __init__( self._default_leader = None self.log_commit_stats = False self._logger = logger + self._close_inactive_transactions = close_inactive_transactions self._encryption_config = encryption_config self._database_dialect = database_dialect self._database_role = database_role @@ -366,7 +373,7 @@ def enable_drop_protection(self, value): def logger(self): """Logger used by the database. - The default logger will log commit stats at the log level INFO using + The default logger will log at the log level INFO using `sys.stderr`. :rtype: :class:`logging.Logger` or `None` @@ -381,6 +388,14 @@ def logger(self): self._logger.addHandler(ch) return self._logger + @property + def close_inactive_transactions(self): + """Whether the database has has closing inactive transactions enabled. Default: False. + :rtype: bool + :returns: True if closing inactive transactions is enabled, else False. 
+ """ + return self._close_inactive_transactions + @property def spanner_api(self): """Helper for session-related API calls.""" @@ -647,7 +662,7 @@ def execute_partitioned_dml( ) def execute_pdml(): - with SessionCheckout(self._pool) as session: + with SessionCheckout(self._pool, isLongRunning=True) as session: txn = api.begin_transaction( session=session.name, options=txn_options, metadata=metadata ) @@ -1019,6 +1034,7 @@ def __enter__(self): """Begin ``with`` block.""" session = self._session = self._database._pool.get() batch = self._batch = Batch(session) + self._session._transaction = batch if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag return batch @@ -1037,7 +1053,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): "CommitStats: {}".format(self._batch.commit_stats), extra={"commit_stats": self._batch.commit_stats}, ) - self._database._pool.put(self._session) + if self._batch._session is not None: + self._database._pool.put(self._session) + self._session._transaction = None class SnapshotCheckout(object): @@ -1061,22 +1079,27 @@ class SnapshotCheckout(object): def __init__(self, database, **kw): self._database = database self._session = None + self._snapshot = None self._kw = kw def __enter__(self): """Begin ``with`` block.""" session = self._session = self._database._pool.get() - return Snapshot(session, **self._kw) + self._snapshot = Snapshot(session, **self._kw) + self._session._transaction = self._snapshot + return self._snapshot def __exit__(self, exc_type, exc_val, exc_tb): """End ``with`` block.""" - if isinstance(exc_val, NotFound): - # If NotFound exception occurs inside the with block - # then we validate if the session still exists. 
- if not self._session.exists(): - self._session = self._database._pool._new_session() - self._session.create() - self._database._pool.put(self._session) + if self._snapshot._session is not None: + if isinstance(exc_val, NotFound): + # If NotFound exception occurs inside the with block + # then we validate if the session still exists. + if not self._session.exists(): + self._session = self._database._pool._new_session() + self._session.create() + self._database._pool.put(self._session) + self._session._transaction = None class BatchSnapshot(object): diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 1b426f8cc2..90bca393f3 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -429,6 +429,7 @@ def database( ddl_statements=(), pool=None, logger=None, + close_inactive_transactions=False, encryption_config=None, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, database_role=None, @@ -453,6 +454,10 @@ def database( will be created when needed that will log the commit statistics to stdout. + :type close_inactive_transactions: boolean + :param close_inactive_transactions: (Optional) Represents whether the database + has close inactive transactions enabled or not. 
Default is False + :type encryption_config: :class:`~google.cloud.spanner_admin_database_v1.types.EncryptionConfig` or :class:`~google.cloud.spanner_admin_database_v1.types.RestoreDatabaseEncryptionConfig` @@ -481,6 +486,7 @@ def database( ddl_statements=ddl_statements, pool=pool, logger=logger, + close_inactive_transactions=close_inactive_transactions, encryption_config=encryption_config, database_dialect=database_dialect, database_role=database_role, diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 56837bfc0b..c61a4feb57 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -16,13 +16,16 @@ import datetime import queue - +import threading +import traceback +import time from google.cloud.exceptions import NotFound from google.cloud.spanner_v1 import BatchCreateSessionsRequest from google.cloud.spanner_v1 import Session from google.cloud.spanner_v1._helpers import ( _metadata_with_prefix, _metadata_with_leader_aware_routing, + LONG_RUNNING_TRANSACTION_ERR_MSG, ) from warnings import warn @@ -38,15 +41,27 @@ class AbstractSessionPool(object): :type database_role: str :param database_role: (Optional) user-assigned database_role for the session. + + :type logging_enabled: boolean + :param logging_enabled: (Optional) Represents whether the session pool + has logging enabled or not. 
Default is True """ _database = None + _sessions = None + _borrowed_sessions = [] + _traces = {} + + _cleanup_task_ongoing_event = threading.Event() + _cleanup_task_ongoing = False + _used_sessions_ratio_threshold = 0.95 - def __init__(self, labels=None, database_role=None): + def __init__(self, labels=None, database_role=None, logging_enabled=True): if labels is None: labels = {} self._labels = labels self._database_role = database_role + self._logging_enabled = logging_enabled @property def labels(self): @@ -66,6 +81,18 @@ def database_role(self): """ return self._database_role + @property + def logging_enabled(self): + """Whether the session pool has logging enabled. Default: True. + :rtype: bool + :returns: True if logging is enabled, else False. + """ + return self._logging_enabled + + @logging_enabled.setter + def logging_enabled(self, value): + self._logging_enabled = value + def bind(self, database): """Associate the pool with a database. @@ -80,9 +107,12 @@ def bind(self, database): """ raise NotImplementedError() - def get(self): + def get(self, isLongRunning=False): """Check a session out from the pool. + :type isLongRunning: bool + :param isLongRunning: Specifies if the session fetched is for long running transaction or not. + Concrete implementations of this method are allowed to raise an error to signal that the pool is exhausted, or to block until a session is available. @@ -126,6 +156,14 @@ def _new_session(self): labels=self.labels, database_role=self.database_role ) + def _set_session_properties(self, session, isLongRunning): + """Helper for setting common session properties.""" + session.checkout_time = datetime.datetime.utcnow() + session.long_running = isLongRunning + session.transaction_logged = False + self._borrowed_sessions.append(session) + self._traces[session._session_id] = "".join(traceback.format_stack()) + def session(self, **kwargs): """Check out a session from the pool. 
@@ -138,6 +176,117 @@ def session(self, **kwargs): """ return SessionCheckout(self, **kwargs) + def startCleaningLongRunningSessions(self): + """Starts background task for recycling session.""" + from google.cloud.spanner_v1._helpers import ( + DELETE_LONG_RUNNING_TRANSACTION_FREQUENCY_SEC, + DELETE_LONG_RUNNING_TRANSACTION_THRESHOLD_SEC, + ) + + if ( + not AbstractSessionPool._cleanup_task_ongoing_event.is_set() + and not AbstractSessionPool._cleanup_task_ongoing + ): + AbstractSessionPool._cleanup_task_ongoing_event.set() + if self.logging_enabled: + self._database.logger.info( + f"{self._used_sessions_ratio_threshold * 100}% of the session pool is exhausted" + ) + background = threading.Thread( + target=self.deleteLongRunningTransactions, + args=( + DELETE_LONG_RUNNING_TRANSACTION_FREQUENCY_SEC, + DELETE_LONG_RUNNING_TRANSACTION_THRESHOLD_SEC, + ), + daemon=True, + name="recycle-sessions", + ) + background.start() + else: + AbstractSessionPool._cleanup_task_ongoing_event.set() + + def stopCleaningLongRunningSessions(self): + """Stops background task for recycling session.""" + AbstractSessionPool._cleanup_task_ongoing_event.clear() + + def deleteLongRunningTransactions( + self, longRunningTransactionFrequency_sec, longRunningTransactionThreshold_sec + ): + """Recycles sessions with long-running transactions + :param longRunningTransactionFrequency_sec: Interval for running background task in seconds. + :param longRunningTransactionThreshold_sec: Timeout for recycling sessions in seconds. 
+ """ + long_running_transaction_timer = time.time() + transactions_closed = 0 + while AbstractSessionPool._cleanup_task_ongoing_event.is_set(): + AbstractSessionPool._cleanup_task_ongoing = True + is_timeout_reached = ( + time.time() - long_running_transaction_timer + >= longRunningTransactionThreshold_sec + + longRunningTransactionFrequency_sec + ) + + if is_timeout_reached and transactions_closed == 0: + break + iteration_start = time.time() + + # Retrieve a list of sessions to delete. + sessions_to_delete = [ + session + for session in self._borrowed_sessions + if not session.long_running + and (datetime.datetime.utcnow() - session.checkout_time) + > datetime.timedelta(seconds=longRunningTransactionThreshold_sec) + ] + + for session in sessions_to_delete: + transactions_closed += int( + self._close_long_running_transactions(session) + ) + + # Calculate and sleep for the time remaining until the next iteration based on the interval. + iteration_elapsed = time.time() - iteration_start + remaining_time = longRunningTransactionFrequency_sec - iteration_elapsed + if remaining_time > 0: + time.sleep(remaining_time) + + AbstractSessionPool._cleanup_task_ongoing = False + AbstractSessionPool._cleanup_task_ongoing_event.clear() + + def _close_long_running_transactions(self, session): + """Helper method to close long running transactions. + :rtype: :bool + :returns: True if transaction is closed else False. 
+ """ + session_recycled = False + if session._session_id in self._traces: + session_trace = self._traces[session._session_id] + if self._database.close_inactive_transactions: + if self.logging_enabled: + # Log a warning for a long-running transaction that has been closed + self._database.logger.warning( + LONG_RUNNING_TRANSACTION_ERR_MSG + session_trace + ) + + # Set the session as None for associated transaction object + if session._transaction is not None: + session._transaction._session = None + + # Increment the count of closed transactions and return the session to the pool + session_recycled = True + self.put(session) + elif self.logging_enabled: + # Log a warning for a potentially leaking long-running transaction. + # Only log the warning if it hasn't been logged already. + if not session.transaction_logged: + self._database.logger.warning( + "Transaction has been running for longer than 60 minutes and might be causing a leak. " + + "Enable closeInactiveTransactions in Session Pool Options to automatically clean such transactions or use batch or partitioned transactions for long running operations." + + session_trace + ) + session.transaction_logged = True + return session_recycled + class FixedSizePool(AbstractSessionPool): """Concrete session pool implementation: @@ -167,6 +316,10 @@ class FixedSizePool(AbstractSessionPool): :type database_role: str :param database_role: (Optional) user-assigned database_role for the session. + + :type logging_enabled: boolean + :param logging_enabled: (Optional) Represents whether the session pool + has logging enabled or not. 
Default is True """ DEFAULT_SIZE = 10 @@ -178,11 +331,16 @@ def __init__( default_timeout=DEFAULT_TIMEOUT, labels=None, database_role=None, + logging_enabled=True, ): - super(FixedSizePool, self).__init__(labels=labels, database_role=database_role) + super(FixedSizePool, self).__init__( + labels=labels, database_role=database_role, logging_enabled=logging_enabled + ) self.size = size self.default_timeout = default_timeout self._sessions = queue.LifoQueue(size) + self._borrowed_sessions = [] + self._traces = {} def bind(self, database): """Associate the pool with a database. @@ -215,7 +373,7 @@ def bind(self, database): session._session_id = session_pb.name.split("/")[-1] self._sessions.put(session) - def get(self, timeout=None): + def get(self, isLongRunning=False, timeout=None): """Check a session out from the pool. :type timeout: int @@ -235,6 +393,19 @@ def get(self, timeout=None): session = self._database.session() session.create() + # Set session properties. + self._set_session_properties(session, isLongRunning) + + # Start background task for handling long-running transactions if used session threshold has reached. + if (self._database.close_inactive_transactions or self.logging_enabled) and len( + self._borrowed_sessions + ) / self.size >= self._used_sessions_ratio_threshold: + self.startCleaningLongRunningSessions() + + # Log a warning message if Session pool is exhausted. + if self.logging_enabled and len(self._borrowed_sessions) == self.size: + self._database.logger.warning("100% of the session pool is exhausted") + return session def put(self, session): @@ -247,11 +418,21 @@ def put(self, session): :raises: :exc:`queue.Full` if the queue is full. 
""" + if self._borrowed_sessions.__contains__(session): + self._borrowed_sessions.remove(session) self._sessions.put_nowait(session) + self._traces.pop(session._session_id, None) + + # Stop background task for handling long running transactions if used sessions are less the threshold" + if (self._database.close_inactive_transactions or self.logging_enabled) and len( + self._borrowed_sessions + ) / self.size < self._used_sessions_ratio_threshold: + self.stopCleaningLongRunningSessions() def clear(self): - """Delete all sessions in the pool.""" + """Delete all sessions in the pool and stops the background cleanup task.""" + self.stopCleaningLongRunningSessions() while True: try: session = self._sessions.get(block=False) @@ -282,13 +463,23 @@ class BurstyPool(AbstractSessionPool): :type database_role: str :param database_role: (Optional) user-assigned database_role for the session. + + :type logging_enabled: boolean + :param logging_enabled: (Optional) Represents whether the session pool + has logging enabled or not. Default is True """ - def __init__(self, target_size=10, labels=None, database_role=None): - super(BurstyPool, self).__init__(labels=labels, database_role=database_role) + def __init__( + self, target_size=10, labels=None, database_role=None, logging_enabled=True + ): + super(BurstyPool, self).__init__( + labels=labels, database_role=database_role, logging_enabled=logging_enabled + ) self.target_size = target_size self._database = None self._sessions = queue.LifoQueue(target_size) + self._borrowed_sessions = [] + self._traces = {} def bind(self, database): """Associate the pool with a database. @@ -300,7 +491,7 @@ def bind(self, database): self._database = database self._database_role = self._database_role or self._database.database_role - def get(self): + def get(self, isLongRunning=False): """Check a session out from the pool. 
:rtype: :class:`~google.cloud.spanner_v1.session.Session` @@ -316,6 +507,20 @@ def get(self): if not session.exists(): session = self._new_session() session.create() + + # Set session properties. + self._set_session_properties(session, isLongRunning) + + # Start background task for handling long-running transactions if used sessions threshold has reached. + if (self._database.close_inactive_transactions or self.logging_enabled) and len( + self._borrowed_sessions + ) / self.target_size >= self._used_sessions_ratio_threshold: + self.startCleaningLongRunningSessions() + + # Log a warning message if Session pool is exhausted. + if self.logging_enabled and len(self._borrowed_sessions) == self.target_size: + self._database.logger.warning("100% of the session pool is exhausted") + return session def put(self, session): @@ -328,7 +533,18 @@ def put(self, session): :param session: the session being returned. """ try: + if self._borrowed_sessions.__contains__(session): + self._borrowed_sessions.remove(session) self._sessions.put_nowait(session) + self._traces.pop(session._session_id, None) + + # Stop background task for handling long running transactions if used sessions are less then threshold." + if ( + self._database.close_inactive_transactions or self.logging_enabled + ) and len( + self._borrowed_sessions + ) / self.target_size < self._used_sessions_ratio_threshold: + self.stopCleaningLongRunningSessions() except queue.Full: try: session.delete() @@ -336,7 +552,9 @@ def put(self, session): pass def clear(self): - """Delete all sessions in the pool.""" + """Delete all sessions in the pool and stops the background cleanup task.""" + + self.stopCleaningLongRunningSessions() while True: try: @@ -383,6 +601,10 @@ class PingingPool(AbstractSessionPool): :type database_role: str :param database_role: (Optional) user-assigned database_role for the session. 
+ + :type logging_enabled: boolean + :param logging_enabled: (Optional) Represents whether the session pool + has logging enabled or not. Default is True """ def __init__( @@ -392,12 +614,17 @@ def __init__( ping_interval=3000, labels=None, database_role=None, + logging_enabled=True, ): - super(PingingPool, self).__init__(labels=labels, database_role=database_role) + super(PingingPool, self).__init__( + labels=labels, database_role=database_role, logging_enabled=logging_enabled + ) self.size = size self.default_timeout = default_timeout self._delta = datetime.timedelta(seconds=ping_interval) self._sessions = queue.PriorityQueue(size) + self._borrowed_sessions = [] + self._traces = {} def bind(self, database): """Associate the pool with a database. @@ -433,7 +660,7 @@ def bind(self, database): self.put(session) created_session_count += len(resp.session) - def get(self, timeout=None): + def get(self, isLongRunning=False, timeout=None): """Check a session out from the pool. :type timeout: int @@ -457,6 +684,18 @@ def get(self, timeout=None): session = self._new_session() session.create() + # Set session properties. + self._set_session_properties(session, isLongRunning) + + # Start background task for handling long-running transactions if used sessions threshold has reached. + if (self._database.close_inactive_transactions or self.logging_enabled) and len( + self._borrowed_sessions + ) / self.size >= self._used_sessions_ratio_threshold: + self.startCleaningLongRunningSessions() + + # Log a warning message if Session pool is exhausted. + if self.logging_enabled and len(self._borrowed_sessions) == self.size: + self._database.logger.warning("100% of the session pool is exhausted") return session def put(self, session): @@ -469,10 +708,22 @@ def put(self, session): :raises: :exc:`queue.Full` if the queue is full. 
""" + if self._borrowed_sessions.__contains__(session): + self._borrowed_sessions.remove(session) self._sessions.put_nowait((_NOW() + self._delta, session)) + self._traces.pop(session._session_id, None) + + # Stop background task for handling long running transactions if used sessions are less then the threshold" + if (self._database.close_inactive_transactions or self.logging_enabled) and len( + self._borrowed_sessions + ) / self.size < self._used_sessions_ratio_threshold: + self.stopCleaningLongRunningSessions() def clear(self): - """Delete all sessions in the pool.""" + """Delete all sessions in the pool and stops the background cleanup task.""" + + self.stopCleaningLongRunningSessions() + while True: try: _, session = self._sessions.get(block=False) @@ -537,6 +788,10 @@ class TransactionPingingPool(PingingPool): :type database_role: str :param database_role: (Optional) user-assigned database_role for the session. + + :type logging_enabled: boolean + :param logging_enabled: (Optional) Represents whether the session pool + has logging enabled or not. 
Default is True """ def __init__( @@ -546,6 +801,7 @@ def __init__( ping_interval=3000, labels=None, database_role=None, + logging_enabled=True, ): """This throws a deprecation warning on initialization.""" warn( @@ -554,6 +810,8 @@ def __init__( stacklevel=2, ) self._pending_sessions = queue.Queue() + self._borrowed_sessions = [] + self._traces = {} super(TransactionPingingPool, self).__init__( size, @@ -561,6 +819,7 @@ def __init__( ping_interval, labels=labels, database_role=database_role, + logging_enabled=logging_enabled, ) self.begin_pending_transactions() @@ -592,7 +851,18 @@ def put(self, session): txn = session._transaction if txn is None or txn.committed or txn.rolled_back: session.transaction() + if self._borrowed_sessions.__contains__(session): + self._borrowed_sessions.remove(session) self._pending_sessions.put(session) + self._traces.pop(session._session_id, None) + + # Stop background task for handling long running transactions if used sessions are less then threshold." + if ( + self._database.close_inactive_transactions or self.logging_enabled + ) and len( + self._borrowed_sessions + ) / self.size < self._used_sessions_ratio_threshold: + self.stopCleaningLongRunningSessions() else: super(TransactionPingingPool, self).put(session) @@ -624,4 +894,9 @@ def __enter__(self): return self._session def __exit__(self, *ignored): - self._pool.put(self._session) + if not ( + self._session._transaction is not None + and self._session._transaction._session is None + ): + self._pool.put(self._session) + self._session._transaction = None diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index b25af53805..ac58c326a1 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -95,6 +95,42 @@ def labels(self): """ return self._labels + @property + def checkout_time(self): + """Check out time for the session. + :rtype: time + :returns: the checked out time for a session. 
+ """ + return self._checkout_time + + @checkout_time.setter + def checkout_time(self, value): + self._checkout_time = value + + @property + def long_running(self): + """Whether the session is used for long running transaction. + :rtype: bool + :returns: True if session is long running, else False. + """ + return self._long_running + + @long_running.setter + def long_running(self, value): + self._long_running = value + + @property + def transaction_logged(self): + """Whether the session is already logged for long running transaction. + :rtype: bool + :returns: True if session is already logged for long running, else False. + """ + return self._transaction_logged + + @transaction_logged.setter + def transaction_logged(self, value): + self._transaction_logged = value + @property def name(self): """Session name used in requests. diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 573042aa11..7b06ec6da0 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -37,6 +37,7 @@ _retry, _check_rst_stream_error, _SessionWrapper, + LONG_RUNNING_TRANSACTION_ERR_MSG, ) from google.cloud.spanner_v1._opentelemetry_tracing import trace_call from google.cloud.spanner_v1.streamed import StreamedResultSet @@ -152,6 +153,10 @@ class _SnapshotBase(_SessionWrapper): _execute_sql_count = 0 _lock = threading.Lock() + def _check_session_state(self): + if self._session is None: + raise ValueError(LONG_RUNNING_TRANSACTION_ERR_MSG) + def _make_txn_selector(self): """Helper for :meth:`read` / :meth:`execute_sql`. @@ -231,6 +236,9 @@ def read( for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. 
""" + + self._check_session_state() + if self._read_request_count > 0: if not self._multi_use: raise ValueError("Cannot re-use single-use snapshot.") @@ -383,6 +391,7 @@ def execute_sql( for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. """ + self._check_session_state() if self._read_request_count > 0: if not self._multi_use: raise ValueError("Cannot re-use single-use snapshot.") @@ -531,6 +540,7 @@ def partition_read( for single-use snapshots, or if a transaction ID is already associated with the snapshot. """ + self._check_session_state() if not self._multi_use: raise ValueError("Cannot use single-use snapshot.") @@ -625,6 +635,7 @@ def partition_query( for single-use snapshots, or if a transaction ID is already associated with the snapshot. """ + self._check_session_state() if not self._multi_use: raise ValueError("Cannot use single-use snapshot.") @@ -794,6 +805,8 @@ def begin(self): if self._read_request_count > 0: raise ValueError("Read-only transaction already pending") + self._check_session_state() + database = self._session._database api = database.spanner_api metadata = _metadata_with_prefix(database.name) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index d564d0d488..af25a62eac 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -24,6 +24,7 @@ _metadata_with_leader_aware_routing, _retry, _check_rst_stream_error, + LONG_RUNNING_TRANSACTION_ERR_MSG, ) from google.cloud.spanner_v1 import CommitRequest from google.cloud.spanner_v1 import ExecuteBatchDmlRequest @@ -62,6 +63,10 @@ def __init__(self, session): super(Transaction, self).__init__(session) + def _check_session_state(self): + if self._session is None: + raise ValueError(LONG_RUNNING_TRANSACTION_ERR_MSG) + def _check_state(self): """Helper for :meth:`commit` et al. 
@@ -102,6 +107,7 @@ def _execute_request( :type request: proto :param request: request proto to call the method with """ + self._check_session_state() transaction = self._make_txn_selector() request.transaction = transaction with trace_call(trace_name, session, attributes): @@ -130,6 +136,7 @@ def begin(self): if self.rolled_back: raise ValueError("Transaction is already rolled back") + self._check_session_state() database = self._session._database api = database.spanner_api metadata = _metadata_with_prefix(database.name) @@ -154,6 +161,7 @@ def begin(self): def rollback(self): """Roll back a transaction on the database.""" + self._check_session_state() self._check_state() if self._transaction_id is not None: @@ -198,6 +206,7 @@ def commit(self, return_commit_stats=False, request_options=None): :returns: timestamp of the committed changes. :raises ValueError: if there are no mutations to commit. """ + self._check_session_state() self._check_state() if self._transaction_id is None and len(self._mutations) > 0: self.begin() @@ -331,6 +340,7 @@ def execute_update( :rtype: int :returns: Count of rows affected by the DML statement. """ + self._check_session_state() params_pb = self._make_params_pb(params, param_types) database = self._session._database metadata = _metadata_with_prefix(database.name) @@ -435,6 +445,7 @@ def batch_update(self, statements, request_options=None): statement triggering the error will not have an entry in the list, nor will any statements following that one. 
""" + self._check_session_state() parsed = [] for statement in statements: if isinstance(statement, str): diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 4a2ce5f495..1744aa3137 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -2521,6 +2521,138 @@ def test_partition_query(sessions_database, not_emulator): batch_txn.close() +def test_should_close_inactive_transactions_with_bursty_pool_and_transaction( + not_emulator, + not_postgres, + shared_instance, + database_operation_timeout, + databases_to_delete, +): + # Overriding the frequency and threshold to smaller value to simulate the behavior. + spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_FREQUENCY_SEC = 2 + spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_THRESHOLD_SEC = 5 + + database_name = _helpers.unique_id("test_longrunning", separator="_") + pool = spanner_v1.BurstyPool(target_size=1) + + temp_db = shared_instance.database( + database_name, + ddl_statements=_helpers.DDL_STATEMENTS, + pool=pool, + close_inactive_transactions=True, + ) + + operation = temp_db.create() + operation.result(database_operation_timeout) + + databases_to_delete.append(temp_db) + + def long_operation(transaction): + transaction.execute_sql("SELECT 1") + time.sleep(10) + transaction.execute_sql("SELECT 1") + + with pytest.raises(Exception) as exc: + temp_db.run_in_transaction(long_operation) + assert ( + exc.value.args[0] + == "Transaction has been closed as it was running for more than 60 minutes. If transaction is expected to run long, run as batch or partitioned DML." + ) + + +def test_should_close_inactive_transactions_with_fixedsize_pool_and_snapshot( + not_emulator, + not_postgres, + shared_instance, + database_operation_timeout, + databases_to_delete, +): + # Overriding the frequency and threshold to smaller value to simulate the behavior. 
+ spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_FREQUENCY_SEC = 2 + spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_THRESHOLD_SEC = 5 + + database_name = _helpers.unique_id("test_longrunning", separator="_") + pool = spanner_v1.FixedSizePool(size=1, default_timeout=500) + + temp_db = shared_instance.database( + database_name, + ) + + create_operation = temp_db.create() + create_operation.result(database_operation_timeout) + databases_to_delete.append(temp_db) + operation = temp_db.update_ddl(_helpers.DDL_STATEMENTS) + operation.result(database_operation_timeout) + + database = shared_instance.database( + database_name, + pool=pool, + close_inactive_transactions=True, + ) + + with pytest.raises(Exception) as exc: + with database.snapshot(multi_use=True) as snapshot: + snapshot.execute_sql("SELECT 1") + time.sleep(10) + snapshot.execute_sql("SELECT 1") + + assert ( + exc.value.args[0] + == "Transaction has been closed as it was running for more than 60 minutes. If transaction is expected to run long, run as batch or partitioned DML." + ) + + +def test_should_close_inactive_transactions_with_pinging_pool_and_batch( + not_emulator, + not_postgres, + shared_instance, + database_operation_timeout, + databases_to_delete, +): + # Overriding the frequency and threshold to smaller value to simulate the behavior. 
+ spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_FREQUENCY_SEC = 2 + spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_THRESHOLD_SEC = 5 + + database_name = _helpers.unique_id("test_longrunning", separator="_") + pool = spanner_v1.PingingPool(size=1, default_timeout=500) + + temp_db = shared_instance.database( + database_name, + ) + + create_operation = temp_db.create() + create_operation.result(database_operation_timeout) + databases_to_delete.append(temp_db) + operation = temp_db.update_ddl(_helpers.DDL_STATEMENTS) + operation.result(database_operation_timeout) + + database = shared_instance.database( + database_name, + pool=pool, + close_inactive_transactions=True, + ) + + table = "contacts" + columns = ["contact_id", "first_name", "last_name", "email"] + rowdata1 = [ + (1, "Alex", "Alex", "testemail@email.com"), + ] + rowdata2 = [ + (2, "Alexander", "Alexander", "testemail@email.com"), + ] + + with pytest.raises(Exception) as exc: + with database.batch() as batch: + batch.insert(table, columns, rowdata1) + batch.insert(table, columns, rowdata2) + time.sleep(10) + + assert ( + exc.value.args[0] + == "Transaction has been closed as it was running for more than 60 minutes. If transaction is expected to run long, run as batch or partitioned DML." 
+ ) + + class FauxCall: def __init__(self, code, details="FauxCall"): self._code = code diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 86dde73159..17122352d4 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -53,7 +53,11 @@ def test_w_implicit(self, mock_client): client.instance.assert_called_once_with(INSTANCE) self.assertIs(connection.database, database) - instance.database.assert_called_once_with(DATABASE, pool=None) + instance.database.assert_called_once_with( + DATABASE, + pool=None, + close_inactive_transactions=False, + ) # Datbase constructs its own pool self.assertIsNotNone(connection.database._pool) self.assertTrue(connection.instance._client.route_to_leader_enabled) @@ -96,7 +100,11 @@ def test_w_explicit(self, mock_client): client.instance.assert_called_once_with(INSTANCE) self.assertIs(connection.database, database) - instance.database.assert_called_once_with(DATABASE, pool=pool) + instance.database.assert_called_once_with( + DATABASE, + pool=pool, + close_inactive_transactions=False, + ) def test_w_credential_file_path(self, mock_client): from google.cloud.spanner_dbapi import connect diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 1628f84062..37cfee4aad 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -19,6 +19,7 @@ import unittest import warnings import pytest +from google.cloud.spanner_v1 import BurstyPool PROJECT = "test-project" INSTANCE = "test-instance" @@ -470,6 +471,7 @@ def test_commit_clears_statements(self, mock_transaction): """ connection = self._make_connection() connection._transaction = mock.Mock(rolled_back=False, committed=False) + connection._session = mock.MagicMock() connection._statements = [{}, {}] self.assertEqual(len(connection._statements), 2) @@ -486,6 +488,7 @@ def 
test_rollback_clears_statements(self, mock_transaction): """ connection = self._make_connection() connection._transaction = mock.Mock() + connection._session = mock.MagicMock() connection._statements = [{}, {}] self.assertEqual(len(connection._statements), 2) @@ -1000,11 +1003,30 @@ def __init__(self, name="instance_id", client=None): self.name = name self._client = client - def database(self, database_id="database_id", pool=None): - return _Database(database_id, pool) + def database( + self, + database_id="database_id", + pool=None, + logging_enabled=False, + close_inactive_transactions=False, + ): + return _Database( + database_id, pool, logging_enabled=False, close_inactive_transactions=False + ) class _Database(object): - def __init__(self, database_id="database_id", pool=None): + def __init__( + self, + database_id="database_id", + pool=None, + logging_enabled=False, + close_inactive_transactions=False, + ): self.name = database_id - self.pool = pool + if pool is None: + pool = BurstyPool() + + self._pool = pool + self.logging_enabled = logging_enabled + self.close_inactive_transactions = close_inactive_transactions diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 856816628f..674deb5dae 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -191,6 +191,17 @@ def test_commit_already_committed(self): self.assertNoSpans() + def test_commit_should_throw_error_for_recycled_session(self): + session = _Session() + batch = self._make_one(session) + batch._session = None + with self.assertRaises(Exception) as cm: + batch.commit() + self.assertEqual( + str(cm.exception), + "Transaction has been closed as it was running for more than 60 minutes. 
If transaction is expected to run long, run as batch or partitioned DML.", + ) + def test_commit_grpc_error(self): from google.api_core.exceptions import Unknown from google.cloud.spanner_v1.keyset import KeySet diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index bd368eed11..00d6a3a249 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -116,6 +116,7 @@ def test_ctor_defaults(self): self.assertTrue(database._pool._sessions.empty()) self.assertIsNone(database.database_role) self.assertTrue(database._route_to_leader_enabled, True) + self.assertFalse(database._close_inactive_transactions) def test_ctor_w_explicit_pool(self): instance = _Instance(self.INSTANCE_NAME) @@ -146,6 +147,24 @@ def test_ctor_w_route_to_leader_disbled(self): self.assertIs(database._instance, instance) self.assertFalse(database._route_to_leader_enabled) + def test_ctor_w_close_inactive_transactions_enabled(self): + instance = _Instance(self.INSTANCE_NAME) + database = self._make_one( + self.DATABASE_ID, instance, close_inactive_transactions=True + ) + self.assertEqual(database.database_id, self.DATABASE_ID) + self.assertIs(database._instance, instance) + self.assertIs(database.close_inactive_transactions, True) + + def test_ctor_w_close_inactive_transactions_disabled(self): + instance = _Instance(self.INSTANCE_NAME) + database = self._make_one( + self.DATABASE_ID, instance, close_inactive_transactions=False + ) + self.assertEqual(database.database_id, self.DATABASE_ID) + self.assertIs(database._instance, instance) + self.assertIs(database.close_inactive_transactions, False) + def test_ctor_w_ddl_statements_non_string(self): with self.assertRaises(ValueError): self._make_one( @@ -1198,6 +1217,20 @@ def test_snapshot_defaults(self): self.assertIs(checkout._database, database) self.assertEqual(checkout._kw, {}) + def test_snapshot_longrunningvalue(self): + from google.cloud.spanner_v1.snapshot import Snapshot + + client = _Client() + instance = 
_Instance(self.INSTANCE_NAME, client=client) + pool = mock.Mock() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + with database.snapshot() as checkout: + self.assertIsInstance(checkout, Snapshot) + + self.assertEqual(pool.get.call_count, 1) + # get method of pool is passed without any param, that means longrunning param is false + pool.get.assert_called_once_with() + def test_snapshot_w_read_timestamp_and_multi_use(self): import datetime from google.cloud._helpers import UTC @@ -2731,8 +2764,9 @@ class _Pool(object): def bind(self, database): self._bound = database - def get(self): + def get(self, isLongRunning=False): session, self._session = self._session, None + session.long_running = isLongRunning return session def put(self, session): diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 23ed3e7251..276179e899 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -15,8 +15,11 @@ from functools import total_ordering import unittest - +import random +import string +import threading import mock +import datetime def _make_database(name="name"): @@ -45,14 +48,18 @@ def test_ctor_defaults(self): self.assertIsNone(pool._database) self.assertEqual(pool.labels, {}) self.assertIsNone(pool.database_role) + self.assertTrue(pool.logging_enabled) def test_ctor_explicit(self): labels = {"foo": "bar"} database_role = "dummy-role" - pool = self._make_one(labels=labels, database_role=database_role) + pool = self._make_one( + labels=labels, database_role=database_role, logging_enabled=False + ) self.assertIsNone(pool._database) self.assertEqual(pool.labels, labels) self.assertEqual(pool.database_role, database_role) + self.assertFalse(pool.logging_enabled) def test_bind_abstract(self): pool = self._make_one() @@ -131,6 +138,220 @@ def test_session_w_kwargs(self): self.assertIsNone(checkout._session) self.assertEqual(checkout._kwargs, {"foo": "bar"}) + @mock.patch("threading.Thread") + def 
 test_startCleaningLongRunningSessions_success(self, mock_thread_class): + mock_thread = mock.MagicMock() + mock_thread.start = mock.MagicMock() + mock_thread_class.return_value = mock_thread + + pool = self._make_one() + pool._database = mock.MagicMock() + pool._cleanup_task_ongoing_event.clear() + pool._cleanup_task_ongoing = False + with mock.patch( + "google.cloud.spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_FREQUENCY_SEC", + new=5, + ), mock.patch( + "google.cloud.spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_THRESHOLD_SEC", + new=10, + ): + pool.startCleaningLongRunningSessions() + + # The event should be set, indicating the task is now ongoing + self.assertTrue(pool._cleanup_task_ongoing_event.is_set()) + + # A new thread should have been created to start the task + threading.Thread.assert_called_once_with( + target=pool.deleteLongRunningTransactions, + args=(5, 10), + daemon=True, + name="recycle-sessions", + ) + mock_thread.start.assert_called_once() + pool.stopCleaningLongRunningSessions() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + @mock.patch("threading.Thread") + def test_startCleaningLongRunningSessions_should_trigger_background_task_once( + self, mock_thread_class + ): + mock_thread = mock.MagicMock() + mock_thread.start = mock.MagicMock() + mock_thread_class.return_value = mock_thread + + pool = self._make_one() + pool._database = mock.MagicMock() + pool._cleanup_task_ongoing_event.clear() + with mock.patch( + "google.cloud.spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_FREQUENCY_SEC", + new=5, + ), mock.patch( + "google.cloud.spanner_v1._helpers.DELETE_LONG_RUNNING_TRANSACTION_THRESHOLD_SEC", + new=10, + ): + pool.startCleaningLongRunningSessions() + + # Calling start and stop background task multiple times. Background should get triggered only once. 
+ threads = [] + threads.append( + threading.Thread( + target=pool.startCleaningLongRunningSessions, + ) + ) + threads.append( + threading.Thread( + target=pool.stopCleaningLongRunningSessions, + ) + ) + threads.append( + threading.Thread( + target=pool.startCleaningLongRunningSessions, + ) + ) + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # A new thread should have been created to start the task + deleteLongRunningTransactions_calls = sum( + [ + 1 + for call in threading.Thread.mock_calls + if "deleteLongRunningTransactions" in call.kwargs.__str__() + ] + ) + + self.assertEqual(deleteLongRunningTransactions_calls, 1) + self.assertEqual( + pool._database.logger.info.mock_calls[0].args[0].__str__(), + "95.0% of the session pool is exhausted", + ) + + pool.stopCleaningLongRunningSessions() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + def test_stopCleaningLongRunningSessions(self): + pool = self._make_one() + pool._cleanup_task_ongoing_event.set() + pool.stopCleaningLongRunningSessions() + + # The event should not be set, indicating the task is now stopped + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + def _setup_session_leak(self, close_inactive_transactions, logging_enabled): + pool = self._make_one() + pool._database = mock.MagicMock() + pool.put = mock.MagicMock() + + def put_side_effect(*args, **kwargs): + pool._borrowed_sessions = [] + pool._cleanup_task_ongoing_event.clear() + + pool.put.side_effect = put_side_effect + + pool.logging_enabled = logging_enabled + pool._cleanup_task_ongoing_event.set() + pool._database.close_inactive_transactions = close_inactive_transactions + pool._borrowed_sessions = [] + pool._database.logger.warning = mock.MagicMock() + pool._format_trace = mock.MagicMock() + return pool + + def test_deleteLongRunningTransactions_noSessionsToDelete(self): + pool = self._setup_session_leak(True, True) + pool.deleteLongRunningTransactions(1, 1) + + # Assert that 
no warnings were logged and no sessions were put back + self.assertEqual(pool._database.logger.warning.call_count, 0) + pool.put.assert_not_called() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + def test_deleteLongRunningTransactions_deleteAndLogSession(self): + pool = self._setup_session_leak(True, True) + # Create a session that needs to be closed + session = mock.MagicMock() + session.transaction_logged = False + session.checkout_time = datetime.datetime.utcnow() - datetime.timedelta( + minutes=61 + ) + session.long_running = False + session._session_id = "session_id" + pool._traces["session_id"] = "trace" + pool._borrowed_sessions = [session] + pool._cleanup_task_ongoing_event.set() + # Call deleteLongRunningTransactions + pool.deleteLongRunningTransactions(2, 2) + + # Assert that the session was put back and a warning was logged + pool.put.assert_called_once() + pool._database.logger.warning.assert_called_once() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + def test_deleteLongRunningTransactions_logSession(self): + pool = self._setup_session_leak(False, True) + # Create a session that needs to be closed + session = mock.MagicMock() + session.transaction_logged = False + session.checkout_time = datetime.datetime.utcnow() - datetime.timedelta( + minutes=61 + ) + session.long_running = False + session._session_id = "session_id" + pool._traces["session_id"] = "trace" + pool._borrowed_sessions = [session] + pool._cleanup_task_ongoing_event.set() + # Call deleteLongRunningTransactions + pool.deleteLongRunningTransactions(2, 2) + + # Assert that the session was not put back and a warning was logged + pool.put.assert_not_called() + pool._database.logger.warning.assert_called_once() + self.assertTrue(session.transaction_logged) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + def test_deleteLongRunningTransactions_deleteSession(self): + pool = self._setup_session_leak(True, False) + # Create a session that needs 
 to be closed + session = mock.MagicMock() + session.transaction_logged = False + session.checkout_time = datetime.datetime.utcnow() - datetime.timedelta( + minutes=61 + ) + session.long_running = False + session._session_id = "session_id" + pool._traces["session_id"] = "trace" + pool._borrowed_sessions = [session] + pool._cleanup_task_ongoing_event.set() + # Call deleteLongRunningTransactions + pool.deleteLongRunningTransactions(2, 2) + + # Assert that the session was put back and no warning was logged + pool.put.assert_called_once() + pool._database.logger.warning.assert_not_called() + self.assertFalse(session.transaction_logged) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + def test_deleteLongRunningTransactions_close_if_no_transaction_is_released(self): + pool = self._setup_session_leak(False, True) + # Create a session that needs to be closed + session = mock.MagicMock() + session.transaction_logged = True + session.checkout_time = datetime.datetime.utcnow() - datetime.timedelta( + minutes=61 + ) + session.long_running = False + session._session_id = "session_id" + pool._traces["session_id"] = "trace" + pool._borrowed_sessions = [session] + pool._cleanup_task_ongoing_event.set() + # Call deleteLongRunningTransactions + pool.deleteLongRunningTransactions(2, 2) + + # Assert that background task was closed as there was no transaction to close. 
+ self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + class TestFixedSizePool(unittest.TestCase): def _getTargetClass(self): @@ -149,12 +370,17 @@ def test_ctor_defaults(self): self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) self.assertIsNone(pool.database_role) + self.assertTrue(pool.logging_enabled) def test_ctor_explicit(self): labels = {"foo": "bar"} database_role = "dummy-role" pool = self._make_one( - size=4, default_timeout=30, labels=labels, database_role=database_role + size=4, + default_timeout=30, + labels=labels, + database_role=database_role, + logging_enabled=False, ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) @@ -162,6 +388,7 @@ def test_ctor_explicit(self): self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) self.assertEqual(pool.database_role, database_role) + self.assertFalse(pool.logging_enabled) def test_bind(self): database_role = "dummy-role" @@ -197,6 +424,8 @@ def test_get_non_expired(self): self.assertIs(session, SESSIONS[i]) self.assertTrue(session._exists_checked) self.assertFalse(pool._sessions.full()) + # Stop Long running session + pool.stopCleaningLongRunningSessions() def test_get_expired(self): pool = self._make_one(size=4) @@ -212,6 +441,29 @@ def test_get_expired(self): session.create.assert_called() self.assertTrue(SESSIONS[0]._exists_checked) self.assertFalse(pool._sessions.full()) + pool.stopCleaningLongRunningSessions() + + def test_get_trigger_longrunning_and_set_defaults(self): + pool = self._make_one(size=2) + database = _Database("name") + SESSIONS = [_Session(database)] * 3 + for session in SESSIONS: + session._exists = True + database._sessions.extend(SESSIONS) + pool.bind(database) + session = pool.get() + self.assertIsInstance(session.checkout_time, datetime.datetime) + self.assertFalse(session.long_running) + self.assertFalse(session.transaction_logged) + self.assertIs(session, pool._borrowed_sessions[0]) + 
self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + # Fetch new session which will trigger the cleanup task. + pool.get() + self.assertTrue(pool._cleanup_task_ongoing_event.is_set()) + + pool.stopCleaningLongRunningSessions() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) def test_get_empty_default_timeout(self): import queue @@ -257,8 +509,14 @@ def test_put_non_full(self): pool.bind(database) pool._sessions.get() - pool.put(_Session(database)) + session = _Session(database) + pool._borrowed_sessions.append(session) + pool._cleanup_task_ongoing_event.set() + + pool.put(session) + self.assertEqual(len(pool._borrowed_sessions), 0) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) self.assertTrue(pool._sessions.full()) def test_clear(self): @@ -274,8 +532,10 @@ def test_clear(self): for session in SESSIONS: session.create.assert_not_called() + pool._cleanup_task_ongoing_event.set() pool.clear() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) for session in SESSIONS: self.assertTrue(session._deleted) @@ -296,16 +556,23 @@ def test_ctor_defaults(self): self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) self.assertIsNone(pool.database_role) + self.assertTrue(pool.logging_enabled) def test_ctor_explicit(self): labels = {"foo": "bar"} database_role = "dummy-role" - pool = self._make_one(target_size=4, labels=labels, database_role=database_role) + pool = self._make_one( + target_size=4, + labels=labels, + database_role=database_role, + logging_enabled=False, + ) self.assertIsNone(pool._database) self.assertEqual(pool.target_size, 4) self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) self.assertEqual(pool.database_role, database_role) + self.assertFalse(pool.logging_enabled) def test_ctor_explicit_w_database_role_in_db(self): database_role = "dummy-role" @@ -328,6 +595,28 @@ def test_get_empty(self): session.create.assert_called() self.assertTrue(pool._sessions.empty()) + 
def test_get_trigger_longrunning_and_set_defaults(self): + pool = self._make_one(target_size=2) + database = _Database("name") + SESSIONS = [_Session(database)] * 3 + database._sessions.extend(SESSIONS) + pool.bind(database) + + session = pool.get() + session.create.assert_called() + self.assertIsInstance(session.checkout_time, datetime.datetime) + self.assertFalse(session.long_running) + self.assertFalse(session.transaction_logged) + self.assertTrue(len(pool._borrowed_sessions), 1) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + # Fetch new session which will trigger the cleanup task. + pool.get() + self.assertTrue(pool._cleanup_task_ongoing_event.is_set()) + + pool.stopCleaningLongRunningSessions() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + def test_get_non_empty_session_exists(self): pool = self._make_one() database = _Database("name") @@ -364,9 +653,13 @@ def test_put_empty(self): database = _Database("name") pool.bind(database) session = _Session(database) + pool._borrowed_sessions.append(session) + pool._cleanup_task_ongoing_event.set() pool.put(session) + self.assertEqual(len(pool._borrowed_sessions), 0) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) self.assertFalse(pool._sessions.empty()) def test_put_full(self): @@ -383,29 +676,16 @@ def test_put_full(self): self.assertTrue(younger._deleted) self.assertIs(pool.get(), older) - def test_put_full_expired(self): - pool = self._make_one(target_size=1) - database = _Database("name") - pool.bind(database) - older = _Session(database) - pool.put(older) - self.assertFalse(pool._sessions.empty()) - - younger = _Session(database, exists=False) - pool.put(younger) # discarded silently - - self.assertTrue(younger._deleted) - self.assertIs(pool.get(), older) - def test_clear(self): pool = self._make_one() database = _Database("name") pool.bind(database) previous = _Session(database) pool.put(previous) + pool._cleanup_task_ongoing_event.set() pool.clear() - + 
self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) self.assertTrue(previous._deleted) @@ -427,6 +707,7 @@ def test_ctor_defaults(self): self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, {}) self.assertIsNone(pool.database_role) + self.assertTrue(pool.logging_enabled) def test_ctor_explicit(self): labels = {"foo": "bar"} @@ -437,6 +718,7 @@ def test_ctor_explicit(self): ping_interval=1800, labels=labels, database_role=database_role, + logging_enabled=False, ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) @@ -445,6 +727,7 @@ def test_ctor_explicit(self): self.assertTrue(pool._sessions.empty()) self.assertEqual(pool.labels, labels) self.assertEqual(pool.database_role, database_role) + self.assertFalse(pool.logging_enabled) def test_ctor_explicit_w_database_role_in_db(self): database_role = "dummy-role" @@ -487,6 +770,27 @@ def test_get_hit_no_ping(self): self.assertFalse(session._exists_checked) self.assertFalse(pool._sessions.full()) + def test_get_trigger_longrunning_and_set_defaults(self): + pool = self._make_one(size=2) + database = _Database("name") + SESSIONS = [_Session(database)] * 3 + database._sessions.extend(SESSIONS) + pool.bind(database) + session = pool.get() + self.assertIsInstance(session.checkout_time, datetime.datetime) + self.assertFalse(session.long_running) + self.assertFalse(session.transaction_logged) + self.assertTrue(len(pool._borrowed_sessions), 1) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + + # Fetch new session which will trigger the cleanup task. + pool.get() + self.assertTrue(pool._cleanup_task_ongoing_event.is_set()) + + # Stop the background task. 
+ pool.stopCleaningLongRunningSessions() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + def test_get_hit_w_ping(self): import datetime from google.cloud._testing import _Monkey @@ -578,10 +882,15 @@ def test_put_non_full(self): now = datetime.datetime.utcnow() database = _Database("name") session = _Session(database) - + pool._borrowed_sessions.append(session) + pool._database = database + pool._cleanup_task_ongoing_event.set() with _Monkey(MUT, _NOW=lambda: now): pool.put(session) + self.assertEqual(len(pool._borrowed_sessions), 0) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + self.assertEqual(len(session_queue._items), 1) ping_after, queued = session_queue._items[0] self.assertEqual(ping_after, now + datetime.timedelta(seconds=3000)) @@ -599,9 +908,11 @@ def test_clear(self): self.assertEqual(api.batch_create_sessions.call_count, 5) for session in SESSIONS: session.create.assert_not_called() + pool._cleanup_task_ongoing_event.set() pool.clear() + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) for session in SESSIONS: self.assertTrue(session._deleted) @@ -676,6 +987,7 @@ def test_ctor_defaults(self): self.assertTrue(pool._pending_sessions.empty()) self.assertEqual(pool.labels, {}) self.assertIsNone(pool.database_role) + self.assertTrue(pool.logging_enabled) def test_ctor_explicit(self): labels = {"foo": "bar"} @@ -686,6 +998,7 @@ def test_ctor_explicit(self): ping_interval=1800, labels=labels, database_role=database_role, + logging_enabled=False, ) self.assertIsNone(pool._database) self.assertEqual(pool.size, 4) @@ -695,6 +1008,7 @@ def test_ctor_explicit(self): self.assertTrue(pool._pending_sessions.empty()) self.assertEqual(pool.labels, labels) self.assertEqual(pool.database_role, database_role) + self.assertFalse(pool.logging_enabled) def test_ctor_explicit_w_database_role_in_db(self): database_role = "dummy-role" @@ -778,7 +1092,9 @@ def test_put_non_full_w_active_txn(self): database = _Database("name") session = 
_Session(database) txn = session.transaction() - + pool._borrowed_sessions.append(session) + pool._cleanup_task_ongoing_event.set() + pool._database = database pool.put(session) self.assertEqual(len(session_queue._items), 1) @@ -786,17 +1102,24 @@ def test_put_non_full_w_active_txn(self): self.assertIs(queued, session) self.assertEqual(len(pending._items), 0) + self.assertEqual(len(pool._borrowed_sessions), 0) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + txn.begin.assert_not_called() def test_put_non_full_w_committed_txn(self): pool = self._make_one(size=1) session_queue = pool._sessions = _Queue() pending = pool._pending_sessions = _Queue() + database = _Database("name") session = _Session(database) + database._sessions.extend([session]) committed = session.transaction() committed.committed = True - + pool._borrowed_sessions.append(session) + pool._cleanup_task_ongoing_event.set() + pool._database = database pool.put(session) self.assertEqual(len(session_queue._items), 0) @@ -804,6 +1127,9 @@ def test_put_non_full_w_committed_txn(self): self.assertEqual(len(pending._items), 1) self.assertIs(pending._items[0], session) self.assertIsNot(session._transaction, committed) + self.assertEqual(len(pool._borrowed_sessions), 0) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) + session._transaction.begin.assert_not_called() def test_put_non_full(self): @@ -812,12 +1138,16 @@ def test_put_non_full(self): pending = pool._pending_sessions = _Queue() database = _Database("name") session = _Session(database) - + pool._database = database + pool._borrowed_sessions.append(session) + pool._cleanup_task_ongoing_event.set() pool.put(session) self.assertEqual(len(session_queue._items), 0) self.assertEqual(len(pending._items), 1) self.assertIs(pending._items[0], session) + self.assertEqual(len(pool._borrowed_sessions), 0) + self.assertFalse(pool._cleanup_task_ongoing_event.is_set()) self.assertFalse(pending.empty()) @@ -835,7 +1165,7 @@ def 
test_begin_pending_transactions_non_empty(self): pending = pool._pending_sessions = _Queue(*PENDING_SESSIONS) self.assertFalse(pending.empty()) - + pool._database = database pool.begin_pending_transactions() # no raise for txn in TRANSACTIONS: @@ -870,7 +1200,8 @@ def test_ctor_w_kwargs(self): ) def test_context_manager_wo_kwargs(self): - session = object() + database = _Database("name") + session = _Session(database) pool = _Pool(session) checkout = self._make_one(pool) @@ -886,7 +1217,8 @@ def test_context_manager_wo_kwargs(self): self.assertEqual(pool._got, {}) def test_context_manager_w_kwargs(self): - session = object() + database = _Database("name") + session = _Session(database) pool = _Pool(session) checkout = self._make_one(pool, foo="bar") @@ -923,6 +1255,7 @@ def __init__(self, database, exists=True, transaction=None): self.create = mock.Mock() self._deleted = False self._transaction = transaction + self._session_id = "".join(random.choices(string.ascii_letters, k=10)) def __lt__(self, other): return id(self) < id(other) @@ -957,6 +1290,10 @@ def __init__(self, name): self._database_role = None self.database_id = name self._route_to_leader_enabled = True + self.close_inactive_transactions = True + self._logger = mock.MagicMock() + self._logger.info = mock.MagicMock() + self._logger.warning = mock.MagicMock() def mock_batch_create_sessions( request=None, @@ -995,6 +1332,10 @@ def database_role(self): """ return self._database_role + @property + def logger(self): + return self._logger + def session(self, **kwargs): # always return first session in the list # to avoid reversing the order of putting diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 0010877396..4ae247df24 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -812,6 +812,19 @@ def test_execute_sql_other_error(self): attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) + def 
test_execute_sql_should_throw_error_for_recycled_session(self): + session = _Session() + derived = self._makeDerived(session) + derived._session = None + + with self.assertRaises(Exception) as cm: + list(derived.execute_sql(SQL_QUERY)) + + self.assertEqual( + str(cm.exception), + "Transaction has been closed as it was running for more than 60 minutes. If transaction is expected to run long, run as batch or partitioned DML.", + ) + def test_execute_sql_w_params_wo_param_types(self): database = _Database() session = _Session(database) @@ -1152,6 +1165,22 @@ def test_partition_read_other_error(self): ), ) + def test_partition_read_should_throw_error_if_session_is_none(self): + from google.cloud.spanner_v1.keyset import KeySet + + keyset = KeySet(all_=True) + session = _Session() + derived = self._makeDerived(session) + derived._session = None + + with self.assertRaises(Exception) as cm: + list(derived.partition_read(TABLE_NAME, COLUMNS, keyset)) + + self.assertEqual( + str(cm.exception), + "Transaction has been closed as it was running for more than 60 minutes. If transaction is expected to run long, run as batch or partitioned DML.", + ) + def test_partition_read_w_retry(self): from google.cloud.spanner_v1.keyset import KeySet from google.api_core.exceptions import InternalServerError @@ -1311,6 +1340,19 @@ def test_partition_query_other_error(self): attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) + def test_partition_query_should_throw_error_if_session_is_none(self): + session = _Session() + derived = self._makeDerived(session) + derived._session = None + + with self.assertRaises(Exception) as cm: + list(derived.partition_query(SQL_QUERY)) + + self.assertEqual( + str(cm.exception), + "Transaction has been closed as it was running for more than 60 minutes. 
If transaction is expected to run long, run as batch or partitioned DML.", + ) + def test_partition_query_w_params_wo_param_types(self): database = _Database() session = _Session(database) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index ffcffa115e..480be0a6c5 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -305,6 +305,17 @@ def test_commit_not_begun(self): self.assertNoSpans() + def test_commit_should_throw_error_for_recycled_session(self): + session = _Session() + transaction = self._make_one(session) + transaction._session = None + with self.assertRaises(Exception) as cm: + transaction.commit() + self.assertEqual( + str(cm.exception), + "Transaction has been closed as it was running for more than 60 minutes. If transaction is expected to run long, run as batch or partitioned DML.", + ) + def test_commit_already_committed(self): session = _Session() transaction = self._make_one(session) @@ -666,6 +677,17 @@ def test_batch_update_other_error(self): with self.assertRaises(RuntimeError): transaction.batch_update(statements=[DML_QUERY]) + def test_batch_update_should_throw_error_for_recycled_session(self): + session = _Session() + transaction = self._make_one(session) + transaction._session = None + with self.assertRaises(Exception) as cm: + transaction.batch_update(statements=[DML_QUERY]) + self.assertEqual( + str(cm.exception), + "Transaction has been closed as it was running for more than 60 minutes. If transaction is expected to run long, run as batch or partitioned DML.", + ) + def _batch_update_helper(self, error_after=None, count=0, request_options=None): from google.rpc.status_pb2 import Status from google.protobuf.struct_pb2 import Struct