Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(db_api): support JSON data type #627

Merged
merged 13 commits into from Nov 22, 2021
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/cursor.py
Expand Up @@ -223,6 +223,7 @@ def execute(self, sql, args=None):
ResultsChecksum(),
classification == parse_utils.STMT_INSERT,
)

(self._result_set, self._checksum,) = self.connection.run_statement(
statement
)
Expand Down
8 changes: 3 additions & 5 deletions google/cloud/spanner_v1/_helpers.py
Expand Up @@ -17,7 +17,6 @@
import datetime
import decimal
import math
import json

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
Expand Down Expand Up @@ -166,9 +165,8 @@ def _make_value_pb(value):
_assert_numeric_precision_and_scale(value)
return Value(string_value=str(value))
if isinstance(value, JsonObject):
return Value(
string_value=json.dumps(value, sort_keys=True, separators=(",", ":"),)
)
return Value(string_value=value.serialize())

raise ValueError("Unknown type: %s" % (value,))


Expand Down Expand Up @@ -243,7 +241,7 @@ def _parse_value_pb(value_pb, field_type):
elif type_code == TypeCode.NUMERIC:
return decimal.Decimal(value_pb.string_value)
elif type_code == TypeCode.JSON:
return value_pb.string_value
return JsonObject.from_str(value_pb.string_value)
else:
raise ValueError("Unknown type: %s" % (field_type,))

Expand Down
33 changes: 32 additions & 1 deletion google/cloud/spanner_v1/data_types.py
Expand Up @@ -14,6 +14,8 @@

"""Custom data types for spanner."""

import json


class JsonObject(dict):
"""
Expand All @@ -22,4 +24,33 @@ class JsonObject(dict):
normal parameters and JSON parameters.
"""

pass
def __init__(self, *args, **kwargs):
self._is_null = (args, kwargs) == ((), {}) or args == (None,)
if not self._is_null:
super(JsonObject, self).__init__(*args, **kwargs)

@classmethod
def from_str(cls, str_repr):
"""Initiate an object from its `str` representation.

Args:
str_repr (str): JSON text representation.

Returns:
JsonObject: JSON object.
"""
if str_repr == "null":
return cls()

return cls(json.loads(str_repr))

def serialize(self):
"""Return the object text representation.

Returns:
str: JSON object text representation.
"""
if self._is_null:
return None

return json.dumps(self, sort_keys=True, separators=(",", ":"))
8 changes: 4 additions & 4 deletions samples/samples/snippets_test.py
Expand Up @@ -50,13 +50,13 @@ def sample_name():

@pytest.fixture(scope="module")
def create_instance_id():
""" Id for the low-cost instance. """
"""Id for the low-cost instance."""
return f"create-instance-{uuid.uuid4().hex[:10]}"


@pytest.fixture(scope="module")
def lci_instance_id():
""" Id for the low-cost instance. """
"""Id for the low-cost instance."""
return f"lci-instance-{uuid.uuid4().hex[:10]}"


Expand Down Expand Up @@ -91,7 +91,7 @@ def database_ddl():

@pytest.fixture(scope="module")
def default_leader():
""" Default leader for multi-region instances. """
"""Default leader for multi-region instances."""
return "us-east4"


Expand Down Expand Up @@ -582,7 +582,7 @@ def test_update_data_with_json(capsys, instance_id, sample_database):
def test_query_data_with_json_parameter(capsys, instance_id, sample_database):
snippets.query_data_with_json_parameter(instance_id, sample_database.database_id)
out, _ = capsys.readouterr()
assert "VenueId: 19, VenueDetails: {\"open\":true,\"rating\":9}" in out
assert "VenueId: 19, VenueDetails: {'open': True, 'rating': 9}" in out


@pytest.mark.dependency(depends=["insert_datatypes_data"])
Expand Down
2 changes: 1 addition & 1 deletion tests/system/test_dbapi.py
Expand Up @@ -364,7 +364,7 @@ def test_autocommit_with_json_data(shared_instance, dbapi_database):
# Assert the response
assert len(got_rows) == 1
assert got_rows[0][0] == 123
assert got_rows[0][1] == '{"age":"26","name":"Jakob"}'
assert got_rows[0][1] == {"age": "26", "name": "Jakob"}

# Drop the table
cur.execute("DROP TABLE JsonDetails")
Expand Down
14 changes: 4 additions & 10 deletions tests/system/test_session_api.py
Expand Up @@ -19,7 +19,6 @@
import struct
import threading
import time
import json
import pytest

import grpc
Expand All @@ -28,6 +27,7 @@
from google.api_core import exceptions
from google.cloud import spanner_v1
from google.cloud._helpers import UTC
from google.cloud.spanner_v1.data_types import JsonObject
from tests import _helpers as ot_helpers
from . import _helpers
from . import _sample_data
Expand All @@ -43,23 +43,17 @@
BYTES_2 = b"Ym9vdHM="
NUMERIC_1 = decimal.Decimal("0.123456789")
NUMERIC_2 = decimal.Decimal("1234567890")
JSON_1 = json.dumps(
JSON_1 = JsonObject(
{
"sample_boolean": True,
"sample_int": 872163,
"sample float": 7871.298,
"sample_null": None,
"sample_string": "abcdef",
"sample_array": [23, 76, 19],
},
sort_keys=True,
separators=(",", ":"),
)
JSON_2 = json.dumps(
{"sample_object": {"name": "Anamika", "id": 2635}},
sort_keys=True,
separators=(",", ":"),
}
)
JSON_2 = JsonObject({"sample_object": {"name": "Anamika", "id": 2635}},)

COUNTERS_TABLE = "counters"
COUNTERS_COLUMNS = ("name", "value")
Expand Down
16 changes: 12 additions & 4 deletions tests/unit/test__helpers.py
Expand Up @@ -567,14 +567,22 @@ def test_w_json(self):
from google.cloud.spanner_v1 import Type
from google.cloud.spanner_v1 import TypeCode

VALUE = json.dumps(
{"id": 27863, "Name": "Anamika"}, sort_keys=True, separators=(",", ":")
)
VALUE = {"id": 27863, "Name": "Anamika"}
str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":"))

field_type = Type(code=TypeCode.JSON)
value_pb = Value(string_value=VALUE)
value_pb = Value(string_value=str_repr)

self.assertEqual(self._callFUT(value_pb, field_type), VALUE)

VALUE = None
str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":"))

field_type = Type(code=TypeCode.JSON)
value_pb = Value(string_value=str_repr)

self.assertEqual(self._callFUT(value_pb, field_type), {})

def test_w_unknown_type(self):
from google.protobuf.struct_pb2 import Value
from google.cloud.spanner_v1 import Type
Expand Down