Mapped Diagnostic Context logging with Play Framework and Akka in Java

I am trying MDC logging in a Play filter in Java for all requests. I followed this tutorial in Scala and tried converting it to Java: http://yanns.github.io/blog/2014/05/04/slf4j-mapped-diagnostic-context-mdc-with-play-framework/
But the MDC is still not propagated to all execution contexts.
I am using this dispatcher as the default dispatcher, but there are many execution contexts under it, and I need the MDC propagated to all of them.
Below is my Java code:
import java.util.Map;

import org.slf4j.MDC;

import scala.concurrent.ExecutionContext;
import scala.concurrent.duration.Duration;
import scala.concurrent.duration.FiniteDuration;
import akka.dispatch.Dispatcher;
import akka.dispatch.ExecutorServiceFactoryProvider;
import akka.dispatch.MessageDispatcherConfigurator;

public class MDCPropagatingDispatcher extends Dispatcher {

    public MDCPropagatingDispatcher(
            MessageDispatcherConfigurator _configurator, String id,
            int throughput, Duration throughputDeadlineTime,
            ExecutorServiceFactoryProvider executorServiceFactoryProvider,
            FiniteDuration shutdownTimeout) {
        super(_configurator, id, throughput, throughputDeadlineTime,
                executorServiceFactoryProvider, shutdownTimeout);
    }

    @Override
    public ExecutionContext prepare() {
        final Map<String, String> mdcContext = MDC.getCopyOfContextMap();
        return new ExecutionContext() {
            @Override
            public void execute(Runnable r) {
                Map<String, String> oldMDCContext = MDC.getCopyOfContextMap();
                setContextMap(mdcContext);
                try {
                    r.run();
                } finally {
                    setContextMap(oldMDCContext);
                }
            }

            @Override
            public ExecutionContext prepare() {
                return this;
            }

            @Override
            public void reportFailure(Throwable t) {
                play.Logger.info("error occurred in dispatcher");
            }
        };
    }

    private void setContextMap(Map<String, String> context) {
        if (context == null) {
            MDC.clear();
        } else {
            play.Logger.info("set context " + context.toString());
            MDC.setContextMap(context);
        }
    }
}
import java.util.concurrent.TimeUnit;

import scala.concurrent.duration.Duration;
import scala.concurrent.duration.FiniteDuration;
import com.typesafe.config.Config;
import akka.dispatch.DispatcherPrerequisites;
import akka.dispatch.MessageDispatcher;
import akka.dispatch.MessageDispatcherConfigurator;

public class MDCPropagatingDispatcherConfigurator extends MessageDispatcherConfigurator {

    private MessageDispatcher instance;

    public MDCPropagatingDispatcherConfigurator(Config config,
            DispatcherPrerequisites prerequisites) {
        super(config, prerequisites);
        Duration throughputDeadlineTime = new FiniteDuration(-1,
                TimeUnit.MILLISECONDS);
        FiniteDuration shutDownDuration = new FiniteDuration(1,
                TimeUnit.MILLISECONDS);
        instance = new MDCPropagatingDispatcher(this,
                "play.akka.actor.contexts.play-filter-context",
                100, throughputDeadlineTime,
                configureExecutor(), shutDownDuration);
    }

    public MessageDispatcher dispatcher() {
        return instance;
    }
}
Filter interceptor:
public class MdcLogFilter implements EssentialFilter {

    @Override
    public EssentialAction apply(final EssentialAction next) {
        return new MdcLogAction() {
            @Override
            public Iteratee<byte[], SimpleResult> apply(
                    final RequestHeader requestHeader) {
                final String uuid = Utils.generateRandomUUID();
                MDC.put("uuid", uuid);
                play.Logger.info("request started " + uuid);
                final ExecutionContext playFilterContext = Akka.system()
                        .dispatchers()
                        .lookup("play.akka.actor.contexts.play-custom-filter-context");
                return next.apply(requestHeader).map(
                        new AbstractFunction1<SimpleResult, SimpleResult>() {
                            @Override
                            public SimpleResult apply(SimpleResult simpleResult) {
                                play.Logger.info("request ended " + uuid);
                                MDC.remove("uuid");
                                return simpleResult;
                            }
                        }, playFilterContext);
            }

            @Override
            public EssentialAction apply() {
                return next.apply();
            }
        };
    }
}
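For reference, the linked tutorial wires such a configurator in as the default dispatcher through application.conf. A sketch of that wiring (the package name is an assumption; use whichever package the configurator above lives in):

play {
  akka {
    actor {
      default-dispatcher = {
        type = "filters.MDCPropagatingDispatcherConfigurator"
      }
    }
  }
}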

Below is my solution, proven in real life. It's in Scala, and it's for Scalatra rather than Play, but the underlying concept is the same. Hopefully you'll be able to figure out how to port this to Java.
import org.slf4j.MDC
import java.util.{Map => JMap}
import scala.concurrent.{ExecutionContextExecutor, ExecutionContext}

object MDCHttpExecutionContext {
  def fromExecutionContextWithCurrentMDC(delegate: ExecutionContext): ExecutionContextExecutor =
    new MDCHttpExecutionContext(MDC.getCopyOfContextMap(), delegate)
}

class MDCHttpExecutionContext(mdcContext: JMap[String, String], delegate: ExecutionContext)
    extends ExecutionContextExecutor {

  def execute(runnable: Runnable): Unit = {
    val callingThreadMDC = MDC.getCopyOfContextMap()
    delegate.execute(new Runnable {
      def run() {
        val currentThreadMDC = MDC.getCopyOfContextMap()
        setContextMap(callingThreadMDC)
        try {
          runnable.run()
        } finally {
          setContextMap(currentThreadMDC)
        }
      }
    })
  }

  private[this] def setContextMap(context: JMap[String, String]): Unit = {
    Option(context) match {
      case Some(ctx) => MDC.setContextMap(ctx)
      case None => MDC.clear()
    }
  }

  def reportFailure(t: Throwable): Unit = delegate.reportFailure(t)
}
You'll have to make sure that this ExecutionContext is used in all of your asynchronous calls. I achieve this through dependency injection, but there are different ways. Here's how I do it with Subcut:
bind[ExecutionContext] idBy BindingIds.GlobalExecutionContext toSingle {
  MDCHttpExecutionContext.fromExecutionContextWithCurrentMDC(
    ExecutionContext.fromExecutorService(
      Executors.newFixedThreadPool(globalThreadPoolSize)
    )
  )
}
The idea behind this approach is as follows. MDC uses thread-local storage for the attributes and their values. If a single request of yours can run on multiple threads, you need to make sure any new thread you start uses the right MDC. For that, you create a custom executor that copies the MDC values into the new thread before it starts executing the task you assign to it. You must also ensure that when the thread finishes your task and continues with something else, you put the old values back into its MDC, because threads from a pool can switch between different requests.
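For those who land here looking for the Java version: below is a fairly direct port of the executor above, as a minimal sketch. It uses only the SLF4J and scala.concurrent APIs already shown in this question (the prepare() override mirrors the question's own anonymous ExecutionContext), and it takes the MDC snapshot in execute(), which is the snapshot the Scala version actually uses:

import java.util.Map;
import org.slf4j.MDC;
import scala.concurrent.ExecutionContext;

public class MDCHttpExecutionContext implements ExecutionContext {

    private final ExecutionContext delegate;

    public static ExecutionContext fromExecutionContextWithCurrentMDC(ExecutionContext delegate) {
        return new MDCHttpExecutionContext(delegate);
    }

    private MDCHttpExecutionContext(ExecutionContext delegate) {
        this.delegate = delegate;
    }

    @Override
    public void execute(final Runnable runnable) {
        // snapshot the MDC of the thread that schedules the work
        final Map<String, String> callingThreadMDC = MDC.getCopyOfContextMap();
        delegate.execute(new Runnable() {
            @Override
            public void run() {
                // remember what the pooled thread had, then install the caller's MDC
                Map<String, String> currentThreadMDC = MDC.getCopyOfContextMap();
                setContextMap(callingThreadMDC);
                try {
                    runnable.run();
                } finally {
                    // restore, since pool threads switch between requests
                    setContextMap(currentThreadMDC);
                }
            }
        });
    }

    @Override
    public ExecutionContext prepare() {
        return this;
    }

    @Override
    public void reportFailure(Throwable t) {
        delegate.reportFailure(t);
    }

    private static void setContextMap(Map<String, String> context) {
        if (context == null) {
            MDC.clear();
        } else {
            MDC.setContextMap(context);
        }
    }
}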


Resilience4j context propagator not able to propagate thread local values

I am trying to migrate my circuit breaker code from Hystrix to Resilience4j. The communication is between two applications: one is an artifact containing all the Resilience4j config in the Java code itself, and the second, a microservice, uses it directly.
There's one RequestId which is generated in the microservice and propagated to the artifact context, where it gets printed in the logs. With Hystrix it was working perfectly fine, but ever since I moved to Resilience4j, I am getting null for the RequestId.
Below is my config for the bulkhead and context propagator:
ThreadPoolBulkheadConfig bulkheadConfig = ThreadPoolBulkheadConfig.custom()
        .maxThreadPoolSize(maxThreadPoolSize)
        .coreThreadPoolSize(coreThreadPoolSize)
        .queueCapacity(queueCapacity)
        .contextPropagator(new DummyContextPropagator())
        .build();

// Bulkhead registry
ThreadPoolBulkheadRegistry bulkheadRegistry = ThreadPoolBulkheadRegistry.of(bulkheadConfig);

// Create bulkhead
ThreadPoolBulkhead bulkhead = bulkheadRegistry.bulkhead(name, bulkheadConfig);
Dummy context propagator:
public class DummyContextPropagator implements ContextPropagator<Object> {

    private static final Logger log = LoggerFactory.getLogger(DummyContextPropagator.class);

    @Override
    public Supplier<Optional<Object>> retrieve() {
        return () -> DummyContextHolder.get();
    }

    @Override
    public Consumer<Optional<Object>> copy() {
        return (t) -> t.ifPresent(e -> {
            DummyContextHolder.clear();
            DummyContextHolder.put(e);
        });
    }

    @Override
    public Consumer<Optional<Object>> clear() {
        return (t) -> DummyContextHolder.clear();
    }

    public static class DummyContextHolder {
        private static final ThreadLocal<Object> threadLocal = new ThreadLocal<>();

        private DummyContextHolder() {
        }

        public static void put(Object context) {
            if (threadLocal.get() != null) {
                clear();
            }
            threadLocal.set(context);
        }

        public static void clear() {
            if (threadLocal.get() != null) {
                threadLocal.remove();
            }
        }

        public static Optional<Object> get() {
            return Optional.ofNullable(threadLocal.get());
        }
    }
}
However, nothing seems to work, and I cannot get the RequestId.
Am I doing everything right, or is there another way to do it?
I think you want to read parameters from the parent thread's ThreadLocal while you are in a sub-thread. In Hystrix this works because it uses a command model to decorate the Callable task. In Resilience4j I think you can fix it like this:
@Resource
DispatcherServlet dispatcherServlet;

@PostConstruct
public void changeThreadLocalModel() {
    dispatcherServlet.setThreadContextInheritable(true);
}
I found that my last answer may lead to some problems: when you use dispatcherServlet.setThreadContextInheritable(true), it may pollute your custom thread pool's ThreadLocal map. So here is my final solution, and it only works with Resilience4j:
@Resource
Resilience4jBulkheadProvider resilience4jBulkheadProvider;

@PostConstruct
public void concurrentThreadContextStrategy() {
    ThreadPoolBulkheadConfig threadPoolBulkheadConfig = ThreadPoolBulkheadConfig.custom()
            .contextPropagator(new CustomInheritContextPropagator()).build();
    resilience4jBulkheadProvider.configureDefault(id -> new Resilience4jBulkheadConfigurationBuilder()
            .bulkheadConfig(BulkheadConfig.ofDefaults()).threadPoolBulkheadConfig(threadPoolBulkheadConfig)
            .build());
}

private static class CustomInheritContextPropagator implements ContextPropagator<RequestAttributes> {

    @Override
    public Supplier<Optional<RequestAttributes>> retrieve() {
        // hands the RequestAttributes reference out of the thread local;
        // called by the web-container thread (Tomcat, Jetty or Undertow,
        // depending on what you use)
        return () -> Optional.ofNullable(RequestContextHolder.getRequestAttributes());
    }

    @Override
    public Consumer<Optional<RequestAttributes>> copy() {
        // loads the request context into the real calling thread;
        // called by the Resilience4j bulkhead thread
        return requestAttributes -> requestAttributes.ifPresent(context -> {
            RequestContextHolder.resetRequestAttributes();
            RequestContextHolder.setRequestAttributes(context);
        });
    }

    @Override
    public Consumer<Optional<RequestAttributes>> clear() {
        // cleans up the request context afterwards;
        // called by the Resilience4j bulkhead thread
        return requestAttributes -> RequestContextHolder.resetRequestAttributes();
    }
}
I got the same problem with Spring Boot 2.5 and Spring Cloud 2020.0.6, and I solved it with an implementation of ContextPropagator:
public class SleuthPropagator implements ContextPropagator<TraceContext> {

    ThreadLocal<ScopedSpan> scopedSpanThreadLocal = new ThreadLocal<>();

    @Override
    public Supplier<Optional<TraceContext>> retrieve() {
        return this::getCurrentContext;
    }

    @Override
    public Consumer<Optional<TraceContext>> copy() {
        return c -> {
            if (!c.isPresent()) {
                return;
            }
            TraceContext traceContext = c.get();
            ScopedSpan resilience4jSpan = getTracer()
                    .map(t -> t.startScopedSpanWithParent("Resilience4j", traceContext))
                    .orElse(null);
            scopedSpanThreadLocal.set(resilience4jSpan);
        };
    }

    @Override
    public Consumer<Optional<TraceContext>> clear() {
        return t -> {
            try {
                ScopedSpan resilience4jSpan = scopedSpanThreadLocal.get();
                if (resilience4jSpan != null) {
                    resilience4jSpan.finish();
                }
            } finally {
                scopedSpanThreadLocal.remove();
            }
        };
    }

    private static Optional<Tracer> getTracer() {
        return Optional.ofNullable(Tracing.current())
                .map(Tracing::tracer);
    }

    private Optional<TraceContext> getCurrentContext() {
        return getTracer()
                .map(Tracer::currentSpan)
                .map(Span::context);
    }
}
And use the propagator by adding this to your application.properties:
resilience4j.thread-pool-bulkhead.instances.YOUR_BULKHEAD_CONFIG.context-propagators=com.your.package.SleuthPropagator
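For cases where the request id lives in the SLF4J MDC rather than in Spring's RequestContextHolder, the same ContextPropagator shape can carry the whole MDC map across the bulkhead threads. A minimal sketch, composing the io.github.resilience4j.core.ContextPropagator interface shown above with MDC (not taken from the answers, just the same pattern applied):

import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Supplier;

import org.slf4j.MDC;
import io.github.resilience4j.core.ContextPropagator;

public class MdcContextPropagator implements ContextPropagator<Map<String, String>> {

    @Override
    public Supplier<Optional<Map<String, String>>> retrieve() {
        // called on the submitting thread: snapshot its MDC
        return () -> Optional.ofNullable(MDC.getCopyOfContextMap());
    }

    @Override
    public Consumer<Optional<Map<String, String>>> copy() {
        // called on the bulkhead thread before the task runs
        return mdc -> mdc.ifPresent(MDC::setContextMap);
    }

    @Override
    public Consumer<Optional<Map<String, String>>> clear() {
        // called on the bulkhead thread after the task finishes
        return mdc -> MDC.clear();
    }
}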

How to add instrumentation to GraphQL Java with graphql-spring-boot?

Does anybody know how I can add instrumentation to a GraphQL execution when using graphql-spring-boot (https://github.com/graphql-java-kickstart/graphql-spring-boot)? I know how this is possible with plain-vanilla graphql-java: https://www.graphql-java.com/documentation/v13/instrumentation/
However, I don't know how to do this when graphql-spring-boot is used and takes control over the execution. Due to lack of documentation, I simply tried it this way:
@Service
public class GraphQLInstrumentationProvider implements InstrumentationProvider {

    @Override
    public Instrumentation getInstrumentation() {
        return SimpleInstrumentation.INSTANCE;
    }
}
But the method getInstrumentation on my InstrumentationProvider bean is (as expected) never called. Any help appreciated.
Answering my own question. In the meantime I managed to do it this way:
final class RequestLoggingInstrumentation extends SimpleInstrumentation {

    private static final Logger logger = LoggerFactory.getLogger(RequestLoggingInstrumentation.class);

    @Override
    public InstrumentationContext<ExecutionResult> beginExecution(InstrumentationExecutionParameters parameters) {
        long startMillis = System.currentTimeMillis();
        var executionId = parameters.getExecutionInput().getExecutionId();
        if (logger.isInfoEnabled()) {
            logger.info("GraphQL execution {} started", executionId);
            var query = parameters.getQuery();
            logger.info("[{}] query: {}", executionId, query);
            if (parameters.getVariables() != null && !parameters.getVariables().isEmpty()) {
                logger.info("[{}] variables: {}", executionId, parameters.getVariables());
            }
        }
        return new SimpleInstrumentationContext<>() {
            @Override
            public void onCompleted(ExecutionResult executionResult, Throwable t) {
                if (logger.isInfoEnabled()) {
                    long endMillis = System.currentTimeMillis();
                    if (t != null) {
                        logger.info("GraphQL execution {} failed: {}", executionId, t.getMessage(), t);
                    } else {
                        var resultMap = executionResult.toSpecification();
                        var resultJSON = ObjectMapper.pojoToJSON(resultMap).replace("\n", "\\n");
                        logger.info("[{}] completed in {}ms", executionId, endMillis - startMillis);
                        logger.info("[{}] result: {}", executionId, resultJSON);
                    }
                }
            }
        };
    }
}
@Service
class InstrumentationService {

    private final ContextFactory contextFactory;

    InstrumentationService(ContextFactory contextFactory) {
        this.contextFactory = contextFactory;
    }

    /**
     * Return all instrumentations as a bean.
     * The result will be used in class {@link com.oembedler.moon.graphql.boot.GraphQLWebAutoConfiguration}.
     */
    @Bean
    List<Instrumentation> instrumentations() {
        // Note: Due to a bug in GraphQLWebAutoConfiguration, the returned list has to be modifiable (it will be sorted)
        return new ArrayList<>(
                List.of(new RequestLoggingInstrumentation()));
    }
}
It helped me to have a look into the class GraphQLWebAutoConfiguration. There I found out that the framework expects a bean of type List<Instrumentation>, which contains all the instrumentations that will be added to the GraphQL execution.
There is a simpler way to add instrumentation with Spring Boot:
@Configuration
public class InstrumentationConfiguration {

    @Bean
    public Instrumentation someFieldCheckingInstrumentation() {
        return new FieldValidationInstrumentation(env -> {
            // ...
        });
    }
}
Spring Boot will collect all beans which implement Instrumentation (see GraphQLWebAutoConfiguration).

How can we use Lagom's read-side processor with Dgraph?

I am a newbie to Lagom and Dgraph, and I am stuck on how to use Lagom's read-side processor with Dgraph. Just to give you an idea, the following is code which uses Cassandra with Lagom:
import akka.NotUsed;
import com.lightbend.lagom.javadsl.api.ServiceCall;
import com.lightbend.lagom.javadsl.persistence.cassandra.CassandraSession;

import java.util.concurrent.CompletableFuture;
import javax.inject.Inject;

import akka.stream.javadsl.Source;

public class FriendServiceImpl implements FriendService {

    private final CassandraSession cassandraSession;

    @Inject
    public FriendServiceImpl(CassandraSession cassandraSession) {
        this.cassandraSession = cassandraSession;
    }

    // Implement your service method here
}
Lagom does not provide out-of-the-box support for Dgraph. If you have to use Lagom's read-side processor with Dgraph, then you have to use Lagom's generic read-side support, like this:
import akka.Done;
import akka.japi.Pair;
import akka.stream.javadsl.Flow;
import com.lightbend.lagom.javadsl.persistence.AggregateEventTag;
import com.lightbend.lagom.javadsl.persistence.Offset;
import com.lightbend.lagom.javadsl.persistence.ReadSideProcessor;
import org.pcollections.PSequence;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;

/**
 * Read-side processor for Dgraph.
 */
public class FriendEventProcessor extends ReadSideProcessor<FriendEvent> {

    private static void createModel() {
        // TODO: Initialize schema in Dgraph
    }

    @Override
    public ReadSideProcessor.ReadSideHandler<FriendEvent> buildHandler() {
        return new ReadSideHandler<FriendEvent>() {
            private final Done doneInstance = Done.getInstance();

            @Override
            public CompletionStage<Done> globalPrepare() {
                createModel();
                return CompletableFuture.completedFuture(doneInstance);
            }

            @Override
            public CompletionStage<Offset> prepare(final AggregateEventTag<FriendEvent> tag) {
                return CompletableFuture.completedFuture(Offset.NONE);
            }

            @Override
            public Flow<Pair<FriendEvent, Offset>, Done, ?> handle() {
                return Flow.<Pair<FriendEvent, Offset>>create()
                        .mapAsync(1, eventAndOffset -> {
                            if (eventAndOffset.first() instanceof FriendCreated) {
                                // TODO: Add Friend in Dgraph
                            }
                            return CompletableFuture.completedFuture(doneInstance);
                        });
            }
        };
    }

    @Override
    public PSequence<AggregateEventTag<FriendEvent>> aggregateTags() {
        return FriendEvent.TAG.allTags();
    }
}
For FriendEvent.TAG.allTags(), you have to add the following code to the FriendEvent interface:
int NUM_SHARDS = 20;

AggregateEventShards<FriendEvent> TAG =
        AggregateEventTag.sharded(FriendEvent.class, NUM_SHARDS);

@Override
default AggregateEventShards<FriendEvent> aggregateTag() {
    return TAG;
}
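One thing the snippets above don't show: the processor still has to be registered so Lagom actually runs it. In the javadsl that is done through the injected ReadSide component, roughly like this (the constructor shape here is an assumption):

import javax.inject.Inject;
import com.lightbend.lagom.javadsl.persistence.ReadSide;

public class FriendServiceImpl implements FriendService {

    @Inject
    public FriendServiceImpl(ReadSide readSide) {
        // tell Lagom to feed FriendEvents to the Dgraph processor
        readSide.register(FriendEventProcessor.class);
    }
}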
I hope this helps!

Actor MDC context in aroundReceive method

I have a Java Akka application and I want to set a separate MDC context for each message handling, based on information inside the message. For example, I have the following base interface for all messages:
public interface IdMessage {
    String getId();
}
Also I have the following base actor for all actors:
public class BaseActor extends AbstractActor {

    private final DiagnosticLoggingAdapter log = Logging.apply(this);

    @Override
    public void aroundReceive(PartialFunction<Object, BoxedUnit> receive, Object msg) {
        if (msg instanceof IdMessage) {
            final Map<String, Object> originalMDC = log.getMDC();
            log.setMDC(ImmutableMap.of("id", ((IdMessage) msg).getId()));
            try {
                super.aroundReceive(receive, msg);
            } finally {
                if (originalMDC != null) {
                    log.setMDC(originalMDC);
                } else {
                    log.clearMDC();
                }
            }
        } else {
            super.aroundReceive(receive, msg);
        }
    }
}
And the actual actor implementation:
public class SomeActor extends BaseActor {
    SomeActor() {
        receive(ReceiveBuilder
                .match(SomeMessage.class, message -> {
                    ...
                })
                .build());
    }
}
I would like to make sure that all log entries inside SomeActor#receive() will have the MDC context set in BaseActor. To make this work, SomeActor#receive() needs to be executed in the same thread as the BaseActor#aroundReceive() method.
I didn't find any information about the behaviour of aroundReceive - is it always going to be executed in the same thread as the actual receive method? Based on my testing, it is.
I was able to figure out the proper implementation by myself and would like to share it in case someone faces the same issue.
aroundReceive is going to be executed in the same thread as receive, so this is the right place to set the MDC context.
I used org.slf4j.MDC for setting the MDC context; here is the full code:
import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;

import com.google.common.collect.ImmutableMap;

import akka.actor.AbstractActor;
import scala.PartialFunction;
import scala.runtime.BoxedUnit;

public class BaseActor extends AbstractActor {

    private final Logger log = LoggerFactory.getLogger(BaseActor.class);

    @Override
    public void aroundReceive(PartialFunction<Object, BoxedUnit> receive, Object msg) {
        if (msg instanceof IdMessage) {
            final Map<String, String> originalMDC = MDC.getCopyOfContextMap();
            MDC.setContextMap(ImmutableMap.of("id", ((IdMessage) msg).getId()));
            try {
                super.aroundReceive(receive, msg);
            } finally {
                if (originalMDC != null) {
                    MDC.setContextMap(originalMDC);
                } else {
                    MDC.clear();
                }
            }
        } else {
            super.aroundReceive(receive, msg);
        }
    }
}
With that implementation of BaseActor, all log entries in receive are logged with the proper MDC context. Additional information can be found in this interesting blog post (with a Scala implementation).
Note: I was not able to achieve the same functionality with Akka's DiagnosticLoggingAdapter, although it has methods to set the MDC context.
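One practical note: the id value set above only shows up in the output if the logging pattern references it. With Logback, for example, a pattern along these lines prints it (the appender name here is illustrative):

<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
    <encoder>
        <!-- %X{id} prints the MDC value stored under the "id" key -->
        <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} id=%X{id} - %msg%n</pattern>
    </encoder>
</appender>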

Unit testing clients of Observables

I have the following method go() I'd like to test:
private Pair<String, String> mPair;

public void go() {
    Observable.zip(
            mApi.webCall(),
            mApi.webCall2(),
            new Func2<String, String, Pair<String, String>>() {
                @Override
                public Pair<String, String> call(String s, String s2) {
                    return new Pair(s, s2);
                }
            }
    )
            .subscribeOn(Schedulers.io())
            .observeOn(AndroidSchedulers.mainThread())
            .subscribe(new Action1<Pair<String, String>>() {
                @Override
                public void call(Pair<String, String> pair) {
                    mApi.webCall3(pair.first, pair.second);
                }
            });
}
This method uses Observable.zip() to execute two HTTP requests asynchronously and merge their results into one Pair. In the end, another HTTP request is executed with the result of these previous requests.
I'd like to verify that calling the go() method makes the webCall() and webCall2() requests, followed by the webCall3(String, String) request. Therefore, I'd like the following test to pass (using Mockito to spy on the Api):
@Test
public void testGo() {
    /* Given */
    Api api = spy(new Api() {
        @Override
        public Observable<String> webCall() {
            return Observable.just("First");
        }

        @Override
        public Observable<String> webCall2() {
            return Observable.just("second");
        }

        @Override
        public void webCall3(String s1, String s2) {
        }
    });
    Test test = new Test(api);

    /* When */
    test.go();

    /* Then */
    verify(api).webCall();
    verify(api).webCall2();
    verify(api).webCall3("First", "second");
}
However, when running this, the web calls are executed asynchronously, and my test executes the assertions before the subscriber is done, causing my test to fail.
I have read that you can use RxJavaSchedulersHook and RxAndroidSchedulersHook to return Schedulers.immediate() for all methods, but this results in the test running indefinitely.
I am running my unit tests on a local JVM.
How can I achieve this, preferably without having to modify the signature of go()?
(Lambdas thanks to Retrolambda.)
For starters, I would rephrase go as:
private Pair<String, String> mPair;

public Observable<Pair<String, String>> go() {
    return Observable.zip(
            mApi.webCall(),
            mApi.webCall2(),
            (String s, String s2) -> new Pair(s, s2))
            .subscribeOn(Schedulers.io())
            .observeOn(AndroidSchedulers.mainThread())
            .doOnNext(pair -> mPair = pair);
}

public Pair<String, String> getPair() {
    return mPair;
}
doOnNext allows you to intercept the value that is being processed in the chain whenever someone subscribes to the Observable.
Then I would call it in the test like this:
Pair result = test.go().toBlocking().lastOrDefault(null);
Then you can assert on what result is.
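For example, with plain JUnit assertions (this assumes android.util.Pair, whose first and second members are public fields):

assertEquals("First", result.first);
assertEquals("second", result.second);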
I would use TestScheduler and TestSubscriber in your tests. In order to do this you'll have to receive the Observable that composes the zip, so you can subscribe to that work with the TestScheduler. You'll also have to parameterize your schedulers. You won't have to change your go() method's signature, but you will have to parameterize the schedulers in the underlying functionality. You could inject the schedulers by constructor, override a protected field by inheritance, or call a package-private overload. I have written my examples assuming an overload that accepts the schedulers as arguments and returns the Observable.
The TestScheduler gives you a synchronous way to trigger async operator behavior in a predictable, reproducible way. The TestSubscriber gives you a way to await termination and assert over the values and signals received. Also be aware that the delay(long, TimeUnit) operator by default schedules work on the computation scheduler; you'll need to use the TestScheduler there as well.
Scheduler ioScheduler = Schedulers.io();
Scheduler mainThreadScheduler = AndroidSchedulers.mainThread();

public void go() {
    go(ioScheduler, mainThreadScheduler).toBlocking().single();
}

/*package*/ Observable<Pair<String, String>> go(Scheduler ioScheduler, Scheduler mainThreadScheduler) {
    return Observable.zip(
            mApi.webCall(),
            mApi.webCall2(),
            new Func2<String, String, Pair<String, String>>() {
                @Override
                public Pair<String, String> call(String s, String s2) {
                    return new Pair(s, s2);
                }
            })
            .doOnNext(new Action1<Pair<String, String>>() {
                @Override
                public void call(Pair<String, String> pair) {
                    mApi.webCall3(pair.first, pair.second);
                }
            })
            .subscribeOn(ioScheduler)
            .observeOn(mainThreadScheduler);
}
Test code
@Test
public void testGo() {
    /* Given */
    TestScheduler testScheduler = new TestScheduler();
    TestSubscriber<Pair<String, String>> subscriber = new TestSubscriber<>();
    Api api = spy(new Api() {
        @Override
        public Observable<String> webCall() {
            return Observable.just("First").delay(1, TimeUnit.SECONDS, testScheduler);
        }

        @Override
        public Observable<String> webCall2() {
            return Observable.just("second");
        }

        @Override
        public void webCall3(String s1, String s2) {
        }
    });
    Test test = new Test(api);

    /* When */
    test.go(testScheduler, testScheduler).subscribe(subscriber);
    // advance virtual time so the 1-second delay elapses
    testScheduler.advanceTimeBy(1, TimeUnit.SECONDS);
    subscriber.awaitTerminalEvent();

    /* Then */
    verify(api).webCall();
    verify(api).webCall2();
    verify(api).webCall3("First", "second");
}
I have found out that I can retrieve my Schedulers in a non-static way, basically injecting them into my client class. The SchedulerProvider replaces the static calls to Schedulers.x():
public interface SchedulerProvider {
    Scheduler io();
    Scheduler mainThread();
}
The production implementation delegates back to Schedulers:
public class SchedulerProviderImpl implements SchedulerProvider {

    public static final SchedulerProvider INSTANCE = new SchedulerProviderImpl();

    @Override
    public Scheduler io() {
        return Schedulers.io();
    }

    @Override
    public Scheduler mainThread() {
        return AndroidSchedulers.mainThread();
    }
}
However, during tests I can create a TestSchedulerProvider:
public class TestSchedulerProvider implements SchedulerProvider {

    private final TestScheduler mIOScheduler = new TestScheduler();
    private final TestScheduler mMainThreadScheduler = new TestScheduler();

    @Override
    public TestScheduler io() {
        return mIOScheduler;
    }

    @Override
    public TestScheduler mainThread() {
        return mMainThreadScheduler;
    }
}
Now I can inject the SchedulerProvider into the Test class containing the go() method:
class Test {
    /* ... */

    Test(Api api, SchedulerProvider schedulerProvider) {
        mApi = api;
        mSchedulerProvider = schedulerProvider;
    }

    void go() {
        Observable.zip(
                mApi.webCall(),
                mApi.webCall2(),
                new Func2<String, String, Pair<String, String>>() {
                    @Override
                    public Pair<String, String> call(String s, String s2) {
                        return new Pair(s, s2);
                    }
                }
        )
                .subscribeOn(mSchedulerProvider.io())
                .observeOn(mSchedulerProvider.mainThread())
                .subscribe(new Action1<Pair<String, String>>() {
                    @Override
                    public void call(Pair<String, String> pair) {
                        mApi.webCall3(pair.first, pair.second);
                    }
                });
    }
}
Testing this works as follows:
@Test
public void testGo() {
    /* Given */
    TestSchedulerProvider testSchedulerProvider = new TestSchedulerProvider();
    Api api = spy(new Api() {
        @Override
        public Observable<String> webCall() {
            return Observable.just("First");
        }

        @Override
        public Observable<String> webCall2() {
            return Observable.just("second");
        }

        @Override
        public void webCall3(String s1, String s2) {
        }
    });
    Test test = new Test(api, testSchedulerProvider);

    /* When */
    test.go();
    testSchedulerProvider.io().triggerActions();
    testSchedulerProvider.mainThread().triggerActions();

    /* Then */
    verify(api).webCall();
    verify(api).webCall2();
    verify(api).webCall3("First", "second");
}
I had a similar issue that took one more step to solve:

existingObservable
        .zipWith(Observable.interval(100, TimeUnit.MILLISECONDS), new Func1<> ...)
        .subscribeOn(schedulersProvider.computation())

was still not using the TestScheduler that the provided schedulersProvider returned. It was necessary to call .subscribeOn() on the individual streams that I was zipping in order for it to work:

existingObservable.subscribeOn(schedulersProvider.computation())
        .zipWith(Observable.interval(100, TimeUnit.MILLISECONDS).subscribeOn(schedulersProvider.computation()), new Func1<> ...)
        .subscribeOn(schedulersProvider.computation())

Note that schedulersProvider is a mock returning the TestScheduler of my test!
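For clarity, "a mock returning the TestScheduler" can be as simple as the following Mockito stubbing (this assumes the provider interface exposes a computation() method, as used above, unlike the io()/mainThread() variant shown earlier):

TestScheduler testScheduler = new TestScheduler();
SchedulerProvider schedulersProvider = mock(SchedulerProvider.class);
when(schedulersProvider.computation()).thenReturn(testScheduler);

// later in the test, advance virtual time past the zipped interval
testScheduler.advanceTimeBy(100, TimeUnit.MILLISECONDS);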
