Skip to content

Commit

Permalink
feat: Add PingingPool check on session age
Browse files Browse the repository at this point in the history
  • Loading branch information
MostafaOmar98 committed Jan 6, 2024
1 parent d3fe937 commit 783ca52
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
10 changes: 9 additions & 1 deletion google/cloud/spanner_v1/pool.py
Expand Up @@ -385,6 +385,8 @@ class PingingPool(AbstractSessionPool):
:param database_role: (Optional) user-assigned database_role for the session.
"""

SESSION_MAX_AGE = 28 * 24 * 60 * 60

def __init__(
self,
size=10,
Expand Down Expand Up @@ -448,8 +450,9 @@ def get(self, timeout=None):
timeout = self.default_timeout

ping_after, session = self._sessions.get(block=True, timeout=timeout)
session_age = (_NOW() - session._created_at).total_seconds()

if _NOW() > ping_after:
if _NOW() > ping_after or session_age >= self.SESSION_MAX_AGE:
# Using session.exists() guarantees the returned session exists.
# session.ping() uses a cached result in the backend which could
# result in a recently deleted session being returned.
Expand Down Expand Up @@ -481,6 +484,11 @@ def clear(self):
else:
session.delete()

def _new_session(self):
session = super()._new_session()
session._created_at = _NOW()
return session

def ping(self):
"""Refresh maybe-expired sessions in the pool.
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/test_pool.py
Expand Up @@ -531,6 +531,48 @@ def test_get_hit_w_ping_expired(self):
self.assertTrue(SESSIONS[0]._exists_checked)
self.assertFalse(pool._sessions.full())

def test_get_hit_w_created(self):
import datetime

pool = self._make_one(size=4)
database = _Database("name")
SESSIONS = [_Session(database)] * 4
database._sessions.extend(SESSIONS)
pool.bind(database)

session_max_age = 28 * 24 * 60 * 60
SESSIONS[0]._created_at = datetime.datetime.utcnow() - datetime.timedelta(
seconds=session_max_age + 10
)

session = pool.get()

self.assertIs(session, SESSIONS[0])
self.assertTrue(session._exists_checked)
self.assertFalse(pool._sessions.full())

def test_get_hit_w_created_expired(self):
import datetime

pool = self._make_one(size=4)
database = _Database("name")
SESSIONS = [_Session(database)] * 5
database._sessions.extend(SESSIONS)
pool.bind(database)

session_max_age = 28 * 24 * 60 * 60
SESSIONS[0]._created_at = datetime.datetime.utcnow() - datetime.timedelta(
seconds=session_max_age
)
SESSIONS[0]._exists = False

session = pool.get()

self.assertIs(session, SESSIONS[4])
session.create.assert_called()
self.assertTrue(SESSIONS[0]._exists_checked)
self.assertFalse(pool._sessions.full())

def test_get_empty_default_timeout(self):
import queue

Expand Down

0 comments on commit 783ca52

Please sign in to comment.