diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java index 6027fc3c90..4785d16109 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParser.java @@ -263,12 +263,15 @@ class MessageBrokerBeanDefinitionParser implements BeanDefinitionParser { } else { RuntimeBeanReference handshakeHandler = WebSocketNamespaceUtils.registerHandshakeHandler(element, context, source); + Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors"); + ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); ConstructorArgumentValues cavs = new ConstructorArgumentValues(); cavs.addIndexedArgumentValue(0, subProtoHandler); if (handshakeHandler != null) { cavs.addIndexedArgumentValue(1, handshakeHandler); } beanDef = new RootBeanDefinition(WebSocketHttpRequestHandler.class, cavs, null); + beanDef.getPropertyValues().add("handshakeInterceptors", interceptors); } return new RuntimeBeanReference(registerBeanDef(beanDef, context, source)); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java index bc7d3b804b..bc1763cce1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/config/WebSocketNamespaceUtils.java @@ -96,7 +96,11 @@ class WebSocketNamespaceUtils { sockJsServiceDef.getConstructorArgumentValues().addIndexedArgumentValue(1, transportHandler); } - String attrValue = sockJsElement.getAttribute("name"); + Element interceptorsElement = DomUtils.getChildElementByTagName(element, "handshake-interceptors"); + ManagedList interceptors = WebSocketNamespaceUtils.parseBeanSubElements(interceptorsElement, context); + sockJsServiceDef.getPropertyValues().add("handshakeInterceptors", interceptors); + + String attrValue = sockJsElement.getAttribute("name"); if (!attrValue.isEmpty()) { sockJsServiceDef.getPropertyValues().add("name", attrValue); } diff --git a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.0.xsd b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.0.xsd index 1c355d1af2..d69f5006a0 100644 --- a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.0.xsd +++ b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.0.xsd @@ -589,6 +589,7 @@ + diff --git a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd index 71a1f76b24..447928b262 100644 --- a/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd +++ b/spring-websocket/src/main/resources/org/springframework/web/socket/config/spring-websocket-4.1.xsd @@ -590,6 +590,7 @@ + diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java index 163ddcab90..0f4a02e345 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/HandlersBeanDefinitionParserTests.java @@ -48,7 +48,9 @@ import org.springframework.web.socket.server.support.DefaultHandshakeHandler; import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler; import org.springframework.web.socket.sockjs.SockJsService; import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler; +import org.springframework.web.socket.sockjs.transport.TransportHandler; import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService; +import org.springframework.web.socket.sockjs.transport.TransportType; import org.springframework.web.socket.sockjs.transport.handler.DefaultSockJsService; import org.springframework.web.socket.sockjs.transport.handler.EventSourceTransportHandler; import org.springframework.web.socket.sockjs.transport.handler.HtmlFileTransportHandler; @@ -79,8 +81,8 @@ public class HandlersBeanDefinitionParserTests { this.appContext = new GenericWebApplicationContext(); } - @Test + @Test public void webSocketHandlers() { loadBeanDefinitions("websocket-config-handlers.xml"); @@ -132,7 +134,6 @@ public class HandlersBeanDefinitionParserTests { assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof TestHandshakeHandler); List interceptors = handler.getHandshakeInterceptors(); - assertNotNull(interceptors); assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); handler = (WebSocketHttpRequestHandler) urlHandlerMapping.getUrlMap().get("/test"); @@ -142,7 +143,6 @@ public class HandlersBeanDefinitionParserTests { assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof TestHandshakeHandler); interceptors = handler.getHandshakeInterceptors(); - assertNotNull(interceptors); assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); } @@ -171,7 +171,9 @@ public class HandlersBeanDefinitionParserTests { assertThat(sockJsService, instanceOf(DefaultSockJsService.class)); DefaultSockJsService defaultSockJsService = (DefaultSockJsService) sockJsService; assertThat(defaultSockJsService.getTaskScheduler(), instanceOf(ThreadPoolTaskScheduler.class)); - assertThat(defaultSockJsService.getTransportHandlers().values(), + + Map transportHandlers = defaultSockJsService.getTransportHandlers(); + assertThat(transportHandlers.values(), containsInAnyOrder( instanceOf(XhrPollingTransportHandler.class), instanceOf(XhrReceivingTransportHandler.class), @@ -181,6 +183,12 @@ public class HandlersBeanDefinitionParserTests { instanceOf(EventSourceTransportHandler.class), instanceOf(HtmlFileTransportHandler.class), instanceOf(WebSocketTransportHandler.class))); + + WebSocketTransportHandler handler = (WebSocketTransportHandler) transportHandlers.get(TransportType.WEBSOCKET); + assertEquals(TestHandshakeHandler.class, handler.getHandshakeHandler().getClass()); + + List interceptors = defaultSockJsService.getHandshakeInterceptors(); + assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); } @Test diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java index 8066de2544..3cd52748bb 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/config/MessageBrokerBeanDefinitionParserTests.java @@ -75,16 +75,19 @@ import org.springframework.web.socket.sockjs.transport.handler.WebSocketTranspor * * @author Brian Clozel * @author Artem Bilan + * @author Rossen Stoyanchev */ public class MessageBrokerBeanDefinitionParserTests { private GenericWebApplicationContext appContext; + @Before public void setup() { this.appContext = new GenericWebApplicationContext(); } + @Test public void simpleBroker() { loadBeanDefinitions("websocket-config-broker-simple.xml"); @@ -103,6 +106,8 @@ public class MessageBrokerBeanDefinitionParserTests { HandshakeHandler handshakeHandler = wsHttpRequestHandler.getHandshakeHandler(); assertNotNull(handshakeHandler); assertTrue(handshakeHandler instanceof TestHandshakeHandler); + List interceptors = wsHttpRequestHandler.getHandshakeInterceptors(); + assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); WebSocketHandler wsHandler = unwrapWebSocketHandler(wsHttpRequestHandler.getWebSocketHandler()); assertNotNull(wsHandler); @@ -140,6 +145,9 @@ public class MessageBrokerBeanDefinitionParserTests { assertEquals(Runtime.getRuntime().availableProcessors(), scheduler.getScheduledThreadPoolExecutor().getCorePoolSize()); assertTrue(scheduler.getScheduledThreadPoolExecutor().getRemoveOnCancelPolicy()); + interceptors = defaultSockJsService.getHandshakeInterceptors(); + assertThat(interceptors, contains(instanceOf(FooTestInterceptor.class), instanceOf(BarTestInterceptor.class))); + UserSessionRegistry userSessionRegistry = this.appContext.getBean(UserSessionRegistry.class); assertNotNull(userSessionRegistry); diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml index d89a7f4223..4f700be019 100644 --- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-broker-simple.xml @@ -13,10 +13,18 @@ + + + + + + + + @@ -29,5 +37,6 @@ + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs-attributes.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs-attributes.xml index 8e7ab6cd45..4a2476c45c 100644 --- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs-attributes.xml +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs-attributes.xml @@ -17,13 +17,8 @@ - - - - - + + + diff --git a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs.xml b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs.xml index 689cddd19b..92c167772b 100644 --- a/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs.xml +++ b/spring-websocket/src/test/resources/org/springframework/web/socket/config/websocket-config-handlers-sockjs.xml @@ -8,10 +8,17 @@ + + + + + + +