Skip to content

Commit

Permalink
feat:enable instance-level connection
Browse files Browse the repository at this point in the history
  • Loading branch information
asthamohta committed Apr 19, 2023
1 parent 6c8672b commit 6d7275c
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 9 deletions.
22 changes: 17 additions & 5 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Connection:
should end a that a new one should be started when the next statement is executed.
"""

def __init__(self, instance, database, read_only=False):
def __init__(self, instance, database=None, read_only=False):
self._instance = instance
self._database = database
self._ddl_statements = []
Expand Down Expand Up @@ -242,6 +242,8 @@ def _session_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.session.Session`
:returns: Cloud Spanner session object ready to use.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
if not self._session:
self._session = self.database._pool.get()

Expand All @@ -252,6 +254,8 @@ def _release_session(self):
The session will be returned into the sessions pool.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self.database._pool.put(self._session)
self._session = None

Expand Down Expand Up @@ -368,7 +372,7 @@ def close(self):
if self.inside_transaction:
self._transaction.rollback()

if self._own_pool:
if self._own_pool and self.database:
self.database._pool.clear()

self.is_closed = True
Expand All @@ -378,6 +382,8 @@ def commit(self):
This method is non-operational in autocommit mode.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self._snapshot = None

if self._autocommit:
Expand Down Expand Up @@ -420,6 +426,8 @@ def cursor(self):

@check_not_closed
def run_prior_DDL_statements(self):
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
if self._ddl_statements:
ddl_statements = self._ddl_statements
self._ddl_statements = []
Expand Down Expand Up @@ -474,6 +482,8 @@ def validate(self):
:raises: :class:`google.cloud.exceptions.NotFound`: if the linked instance
or database doesn't exist.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
with self.database.snapshot() as snapshot:
result = list(snapshot.execute_sql("SELECT 1"))
if result != [[1]]:
Expand All @@ -492,7 +502,7 @@ def __exit__(self, etype, value, traceback):

def connect(
instance_id,
database_id,
database_id=None,
project=None,
credentials=None,
pool=None,
Expand All @@ -505,7 +515,7 @@ def connect(
:param instance_id: The ID of the instance to connect to.
:type database_id: str
:param database_id: The ID of the database to connect to.
:param database_id: (Optional) The ID of the database to connect to.
:type project: str
:param project: (Optional) The ID of the project which owns the
Expand Down Expand Up @@ -557,7 +567,9 @@ def connect(
raise ValueError("project in url does not match client object project")

instance = client.instance(instance_id)
conn = Connection(instance, instance.database(database_id, pool=pool))
conn = Connection(
instance, instance.database(database_id, pool=pool) if database_id else None
)
if pool is not None:
conn._own_pool = False

Expand Down
8 changes: 8 additions & 0 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def execute(self, sql, args=None):
:type args: list
:param args: Additional parameters to supplement the SQL query.
"""
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT
Expand Down Expand Up @@ -301,6 +303,8 @@ def executemany(self, operation, seq_of_params):
:param seq_of_params: Sequence of additional parameters to run
the query with.
"""
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT
Expand Down Expand Up @@ -444,6 +448,8 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
self._row_count = _UNSET_COUNT

def _handle_DQL(self, sql, params):
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
if self.connection.read_only and not self.connection.autocommit:
# initiate or use the existing multi-use snapshot
Expand Down Expand Up @@ -484,6 +490,8 @@ def list_tables(self):
def run_sql_in_snapshot(self, sql, params=None, param_types=None):
# Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions
# hence this method exists to circumvent that limit.
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
self.connection.run_prior_DDL_statements()

with self.connection.database.snapshot() as snapshot:
Expand Down
50 changes: 46 additions & 4 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ def test__session_checkout(self, mock_database):
connection._session_checkout()
self.assertEqual(connection._session, "db_session")

def test__session_checkout_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)

with pytest.raises(ValueError):
connection._session_checkout()

@mock.patch("google.cloud.spanner_v1.database.Database")
def test__release_session(self, mock_database):
from google.cloud.spanner_dbapi import Connection
Expand All @@ -182,6 +190,13 @@ def test__release_session(self, mock_database):
pool.put.assert_called_once_with("session")
self.assertIsNone(connection._session)

def test__release_session_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
with pytest.raises(ValueError):
connection._release_session()

def test_transaction_checkout(self):
from google.cloud.spanner_dbapi import Connection

Expand Down Expand Up @@ -294,6 +309,14 @@ def test_commit(self, mock_warn):
AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2
)

def test_commit_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)

with pytest.raises(ValueError):
connection.commit()

@mock.patch.object(warnings, "warn")
def test_rollback(self, mock_warn):
from google.cloud.spanner_dbapi import Connection
Expand Down Expand Up @@ -347,6 +370,13 @@ def test_run_prior_DDL_statements(self, mock_database):
with self.assertRaises(InterfaceError):
connection.run_prior_DDL_statements()

def test_run_prior_DDL_statements_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
with pytest.raises(ValueError):
connection.run_prior_DDL_statements()

def test_as_context_manager(self):
connection = self._make_connection()
with connection as conn:
Expand Down Expand Up @@ -766,6 +796,14 @@ def test_validate_error(self):

snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")

def test_validate_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)

with pytest.raises(ValueError):
connection.validate()

def test_validate_closed(self):
from google.cloud.spanner_dbapi.exceptions import InterfaceError

Expand Down Expand Up @@ -916,16 +954,14 @@ def test_request_priority(self):
sql, params, param_types=param_types, request_options=None
)

@mock.patch("google.cloud.spanner_v1.Client")
def test_custom_client_connection(self, mock_client):
def test_custom_client_connection(self):
from google.cloud.spanner_dbapi import connect

client = _Client()
connection = connect("test-instance", "test-database", client=client)
self.assertTrue(connection.instance._client == client)

@mock.patch("google.cloud.spanner_v1.Client")
def test_invalid_custom_client_connection(self, mock_client):
def test_invalid_custom_client_connection(self):
from google.cloud.spanner_dbapi import connect

client = _Client()
Expand All @@ -937,6 +973,12 @@ def test_invalid_custom_client_connection(self, mock_client):
client=client,
)

def test_connection_wo_database(self):
from google.cloud.spanner_dbapi import connect

connection = connect("test-instance")
self.assertTrue(connection.database is None)


def exit_ctx_func(self, exc_type, exc_value, traceback):
"""Context __exit__ method mock."""
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ def test_execute_attribute_error(self):
with self.assertRaises(AttributeError):
cursor.execute(sql="SELECT 1")

def test_execute_database_error(self):
connection = self._make_connection(self.INSTANCE)
cursor = self._make_one(connection)

with self.assertRaises(ValueError):
cursor.execute(sql="SELECT 1")

def test_execute_autocommit_off(self):
from google.cloud.spanner_dbapi.utils import PeekIterator

Expand Down Expand Up @@ -607,6 +614,16 @@ def test_executemany_insert_batch_aborted(self):
)
self.assertIsInstance(connection._statements[0][1], ResultsChecksum)

@mock.patch("google.cloud.spanner_v1.Client")
def test_executemany_database_error(self, mock_client):
from google.cloud.spanner_dbapi import connect

connection = connect("test-instance")
cursor = connection.cursor()

with self.assertRaises(ValueError):
cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ())

@unittest.skipIf(
sys.version_info[0] < 3, "Python 2 has an outdated iterator definition"
)
Expand Down Expand Up @@ -754,6 +771,13 @@ def test_handle_dql_priority(self):
sql, None, None, request_options=RequestOptions(priority=1)
)

def test_handle_dql_database_error(self):
connection = self._make_connection(self.INSTANCE)
cursor = self._make_one(connection)

with self.assertRaises(ValueError):
cursor._handle_DQL("sql", params=None)

def test_context(self):
connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)
Expand Down Expand Up @@ -814,6 +838,13 @@ def test_run_sql_in_snapshot(self):
mock_snapshot.execute_sql.return_value = results
self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results))

def test_run_sql_in_snapshot_database_error(self):
connection = self._make_connection(self.INSTANCE)
cursor = self._make_one(connection)

with self.assertRaises(ValueError):
cursor.run_sql_in_snapshot("sql")

def test_get_table_column_schema(self):
from google.cloud.spanner_dbapi.cursor import ColumnDetails
from google.cloud.spanner_dbapi import _helpers
Expand Down

0 comments on commit 6d7275c

Please sign in to comment.