From e74ac06733cccf5156b5f6f497f02851716f4fbc Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 23 Jul 2014 22:40:38 -0400 Subject: [PATCH] Filter WebSocket extensions Before this change the DefaultHandshakeHandler by default passed the list of requested WebSocket extensions as-is relying on the WebSocket engine to remove those not supported. This change ensures that WebSocket extensions not supported by the runtime are proactively removed instead. This change is preparation for SPR-11094. --- .../support/DefaultHandshakeHandler.java | 21 +++++++++----- .../server/DefaultHandshakeHandlerTests.java | 29 ++++++++++++++++++- 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java index a563da2796..194d939c36 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java @@ -205,7 +205,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler { Principal user = determineUser(request, wsHandler, attributes); if (logger.isTraceEnabled()) { - logger.trace("Upgrading to WebSocket"); + logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions); } this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes); return true; @@ -300,18 +300,23 @@ public class DefaultHandshakeHandler implements HandshakeHandler { /** * Filter the list of requested WebSocket extensions. - *

By default all request extensions are returned. The WebSocket server will further - * compare the requested extensions against the list of supported extensions and - * return only the ones that are both requested and supported. + *

As of 4.1 the default implementation of this method filters the list to + * leave only extensions that are both requested and supported. * @param request the current request - * @param requested the list of extensions requested by the client - * @param supported the list of extensions supported by the server + * @param requestedExtensions the list of extensions requested by the client + * @param supportedExtensions the list of extensions supported by the server * @return the selected extensions or an empty list */ protected List filterRequestedExtensions(ServerHttpRequest request, - List requested, List supported) { + List requestedExtensions, List supportedExtensions) { - return requested; + List result = new ArrayList(requestedExtensions.size()); + for (WebSocketExtension extension : requestedExtensions) { + if (supportedExtensions.contains(extension)) { + result.add(extension); + } + } + return result; } /** diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java index 0d24ce0bdf..5f04cdae28 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java @@ -60,7 +60,7 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { this.handshakeHandler.setSupportedProtocols("stomp", "mqtt"); - when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" }); + when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] {"13"}); this.servletRequest.setMethod("GET"); @@ -79,6 +79,33 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests { "STOMP", Collections.emptyList(), null, handler, attributes); } + + @Test + public void supportedExtensions() throws Exception { + + WebSocketExtension extension1 = new WebSocketExtension("ext1"); + WebSocketExtension extension2 = new WebSocketExtension("ext2"); + + when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] {"13"}); + when(this.upgradeStrategy.getSupportedExtensions(this.request)).thenReturn(Arrays.asList(extension1)); + + this.servletRequest.setMethod("GET"); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders()); + headers.setUpgrade("WebSocket"); + headers.setConnection("Upgrade"); + headers.setSecWebSocketVersion("13"); + headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw=="); + headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2)); + + WebSocketHandler handler = new TextWebSocketHandler(); + Map attributes = Collections.emptyMap(); + this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes); + + verify(this.upgradeStrategy).upgrade(this.request, this.response, null, Arrays.asList(extension1), + null, handler, attributes); + } + @Test public void subProtocolCapableHandler() throws Exception {