From c57c1956f01257ec425f6e25ecd462926f4e8811 Mon Sep 17 00:00:00 2001 From: Mattie Fu Date: Mon, 26 Jul 2021 18:39:59 +0000 Subject: [PATCH] feat: add extra contexts to ApiCallContext --- .../google/api/gax/grpc/GrpcCallContext.java | 54 +++++++++++++- .../api/gax/grpc/GrpcCallContextTest.java | 42 +++++++++++ .../api/gax/httpjson/HttpJsonCallContext.java | 54 +++++++++++++- .../gax/httpjson/HttpJsonCallContextTest.java | 42 +++++++++++ .../google/api/gax/rpc/ApiCallContext.java | 43 +++++++++++ .../gax/rpc/internal/ApiCallContextUtil.java | 71 +++++++++++++++++++ .../api/gax/rpc/testing/FakeCallContext.java | 56 ++++++++++++++- 7 files changed, 358 insertions(+), 4 deletions(-) create mode 100644 gax/src/main/java/com/google/api/gax/rpc/internal/ApiCallContextUtil.java diff --git a/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java b/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java index fdcdd9588b..f79405e707 100644 --- a/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java +++ b/gax-grpc/src/main/java/com/google/api/gax/grpc/GrpcCallContext.java @@ -34,6 +34,7 @@ import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.TransportChannel; +import com.google.api.gax.rpc.internal.ApiCallContextUtil; import com.google.api.gax.rpc.internal.Headers; import com.google.api.gax.tracing.ApiTracer; import com.google.api.gax.tracing.BaseApiTracer; @@ -43,7 +44,6 @@ import com.google.common.collect.ImmutableSet; import io.grpc.CallCredentials; import io.grpc.CallOptions; -import io.grpc.CallOptions.Key; import io.grpc.Channel; import io.grpc.Deadline; import io.grpc.Metadata; @@ -66,7 +66,7 @@ */ @BetaApi("Reference ApiCallContext instead - this class is likely to experience breaking changes") public final class GrpcCallContext implements ApiCallContext { - static final CallOptions.Key TRACER_KEY = Key.create("gax.tracer"); + static final CallOptions.Key TRACER_KEY = CallOptions.Key.create("gax.tracer"); private final Channel channel; private final CallOptions callOptions; @@ -77,6 +77,7 @@ public final class GrpcCallContext implements ApiCallContext { @Nullable private final RetrySettings retrySettings; @Nullable private final ImmutableSet retryableCodes; private final ImmutableMap> extraHeaders; + private final ImmutableMap extraContexts; /** Returns an empty instance with a null channel and default {@link CallOptions}. */ public static GrpcCallContext createDefault() { @@ -88,6 +89,7 @@ public static GrpcCallContext createDefault() { null, null, ImmutableMap.>of(), + ImmutableMap.of(), null, null); } @@ -102,6 +104,7 @@ public static GrpcCallContext of(Channel channel, CallOptions callOptions) { null, null, ImmutableMap.>of(), + ImmutableMap.of(), null, null); } @@ -114,6 +117,7 @@ private GrpcCallContext( @Nullable Duration streamIdleTimeout, @Nullable Integer channelAffinity, ImmutableMap> extraHeaders, + ImmutableMap extraContexts, @Nullable RetrySettings retrySettings, @Nullable Set retryableCodes) { this.channel = channel; @@ -123,6 +127,7 @@ private GrpcCallContext( this.streamIdleTimeout = streamIdleTimeout; this.channelAffinity = channelAffinity; this.extraHeaders = Preconditions.checkNotNull(extraHeaders); + this.extraContexts = Preconditions.checkNotNull(extraContexts); this.retrySettings = retrySettings; this.retryableCodes = retryableCodes == null ? null : ImmutableSet.copyOf(retryableCodes); } @@ -187,6 +192,7 @@ public GrpcCallContext withTimeout(@Nullable Duration timeout) { this.streamIdleTimeout, this.channelAffinity, this.extraHeaders, + this.extraContexts, this.retrySettings, this.retryableCodes); } @@ -212,6 +218,7 @@ public GrpcCallContext withStreamWaitTimeout(@Nullable Duration streamWaitTimeou this.streamIdleTimeout, this.channelAffinity, this.extraHeaders, + this.extraContexts, this.retrySettings, this.retryableCodes); } @@ -231,6 +238,7 @@ public GrpcCallContext withStreamIdleTimeout(@Nullable Duration streamIdleTimeou streamIdleTimeout, this.channelAffinity, this.extraHeaders, + this.extraContexts, this.retrySettings, this.retryableCodes); } @@ -245,6 +253,7 @@ public GrpcCallContext withChannelAffinity(@Nullable Integer affinity) { this.streamIdleTimeout, affinity, this.extraHeaders, + this.extraContexts, this.retrySettings, this.retryableCodes); } @@ -263,6 +272,7 @@ public GrpcCallContext withExtraHeaders(Map> extraHeaders) this.streamIdleTimeout, this.channelAffinity, newExtraHeaders, + this.extraContexts, this.retrySettings, this.retryableCodes); } @@ -282,6 +292,7 @@ public GrpcCallContext withRetrySettings(RetrySettings retrySettings) { this.streamIdleTimeout, this.channelAffinity, this.extraHeaders, + this.extraContexts, retrySettings, this.retryableCodes); } @@ -301,6 +312,7 @@ public GrpcCallContext withRetryableCodes(Set retryableCodes) { this.streamIdleTimeout, this.channelAffinity, this.extraHeaders, + this.extraContexts, this.retrySettings, retryableCodes); } @@ -370,6 +382,9 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { ImmutableMap> newExtraHeaders = Headers.mergeHeaders(this.extraHeaders, grpcCallContext.extraHeaders); + ImmutableMap newExtraContexts = + ApiCallContextUtil.mergeExtraContexts(this.extraContexts, grpcCallContext.extraContexts); + CallOptions newCallOptions = grpcCallContext .callOptions @@ -388,6 +403,7 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newStreamIdleTimeout, newChannelAffinity, newExtraHeaders, + newExtraContexts, newRetrySettings, newRetryableCodes); } @@ -448,6 +464,7 @@ public GrpcCallContext withChannel(Channel newChannel) { this.streamIdleTimeout, this.channelAffinity, this.extraHeaders, + this.extraContexts, this.retrySettings, this.retryableCodes); } @@ -462,6 +479,7 @@ public GrpcCallContext withCallOptions(CallOptions newCallOptions) { this.streamIdleTimeout, this.channelAffinity, this.extraHeaders, + this.extraContexts, this.retrySettings, this.retryableCodes); } @@ -491,6 +509,36 @@ public GrpcCallContext withTracer(@Nonnull ApiTracer tracer) { return withCallOptions(callOptions.withOption(TRACER_KEY, tracer)); } + /** {@inheritDoc} */ + @Override + public GrpcCallContext withExtraContext(Key key, T extraContext) { + Preconditions.checkNotNull(key); + + ImmutableMap newExtraContexts = + ApiCallContextUtil.addExtraContext(extraContexts, key, extraContext); + return new GrpcCallContext( + this.channel, + this.callOptions, + this.timeout, + this.streamWaitTimeout, + this.streamIdleTimeout, + this.channelAffinity, + this.extraHeaders, + newExtraContexts, + this.retrySettings, + this.retryableCodes); + } + + /** {@inheritDoc} */ + @Override + public T getExtraContext(Key key) { + Preconditions.checkNotNull(key); + if (extraContexts.containsKey(key)) { + return (T) extraContexts.get(key); + } + return key.getDefault(); + } + @Override public int hashCode() { return Objects.hash( @@ -501,6 +549,7 @@ public int hashCode() { streamIdleTimeout, channelAffinity, extraHeaders, + extraContexts, retrySettings, retryableCodes); } @@ -522,6 +571,7 @@ public boolean equals(Object o) { && Objects.equals(this.streamIdleTimeout, that.streamIdleTimeout) && Objects.equals(this.channelAffinity, that.channelAffinity) && Objects.equals(this.extraHeaders, that.extraHeaders) + && Objects.equals(this.extraContexts, that.extraContexts) && Objects.equals(this.retrySettings, that.retrySettings) && Objects.equals(this.retryableCodes, that.retryableCodes); } diff --git a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java index 1b9b5c1871..9d4e235502 100644 --- a/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java +++ b/gax-grpc/src/test/java/com/google/api/gax/grpc/GrpcCallContextTest.java @@ -354,6 +354,48 @@ public void testWithRetryableCodes() { assertNotNull(context.getRetryableCodes()); } + @Test + public void testWithExtraContext() { + GrpcCallContext emptyCallContext = GrpcCallContext.createDefault(); + ApiCallContext.Key contextKey1 = ApiCallContext.Key.create("testKey1"); + ApiCallContext.Key contextKey2 = ApiCallContext.Key.create("testKey2"); + String testContext1 = "test1"; + String testContext2 = "test2"; + String testContextOverwrite = "test1Overwrite"; + GrpcCallContext context = + emptyCallContext + .withExtraContext(contextKey1, testContext1) + .withExtraContext(contextKey2, testContext2); + assertEquals(testContext1, context.getExtraContext(contextKey1)); + assertEquals(testContext2, context.getExtraContext(contextKey2)); + GrpcCallContext newContext = context.withExtraContext(contextKey1, testContextOverwrite); + assertEquals(testContextOverwrite, newContext.getExtraContext(contextKey1)); + } + + @Test + public void testMergeExtraContext() { + GrpcCallContext emptyCallContext = GrpcCallContext.createDefault(); + ApiCallContext.Key contextKey1 = ApiCallContext.Key.create("testKey1"); + ApiCallContext.Key contextKey2 = ApiCallContext.Key.create("testKey2"); + ApiCallContext.Key contextKey3 = ApiCallContext.Key.create("testKey3"); + String testContext1 = "test1"; + String testContext2 = "test2"; + String testContext3 = "test3"; + String testContextOverwrite = "test1Overwrite"; + GrpcCallContext context1 = + emptyCallContext + .withExtraContext(contextKey1, testContext1) + .withExtraContext(contextKey2, testContext2); + GrpcCallContext context2 = + emptyCallContext + .withExtraContext(contextKey1, testContextOverwrite) + .withExtraContext(contextKey3, testContext3); + ApiCallContext mergedContext = context1.merge(context2); + assertEquals(testContextOverwrite, mergedContext.getExtraContext(contextKey1)); + assertEquals(testContext2, mergedContext.getExtraContext(contextKey2)); + assertEquals(testContext3, mergedContext.getExtraContext(contextKey3)); + } + private static Map> createTestExtraHeaders(String... keyValues) { Map> extraHeaders = new HashMap<>(); for (int i = 0; i < keyValues.length; i += 2) { diff --git a/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java b/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java index 700716217f..81fcfcd0d9 100644 --- a/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java +++ b/gax-httpjson/src/main/java/com/google/api/gax/httpjson/HttpJsonCallContext.java @@ -34,6 +34,7 @@ import com.google.api.gax.rpc.ApiCallContext; import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.TransportChannel; +import com.google.api.gax.rpc.internal.ApiCallContextUtil; import com.google.api.gax.rpc.internal.Headers; import com.google.api.gax.tracing.ApiTracer; import com.google.api.gax.tracing.BaseApiTracer; @@ -65,6 +66,7 @@ public final class HttpJsonCallContext implements ApiCallContext { private final Instant deadline; private final Credentials credentials; private final ImmutableMap> extraHeaders; + private final ImmutableMap extraContexts; private final ApiTracer tracer; private final RetrySettings retrySettings; private final ImmutableSet retryableCodes; @@ -72,7 +74,15 @@ public final class HttpJsonCallContext implements ApiCallContext { /** Returns an empty instance. */ public static HttpJsonCallContext createDefault() { return new HttpJsonCallContext( - null, null, null, null, ImmutableMap.>of(), null, null, null); + null, + null, + null, + null, + ImmutableMap.>of(), + ImmutableMap.of(), + null, + null, + null); } private HttpJsonCallContext( @@ -81,6 +91,7 @@ private HttpJsonCallContext( Instant deadline, Credentials credentials, ImmutableMap> extraHeaders, + ImmutableMap extraContexts, ApiTracer tracer, RetrySettings defaultRetrySettings, Set defaultRetryableCodes) { @@ -89,6 +100,7 @@ private HttpJsonCallContext( this.deadline = deadline; this.credentials = credentials; this.extraHeaders = extraHeaders; + this.extraContexts = extraContexts; this.tracer = tracer; this.retrySettings = defaultRetrySettings; this.retryableCodes = @@ -152,6 +164,9 @@ public HttpJsonCallContext merge(ApiCallContext inputCallContext) { ImmutableMap> newExtraHeaders = Headers.mergeHeaders(extraHeaders, httpJsonCallContext.extraHeaders); + ImmutableMap newExtraContext = + ApiCallContextUtil.mergeExtraContexts(extraContexts, httpJsonCallContext.extraContexts); + ApiTracer newTracer = httpJsonCallContext.tracer; if (newTracer == null) { newTracer = this.tracer; @@ -173,6 +188,7 @@ public HttpJsonCallContext merge(ApiCallContext inputCallContext) { newDeadline, newCredentials, newExtraHeaders, + newExtraContext, newTracer, newRetrySettings, newRetryableCodes); @@ -186,6 +202,7 @@ public HttpJsonCallContext withCredentials(Credentials newCredentials) { this.deadline, newCredentials, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -220,6 +237,7 @@ public HttpJsonCallContext withTimeout(Duration timeout) { this.deadline, this.credentials, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -265,6 +283,7 @@ public ApiCallContext withExtraHeaders(Map> extraHeaders) { this.deadline, this.credentials, newExtraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -276,6 +295,32 @@ public Map> getExtraHeaders() { return extraHeaders; } + /** {@inheritDoc} */ + @Override + public ApiCallContext withExtraContext(Key key, T extraContext) { + ImmutableMap newExtraContexts = + ApiCallContextUtil.addExtraContext(extraContexts, key, extraContext); + return new HttpJsonCallContext( + this.channel, + this.timeout, + this.deadline, + this.credentials, + this.extraHeaders, + newExtraContexts, + this.tracer, + this.retrySettings, + this.retryableCodes); + } + + /** {@inheritDoc} */ + @Override + public T getExtraContext(Key key) { + if (extraContexts.containsKey(key)) { + return (T) extraContexts.get(key); + } + return key.getDefault(); + } + public HttpJsonChannel getChannel() { return channel; } @@ -301,6 +346,7 @@ public HttpJsonCallContext withRetrySettings(RetrySettings retrySettings) { this.deadline, this.credentials, this.extraHeaders, + this.extraContexts, this.tracer, retrySettings, this.retryableCodes); @@ -319,6 +365,7 @@ public HttpJsonCallContext withRetryableCodes(Set retryableCode this.deadline, this.credentials, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, retryableCodes); @@ -331,6 +378,7 @@ public HttpJsonCallContext withChannel(HttpJsonChannel newChannel) { this.deadline, this.credentials, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -343,6 +391,7 @@ public HttpJsonCallContext withDeadline(Instant newDeadline) { newDeadline, this.credentials, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -368,6 +417,7 @@ public HttpJsonCallContext withTracer(@Nonnull ApiTracer newTracer) { this.deadline, this.credentials, this.extraHeaders, + this.extraContexts, newTracer, this.retrySettings, this.retryableCodes); @@ -387,6 +437,7 @@ public boolean equals(Object o) { && Objects.equals(this.deadline, that.deadline) && Objects.equals(this.credentials, that.credentials) && Objects.equals(this.extraHeaders, that.extraHeaders) + && Objects.equals(this.extraContexts, that.extraContexts) && Objects.equals(this.tracer, that.tracer) && Objects.equals(this.retrySettings, that.retrySettings) && Objects.equals(this.retryableCodes, that.retryableCodes); @@ -400,6 +451,7 @@ public int hashCode() { deadline, credentials, extraHeaders, + extraContexts, tracer, retrySettings, retryableCodes); diff --git a/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonCallContextTest.java b/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonCallContextTest.java index 245ace65ca..dea746f2c0 100644 --- a/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonCallContextTest.java +++ b/gax-httpjson/src/test/java/com/google/api/gax/httpjson/HttpJsonCallContextTest.java @@ -229,4 +229,46 @@ public void testWithExtraHeaders() { ApiCallContext context = emptyContext.withExtraHeaders(headers); assertEquals(headers, context.getExtraHeaders()); } + + @Test + public void testWithExtraContext() { + ApiCallContext emptyCallContext = HttpJsonCallContext.createDefault(); + ApiCallContext.Key contextKey1 = ApiCallContext.Key.create("testKey1"); + ApiCallContext.Key contextKey2 = ApiCallContext.Key.create("testKey2"); + String testContext1 = "test1"; + String testContext2 = "test2"; + String testContextOverwrite = "test1Overwrite"; + ApiCallContext context = + emptyCallContext + .withExtraContext(contextKey1, testContext1) + .withExtraContext(contextKey2, testContext2); + assertEquals(testContext1, context.getExtraContext(contextKey1)); + assertEquals(testContext2, context.getExtraContext(contextKey2)); + ApiCallContext newContext = context.withExtraContext(contextKey1, testContextOverwrite); + assertEquals(testContextOverwrite, newContext.getExtraContext(contextKey1)); + } + + @Test + public void testMergeExtraContext() { + ApiCallContext emptyCallContext = HttpJsonCallContext.createDefault(); + ApiCallContext.Key contextKey1 = ApiCallContext.Key.create("testKey1"); + ApiCallContext.Key contextKey2 = ApiCallContext.Key.create("testKey2"); + ApiCallContext.Key contextKey3 = ApiCallContext.Key.create("testKey3"); + String testContext1 = "test1"; + String testContext2 = "test2"; + String testContext3 = "test3"; + String testContextOverwrite = "test1Overwrite"; + ApiCallContext context1 = + emptyCallContext + .withExtraContext(contextKey1, testContext1) + .withExtraContext(contextKey2, testContext2); + ApiCallContext context2 = + emptyCallContext + .withExtraContext(contextKey1, testContextOverwrite) + .withExtraContext(contextKey3, testContext3); + ApiCallContext mergedContext = context1.merge(context2); + assertEquals(testContextOverwrite, mergedContext.getExtraContext(contextKey1)); + assertEquals(testContext2, mergedContext.getExtraContext(contextKey2)); + assertEquals(testContext3, mergedContext.getExtraContext(contextKey3)); + } } diff --git a/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java b/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java index a9366fbc50..02421142fd 100644 --- a/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java +++ b/gax/src/main/java/com/google/api/gax/rpc/ApiCallContext.java @@ -36,6 +36,7 @@ import com.google.api.gax.rpc.StatusCode.Code; import com.google.api.gax.tracing.ApiTracer; import com.google.auth.Credentials; +import com.google.common.base.Preconditions; import java.util.List; import java.util.Map; import java.util.Set; @@ -242,4 +243,46 @@ public interface ApiCallContext extends RetryingContext { /** Return the extra headers set for this context. */ @BetaApi("The surface for extra headers is not stable yet and may change in the future.") Map> getExtraHeaders(); + + /** + * Return a new ApiCallContext with the extra context merged into the present instance. Any + * existing value of the key is overwritten. + */ + @BetaApi("The surface for extra contexts is not stable yet and may change in the future.") + ApiCallContext withExtraContext(Key key, T extraContext); + + /** Return the extra context set for this context. */ + @SuppressWarnings("unchecked ") + @BetaApi("The surface for extra contexts is not stable yet and may change in the future.") + T getExtraContext(Key key); + + /** Key for extra contexts key-value pair. */ + public static final class Key { + private final String name; + private final T defaultValue; + + private Key(String name, T defaultValue) { + this.name = name; + this.defaultValue = defaultValue; + } + + /** + * Factory method for creating instances of {@link Key}. The default value of the key is null. + */ + public static Key create(String name) { + Preconditions.checkNotNull(name, "Key name cannot be null."); + return new Key<>(name, null); + } + + /** Factory method for creating instances of {@link Key} with default values. */ + public static Key createWithDefault(String name, T defaultValue) { + Preconditions.checkNotNull(name, "Key name cannot be null."); + return new Key<>(name, defaultValue); + } + + /** Returns the user supplied default value of the key. */ + public T getDefault() { + return defaultValue; + } + } } diff --git a/gax/src/main/java/com/google/api/gax/rpc/internal/ApiCallContextUtil.java b/gax/src/main/java/com/google/api/gax/rpc/internal/ApiCallContextUtil.java new file mode 100644 index 0000000000..853326e658 --- /dev/null +++ b/gax/src/main/java/com/google/api/gax/rpc/internal/ApiCallContextUtil.java @@ -0,0 +1,71 @@ +/* + * Copyright 2021 Google LLC + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google LLC nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ +package com.google.api.gax.rpc.internal; + +import com.google.api.core.InternalApi; +import com.google.api.gax.rpc.ApiCallContext.Key; +import com.google.common.collect.ImmutableMap; + +@InternalApi +public final class ApiCallContextUtil { + + public static ImmutableMap addExtraContext( + ImmutableMap oldExtraContexts, Key newKey, Object newContext) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + if (!oldExtraContexts.containsKey(newKey)) { + builder.putAll(oldExtraContexts).put(newKey, newContext); + return builder.build(); + } + for (Key oldKey : oldExtraContexts.keySet()) { + if (oldKey.equals(newKey)) { + builder.put(oldKey, newContext); + } else { + builder.put(oldKey, oldExtraContexts.get(oldKey)); + } + } + return builder.build(); + } + + public static ImmutableMap mergeExtraContexts( + ImmutableMap oldExtraContexts, ImmutableMap newExtraContexts) { + ImmutableMap.Builder builder = ImmutableMap.builder(); + for (Key key : oldExtraContexts.keySet()) { + Object oldValue = oldExtraContexts.get(key); + Object newValue = newExtraContexts.get(key); + builder.put(key, newValue != null ? newValue : oldValue); + } + for (Key key : newExtraContexts.keySet()) { + if (!oldExtraContexts.containsKey(key)) { + builder.put(key, newExtraContexts.get(key)); + } + } + return builder.build(); + } +} diff --git a/gax/src/test/java/com/google/api/gax/rpc/testing/FakeCallContext.java b/gax/src/test/java/com/google/api/gax/rpc/testing/FakeCallContext.java index 6d16048d11..08fb27505b 100644 --- a/gax/src/test/java/com/google/api/gax/rpc/testing/FakeCallContext.java +++ b/gax/src/test/java/com/google/api/gax/rpc/testing/FakeCallContext.java @@ -35,6 +35,7 @@ import com.google.api.gax.rpc.ClientContext; import com.google.api.gax.rpc.StatusCode; import com.google.api.gax.rpc.TransportChannel; +import com.google.api.gax.rpc.internal.ApiCallContextUtil; import com.google.api.gax.rpc.internal.Headers; import com.google.api.gax.tracing.ApiTracer; import com.google.api.gax.tracing.BaseApiTracer; @@ -57,6 +58,7 @@ public class FakeCallContext implements ApiCallContext { private final Duration streamWaitTimeout; private final Duration streamIdleTimeout; private final ImmutableMap> extraHeaders; + private final ImmutableMap extraContexts; private final ApiTracer tracer; private final RetrySettings retrySettings; private final ImmutableSet retryableCodes; @@ -68,6 +70,7 @@ private FakeCallContext( Duration streamWaitTimeout, Duration streamIdleTimeout, ImmutableMap> extraHeaders, + ImmutableMap extraContexts, ApiTracer tracer, RetrySettings retrySettings, Set retryableCodes) { @@ -77,6 +80,7 @@ private FakeCallContext( this.streamWaitTimeout = streamWaitTimeout; this.streamIdleTimeout = streamIdleTimeout; this.extraHeaders = extraHeaders; + this.extraContexts = extraContexts; this.tracer = tracer; this.retrySettings = retrySettings; this.retryableCodes = retryableCodes == null ? null : ImmutableSet.copyOf(retryableCodes); @@ -84,7 +88,16 @@ private FakeCallContext( public static FakeCallContext createDefault() { return new FakeCallContext( - null, null, null, null, null, ImmutableMap.>of(), null, null, null); + null, + null, + null, + null, + null, + ImmutableMap.>of(), + ImmutableMap.of(), + null, + null, + null); } @Override @@ -157,6 +170,10 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { ImmutableMap> newExtraHeaders = Headers.mergeHeaders(extraHeaders, fakeCallContext.extraHeaders); + + ImmutableMap newExtraContext = + ApiCallContextUtil.mergeExtraContexts(extraContexts, fakeCallContext.extraContexts); + return new FakeCallContext( newCallCredentials, newChannel, @@ -164,6 +181,7 @@ public ApiCallContext merge(ApiCallContext inputCallContext) { newStreamWaitTimeout, newStreamIdleTimeout, newExtraHeaders, + newExtraContext, newTracer, newRetrySettings, newRetryableCodes); @@ -181,6 +199,7 @@ public FakeCallContext withRetrySettings(RetrySettings retrySettings) { this.streamWaitTimeout, this.streamIdleTimeout, this.extraHeaders, + this.extraContexts, this.tracer, retrySettings, this.retryableCodes); @@ -198,6 +217,7 @@ public FakeCallContext withRetryableCodes(Set retryableCodes) { this.streamWaitTimeout, this.streamIdleTimeout, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, retryableCodes); @@ -237,6 +257,7 @@ public FakeCallContext withCredentials(Credentials credentials) { this.streamWaitTimeout, this.streamIdleTimeout, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -261,6 +282,7 @@ public FakeCallContext withChannel(FakeChannel channel) { this.streamWaitTimeout, this.streamIdleTimeout, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -285,6 +307,7 @@ public FakeCallContext withTimeout(Duration timeout) { this.streamWaitTimeout, this.streamIdleTimeout, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -299,6 +322,7 @@ public ApiCallContext withStreamWaitTimeout(@Nullable Duration streamWaitTimeout streamWaitTimeout, this.streamIdleTimeout, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -314,6 +338,7 @@ public ApiCallContext withStreamIdleTimeout(@Nullable Duration streamIdleTimeout this.streamWaitTimeout, streamIdleTimeout, this.extraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -331,6 +356,7 @@ public ApiCallContext withExtraHeaders(Map> extraHeaders) { streamWaitTimeout, streamIdleTimeout, newExtraHeaders, + this.extraContexts, this.tracer, this.retrySettings, this.retryableCodes); @@ -341,6 +367,33 @@ public Map> getExtraHeaders() { return this.extraHeaders; } + @Override + public ApiCallContext withExtraContext(Key key, T extraContext) { + Preconditions.checkNotNull(key); + ImmutableMap newExtraContexts = + ApiCallContextUtil.addExtraContext(extraContexts, key, extraContext); + return new FakeCallContext( + credentials, + channel, + timeout, + streamWaitTimeout, + streamIdleTimeout, + extraHeaders, + newExtraContexts, + tracer, + retrySettings, + retryableCodes); + } + + @Override + public T getExtraContext(Key key) { + Preconditions.checkNotNull(key); + if (extraContexts.containsKey(key)) { + return (T) extraContexts.get(key); + } + return key.getDefault(); + } + /** {@inheritDoc} */ @Override @Nonnull @@ -363,6 +416,7 @@ public ApiCallContext withTracer(@Nonnull ApiTracer tracer) { this.streamWaitTimeout, this.streamIdleTimeout, this.extraHeaders, + this.extraContexts, tracer, this.retrySettings, this.retryableCodes);