diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java index a563da2796..194d939c36 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java @@ -205,7 +205,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { Principal user = determineUser(request, wsHandler, attributes); if (logger.isTraceEnabled()) { - logger.trace("Upgrading to WebSocket"); + logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions); } this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes); return true; @@ -300,18 +300,23 @@ public class DefaultHandshakeHandler implements HandshakeHandler { /** * Filter the list of requested WebSocket extensions. - *

By default all request extensions are returned. The WebSocket server will further - * compare the requested extensions against the list of supported extensions and - * return only the ones that are both requested and supported. + *

As of 4.1 the default implementation of this method filters the list to + * leave only extensions that are both requested and supported. * @param request the current request - * @param requested the list of extensions requested by the client - * @param supported the list of extensions supported by the server + * @param requestedExtensions the list of extensions requested by the client + * @param supportedExtensions the list of extensions supported by the server * @return the selected extensions or an empty list */ protected List filterRequestedExtensions(ServerHttpRequest request, - List requested, List supported) { + List requestedExtensions, List supportedExtensions) { - return requested; + List result = new ArrayList(requestedExtensions.size()); + for (WebSocketExtension extension : requestedExtensions) { + if (supportedExtensions.contains(extension)) { + result.add(extension); + } + } + return result; } /** diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java index 0d24ce0bdf..5f04cdae28 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -60,7 +60,7 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { this.handshakeHandler.setSupportedProtocols("stomp", "mqtt"); - when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" }); + when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] {"13"}); this.servletRequest.setMethod("GET"); @@ -79,6 +79,33 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { "STOMP", Collections.emptyList(), null, handler, attributes); } + + @Test + public void supportedExtensions() throws Exception { + + WebSocketExtension extension1 = new WebSocketExtension("ext1"); + WebSocketExtension extension2 = new WebSocketExtension("ext2"); + + when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] {"13"}); + when(this.upgradeStrategy.getSupportedExtensions(this.request)).thenReturn(Arrays.asList(extension1)); + + this.servletRequest.setMethod("GET"); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); + headers.setUpgrade("WebSocket"); + headers.setConnection("Upgrade"); + headers.setSecWebSocketVersion("13"); + headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); + headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2)); + + WebSocketHandler handler = new TextWebSocketHandler(); + Map attributes = Collections.emptyMap(); + this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); + + verify(this.upgradeStrategy).upgrade(this.request, this.response, null, Arrays.asList(extension1), + null, handler, attributes); + } + @Test public void subProtocolCapableHandler() throws Exception {