diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java
index a563da2796..194d939c36 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java
@@ -205,7 +205,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
Principal user = determineUser(request, wsHandler, attributes);
if (logger.isTraceEnabled()) {
- logger.trace("Upgrading to WebSocket");
+ logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions);
}
this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes);
return true;
@@ -300,18 +300,23 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
/**
* Filter the list of requested WebSocket extensions.
- *
By default all request extensions are returned. The WebSocket server will further
- * compare the requested extensions against the list of supported extensions and
- * return only the ones that are both requested and supported.
+ *
As of 4.1 the default implementation of this method filters the list to
+ * leave only extensions that are both requested and supported.
* @param request the current request
- * @param requested the list of extensions requested by the client
- * @param supported the list of extensions supported by the server
+ * @param requestedExtensions the list of extensions requested by the client
+ * @param supportedExtensions the list of extensions supported by the server
* @return the selected extensions or an empty list
*/
protected List filterRequestedExtensions(ServerHttpRequest request,
- List requested, List supported) {
+ List requestedExtensions, List supportedExtensions) {
- return requested;
+ List result = new ArrayList(requestedExtensions.size());
+ for (WebSocketExtension extension : requestedExtensions) {
+ if (supportedExtensions.contains(extension)) {
+ result.add(extension);
+ }
+ }
+ return result;
}
/**
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java
index 0d24ce0bdf..5f04cdae28 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java
@@ -60,7 +60,7 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
- when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" });
+ when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] {"13"});
this.servletRequest.setMethod("GET");
@@ -79,6 +79,33 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
"STOMP", Collections.emptyList(), null, handler, attributes);
}
+
+ @Test
+ public void supportedExtensions() throws Exception {
+
+ WebSocketExtension extension1 = new WebSocketExtension("ext1");
+ WebSocketExtension extension2 = new WebSocketExtension("ext2");
+
+ when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] {"13"});
+ when(this.upgradeStrategy.getSupportedExtensions(this.request)).thenReturn(Arrays.asList(extension1));
+
+ this.servletRequest.setMethod("GET");
+
+ WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
+ headers.setUpgrade("WebSocket");
+ headers.setConnection("Upgrade");
+ headers.setSecWebSocketVersion("13");
+ headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
+ headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
+
+ WebSocketHandler handler = new TextWebSocketHandler();
+ Map attributes = Collections.emptyMap();
+ this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
+
+ verify(this.upgradeStrategy).upgrade(this.request, this.response, null, Arrays.asList(extension1),
+ null, handler, attributes);
+ }
+
@Test
public void subProtocolCapableHandler() throws Exception {