Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support JSON data type #135

Merged
merged 9 commits into from Dec 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions google/cloud/sqlalchemy_spanner/requirements.py
Expand Up @@ -17,6 +17,10 @@


class Requirements(SuiteRequirements): # pragma: no cover
@property
def json_type(self):
return exclusions.open()

@property
def computed_columns(self):
return exclusions.open()
Expand Down
62 changes: 62 additions & 0 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Expand Up @@ -34,9 +34,13 @@
GenericTypeCompiler,
IdentifierPreparer,
SQLCompiler,
OPERATORS,
RESERVED_WORDS,
)
from sqlalchemy.sql.default_comparator import operator_lookup
from sqlalchemy.sql.operators import json_getitem_op

from google.cloud.spanner_v1.data_types import JsonObject
from google.cloud import spanner_dbapi
from google.cloud.sqlalchemy_spanner._opentelemetry_tracing import trace_call

Expand All @@ -47,6 +51,10 @@ def reset_connection(dbapi_conn, connection_record):
dbapi_conn.connection.staleness = None


# register a method to get a single value of a JSON object
OPERATORS[json_getitem_op] = operator_lookup["json_getitem_op"]


# Spanner-to-SQLAlchemy types map
_type_map = {
"BOOL": types.Boolean,
Expand All @@ -60,8 +68,10 @@ def reset_connection(dbapi_conn, connection_record):
"TIME": types.TIME,
"TIMESTAMP": types.TIMESTAMP,
"ARRAY": types.ARRAY,
"JSON": types.JSON,
}


_type_map_inv = {
types.Boolean: "BOOL",
types.BINARY: "BYTES(MAX)",
Expand Down Expand Up @@ -210,6 +220,53 @@ def visit_like_op_binary(self, binary, operator, **kw):
binary.right._compiler_dispatch(self, **kw),
)

def _generate_generic_binary(self, binary, opstring, eager_grouping=False, **kw):
"""The method is overriden to process JSON data type cases."""
_in_binary = kw.get("_in_binary", False)

kw["_in_binary"] = True

if isinstance(opstring, str):
text = (
binary.left._compiler_dispatch(
self, eager_grouping=eager_grouping, **kw
)
+ opstring
+ binary.right._compiler_dispatch(
self, eager_grouping=eager_grouping, **kw
)
)
if _in_binary and eager_grouping:
text = "(%s)" % text
else:
# got JSON data
right_value = getattr(
binary.right, "value", None
) or binary.right._compiler_dispatch(
self, eager_grouping=eager_grouping, **kw
)

text = (
binary.left._compiler_dispatch(
self, eager_grouping=eager_grouping, **kw
)
+ """, "$."""
+ str(right_value)
+ '"'
)
text = "JSON_VALUE(%s)" % text

return text

def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
"""Build a JSON_VALUE() function call."""
expr = """JSON_VALUE(%s, "$.%s")"""

return expr % (
self.process(binary.left, **kw),
self.process(binary.right, **kw),
)

def render_literal_value(self, value, type_):
"""Render the value of a bind parameter as a quoted literal.

Expand Down Expand Up @@ -404,6 +461,9 @@ def visit_NUMERIC(self, type_, **kw):
def visit_BIGINT(self, type_, **kw):
return "INT64"

def visit_JSON(self, type_, **kw):
return "JSON"


class SpannerDialect(DefaultDialect):
"""Cloud Spanner dialect.
Expand Down Expand Up @@ -434,6 +494,8 @@ class SpannerDialect(DefaultDialect):
statement_compiler = SpannerSQLCompiler
type_compiler = SpannerTypeCompiler
execution_ctx_cls = SpannerExecutionContext
_json_serializer = JsonObject
_json_deserializer = JsonObject

@classmethod
def dbapi(cls):
Expand Down
134 changes: 130 additions & 4 deletions test/test_suite.py
Expand Up @@ -20,6 +20,7 @@
import os
import pkg_resources
import pytest
import random
import unittest
from unittest import mock

Expand Down Expand Up @@ -61,7 +62,6 @@
)

from google.api_core.datetime_helpers import DatetimeWithNanoseconds

from google.cloud import spanner_dbapi

from sqlalchemy.testing.suite.test_cte import * # noqa: F401, F403
Expand Down Expand Up @@ -98,15 +98,17 @@
)
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403
_DateFixture as _DateFixtureTest,
_LiteralRoundTripFixture,
_UnicodeFixture as _UnicodeFixtureTest,
BooleanTest as _BooleanTest,
DateTest as _DateTest,
_DateFixture as _DateFixtureTest,
DateTimeHistoricTest,
DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest,
DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest,
DateTimeTest as _DateTimeTest,
IntegerTest as _IntegerTest,
_LiteralRoundTripFixture,
JSONTest as _JSONTest,
NumericTest as _NumericTest,
StringTest as _StringTest,
TextTest as _TextTest,
Expand All @@ -115,7 +117,6 @@
TimestampMicrosecondsTest,
UnicodeVarcharTest as _UnicodeVarcharTest,
UnicodeTextTest as _UnicodeTextTest,
_UnicodeFixture as _UnicodeFixtureTest,
)
from test._helpers import get_db_url

Expand Down Expand Up @@ -1751,3 +1752,128 @@ def test_get_column_returns_computed(self):
is_true("computed" in compData)
is_true("sqltext" in compData["computed"])
eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")


@pytest.mark.skipif(
bool(os.environ.get("SPANNER_EMULATOR_HOST")), reason="Skipped on emulator"
)
class JSONTest(_JSONTest):
@pytest.mark.skip("Values without keys are not supported.")
def test_single_element_round_trip(self, element):
pass

def _test_round_trip(self, data_element):
data_table = self.tables.data_table

config.db.execute(
data_table.insert(),
{"id": random.randint(1, 100000000), "name": "row1", "data": data_element},
)

row = config.db.execute(select([data_table.c.data])).first()

eq_(row, (data_element,))

def test_unicode_round_trip(self):
# note we include Unicode supplementary characters as well
with config.db.connect() as conn:
conn.execute(
self.tables.data_table.insert(),
{
"id": random.randint(1, 100000000),
"name": "r1",
"data": {
util.u("réve🐍 illé"): util.u("réve🐍 illé"),
"data": {"k1": util.u("drôl🐍e")},
},
},
)

eq_(
conn.scalar(select([self.tables.data_table.c.data])),
{
util.u("réve🐍 illé"): util.u("réve🐍 illé"),
"data": {"k1": util.u("drôl🐍e")},
},
)

@pytest.mark.skip("Parameterized types are not supported.")
def test_eval_none_flag_orm(self):
pass

@pytest.mark.skip(
"Spanner JSON_VALUE() always returns STRING,"
"thus, this test case can't be executed."
)
def test_index_typed_comparison(self):
pass

@pytest.mark.skip(
"Spanner JSON_VALUE() always returns STRING,"
"thus, this test case can't be executed."
)
def test_path_typed_comparison(self):
pass

@pytest.mark.skip("Custom JSON de-/serializers are not supported.")
def test_round_trip_custom_json(self):
pass

def _index_fixtures(fn):
fn = testing.combinations(
("boolean", True),
("boolean", False),
("boolean", None),
("string", "some string"),
("string", None),
("integer", 15),
("integer", 1),
("integer", 0),
("integer", None),
("float", 28.5),
("float", None),
id_="sa",
)(fn)
return fn

@_index_fixtures
def test_index_typed_access(self, datatype, value):
data_table = self.tables.data_table
data_element = {"key1": value}
with config.db.connect() as conn:
conn.execute(
data_table.insert(),
{
"id": random.randint(1, 100000000),
"name": "row1",
"data": data_element,
"nulldata": data_element,
},
)

expr = data_table.c.data["key1"]
expr = getattr(expr, "as_%s" % datatype)()

roundtrip = conn.scalar(select([expr]))
if roundtrip in ("true", "false", None):
roundtrip = str(roundtrip).capitalize()

eq_(str(roundtrip), str(value))

@pytest.mark.skip(
"Spanner doesn't support type casts inside JSON_VALUE() function."
)
def test_round_trip_json_null_as_json_null(self):
pass

@pytest.mark.skip(
"Spanner doesn't support type casts inside JSON_VALUE() function."
)
def test_round_trip_none_as_json_null(self):
pass

@pytest.mark.skip(
"Spanner doesn't support type casts inside JSON_VALUE() function."
)
def test_round_trip_none_as_sql_null(self):
pass