Filter WebSocket extensions

Before this change the DefaultHandshakeHandler by default passed the
list of requested WebSocket extensions as-is relying on the WebSocket
engine to remove those not supported.

This change ensures that WebSocket extensions not supported by the
runtime are proactively removed instead.

This change is preparation for SPR-11094.
master
Rossen Stoyanchev 10 years ago
parent 78484129f5
commit e74ac06733
  1. 21
      spring-websocket/src/main/java/org/springframework/web/socket/server/support/DefaultHandshakeHandler.java
  2. 29
      spring-websocket/src/test/java/org/springframework/web/socket/server/DefaultHandshakeHandlerTests.java

@ -205,7 +205,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
Principal user = determineUser(request, wsHandler, attributes); Principal user = determineUser(request, wsHandler, attributes);
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("Upgrading to WebSocket"); logger.trace("Upgrading to WebSocket, subProtocol=" + subProtocol + ", extensions=" + extensions);
} }
this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes); this.requestUpgradeStrategy.upgrade(request, response, subProtocol, extensions, user, wsHandler, attributes);
return true; return true;
@ -300,18 +300,23 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
/** /**
* Filter the list of requested WebSocket extensions. * Filter the list of requested WebSocket extensions.
* <p>By default all request extensions are returned. The WebSocket server will further * <p>As of 4.1 the default implementation of this method filters the list to
* compare the requested extensions against the list of supported extensions and * leave only extensions that are both requested and supported.
* return only the ones that are both requested and supported.
* @param request the current request * @param request the current request
* @param requested the list of extensions requested by the client * @param requestedExtensions the list of extensions requested by the client
* @param supported the list of extensions supported by the server * @param supportedExtensions the list of extensions supported by the server
* @return the selected extensions or an empty list * @return the selected extensions or an empty list
*/ */
protected List<WebSocketExtension> filterRequestedExtensions(ServerHttpRequest request, protected List<WebSocketExtension> filterRequestedExtensions(ServerHttpRequest request,
List<WebSocketExtension> requested, List<WebSocketExtension> supported) { List<WebSocketExtension> requestedExtensions, List<WebSocketExtension> supportedExtensions) {
return requested; List<WebSocketExtension> result = new ArrayList<WebSocketExtension>(requestedExtensions.size());
for (WebSocketExtension extension : requestedExtensions) {
if (supportedExtensions.contains(extension)) {
result.add(extension);
}
}
return result;
} }
/** /**

@ -60,7 +60,7 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt"); this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] { "13" }); when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] {"13"});
this.servletRequest.setMethod("GET"); this.servletRequest.setMethod("GET");
@ -79,6 +79,33 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
"STOMP", Collections.<WebSocketExtension>emptyList(), null, handler, attributes); "STOMP", Collections.<WebSocketExtension>emptyList(), null, handler, attributes);
} }
@Test
public void supportedExtensions() throws Exception {
WebSocketExtension extension1 = new WebSocketExtension("ext1");
WebSocketExtension extension2 = new WebSocketExtension("ext2");
when(this.upgradeStrategy.getSupportedVersions()).thenReturn(new String[] {"13"});
when(this.upgradeStrategy.getSupportedExtensions(this.request)).thenReturn(Arrays.asList(extension1));
this.servletRequest.setMethod("GET");
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
headers.setUpgrade("WebSocket");
headers.setConnection("Upgrade");
headers.setSecWebSocketVersion("13");
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
WebSocketHandler handler = new TextWebSocketHandler();
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
verify(this.upgradeStrategy).upgrade(this.request, this.response, null, Arrays.asList(extension1),
null, handler, attributes);
}
@Test @Test
public void subProtocolCapableHandler() throws Exception { public void subProtocolCapableHandler() throws Exception {

Loading…
Cancel
Save