Decorators in Java
The decorator pattern is one of my favourites in Java. Early in my career we used it heavily: with Project Reactor (reactive Java) to carry MDC (Mapped Diagnostic Context) values across thread switches, and with Spring AOP to add behaviour to our classes through annotations (for example, acquiring a distributed lock in Redis on an orderId before the request reached the controller).
A decorator lets us add behaviour to an object at runtime by wrapping it in another object that implements the same interface. Let us see how it is done with an example: we will implement a BoundedSet, a Set whose size is capped at a fixed bound. The example is taken from Java Concurrency in Practice, but the implementation is my own.
import java.util.Collection;
import java.util.Iterator;
import java.util.Set;

// Forwards every Set method to delegate(); subclasses override only what they want to change
public abstract class ForwardingSet<T> implements Set<T> {
    public abstract Set<T> delegate();

    @Override
    public int size() {
        return delegate().size();
    }

    @Override
    public boolean isEmpty() {
        return delegate().isEmpty();
    }

    @Override
    public boolean contains(Object o) {
        return delegate().contains(o);
    }

    @Override
    public Iterator<T> iterator() {
        return delegate().iterator();
    }

    @Override
    public Object[] toArray() {
        return delegate().toArray();
    }

    @Override
    public <T1> T1[] toArray(T1[] a) {
        return delegate().toArray(a);
    }

    @Override
    public boolean add(T t) {
        return delegate().add(t);
    }

    @Override
    public boolean remove(Object o) {
        return delegate().remove(o);
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        return delegate().containsAll(c);
    }

    @Override
    public boolean addAll(Collection<? extends T> c) {
        return delegate().addAll(c);
    }

    @Override
    public boolean retainAll(Collection<?> c) {
        return delegate().retainAll(c);
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        return delegate().removeAll(c);
    }

    @Override
    public void clear() {
        delegate().clear();
    }

    // The Set contract defines equality by contents, so forward these too
    // (otherwise Object's identity-based versions would be used)
    @Override
    public boolean equals(Object o) {
        return o == this || delegate().equals(o);
    }

    @Override
    public int hashCode() {
        return delegate().hashCode();
    }

    @Override
    public String toString() {
        return delegate().toString();
    }
}
ForwardingSet is inspired by Guava’s ForwardingSet. It simply forwards every method call to an underlying delegate object (a digression: this is composition rather than inheritance). What is a delegate? It is the wrapped object that ultimately handles each call, after our added behaviour has run. We can now subclass ForwardingSet and override only the methods whose behaviour we want to change, as we have done in BoundedSet:
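import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.Semaphore;
import net.jcip.annotations.ThreadSafe; // from the jcip-annotations library

@ThreadSafe
public class BoundedSet<T> extends ForwardingSet<T> {
    private final Set<T> set;
    private final Semaphore semaphore;

    public BoundedSet(int bound) {
        // One permit per slot: the set can hold at most `bound` elements
        this.semaphore = new Semaphore(bound);
        this.set = Collections.synchronizedSet(new HashSet<>());
    }

    @Override
    public boolean add(T t) {
        acquirePermit(); // blocks while the set is full
        boolean wasAdded = false;
        try {
            wasAdded = super.add(t);
            return wasAdded;
        } finally {
            // Give the permit back if nothing was added (duplicate element, or add threw)
            if (!wasAdded) {
                semaphore.release();
            }
        }
    }

    @Override
    public boolean remove(Object o) {
        boolean wasRemoved = super.remove(o);
        if (wasRemoved) {
            semaphore.release();
        }
        return wasRemoved;
    }

    @Override
    public boolean addAll(Collection<? extends T> c) {
        boolean setChanged = false;
        for (T e : c) {
            setChanged |= add(e); // one permit per element actually added
        }
        return setChanged;
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        boolean setChanged = false;
        for (Object e : c) {
            setChanged |= remove(e); // one permit released per element actually removed
        }
        return setChanged;
    }

    // retainAll() and clear() would otherwise be forwarded untouched and never release permits
    @Override
    public boolean retainAll(Collection<?> c) {
        synchronized (set) { // a synchronizedSet uses itself as the lock
            int before = set.size();
            boolean changed = set.retainAll(c);
            semaphore.release(before - set.size());
            return changed;
        }
    }

    @Override
    public void clear() {
        synchronized (set) {
            int removed = set.size();
            set.clear();
            semaphore.release(removed);
        }
    }

    @Override
    public Iterator<T> iterator() {
        // Read-only iterator: Iterator.remove() would otherwise bypass the semaphore
        return Collections.unmodifiableSet(set).iterator();
    }

    @Override
    public Set<T> delegate() {
        return this.set;
    }

    private void acquirePermit() {
        try {
            semaphore.acquire();
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt(); // preserve the interrupt status for callers
            throw new RuntimeException(ie);
        }
    }
}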
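Here, we have modified the behaviour of mutating methods such as add, remove, addAll and removeAll. Each successful add consumes a permit from the semaphore, so once the set holds as many elements as the bound, further adds block until something is removed; each successful remove releases a permit. If an add turns out to be a no-op because the element is already present, its permit is released straight away. The implementation is a bit naive: a synchronised set guards every operation with a single lock, which does not scale well under contention, but you get the idea. To use this class, do: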
Set<Integer> set = new BoundedSet<>(10);
The delegate in this case is the underlying synchronised set.
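BoundedSet blocks inside add() when no permit is left. If callers would rather fail fast, the same permit accounting works with Semaphore.tryAcquire(). The sketch below is a hypothetical variant (NonBlockingBoundedSet is my own name, not from the original or from Java Concurrency in Practice) that also receives its delegate through the constructor, so it can decorate any Set, such as a TreeSet. It extends java.util.AbstractSet instead of ForwardingSet so that it compiles on its own, and it assumes single-threaded use or external synchronization.

```java
import java.util.AbstractSet;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.Semaphore;

// Hypothetical fail-fast variant of BoundedSet. Not thread-safe: assumes
// single-threaded use or external synchronization around every call.
class NonBlockingBoundedSet<T> extends AbstractSet<T> {
    private final Set<T> delegate;    // any Set implementation supplied by the caller
    private final Semaphore permits;  // one permit per free slot

    NonBlockingBoundedSet(Set<T> delegate, int bound) {
        if (!delegate.isEmpty()) {
            throw new IllegalArgumentException("delegate must start empty");
        }
        this.delegate = delegate;
        this.permits = new Semaphore(bound);
    }

    @Override
    public boolean add(T t) {
        if (delegate.contains(t)) {
            return false; // already present: no slot needed
        }
        if (!permits.tryAcquire()) {
            // Collection.add must throw, not return false, when it refuses an
            // element for any reason other than "already present"
            throw new IllegalStateException("bound reached");
        }
        boolean added = false;
        try {
            added = delegate.add(t);
            return added;
        } finally {
            if (!added) {
                permits.release(); // nothing was stored, so give the slot back
            }
        }
    }

    @Override
    public boolean remove(Object o) {
        boolean removed = delegate.remove(o);
        if (removed) {
            permits.release();
        }
        return removed;
    }

    @Override
    public boolean contains(Object o) {
        return delegate.contains(o);
    }

    @Override
    public int size() {
        return delegate.size();
    }

    @Override
    public Iterator<T> iterator() {
        Iterator<T> it = delegate.iterator();
        // Wrap the iterator so Iterator.remove() also frees a slot; the inherited
        // clear(), retainAll() and removeIf() all remove elements through it
        return new Iterator<T>() {
            @Override public boolean hasNext() { return it.hasNext(); }
            @Override public T next() { return it.next(); }
            @Override public void remove() {
                it.remove(); // if this throws, no permit is released
                permits.release();
            }
        };
    }

    int remainingCapacity() {
        return permits.availablePermits();
    }
}
```

For example, new NonBlockingBoundedSet<>(new TreeSet<String>(), 2) accepts two distinct strings, keeps them in sorted order because the TreeSet does the storing, and throws IllegalStateException on a third until one is removed or the set is cleared.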
Spring uses decorators heavily, in the form of AOP proxies, to intercept calls to your objects before they reach them. In our case, we had to solve the tricky problem of carrying MDC values from one thread to another whenever work switched threads. Because the MDC is stored per thread, a task running on a pool thread does not see the values set by the thread that submitted it. We solved this by decorating the submit methods of the executor service: the decorator captures the calling thread’s MDC at submission time, wraps the task so that it copies that MDC into the worker thread’s MDC before running, and then hands the wrapped task to the underlying delegate. Here’s the code:
ForwardedExecutorService.java
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Forwards every ExecutorService method to delegate(); subclasses override the ones they decorate
public abstract class ForwardedExecutorService implements ExecutorService {
    public abstract ExecutorService delegate();

    @Override
    public void shutdown() {
        delegate().shutdown();
    }

    @Override
    public List<Runnable> shutdownNow() {
        return delegate().shutdownNow();
    }

    @Override
    public boolean isShutdown() {
        return delegate().isShutdown();
    }

    @Override
    public boolean isTerminated() {
        return delegate().isTerminated();
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
        return delegate().awaitTermination(timeout, unit);
    }

    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return delegate().submit(task);
    }

    @Override
    public <T> Future<T> submit(Runnable task, T result) {
        return delegate().submit(task, result);
    }

    @Override
    public Future<?> submit(Runnable task) {
        return delegate().submit(task);
    }

    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
        return delegate().invokeAll(tasks);
    }

    @Override
    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) throws InterruptedException {
        return delegate().invokeAll(tasks, timeout, unit);
    }

    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
        return delegate().invokeAny(tasks);
    }

    @Override
    public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
        return delegate().invokeAny(tasks, timeout, unit);
    }

    @Override
    public void execute(Runnable command) {
        delegate().execute(command);
    }
}
MDCAwareExecutorService.java
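import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class MDCAwareExecutorService extends ForwardedExecutorService {
    private final ExecutorService es;

    public MDCAwareExecutorService() {
        this.es = Executors.newCachedThreadPool();
    }

    @Override
    public ExecutorService delegate() {
        return this.es;
    }

    @Override
    public <T> Future<T> submit(Callable<T> task) {
        return super.submit(decorateTask(task));
    }

    @Override
    public <T> Future<T> submit(Runnable task, T result) {
        return super.submit(decorateTask(task), result);
    }

    @Override
    public Future<?> submit(Runnable task) {
        return super.submit(decorateTask(task));
    }

    @Override
    public void execute(Runnable command) {
        super.execute(decorateTask(command));
    }

    // Note: invokeAll/invokeAny are still forwarded undecorated by ForwardedExecutorService

    private <V> Callable<V> decorateTask(Callable<V> task) {
        // Runs on the submitting thread: capture its MDC
        Map<String, String> mdc = MDC.METADATA.get();
        return () -> {
            // Runs on the worker thread: install the caller's MDC for the duration of the task
            Map<String, String> previous = MDC.METADATA.get();
            MDC.METADATA.set(mdc);
            try {
                return task.call();
            } finally {
                restore(previous);
            }
        };
    }

    private Runnable decorateTask(Runnable task) {
        Map<String, String> mdc = MDC.METADATA.get();
        return () -> {
            Map<String, String> previous = MDC.METADATA.get();
            MDC.METADATA.set(mdc);
            try {
                task.run();
            } finally {
                restore(previous);
            }
        };
    }

    // Pool threads are reused, so leaving the caller's MDC behind would leak it into later tasks
    private static void restore(Map<String, String> previous) {
        if (previous == null) {
            MDC.METADATA.remove();
        } else {
            MDC.METADATA.set(previous);
        }
    }
}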
MDC.java
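import java.util.Map;

// A minimal stand-in for a logging library's MDC: one map of context values per thread
public class MDC {
    public static final ThreadLocal<Map<String, String>> METADATA = new ThreadLocal<>();
}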
Driver.java
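import java.util.Map;
import java.util.concurrent.TimeUnit;

public class Driver {
    public static void main(String[] args) throws InterruptedException {
        // Set the MDC on the main (submitting) thread
        MDC.METADATA.set(Map.of("k1", "v1", "k2", "v2"));
        var es = new MDCAwareExecutorService();
        // The task runs on a pool thread but prints the main thread's MDC
        es.submit(() -> {
            System.out.println("Thread: " + Thread.currentThread().getName() + ", MDC: " + MDC.METADATA.get());
        });
        // Let the task finish, then exit instead of sleeping forever
        es.shutdown();
        es.awaitTermination(10, TimeUnit.SECONDS);
    }
}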
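One gap remains: ForwardedExecutorService passes invokeAll and invokeAny straight through, and MDCAwareExecutorService only overrides submit and execute, so tasks submitted in bulk run without the caller's MDC. A minimal sketch of closing that gap, assuming a standalone ThreadLocal named CONTEXT in place of MDC.METADATA and a hypothetical helper class ContextTasks (neither is from the original code), wraps every task in a batch with the same capture-and-copy step. It also restores the worker's previous value afterwards, so reused pool threads do not keep stale context.

```java
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;

// Hypothetical helper: applies a capture-and-copy step to single tasks or whole batches
final class ContextTasks {
    // Assumed stand-in for MDC.METADATA: one map of context values per thread
    static final ThreadLocal<Map<String, String>> CONTEXT = new ThreadLocal<>();

    private ContextTasks() {}

    static <V> Callable<V> wrap(Callable<V> task) {
        Map<String, String> captured = CONTEXT.get(); // read on the submitting thread
        return () -> {
            Map<String, String> previous = CONTEXT.get(); // whatever the worker had before
            CONTEXT.set(captured);
            try {
                return task.call();
            } finally {
                // Put the worker back as we found it
                if (previous == null) {
                    CONTEXT.remove();
                } else {
                    CONTEXT.set(previous);
                }
            }
        };
    }

    // Wraps every task of a batch, ready to hand to invokeAll or invokeAny
    static <V> List<Callable<V>> wrapAll(Collection<? extends Callable<V>> tasks) {
        List<Callable<V>> wrapped = new ArrayList<>(tasks.size());
        for (Callable<V> task : tasks) {
            wrapped.add(wrap(task));
        }
        return wrapped;
    }
}
```

With CONTEXT swapped for MDC.METADATA, each of the four invokeAll/invokeAny overrides in MDCAwareExecutorService could simply pass the wrapped list to super, for instance super.invokeAll(ContextTasks.wrapAll(tasks)) for the single-argument invokeAll.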
That's all, folks! Till next time.