diff --git a/google/cloud/sqlalchemy_spanner/requirements.py b/google/cloud/sqlalchemy_spanner/requirements.py
index 8e30dd7a..d552dc34 100644
--- a/google/cloud/sqlalchemy_spanner/requirements.py
+++ b/google/cloud/sqlalchemy_spanner/requirements.py
@@ -17,6 +17,10 @@
 
 
 class Requirements(SuiteRequirements):  # pragma: no cover
+    @property
+    def json_type(self):
+        return exclusions.open()
+
     @property
     def computed_columns(self):
         return exclusions.open()
diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
index 68a0f94e..ff01e277 100644
--- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
+++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
@@ -34,9 +34,13 @@
     GenericTypeCompiler,
     IdentifierPreparer,
     SQLCompiler,
+    OPERATORS,
     RESERVED_WORDS,
 )
 
+from sqlalchemy.sql.default_comparator import operator_lookup
+from sqlalchemy.sql.operators import json_getitem_op
+from google.cloud.spanner_v1.data_types import JsonObject
 from google.cloud import spanner_dbapi
 from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call
 
@@ -47,6 +51,10 @@ def reset_connection(dbapi_conn, connection_record):
         dbapi_conn.connection.staleness = None
 
 
+# register a method to get a single value of a JSON object
+OPERATORS[json_getitem_op] = operator_lookup["json_getitem_op"]
+
+
 # Spanner-to-SQLAlchemy types map
 _type_map = {
     "BOOL": types.Boolean,
@@ -60,8 +68,10 @@ def reset_connection(dbapi_conn, connection_record):
     "TIME": types.TIME,
     "TIMESTAMP": types.TIMESTAMP,
     "ARRAY": types.ARRAY,
+    "JSON": types.JSON,
 }
 
+
 _type_map_inv = {
     types.Boolean: "BOOL",
     types.BINARY: "BYTES(MAX)",
@@ -210,6 +220,53 @@ def visit_like_op_binary(self, binary, operator, **kw):
             binary.right._compiler_dispatch(self, **kw),
         )
 
+    def _generate_generic_binary(self, binary, opstring, eager_grouping=False, **kw):
+        """The method is overridden to process JSON data type cases."""
+        _in_binary = kw.get("_in_binary", False)
+
+        kw["_in_binary"] = True
+
+        if isinstance(opstring, str):
+            text = (
+                binary.left._compiler_dispatch(
+                    self, eager_grouping=eager_grouping, **kw
+                )
+                + opstring
+                + binary.right._compiler_dispatch(
+                    self, eager_grouping=eager_grouping, **kw
+                )
+            )
+            if _in_binary and eager_grouping:
+                text = "(%s)" % text
+        else:
+            # got JSON data
+            right_value = getattr(
+                binary.right, "value", None
+            ) or binary.right._compiler_dispatch(
+                self, eager_grouping=eager_grouping, **kw
+            )
+
+            text = (
+                binary.left._compiler_dispatch(
+                    self, eager_grouping=eager_grouping, **kw
+                )
+                + """, "$."""
+                + str(right_value)
+                + '"'
+            )
+            text = "JSON_VALUE(%s)" % text
+
+        return text
+
+    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
+        """Build a JSON_VALUE() function call."""
+        expr = """JSON_VALUE(%s, "$.%s")"""
+
+        return expr % (
+            self.process(binary.left, **kw),
+            self.process(binary.right, **kw),
+        )
+
     def render_literal_value(self, value, type_):
         """Render the value of a bind parameter as a quoted literal.
 
@@ -404,6 +461,9 @@ def visit_NUMERIC(self, type_, **kw):
     def visit_BIGINT(self, type_, **kw):
         return "INT64"
 
+    def visit_JSON(self, type_, **kw):
+        return "JSON"
+
 
 class SpannerDialect(DefaultDialect):
     """Cloud Spanner dialect.
@@ -434,6 +494,8 @@ class SpannerDialect(DefaultDialect):
     statement_compiler = SpannerSQLCompiler
     type_compiler = SpannerTypeCompiler
     execution_ctx_cls = SpannerExecutionContext
+    _json_serializer = JsonObject
+    _json_deserializer = JsonObject
 
     @classmethod
     def dbapi(cls):
diff --git a/test/test_suite.py b/test/test_suite.py
index 7915c49c..b2339466 100644
--- a/test/test_suite.py
+++ b/test/test_suite.py
@@ -20,6 +20,7 @@
 import os
 import pkg_resources
 import pytest
+import random
 import unittest
 from unittest import mock
 
@@ -61,7 +62,6 @@
 )
 
 from google.api_core.datetime_helpers import DatetimeWithNanoseconds
-
 from google.cloud import spanner_dbapi
 
 from sqlalchemy.testing.suite.test_cte import *  # noqa: F401, F403
@@ -98,15 +98,17 @@
 )
 from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
 from sqlalchemy.testing.suite.test_types import (  # noqa: F401, F403
+    _DateFixture as _DateFixtureTest,
+    _LiteralRoundTripFixture,
+    _UnicodeFixture as _UnicodeFixtureTest,
     BooleanTest as _BooleanTest,
     DateTest as _DateTest,
-    _DateFixture as _DateFixtureTest,
     DateTimeHistoricTest,
     DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest,
     DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest,
     DateTimeTest as _DateTimeTest,
     IntegerTest as _IntegerTest,
-    _LiteralRoundTripFixture,
+    JSONTest as _JSONTest,
     NumericTest as _NumericTest,
     StringTest as _StringTest,
     TextTest as _TextTest,
@@ -115,7 +117,6 @@
     TimestampMicrosecondsTest,
     UnicodeVarcharTest as _UnicodeVarcharTest,
     UnicodeTextTest as _UnicodeTextTest,
-    _UnicodeFixture as _UnicodeFixtureTest,
 )
 
 from test._helpers import get_db_url
@@ -1751,3 +1752,128 @@ def test_get_column_returns_computed(self):
         is_true("computed" in compData)
         is_true("sqltext" in compData["computed"])
         eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")
+
+
+@pytest.mark.skipif(
+    bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator"
+)
+class JSONTest(_JSONTest):
+    @pytest.mark.skip("Values without keys are not supported.")
+    def test_single_element_round_trip(self, element):
+        pass
+
+    def _test_round_trip(self, data_element):
+        data_table = self.tables.data_table
+
+        config.db.execute(
+            data_table.insert(),
+            {"id": random.randint(1, 100000000), "name": "row1", "data": data_element},
+        )
+
+        row = config.db.execute(select([data_table.c.data])).first()
+
+        eq_(row, (data_element,))
+
+    def test_unicode_round_trip(self):
+        # note we include Unicode supplementary characters as well
+        with config.db.connect() as conn:
+            conn.execute(
+                self.tables.data_table.insert(),
+                {
+                    "id": random.randint(1, 100000000),
+                    "name": "r1",
+                    "data": {
+                        util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+                        "data": {"k1": util.u("drôl🐍e")},
+                    },
+                },
+            )
+
+            eq_(
+                conn.scalar(select([self.tables.data_table.c.data])),
+                {
+                    util.u("réve🐍 illé"): util.u("réve🐍 illé"),
+                    "data": {"k1": util.u("drôl🐍e")},
+                },
+            )
+
+    @pytest.mark.skip("Parameterized types are not supported.")
+    def test_eval_none_flag_orm(self):
+        pass
+
+    @pytest.mark.skip(
+        "Spanner JSON_VALUE() always returns STRING; "
+        "thus, this test case can't be executed."
+    )
+    def test_index_typed_comparison(self):
+        pass
+
+    @pytest.mark.skip(
+        "Spanner JSON_VALUE() always returns STRING; "
+        "thus, this test case can't be executed."
+    )
+    def test_path_typed_comparison(self):
+        pass
+
+    @pytest.mark.skip("Custom JSON de-/serializers are not supported.")
+    def test_round_trip_custom_json(self):
+        pass
+
+    def _index_fixtures(fn):
+        fn = testing.combinations(
+            ("boolean", True),
+            ("boolean", False),
+            ("boolean", None),
+            ("string", "some string"),
+            ("string", None),
+            ("integer", 15),
+            ("integer", 1),
+            ("integer", 0),
+            ("integer", None),
+            ("float", 28.5),
+            ("float", None),
+            id_="sa",
+        )(fn)
+        return fn
+
+    @_index_fixtures
+    def test_index_typed_access(self, datatype, value):
+        data_table = self.tables.data_table
+        data_element = {"key1": value}
+        with config.db.connect() as conn:
+            conn.execute(
+                data_table.insert(),
+                {
+                    "id": random.randint(1, 100000000),
+                    "name": "row1",
+                    "data": data_element,
+                    "nulldata": data_element,
+                },
+            )
+
+            expr = data_table.c.data["key1"]
+            expr = getattr(expr, "as_%s" % datatype)()
+
+            roundtrip = conn.scalar(select([expr]))
+            if roundtrip in ("true", "false", None):
+                roundtrip = str(roundtrip).capitalize()
+
+            eq_(str(roundtrip), str(value))
+
+    @pytest.mark.skip(
+        "Spanner doesn't support type casts inside JSON_VALUE() function."
+    )
+    def test_round_trip_json_null_as_json_null(self):
+        pass
+
+    @pytest.mark.skip(
+        "Spanner doesn't support type casts inside JSON_VALUE() function."
+    )
+    def test_round_trip_none_as_json_null(self):
+        pass
+
+    @pytest.mark.skip(
+        "Spanner doesn't support type casts inside JSON_VALUE() function."
+    )
+    def test_round_trip_none_as_sql_null(self):
+        pass
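
For reviewers who want to try the change locally, below is a minimal usage sketch of the new JSON support. It is illustrative only and not part of the patch: the connection URL, the project/instance/database identifiers, and the "users" table are placeholders, the target Spanner database is assumed to already exist, and the select([...]) call style assumes the SQLAlchemy 1.3 series used by the test suite above.

    import random

    from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, select
    from sqlalchemy.types import JSON

    # Placeholder URL; substitute your own project, instance and database.
    engine = create_engine(
        "spanner:///projects/my-project/instances/my-instance/databases/my-database"
    )

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(64)),
        # Rendered as the Cloud Spanner JSON type via SpannerTypeCompiler.visit_JSON().
        Column("preferences", JSON),
    )
    metadata.create_all(engine)

    with engine.connect() as conn:
        # Plain dicts are wrapped with spanner_v1 JsonObject by the dialect's
        # _json_serializer before being sent to the database.
        conn.execute(
            users.insert(),
            {
                "id": random.randint(1, 100000000),
                "name": "row1",
                "preferences": {"theme": "dark", "page_size": 25},
            },
        )

        # Indexed access compiles to JSON_VALUE(preferences, "$.theme"), which
        # always returns a STRING, hence .as_string(), mirroring
        # test_index_typed_access above.
        theme = conn.scalar(select([users.c.preferences["theme"].as_string()]))
        print(theme)  # expected: dark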