From 1f990c3df62b2ea26e44e93af547b074b0efc18d Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 18 Feb 2015 11:36:22 -0500 Subject: [PATCH] Fix handling of empty payload Pong message on Jetty Issue: SPR-12727 --- .../jetty/JettyWebSocketHandlerAdapter.java | 11 ++- .../web/socket/WebSocketIntegrationTests.java | 75 ++++++++++++++++++- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapter.java b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapter.java index 6eb7326425..1972184f00 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapter.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/adapter/jetty/JettyWebSocketHandlerAdapter.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ package org.springframework.web.socket.adapter.jetty; +import java.nio.ByteBuffer; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.eclipse.jetty.websocket.api.Session; @@ -36,6 +38,7 @@ import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator; + /** * Adapts {@link WebSocketHandler} to the Jetty 9 WebSocket API. * @@ -45,8 +48,11 @@ import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator @WebSocket public class JettyWebSocketHandlerAdapter { + private static final ByteBuffer EMPTY_PAYLOAD = ByteBuffer.wrap(new byte[0]); + private static final Log logger = LogFactory.getLog(JettyWebSocketHandlerAdapter.class); + private final WebSocketHandler webSocketHandler; private final JettyWebSocketSession wsSession; @@ -96,7 +102,8 @@ public class JettyWebSocketHandlerAdapter { @OnWebSocketFrame public void onWebSocketFrame(Frame frame) { if (OpCode.PONG == frame.getOpCode()) { - PongMessage message = new PongMessage(frame.getPayload()); + ByteBuffer payload = frame.getPayload() != null ? frame.getPayload() : EMPTY_PAYLOAD; + PongMessage message = new PongMessage(payload); try { this.webSocketHandler.handleMessage(this.wsSession, message); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java index 4598394a63..9b1086a780 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/WebSocketIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,14 @@ package org.springframework.web.socket; +import static org.junit.Assert.*; + import java.net.URI; +import java.util.ArrayList; import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,10 +38,10 @@ import org.springframework.web.socket.client.standard.StandardWebSocketClient; import org.springframework.web.socket.config.annotation.EnableWebSocket; import org.springframework.web.socket.config.annotation.WebSocketConfigurer; import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; +import org.springframework.web.socket.handler.AbstractWebSocketHandler; import org.springframework.web.socket.handler.TextWebSocketHandler; import org.springframework.web.socket.server.support.DefaultHandshakeHandler; -import static org.junit.Assert.*; /** * Client and server-side WebSocket integration tests. @@ -67,6 +73,24 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest URI url = new URI(getWsBaseUrl() + "/ws"); WebSocketSession session = this.webSocketClient.doHandshake(new TextWebSocketHandler(), headers, url).get(); assertEquals("foo", session.getAcceptedProtocol()); + session.close(); + } + + // SPR-12727 + + @Test + public void unsolicitedPongWithEmptyPayload() throws Exception { + TestWebSocketHandler serverHandler = this.wac.getBean(TestWebSocketHandler.class); + serverHandler.setWaitMessageCount(1); + + String url = getWsBaseUrl() + "/ws"; + WebSocketSession session = this.webSocketClient.doHandshake(new AbstractWebSocketHandler() {}, url).get(); + session.sendMessage(new PongMessage()); + + serverHandler.await(); + assertNull(serverHandler.getTransportError()); + assertEquals(1, serverHandler.getReceivedMessages().size()); + assertEquals(PongMessage.class, serverHandler.getReceivedMessages().get(0).getClass()); } @@ -84,8 +108,51 @@ public class WebSocketIntegrationTests extends AbstractWebSocketIntegrationTest } @Bean - public TextWebSocketHandler handler() { - return new TextWebSocketHandler(); + public TestWebSocketHandler handler() { + return new TestWebSocketHandler(); + } + + } + + private static class TestWebSocketHandler extends AbstractWebSocketHandler { + + private List receivedMessages = new ArrayList<>(); + + private int waitMessageCount; + + private final CountDownLatch latch = new CountDownLatch(1); + + private Throwable transportError; + + + public void setWaitMessageCount(int waitMessageCount) { + this.waitMessageCount = waitMessageCount; + } + + public List getReceivedMessages() { + return this.receivedMessages; + } + + public Throwable getTransportError() { + return this.transportError; + } + + @Override + public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { + this.receivedMessages.add(message); + if (this.receivedMessages.size() >= this.waitMessageCount) { + this.latch.countDown(); + } + } + + @Override + public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { + this.transportError = exception; + this.latch.countDown(); + } + + public void await() throws InterruptedException { + this.latch.await(5, TimeUnit.SECONDS); } }