diff --git a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java index 613b2c6456..d067fc9ec1 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java @@ -16,6 +16,8 @@ package org.springframework.messaging.simp.broker; +import static org.springframework.messaging.support.MessageHeaderAccessor.getAccessor; + import java.util.Collection; import java.util.HashSet; import java.util.LinkedHashMap; @@ -24,8 +26,20 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; - +import java.util.concurrent.CopyOnWriteArraySet; + +import org.springframework.expression.AccessException; +import org.springframework.expression.EvaluationContext; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.PropertyAccessor; +import org.springframework.expression.TypedValue; +import org.springframework.expression.spel.SpelEvaluationException; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.StandardEvaluationContext; import org.springframework.messaging.Message; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; import org.springframework.util.AntPathMatcher; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -34,7 +48,13 @@ import org.springframework.util.PathMatcher; /** - * A default, simple in-memory implementation of {@link SubscriptionRegistry}. + * Implementation of {@link SubscriptionRegistry} that stores subscriptions + * in memory and uses a {@link org.springframework.util.PathMatcher PathMatcher} + * for matching destinations. + * + *

As of 4.2 this class supports a {@link #setSelectorHeaderName selector} + * header on subscription messages with Spring EL expressions evaluated against + * the headers to filter out messages in addition to destination matching. * * @author Rossen Stoyanchev * @author Sebastien Deleuze @@ -51,6 +71,10 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { private PathMatcher pathMatcher = new AntPathMatcher(); + private String selectorHeaderName = "selector"; + + private ExpressionParser expressionParser = new SpelExpressionParser(); + private final DestinationCache destinationCache = new DestinationCache(); private final SessionSubscriptionRegistry subscriptionRegistry = new SessionSubscriptionRegistry(); @@ -85,10 +109,52 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { return this.pathMatcher; } + /** + * Configure the name of a selector header that a subscription message can + * have in order to filter messages based on their headers. The value of the + * header can use Spring EL expressions against message headers. + *

For example the following expression expects a header called "foo" to + * have the value "bar": + *

+	 * headers.foo == 'bar'
+	 * 
+ *

By default this is set to "selector". + * @since 4.2 + */ + public void setSelectorHeaderName(String selectorHeaderName) { + Assert.notNull(selectorHeaderName); + this.selectorHeaderName = selectorHeaderName; + } + + /** + * Return the name for the selector header. + */ + public String getSelectorHeaderName() { + return this.selectorHeaderName; + } + @Override - protected void addSubscriptionInternal(String sessionId, String subsId, String destination, Message message) { - this.subscriptionRegistry.addSubscription(sessionId, subsId, destination); + protected void addSubscriptionInternal(String sessionId, String subsId, String destination, + Message message) { + + Expression expression = null; + MessageHeaders headers = message.getHeaders(); + String selector = SimpMessageHeaderAccessor.getFirstNativeHeader(getSelectorHeaderName(), headers); + if (selector != null) { + try { + expression = this.expressionParser.parseExpression(selector); + if (logger.isTraceEnabled()) { + logger.trace("Subscription selector: [" + selector + "]"); + } + } + catch (Throwable ex) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to parse selector: " + selector, ex); + } + } + } + this.subscriptionRegistry.addSubscription(sessionId, subsId, destination, expression); this.destinationCache.updateAfterNewSubscription(destination, sessionId, subsId); } @@ -112,17 +178,19 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { } @Override - protected MultiValueMap findSubscriptionsInternal(String destination, Message message) { + protected MultiValueMap findSubscriptionsInternal(String destination, + Message message) { + MultiValueMap result = this.destinationCache.getSubscriptions(destination); if (result != null) { - return result; + return filterSubscriptions(result, message); } result = new LinkedMultiValueMap(); for (SessionSubscriptionInfo info : this.subscriptionRegistry.getAllSubscriptions()) { for (String destinationPattern : info.getDestinations()) { if (this.pathMatcher.match(destinationPattern, destination)) { - for (String subscriptionId : info.getSubscriptions(destinationPattern)) { - result.add(info.sessionId, subscriptionId); + for (Subscription subscription : info.getSubscriptions(destinationPattern)) { + result.add(info.sessionId, subscription.getId()); } } } @@ -130,6 +198,44 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { if (!result.isEmpty()) { this.destinationCache.addSubscriptions(destination, result); } + return filterSubscriptions(result, message); + } + + private MultiValueMap filterSubscriptions(MultiValueMap allMatches, + Message message) { + + EvaluationContext context = null; + MultiValueMap result = new LinkedMultiValueMap(allMatches.size()); + for (String sessionId : allMatches.keySet()) { + for (String subId : allMatches.get(sessionId)) { + SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId); + Subscription sub = info.getSubscription(subId); + Expression expression = sub.getSelectorExpression(); + if (expression == null) { + result.add(sessionId, subId); + continue; + } + if (context == null) { + context = new StandardEvaluationContext(message); + context.getPropertyAccessors().add(new SimpMessageHeaderPropertyAccessor()); + } + try { + if (expression.getValue(context, boolean.class)) { + result.add(sessionId, subId); + } + } + catch (SpelEvaluationException ex) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to evaluate selector: " + ex.getMessage()); + } + } + catch (Throwable ex) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to evaluate selector.", ex); + } + } + } + } return result; } @@ -257,7 +363,9 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { return this.sessions.values(); } - public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, String destination) { + public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, + String destination, Expression selectorExpression) { + SessionSubscriptionInfo info = this.sessions.get(sessionId); if (info == null) { info = new SessionSubscriptionInfo(sessionId); @@ -266,7 +374,7 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { info = value; } } - info.addSubscription(destination, subscriptionId); + info.addSubscription(destination, subscriptionId, selectorExpression); return info; } @@ -287,8 +395,9 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { private final String sessionId; - // destination -> subscriptionIds - private final Map> subscriptions = new ConcurrentHashMap>(4); + // destination -> subscriptions + private final Map> destinationLookup = + new ConcurrentHashMap>(4); private final Object monitor = new Object(); @@ -303,37 +412,50 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { } public Set getDestinations() { - return this.subscriptions.keySet(); + return this.destinationLookup.keySet(); } - public Set getSubscriptions(String destination) { - return this.subscriptions.get(destination); + public Set getSubscriptions(String destination) { + return this.destinationLookup.get(destination); + } + + public Subscription getSubscription(String subscriptionId) { + for (String destination : this.destinationLookup.keySet()) { + for (Subscription sub : this.destinationLookup.get(destination)) { + if (sub.getId().equalsIgnoreCase(subscriptionId)) { + return sub; + } + } + } + return null; } - public void addSubscription(String destination, String subscriptionId) { - Set subs = this.subscriptions.get(destination); + public void addSubscription(String destination, String subscriptionId, Expression selectorExpression) { + Set subs = this.destinationLookup.get(destination); if (subs == null) { synchronized (this.monitor) { - subs = this.subscriptions.get(destination); + subs = this.destinationLookup.get(destination); if (subs == null) { - subs = new HashSet(4); - this.subscriptions.put(destination, subs); + subs = new CopyOnWriteArraySet(); + this.destinationLookup.put(destination, subs); } } } - subs.add(subscriptionId); + subs.add(new Subscription(subscriptionId, selectorExpression)); } public String removeSubscription(String subscriptionId) { - for (String destination : this.subscriptions.keySet()) { - Set subscriptionIds = this.subscriptions.get(destination); - if (subscriptionIds.remove(subscriptionId)) { - synchronized (this.monitor) { - if (subscriptionIds.isEmpty()) { - this.subscriptions.remove(destination); + for (String destination : this.destinationLookup.keySet()) { + Set subscriptions = this.destinationLookup.get(destination); + for (Subscription sub : subscriptions) { + if (sub.getId().equals(subscriptionId) && subscriptions.remove(sub)) { + synchronized (this.monitor) { + if (subscriptions.isEmpty()) { + this.destinationLookup.remove(destination); + } } + return destination; } - return destination; } } return null; @@ -341,7 +463,73 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry { @Override public String toString() { - return "[sessionId=" + this.sessionId + ", subscriptions=" + this.subscriptions + "]"; + return "[sessionId=" + this.sessionId + ", subscriptions=" + this.destinationLookup + "]"; + } + } + + private static class Subscription { + + private final String id; + + private final Expression selectorExpression; + + + public Subscription(String id, Expression selector) { + this.id = id; + this.selectorExpression = selector; + } + + + public String getId() { + return this.id; + } + + public Expression getSelectorExpression() { + return this.selectorExpression; + } + + @Override + public String toString() { + return "Subscription id='" + this.id; + } + } + + private static class SimpMessageHeaderPropertyAccessor implements PropertyAccessor { + + @Override + public Class[] getSpecificTargetClasses() { + return new Class[] {MessageHeaders.class}; + } + + @Override + public boolean canRead(EvaluationContext context, Object target, String name) { + return true; + } + + @Override + public TypedValue read(EvaluationContext context, Object target, String name) throws AccessException { + MessageHeaders headers = (MessageHeaders) target; + SimpMessageHeaderAccessor accessor = getAccessor(headers, SimpMessageHeaderAccessor.class); + Object value; + if ("destination".equalsIgnoreCase(name)) { + value = accessor.getDestination(); + } + else { + value = accessor.getFirstNativeHeader(name); + if (value == null) { + value = headers.get(name); + } + } + return new TypedValue(value); + } + + @Override + public boolean canWrite(EvaluationContext context, Object target, String name) { + return false; + } + + @Override + public void write(EvaluationContext context, Object target, String name, Object value) { } } diff --git a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java index ef46586df0..d49ffa3e90 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistryTests.java @@ -249,6 +249,32 @@ public class DefaultSubscriptionRegistryTests { assertEquals("Expected no elements " + actual, 0, actual.size()); } + @Test + public void registerSubscriptionWithSelector() throws Exception { + + String sessionId = "sess01"; + String subscriptionId = "subs01"; + String destination = "/foo"; + String selector = "headers.foo == 'bar'"; + + this.registry.registerSubscription(subscribeMessage(sessionId, subscriptionId, destination, selector)); + + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(); + accessor.setDestination(destination); + accessor.setNativeHeader("foo", "bar"); + Message message = MessageBuilder.createMessage("", accessor.getMessageHeaders()); + + MultiValueMap actual = this.registry.findSubscriptions(message); + assertEquals(1, actual.size()); + assertEquals(Arrays.asList(subscriptionId), actual.get(sessionId)); + + accessor = SimpMessageHeaderAccessor.create(); + accessor.setDestination(destination); + message = MessageBuilder.createMessage("", accessor.getMessageHeaders()); + + assertEquals(0, this.registry.findSubscriptions(message).size()); + } + @Test public void unregisterSubscription() { @@ -348,26 +374,33 @@ public class DefaultSubscriptionRegistryTests { } private Message subscribeMessage(String sessionId, String subscriptionId, String destination) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); - headers.setSessionId(sessionId); - headers.setSubscriptionId(subscriptionId); - if (destination != null) { - headers.setDestination(destination); + return subscribeMessage(sessionId, subscriptionId, destination, null); + } + + private Message subscribeMessage(String sessionId, String subId, String dest, String selector) { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); + accessor.setSessionId(sessionId); + accessor.setSubscriptionId(subId); + if (dest != null) { + accessor.setDestination(dest); + } + if (selector != null) { + accessor.setNativeHeader("selector", selector); } - return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build(); + return MessageBuilder.createMessage("", accessor.getMessageHeaders()); } - private Message unsubscribeMessage(String sessionId, String subscriptionId) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); - headers.setSessionId(sessionId); - headers.setSubscriptionId(subscriptionId); - return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build(); + private Message unsubscribeMessage(String sessionId, String subId) { + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); + accessor.setSessionId(sessionId); + accessor.setSubscriptionId(subId); + return MessageBuilder.createMessage("", accessor.getMessageHeaders()); } private Message message(String destination) { - SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(); - headers.setDestination(destination); - return MessageBuilder.withPayload("").copyHeaders(headers.toMap()).build(); + SimpMessageHeaderAccessor accessor = SimpMessageHeaderAccessor.create(); + accessor.setDestination(destination); + return MessageBuilder.createMessage("", accessor.getMessageHeaders()); } private List sort(List list) {