From e880a96725c35e46804080a31e56934aefc711d8 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 19 Oct 2021 15:05:50 +0300 Subject: [PATCH 1/6] feat: support JSON data type --- .../cloud/sqlalchemy_spanner/requirements.py | 4 + .../sqlalchemy_spanner/sqlalchemy_spanner.py | 39 ++++ test/test_suite.py | 180 +++++++++++++++++- 3 files changed, 218 insertions(+), 5 deletions(-) diff --git a/google/cloud/sqlalchemy_spanner/requirements.py b/google/cloud/sqlalchemy_spanner/requirements.py index 1b929847..4032e93e 100644 --- a/google/cloud/sqlalchemy_spanner/requirements.py +++ b/google/cloud/sqlalchemy_spanner/requirements.py @@ -17,6 +17,10 @@ class Requirements(SuiteRequirements): + @property + def json_type(self): + return exclusions.open() + @property def foreign_key_constraint_name_reflection(self): return exclusions.open() diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index a9abec59..8eaf9d66 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -32,12 +32,20 @@ GenericTypeCompiler, IdentifierPreparer, SQLCompiler, + OPERATORS, RESERVED_WORDS, ) +from sqlalchemy.sql.default_comparator import operator_lookup +from sqlalchemy.sql.operators import json_getitem_op +from google.cloud.spanner_v1.data_types import JsonObject from google.cloud import spanner_dbapi from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call + +# register a method to get a single value of a JSON object +OPERATORS[json_getitem_op] = operator_lookup["json_getitem_op"] + # Spanner-to-SQLAlchemy types map _type_map = { "BOOL": types.Boolean, @@ -51,8 +59,10 @@ "TIME": types.TIME, "TIMESTAMP": types.TIMESTAMP, "ARRAY": types.ARRAY, + "JSON": types.JSON, } + _type_map_inv = { types.Boolean: "BOOL", types.BINARY: "BYTES(MAX)", @@ -197,6 +207,30 @@ def visit_like_op_binary(self, binary, operator, **kw): binary.right._compiler_dispatch(self, **kw), ) + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): + """Build a JSON_VALUE() function call.""" + expr = """JSON_VALUE(%s, "$.%s")""" + + return expr % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + + def _generate_generic_binary(self, binary, opstring, eager_grouping=False, **kw): + """Build a JSON_VALUE() function arguments.""" + kw["_in_binary"] = True + right_value = getattr( + binary.right, "value", None + ) or binary.right._compiler_dispatch(self, eager_grouping=eager_grouping, **kw) + + text = ( + binary.left._compiler_dispatch(self, eager_grouping=eager_grouping, **kw) + + """, "$.""" + + right_value + + '"' + ) + return "JSON_VALUE(%s)" % text + def render_literal_value(self, value, type_): """Render the value of a bind parameter as a quoted literal. @@ -377,6 +411,9 @@ def visit_NUMERIC(self, type_, **kw): def visit_BIGINT(self, type_, **kw): return "INT64" + def visit_JSON(self, type_, **kw): + return "JSON" + class SpannerDialect(DefaultDialect): """Cloud Spanner dialect. @@ -407,6 +444,8 @@ class SpannerDialect(DefaultDialect): statement_compiler = SpannerSQLCompiler type_compiler = SpannerTypeCompiler execution_ctx_cls = SpannerExecutionContext + _json_serializer = JsonObject + _json_deserializer = JsonObject @classmethod def dbapi(cls): diff --git a/test/test_suite.py b/test/test_suite.py index 500114c3..c17c3841 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -20,6 +20,7 @@ import os import pkg_resources import pytest +import random from unittest import mock import sqlalchemy @@ -54,9 +55,7 @@ from sqlalchemy.types import Numeric from sqlalchemy.types import Text from sqlalchemy.testing import requires - from google.api_core.datetime_helpers import DatetimeWithNanoseconds - from google.cloud import spanner_dbapi from sqlalchemy.testing.suite.test_cte import * # noqa: F401, F403 @@ -92,15 +91,17 @@ ) from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403 + _DateFixture as _DateFixtureTest, + _LiteralRoundTripFixture, + _UnicodeFixture as _UnicodeFixtureTest, BooleanTest as _BooleanTest, DateTest as _DateTest, - _DateFixture as _DateFixtureTest, DateTimeHistoricTest, DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest, DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest, DateTimeTest as _DateTimeTest, IntegerTest as _IntegerTest, - _LiteralRoundTripFixture, + JSONTest as _JSONTest, NumericTest as _NumericTest, StringTest as _StringTest, TextTest as _TextTest, @@ -109,7 +110,6 @@ TimestampMicrosecondsTest, UnicodeVarcharTest as _UnicodeVarcharTest, UnicodeTextTest as _UnicodeTextTest, - _UnicodeFixture as _UnicodeFixtureTest, ) config.test_schema = "" @@ -1576,6 +1576,176 @@ def test_user_agent(self): ) +class JSONTest(_JSONTest): + def _index_fixtures(fn): + fn = testing.combinations( + ("boolean", True), + ("boolean", False), + ("boolean", None), + ("string", "some string"), + ("string", None), + ("string", util.u("réve illé")), + ( + "string", + util.u("réve🐍 illé"), + testing.requires.json_index_supplementary_unicode_element, + ), + ("integer", 15), + ("integer", 1), + ("integer", 0), + ("integer", None), + ("float", 28.5), + ("float", None), + # TODO: how to test for comaprison + # ("json", {"foo": "bar"}), + id_="sa", + )(fn) + return fn + + @pytest.mark.skip("Values without keys are not supported.") + def test_single_element_round_trip(self, element): + pass + + def _test_round_trip(self, data_element): + data_table = self.tables.data_table + + config.db.execute( + data_table.insert(), + {"id": random.randint(1, 100000000), "name": "row1", "data": data_element}, + ) + + row = config.db.execute(select([data_table.c.data])).first() + + eq_(row, (data_element,)) + + @_index_fixtures + def test_path_typed_comparison(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": {"subkey1": value}} + with config.db.connect() as conn: + conn.execute( + data_table.insert(), + { + "id": random.randint(1, 100000000), + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + expr = data_table.c.data[("key1", "subkey1")] + expr = getattr(expr, "as_%s" % datatype)() + + row = conn.execute(select([expr]).where(expr == value)).first() + + # make sure we get a row even if value is None + eq_(row, (value,)) + + def test_unicode_round_trip(self): + # note we include Unicode supplementary characters as well + with config.db.connect() as conn: + conn.execute( + self.tables.data_table.insert(), + { + "id": random.randint(1, 100000000), + "name": "r1", + "data": { + util.u("réve🐍 illé"): util.u("réve🐍 illé"), + "data": {"k1": util.u("drôl🐍e")}, + }, + }, + ) + + eq_( + conn.scalar(select([self.tables.data_table.c.data])), + { + util.u("réve🐍 illé"): util.u("réve🐍 illé"), + "data": {"k1": util.u("drôl🐍e")}, + }, + ) + + @pytest.mark.skip("Parameterized types are not supported.") + def test_eval_none_flag_orm(self): + pass + + @pytest.mark.skip( + "Spanner JSON_VALUE() always returns STRING," + "thus, this test case can't be executed." + ) + def test_index_typed_comparison(self): + pass + + @pytest.mark.skip( + "Spanner JSON_VALUE() always returns STRING," + "thus, this test case can't be executed." + ) + def test_path_typed_comparison(self): + pass + + @pytest.mark.skip("Custom JSON de-/serializers are not supported.") + def test_round_trip_custom_json(self): + pass + + def _index_fixtures(fn): + fn = testing.combinations( + ("boolean", True), + ("boolean", False), + ("boolean", None), + ("string", "some string"), + ("string", None), + ("integer", 15), + ("integer", 1), + ("integer", 0), + ("integer", None), + ("float", 28.5), + ("float", None), + id_="sa", + )(fn) + return fn + + @_index_fixtures + def test_index_typed_access(self, datatype, value): + data_table = self.tables.data_table + data_element = {"key1": value} + with config.db.connect() as conn: + conn.execute( + data_table.insert(), + { + "id": random.randint(1, 100000000), + "name": "row1", + "data": data_element, + "nulldata": data_element, + }, + ) + + expr = data_table.c.data["key1"] + expr = getattr(expr, "as_%s" % datatype)() + + roundtrip = conn.scalar(select([expr])) + if roundtrip in ("true", "false", None): + roundtrip = str(roundtrip).capitalize() + + eq_(str(roundtrip), str(value)) + + @pytest.mark.skip( + "Spanner doesn't support type casts inside JSON_VALUE() function." + ) + def test_round_trip_json_null_as_json_null(self): + pass + + @pytest.mark.skip( + "Spanner doesn't support type casts inside JSON_VALUE() function." + ) + def test_round_trip_none_as_json_null(self): + pass + + @pytest.mark.skip( + "Spanner doesn't support type casts inside JSON_VALUE() function." + ) + def test_round_trip_none_as_sql_null(self): + pass + + class ExecutionOptionsTest(fixtures.TestBase): """ Check that `execution_options()` method correctly From db062646f31bedff89d21a97da227ce347eea683 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 19 Oct 2021 15:09:09 +0300 Subject: [PATCH 2/6] fix type --- google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 8eaf9d66..4fb3933d 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -226,7 +226,7 @@ def _generate_generic_binary(self, binary, opstring, eager_grouping=False, **kw) text = ( binary.left._compiler_dispatch(self, eager_grouping=eager_grouping, **kw) + """, "$.""" - + right_value + + str(right_value) + '"' ) return "JSON_VALUE(%s)" % text From 3a5d90bf5271a0bf86f542d2fc40a536660561bc Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 23 Nov 2021 12:40:11 +0300 Subject: [PATCH 3/6] bug fixes --- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 0bfc7926..1fbdd642 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -50,7 +50,7 @@ def reset_connection(dbapi_conn, connection_record): """An event of returning a connection back to a pool.""" dbapi_conn.connection.staleness = None - + # register a method to get a single value of a JSON object OPERATORS[json_getitem_op] = operator_lookup["json_getitem_op"] @@ -229,21 +229,6 @@ def visit_json_path_getitem_op_binary(self, binary, operator, **kw): self.process(binary.right, **kw), ) - def _generate_generic_binary(self, binary, opstring, eager_grouping=False, **kw): - """Build a JSON_VALUE() function arguments.""" - kw["_in_binary"] = True - right_value = getattr( - binary.right, "value", None - ) or binary.right._compiler_dispatch(self, eager_grouping=eager_grouping, **kw) - - text = ( - binary.left._compiler_dispatch(self, eager_grouping=eager_grouping, **kw) - + """, "$.""" - + str(right_value) - + '"' - ) - return "JSON_VALUE(%s)" % text - def render_literal_value(self, value, type_): """Render the value of a bind parameter as a quoted literal. From ed8e149dcb23744d838bf27058943df86b7ba856 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 23 Nov 2021 12:43:44 +0300 Subject: [PATCH 4/6] erase excess test override --- test/test_suite.py | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/test/test_suite.py b/test/test_suite.py index 9617ab78..28854ec4 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -1622,29 +1622,6 @@ def _test_round_trip(self, data_element): eq_(row, (data_element,)) - @_index_fixtures - def test_path_typed_comparison(self, datatype, value): - data_table = self.tables.data_table - data_element = {"key1": {"subkey1": value}} - with config.db.connect() as conn: - conn.execute( - data_table.insert(), - { - "id": random.randint(1, 100000000), - "name": "row1", - "data": data_element, - "nulldata": data_element, - }, - ) - - expr = data_table.c.data[("key1", "subkey1")] - expr = getattr(expr, "as_%s" % datatype)() - - row = conn.execute(select([expr]).where(expr == value)).first() - - # make sure we get a row even if value is None - eq_(row, (value,)) - def test_unicode_round_trip(self): # note we include Unicode supplementary characters as well with config.db.connect() as conn: From b4390a0779f46d82758a1786dc41c9d8e898b3c1 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Tue, 23 Nov 2021 12:51:27 +0300 Subject: [PATCH 5/6] erase excess override --- test/test_suite.py | 25 ------------------------- 1 file changed, 25 deletions(-) diff --git a/test/test_suite.py b/test/test_suite.py index 28854ec4..0c8010c9 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -1581,31 +1581,6 @@ def test_user_agent(self): class JSONTest(_JSONTest): - def _index_fixtures(fn): - fn = testing.combinations( - ("boolean", True), - ("boolean", False), - ("boolean", None), - ("string", "some string"), - ("string", None), - ("string", util.u("réve illé")), - ( - "string", - util.u("réve🐍 illé"), - testing.requires.json_index_supplementary_unicode_element, - ), - ("integer", 15), - ("integer", 1), - ("integer", 0), - ("integer", None), - ("float", 28.5), - ("float", None), - # TODO: how to test for comaprison - # ("json", {"foo": "bar"}), - id_="sa", - )(fn) - return fn - @pytest.mark.skip("Values without keys are not supported.") def test_single_element_round_trip(self, element): pass From 374396952674c09d47006db0d23b28e4eb309ba2 Mon Sep 17 00:00:00 2001 From: IlyaFaer Date: Wed, 24 Nov 2021 13:19:01 +0300 Subject: [PATCH 6/6] fix errors --- .../sqlalchemy_spanner/sqlalchemy_spanner.py | 38 +++++++++++++++++++ test/test_suite.py | 3 ++ 2 files changed, 41 insertions(+) diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 1fbdd642..cb9d27cc 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -220,6 +220,44 @@ def visit_like_op_binary(self, binary, operator, **kw): binary.right._compiler_dispatch(self, **kw), ) + def _generate_generic_binary(self, binary, opstring, eager_grouping=False, **kw): + """The method is overriden to process JSON data type cases.""" + _in_binary = kw.get("_in_binary", False) + + kw["_in_binary"] = True + + if isinstance(opstring, str): + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + opstring + + binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + ) + if _in_binary and eager_grouping: + text = "(%s)" % text + else: + # got JSON data + right_value = getattr( + binary.right, "value", None + ) or binary.right._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + text = ( + binary.left._compiler_dispatch( + self, eager_grouping=eager_grouping, **kw + ) + + """, "$.""" + + str(right_value) + + '"' + ) + text = "JSON_VALUE(%s)" % text + + return text + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): """Build a JSON_VALUE() function call.""" expr = """JSON_VALUE(%s, "$.%s")""" diff --git a/test/test_suite.py b/test/test_suite.py index 0c8010c9..3992b3d1 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -1580,6 +1580,9 @@ def test_user_agent(self): ) +@pytest.mark.skipif( + bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator" +) class JSONTest(_JSONTest): @pytest.mark.skip("Values without keys are not supported.") def test_single_element_round_trip(self, element):