diff --git a/firestore/collgroupref.go b/firestore/collgroupref.go index 90986ad71f6..8d278bab053 100644 --- a/firestore/collgroupref.go +++ b/firestore/collgroupref.go @@ -54,6 +54,9 @@ func newCollectionGroupRef(c *Client, dbPath, collectionID string) *CollectionGr // // If a Collection Group Query would return a large number of documents, this // can help to subdivide the query to smaller working units that can be distributed. +// +// If the goal is to run the queries across processes or workers, it may be useful to use +// `Query.Serialize` and `Query.Deserialize` to serialize the query. func (cgr CollectionGroupRef) GetPartitionedQueries(ctx context.Context, partitionCount int) ([]Query, error) { qp, err := cgr.getPartitions(ctx, partitionCount) if err != nil { diff --git a/firestore/doc.go b/firestore/doc.go index 3c320f62805..c3272cd634d 100644 --- a/firestore/doc.go +++ b/firestore/doc.go @@ -193,6 +193,29 @@ as a query. iter = client.Collection("States").Documents(ctx) +Collection Group Partition Queries + +You can partition the documents of a Collection Group allowing for smaller subqueries. + + collectionGroup = client.CollectionGroup("States") + partitions, err = collectionGroup.GetPartitionedQueries(ctx, 20) + +You can also Serialize/Deserialize queries making it possible to run/stream the +queries elsewhere; another process or machine for instance. + + queryProtos := make([][]byte, 0) + for _, query := range partitions { + protoBytes, err := query.Serialize() + // handle err + queryProtos = append(queryProtos, protoBytes) + ... + } + + for _, protoBytes := range queryProtos { + query, err := client.CollectionGroup("").Deserialize(protoBytes) + ... + } + Transactions Use a transaction to execute reads and writes atomically. All reads must happen diff --git a/firestore/integration_test.go b/firestore/integration_test.go index 9d823be9b60..7e1075a1d6c 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -1603,7 +1603,6 @@ func TestDetectProjectID(t *testing.T) { } func TestIntegration_ColGroupRefPartitions(t *testing.T) { - t.Skip("https://github.com/googleapis/google-cloud-go/issues/4325") h := testHelper{t} coll := integrationColl(t) ctx := context.Background() @@ -1622,7 +1621,7 @@ func TestIntegration_ColGroupRefPartitions(t *testing.T) { {collectionID: coll.collectionID, expectedPartitionCount: 1}, } { colGroup := iClient.CollectionGroup(tc.collectionID) - partitions, err := colGroup.getPartitions(ctx, 10) + partitions, err := colGroup.GetPartitionedQueries(ctx, 10) if err != nil { t.Fatalf("getPartitions: received unexpected error: %v", err) } @@ -1675,12 +1674,30 @@ func TestIntegration_ColGroupRefPartitionsLarge(t *testing.T) { // Verify that we retrieve 383 documents across all partitions. (128*2 + 127) totalCount := 0 for _, query := range partitions { - allDocs, err := query.Documents(ctx).GetAll() if err != nil { t.Fatalf("GetAll(): received unexpected error: %v", err) } totalCount += len(allDocs) + + // Verify that serialization round-trips. Check that the same results are + // returned even if we use the proto converted query + queryBytes, err := query.Serialize() + if err != nil { + t.Fatalf("Serialize error: %v", err) + } + q, err := iClient.CollectionGroup("DNE").Deserialize(queryBytes) + if err != nil { + t.Fatalf("Deserialize error: %v", err) + } + + protoReturnedDocs, err := q.Documents(ctx).GetAll() + if err != nil { + t.Fatalf("GetAll error: %v", err) + } + if len(allDocs) != len(protoReturnedDocs) { + t.Fatalf("Expected document count to be the same on both query runs: %v", err) + } } if got, want := totalCount, documentCount; got != want { diff --git a/firestore/query.go b/firestore/query.go index 0353b6bc796..e8df3eb2bf9 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -25,6 +25,7 @@ import ( "cloud.google.com/go/internal/btree" "cloud.google.com/go/internal/trace" + "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes/wrappers" "google.golang.org/api/iterator" pb "google.golang.org/genproto/googleapis/firestore/v1" @@ -39,8 +40,8 @@ type Query struct { path string // path to query (collection) parentPath string // path of the collection's parent (document) collectionID string - selection []FieldPath - filters []filter + selection []*pb.StructuredQuery_FieldReference + filters []*pb.StructuredQuery_Filter orders []order offset int32 limit *wrappers.Int32Value @@ -83,10 +84,26 @@ func (q Query) Select(paths ...string) Query { // // An empty SelectPaths call will produce a query that returns only document IDs. func (q Query) SelectPaths(fieldPaths ...FieldPath) Query { + if len(fieldPaths) == 0 { - q.selection = []FieldPath{{DocumentID}} + ref, err := fref(FieldPath{DocumentID}) + if err != nil { + q.err = err + return q + } + q.selection = []*pb.StructuredQuery_FieldReference{ + ref, + } } else { - q.selection = fieldPaths + q.selection = make([]*pb.StructuredQuery_FieldReference, len(fieldPaths)) + for i, fieldPath := range fieldPaths { + ref, err := fref(fieldPath) + if err != nil { + q.err = err + return q + } + q.selection[i] = ref + } } return q } @@ -103,8 +120,7 @@ func (q Query) Where(path, op string, value interface{}) Query { q.err = err return q } - q.filters = append(append([]filter(nil), q.filters...), filter{fp, op, value}) - return q + return q.WherePath(fp, op, value) } // WherePath returns a new Query that filters the set of results. @@ -112,7 +128,12 @@ func (q Query) Where(path, op string, value interface{}) Query { // The op argument must be one of "==", "!=", "<", "<=", ">", ">=", // "array-contains", "array-contains-any", "in" or "not-in". func (q Query) WherePath(fp FieldPath, op string, value interface{}) Query { - q.filters = append(append([]filter(nil), q.filters...), filter{fp, op, value}) + proto, err := filter{fp, op, value}.toProto() + if err != nil { + q.err = err + return q + } + q.filters = append(append([]*pb.StructuredQuery_Filter(nil), q.filters...), proto) return q } @@ -141,7 +162,7 @@ func (q Query) OrderBy(path string, dir Direction) Query { q.err = err return q } - q.orders = append(q.copyOrders(), order{fp, dir}) + q.orders = append(q.copyOrders(), order{fieldPath: fp, dir: dir}) return q } @@ -149,7 +170,7 @@ func (q Query) OrderBy(path string, dir Direction) Query { // returned. A Query can have multiple OrderBy/OrderByPath specifications. // OrderByPath appends the specification to the list of existing ones. func (q Query) OrderByPath(fp FieldPath, dir Direction) Query { - q.orders = append(q.copyOrders(), order{fp, dir}) + q.orders = append(q.copyOrders(), order{fieldPath: fp, dir: dir}) return q } @@ -250,6 +271,145 @@ func (q *Query) processCursorArg(name string, docSnapshotOrFieldValues []interfa func (q Query) query() *Query { return &q } +// Serialize creates a RunQueryRequest wire-format byte slice from a Query object. +// This can be used in combination with Deserialize to marshal Query objects. +// This could be useful, for instance, if executing a query formed in one +// process in another. +func (q Query) Serialize() ([]byte, error) { + structuredQuery, err := q.toProto() + if err != nil { + return nil, err + } + + p := &pb.RunQueryRequest{ + Parent: q.parentPath, + QueryType: &pb.RunQueryRequest_StructuredQuery{StructuredQuery: structuredQuery}, + } + + return proto.Marshal(p) +} + +// Deserialize takes a slice of bytes holding the wire-format message of RunQueryRequest, +// the underlying proto message used by Queries. It then populates and returns a +// Query object that can be used to execut that Query. +func (q Query) Deserialize(bytes []byte) (Query, error) { + runQueryRequest := pb.RunQueryRequest{} + err := proto.Unmarshal(bytes, &runQueryRequest) + if err != nil { + q.err = err + return q, err + } + return q.fromProto(&runQueryRequest) +} + +// fromProto creates a new Query object from a RunQueryRequest. This can be used +// in combination with ToProto to serialize Query objects. This could be useful, +// for instance, if executing a query formed in one process in another. +func (q Query) fromProto(pbQuery *pb.RunQueryRequest) (Query, error) { + // Ensure we are starting from an empty query, but with this client. + q = Query{c: q.c} + + pbq := pbQuery.GetStructuredQuery() + if from := pbq.GetFrom(); len(from) > 0 { + if len(from) > 1 { + err := errors.New("can only deserialize query with exactly one collection selector") + q.err = err + return q, err + } + + // collectionID string + q.collectionID = from[0].CollectionId + // allDescendants indicates whether this query is for all collections + // that match the ID under the specified parentPath. + q.allDescendants = from[0].AllDescendants + } + + // path string // path to query (collection) + // parentPath string // path of the collection's parent (document) + parent := pbQuery.GetParent() + q.parentPath = parent + q.path = parent + "/" + q.collectionID + + // startVals, endVals []interface{} + // startDoc, endDoc *DocumentSnapshot + // startBefore, endBefore bool + if startAt := pbq.GetStartAt(); startAt != nil { + if startAt.GetBefore() { + q.startBefore = true + } + for _, v := range startAt.GetValues() { + c, err := createFromProtoValue(v, q.c) + if err != nil { + q.err = err + return q, err + } + + var newQ Query + if startAt.GetBefore() { + newQ = q.StartAt(c) + } else { + newQ = q.StartAfter(c) + } + + q.startVals = append(q.startVals, newQ.startVals...) + } + } + if endAt := pbq.GetEndAt(); endAt != nil { + for _, v := range endAt.GetValues() { + c, err := createFromProtoValue(v, q.c) + + if err != nil { + q.err = err + return q, err + } + + var newQ Query + if endAt.GetBefore() { + newQ = q.EndBefore(c) + q.endBefore = true + } else { + newQ = q.EndAt(c) + } + q.endVals = append(q.endVals, newQ.endVals...) + + } + } + + // selection []*pb.StructuredQuery_FieldReference + if s := pbq.GetSelect(); s != nil { + q.selection = s.GetFields() + } + + // filters []*pb.StructuredQuery_Filter + if w := pbq.GetWhere(); w != nil { + if cf := w.GetCompositeFilter(); cf != nil { + q.filters = cf.GetFilters() + } else { + q.filters = []*pb.StructuredQuery_Filter{w} + } + } + + // orders []order + if orderBy := pbq.GetOrderBy(); orderBy != nil { + for _, v := range orderBy { + fp := v.GetField() + q.orders = append(q.orders, order{fieldReference: fp, dir: Direction(v.GetDirection())}) + } + } + + // offset int32 + q.offset = pbq.GetOffset() + + // limit *wrappers.Int32Value + if limit := pbq.GetLimit(); limit != nil { + q.limit = limit + } + + // NOTE: limit to last isn't part of the proto, this is a client-side concept + // limitToLast bool + return q, q.err +} + func (q Query) toProto() (*pb.StructuredQuery, error) { if q.err != nil { return nil, q.err @@ -277,20 +437,13 @@ func (q Query) toProto() (*pb.StructuredQuery, error) { } if len(q.selection) > 0 { p.Select = &pb.StructuredQuery_Projection{} - for _, fp := range q.selection { - if err := fp.validate(); err != nil { - return nil, err - } - p.Select.Fields = append(p.Select.Fields, fref(fp)) - } + p.Select.Fields = q.selection } // If there is only filter, use it directly. Otherwise, construct // a CompositeFilter. if len(q.filters) == 1 { - pf, err := q.filters[0].toProto() - if err != nil { - return nil, err - } + pf := q.filters[0] + p.Where = pf } else if len(q.filters) > 1 { cf := &pb.StructuredQuery_CompositeFilter{ @@ -299,13 +452,7 @@ func (q Query) toProto() (*pb.StructuredQuery, error) { p.Where = &pb.StructuredQuery_Filter{ FilterType: &pb.StructuredQuery_Filter_CompositeFilter{cf}, } - for _, f := range q.filters { - pf, err := f.toProto() - if err != nil { - return nil, err - } - cf.Filters = append(cf.Filters, pf) - } + cf.Filters = append(cf.Filters, q.filters...) } orders := q.orders if q.startDoc != nil || q.endDoc != nil { @@ -353,9 +500,12 @@ func (q *Query) adjustOrders() []order { // for the field of the first inequality. var orders []order for _, f := range q.filters { - if f.op != "==" { - orders = []order{{fieldPath: f.fieldPath, dir: Asc}} - break + if fieldFilter := f.GetFieldFilter(); fieldFilter != nil { + if fieldFilter.Op != pb.StructuredQuery_FieldFilter_EQUAL { + fp := f.GetFieldFilter().Field + orders = []order{{fieldReference: fp, dir: Asc}} + break + } } } // Add an ascending OrderBy(DocumentID). @@ -388,22 +538,26 @@ func (q *Query) fieldValuesToCursorValues(fieldValues []interface{}) ([]*pb.Valu for i, ord := range q.orders { fval := fieldValues[i] if ord.isDocumentID() { - // TODO(jba): support DocumentRefs as well as strings. // TODO(jba): error if document ref does not belong to the right collection. - docID, ok := fval.(string) - if !ok { + + switch docID := fval.(type) { + case string: + vals[i] = &pb.Value{ValueType: &pb.Value_ReferenceValue{q.path + "/" + docID}} + continue + case *DocumentRef: + // DocumentRef can be transformed in usual way. + default: return nil, fmt.Errorf("firestore: expected doc ID for DocumentID field, got %T", fval) } - vals[i] = &pb.Value{ValueType: &pb.Value_ReferenceValue{q.path + "/" + docID}} - } else { - var sawTransform bool - vals[i], sawTransform, err = toProtoValue(reflect.ValueOf(fval)) - if err != nil { - return nil, err - } - if sawTransform { - return nil, errors.New("firestore: transforms disallowed in query value") - } + } + + var sawTransform bool + vals[i], sawTransform, err = toProtoValue(reflect.ValueOf(fval)) + if err != nil { + return nil, err + } + if sawTransform { + return nil, errors.New("firestore: transforms disallowed in query value") } } return vals, nil @@ -419,7 +573,18 @@ func (q *Query) docSnapshotToCursorValues(ds *DocumentSnapshot, orders []order) } vals[i] = &pb.Value{ValueType: &pb.Value_ReferenceValue{ds.Ref.Path}} } else { - val, err := valueAtPath(ord.fieldPath, ds.proto.Fields) + var val *pb.Value + var err error + if len(ord.fieldPath) > 0 { + val, err = valueAtPath(ord.fieldPath, ds.proto.Fields) + } else { + // parse the field reference field path so we can use it to look up + fp, err := parseDotSeparatedString(ord.fieldReference.FieldPath) + if err != nil { + return nil, err + } + val, err = valueAtPath(fp, ds.proto.Fields) + } if err != nil { return nil, err } @@ -436,7 +601,7 @@ func (q Query) compareFunc() func(d1, d2 *DocumentSnapshot) (int, error) { if len(q.orders) > 0 { lastDir = q.orders[len(q.orders)-1].dir } - orders := append(q.copyOrders(), order{[]string{DocumentID}, lastDir}) + orders := append(q.copyOrders(), order{fieldPath: []string{DocumentID}, dir: lastDir}) return func(d1, d2 *DocumentSnapshot) (int, error) { for _, ord := range orders { var cmp int @@ -478,11 +643,15 @@ func (f filter) toProto() (*pb.StructuredQuery_Filter, error) { if f.op != "==" { return nil, fmt.Errorf("firestore: must use '==' when comparing %v", f.value) } + ref, err := fref(f.fieldPath) + if err != nil { + return nil, err + } return &pb.StructuredQuery_Filter{ FilterType: &pb.StructuredQuery_Filter_UnaryFilter{ UnaryFilter: &pb.StructuredQuery_UnaryFilter{ OperandType: &pb.StructuredQuery_UnaryFilter_Field{ - Field: fref(f.fieldPath), + Field: ref, }, Op: uop, }, @@ -521,10 +690,14 @@ func (f filter) toProto() (*pb.StructuredQuery_Filter, error) { if sawTransform { return nil, errors.New("firestore: transforms disallowed in query value") } + ref, err := fref(f.fieldPath) + if err != nil { + return nil, err + } return &pb.StructuredQuery_Filter{ FilterType: &pb.StructuredQuery_Filter_FieldFilter{ FieldFilter: &pb.StructuredQuery_FieldFilter{ - Field: fref(f.fieldPath), + Field: ref, Op: op, Value: val, }, @@ -555,26 +728,43 @@ func isNaN(x interface{}) bool { } type order struct { - fieldPath FieldPath - dir Direction + fieldPath FieldPath + fieldReference *pb.StructuredQuery_FieldReference + dir Direction } func (r order) isDocumentID() bool { + if r.fieldReference != nil { + return r.fieldReference.GetFieldPath() == DocumentID + } return len(r.fieldPath) == 1 && r.fieldPath[0] == DocumentID } func (r order) toProto() (*pb.StructuredQuery_Order, error) { - if err := r.fieldPath.validate(); err != nil { + if r.fieldReference != nil { + return &pb.StructuredQuery_Order{ + Field: r.fieldReference, + Direction: pb.StructuredQuery_Direction(r.dir), + }, nil + } + + field, err := fref(r.fieldPath) + if err != nil { return nil, err } + return &pb.StructuredQuery_Order{ - Field: fref(r.fieldPath), + Field: field, Direction: pb.StructuredQuery_Direction(r.dir), }, nil } -func fref(fp FieldPath) *pb.StructuredQuery_FieldReference { - return &pb.StructuredQuery_FieldReference{FieldPath: fp.toServiceFieldPath()} +func fref(fp FieldPath) (*pb.StructuredQuery_FieldReference, error) { + err := fp.validate() + if err != nil { + return &pb.StructuredQuery_FieldReference{}, err + } + return &pb.StructuredQuery_FieldReference{FieldPath: fp.toServiceFieldPath()}, nil } func trunc32(i int) int32 { diff --git a/firestore/query_test.go b/firestore/query_test.go index ca8c0061442..b97e446ddb7 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -16,6 +16,7 @@ package firestore import ( "context" + "fmt" "math" "sort" "testing" @@ -23,7 +24,9 @@ import ( "cloud.google.com/go/internal/pretty" tspb "github.com/golang/protobuf/ptypes/timestamp" "github.com/golang/protobuf/ptypes/wrappers" + "github.com/google/go-cmp/cmp" pb "google.golang.org/genproto/googleapis/firestore/v1" + "google.golang.org/protobuf/testing/protocmp" ) func TestFilterToProto(t *testing.T) { @@ -74,7 +77,14 @@ func TestFilterToProto(t *testing.T) { } } -func TestQueryToProto(t *testing.T) { +type toProtoScenario struct { + desc string + in Query + want *pb.StructuredQuery +} + +// Creates protos used to test toProto, FromProto, ToProto funcs. +func createTestScenarios(t *testing.T) []toProtoScenario { filtr := func(path []string, op string, val interface{}) *pb.StructuredQuery_Filter { f, err := filter{path, op, val}.toProto() if err != nil { @@ -92,11 +102,8 @@ func TestQueryToProto(t *testing.T) { Fields: map[string]*pb.Value{"a": intval(7), "b": intval(8), "c": arrayval(intval(1), intval(2))}, }, } - for _, test := range []struct { - desc string - in Query - want *pb.StructuredQuery - }{ + + return []toProtoScenario{ { desc: "q.Select()", in: q.Select(), @@ -435,7 +442,11 @@ func TestQueryToProto(t *testing.T) { }, }, }, - } { + } +} + +func TestQueryToProto(t *testing.T) { + for _, test := range createTestScenarios(t) { got, err := test.in.toProto() if err != nil { t.Fatalf("%s: %v", test.desc, err) @@ -448,8 +459,41 @@ func TestQueryToProto(t *testing.T) { } } +// Convert a Query to a Proto and back again verifying roundtripping +func TestQueryFromProtoRoundTrip(t *testing.T) { + c := &Client{projectID: "P", databaseID: "DB"} + + for _, test := range createTestScenarios(t) { + fmt.Println(test.desc) + proto, err := test.in.Serialize() + if err != nil { + t.Fatalf("%s: %v", test.desc, err) + continue + } + fmt.Printf("proto: %v\n", proto) + got, err := Query{c: c}.Deserialize(proto) + if err != nil { + t.Fatalf("%s: %v", test.desc, err) + continue + } + + want := test.in + gotProto, err := got.Serialize() + fmt.Println(gotProto) + if err != nil { + t.Fatalf("%s: %v", test.desc, err) + } + + // Compare protos before and after taking to a query. proto -> query -> proto. + if diff := cmp.Diff(gotProto, proto, protocmp.Transform()); diff != "" { + t.Errorf("%s:\ngot\n%v\nwant\n%v\ndiff\n%v", test.desc, pretty.Value(got), pretty.Value(want), diff) + } + } +} + func fref1(s string) *pb.StructuredQuery_FieldReference { - return fref([]string{s}) + ref, _ := fref([]string{s}) + return ref } func TestQueryToProtoErrors(t *testing.T) { @@ -568,13 +612,16 @@ func TestQueryFromCollectionRef(t *testing.T) { c := &Client{projectID: "P", databaseID: "D"} coll := c.Collection("C") got := coll.Select("x").Offset(8) + ref, _ := fref(FieldPath{"x"}) want := Query{ c: c, parentPath: c.path() + "/documents", path: "projects/P/databases/D/documents/C", collectionID: "C", - selection: []FieldPath{{"x"}}, - offset: 8, + selection: []*pb.StructuredQuery_FieldReference{ + ref, + }, + offset: 8, } if !testEqual(got, want) { t.Fatalf("\ngot %+v, \nwant %+v", got, want)