diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java index 0f989b1a44..a519f4581e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java @@ -34,7 +34,7 @@ import java.util.concurrent.locks.ReentrantLock; * only one thread can send messages at a time. * *

If a send is slow, subsequent attempts to send more messages from a different - * thread will fail to acquire the lock and the messages will be buffered instead -- + * thread will fail to acquire the flushLock and the messages will be buffered instead -- * at that time the specified buffer size limit and send time limit will be checked * and the session closed if the limits are exceeded. * @@ -46,23 +46,28 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat private static Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class); - private final int sendTimeLimit; - - private final int bufferSizeLimit; - private final Queue> buffer = new LinkedBlockingQueue>(); private final AtomicInteger bufferSize = new AtomicInteger(); + private final int bufferSizeLimit; + + private volatile long sendStartTime; - private final Lock lock = new ReentrantLock(); + private final int sendTimeLimit; + + + private volatile boolean sessionLimitExceeded; - public ConcurrentWebSocketSessionDecorator( - WebSocketSession delegateSession, int sendTimeLimit, int bufferSizeLimit) { + private final Lock flushLock = new ReentrantLock(); - super(delegateSession); + private final Lock closeLock = new ReentrantLock(); + + + public ConcurrentWebSocketSessionDecorator(WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit) { + super(delegate); this.sendTimeLimit = sendTimeLimit; this.bufferSizeLimit = bufferSizeLimit; } @@ -72,7 +77,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat return this.bufferSize.get(); } - public long getInProgressSendTime() { + public long getTimeSinceSendStarted() { long start = this.sendStartTime; return (start > 0 ? (System.currentTimeMillis() - start) : 0); } @@ -80,11 +85,20 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat public void sendMessage(WebSocketMessage message) throws IOException { + if (this.sessionLimitExceeded) { + return; + } + this.buffer.add(message); this.bufferSize.addAndGet(message.getPayloadLength()); do { if (!tryFlushMessageBuffer()) { + if (logger.isDebugEnabled()) { + logger.debug("Another send already in progress, session id '" + + getId() + "'" + ", in-progress send time " + getTimeSinceSendStarted() + + " (ms)" + ", buffer size " + this.bufferSize + " bytes"); + } checkSessionLimits(); break; } @@ -93,8 +107,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat } private boolean tryFlushMessageBuffer() throws IOException { - - if (this.lock.tryLock()) { + if (this.flushLock.tryLock() && !this.sessionLimitExceeded) { try { while (true) { WebSocketMessage messageToSend = this.buffer.poll(); @@ -109,34 +122,36 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat } finally { this.sendStartTime = 0; - lock.unlock(); + flushLock.unlock(); } return true; } - return false; } private void checkSessionLimits() throws IOException { - if (logger.isDebugEnabled()) { - logger.debug("Another send already in progress, session id '" + getId() + "'" + - ", in-progress send time " + getInProgressSendTime() + " (ms), " + - ", buffer size " + this.bufferSize + " bytes"); - } - if (getInProgressSendTime() > this.sendTimeLimit) { - logError("A message could not be sent due to a timeout"); - getDelegate().close(); - } - else if (this.bufferSize.get() > this.bufferSizeLimit) { - logError("The total send buffer byte count '" + this.bufferSize.get() + - "' for session '" + getId() + "' exceeds the allowed limit '" + this.bufferSizeLimit + "'"); - getDelegate().close(); + if (this.closeLock.tryLock() && !this.sessionLimitExceeded) { + try { + if (getTimeSinceSendStarted() > this.sendTimeLimit) { + sessionLimitReached( + "Message send time " + getTimeSinceSendStarted() + + " (ms) exceeded the allowed limit " + this.sendTimeLimit); + } + else if (this.bufferSize.get() > this.bufferSizeLimit) { + sessionLimitReached( + "The send buffer size " + this.bufferSize.get() + " bytes for " + + "session '" + getId() + " exceeded the allowed limit " + this.bufferSizeLimit); + } + } + finally { + this.closeLock.unlock(); + } } } - private void logError(String reason) { - logger.error(reason + ", number of buffered messages is '" + this.buffer.size() + - "', time since the last send started is '" + getInProgressSendTime() + "' (ms)"); + private void sessionLimitReached(String reason) { + this.sessionLimitExceeded = true; + throw new SessionLimitExceededException(reason); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java new file mode 100644 index 0000000000..5f5bad5666 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.handler; + +/** + * Raised when a WebSocket session has exceeded limits it has been configured + * for, e.g. timeout, buffer size, etc. + * + * @author Rossen Stoyanchev + * @since 3.0.4 + */ +@SuppressWarnings("serial") +public class SessionLimitExceededException extends RuntimeException { + + + public SessionLimitExceededException(String message) { + super(message); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 7811b8f691..e9de1bbd34 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -43,6 +43,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.SessionLimitExceededException; /** * A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2 @@ -202,6 +203,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler { session.sendMessage(textMessage); } + catch (SessionLimitExceededException ex) { + // Bad session, just get out + throw ex; + } catch (Throwable ex) { sendErrorMessage(session, ex); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index 670544c580..21b10891a1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -43,6 +43,7 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; +import org.springframework.web.socket.handler.SessionLimitExceededException; /** * An implementation of {@link WebSocketHandler} that delegates incoming WebSocket @@ -75,9 +76,9 @@ public class SubProtocolWebSocketHandler private final Map sessions = new ConcurrentHashMap(); - private int sendTimeLimit = 20 * 1000; + private int sendTimeLimit = 10 * 1000; - private int sendBufferSizeLimit = 1024 * 1024; + private int sendBufferSizeLimit = 64 * 1024; private Object lifecycleMonitor = new Object(); @@ -282,6 +283,18 @@ public class SubProtocolWebSocketHandler try { findProtocolHandler(session).handleMessageToClient(session, message); } + catch (SessionLimitExceededException e) { + try { + logger.error("Terminating session id '" + sessionId + "'", e); + + // Session may be unresponsive so clear first + clearSession(session, CloseStatus.NO_STATUS_CODE); + session.close(); + } + catch (Exception secondException) { + logger.error("Exception terminating session id '" + sessionId + "'", secondException); + } + } catch (Exception e) { logger.error("Failed to send message to client " + message, e); } @@ -309,6 +322,10 @@ public class SubProtocolWebSocketHandler @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { + clearSession(session, closeStatus); + } + + private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception { this.sessions.remove(session.getId()); findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java index 21413c32c6..6bfad1f910 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java @@ -29,6 +29,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * Unit tests for @@ -55,7 +56,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { assertEquals(textMessage, session.getSentMessages().get(0)); assertEquals(0, concurrentSession.getBufferSize()); - assertEquals(0, concurrentSession.getInProgressSendTime()); + assertEquals(0, concurrentSession.getTimeSinceSendStarted()); assertTrue(session.isOpen()); } @@ -86,14 +87,14 @@ public class ConcurrentWebSocketSessionDecoratorTests { // ensure some send time elapses Thread.sleep(100); - assertTrue(concurrentSession.getInProgressSendTime() > 0); + assertTrue(concurrentSession.getTimeSinceSendStarted() > 0); TextMessage payload = new TextMessage("payload"); for (int i=0; i < 5; i++) { concurrentSession.sendMessage(payload); } - assertTrue(concurrentSession.getInProgressSendTime() > 0); + assertTrue(concurrentSession.getTimeSinceSendStarted() > 0); assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize()); assertTrue(blockingSession.isOpen()); } @@ -129,10 +130,13 @@ public class ConcurrentWebSocketSessionDecoratorTests { // ensure some send time elapses Thread.sleep(sendTimeLimit + 100); - TextMessage payload = new TextMessage("payload"); - concurrentSession.sendMessage(payload); - - assertFalse(blockingSession.isOpen()); + try { + TextMessage payload = new TextMessage("payload"); + concurrentSession.sendMessage(payload); + fail("Expected exception"); + } + catch (SessionLimitExceededException ex) { + } } @Test @@ -174,8 +178,12 @@ public class ConcurrentWebSocketSessionDecoratorTests { assertEquals(1023, concurrentSession.getBufferSize()); assertTrue(blockingSession.isOpen()); - concurrentSession.sendMessage(message); - assertFalse(blockingSession.isOpen()); + try { + concurrentSession.sendMessage(message); + fail("Expected exception"); + } + catch (SessionLimitExceededException ex) { + } } @@ -217,4 +225,47 @@ public class ConcurrentWebSocketSessionDecoratorTests { } } +// @Test +// public void sendSessionLimitException() throws IOException, InterruptedException { +// +// BlockingSession blockingSession = new BlockingSession(); +// blockingSession.setOpen(true); +// CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); +// +// int sendTimeLimit = 10 * 1000; +// int bufferSizeLimit = 1024; +// +// final ConcurrentWebSocketSessionDecorator concurrentSession = +// new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); +// +// Executors.newSingleThreadExecutor().submit(new Runnable() { +// @Override +// public void run() { +// TextMessage textMessage = new TextMessage("slow message"); +// try { +// concurrentSession.sendMessage(textMessage); +// } +// catch (IOException e) { +// e.printStackTrace(); +// } +// } +// }); +// +// assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); +// +// StringBuilder sb = new StringBuilder(); +// for (int i=0 ; i < 1023; i++) { +// sb.append("a"); +// } +// +// TextMessage message = new TextMessage(sb.toString()); +// concurrentSession.sendMessage(message); +// +// assertEquals(1023, concurrentSession.getBufferSize()); +// assertTrue(blockingSession.isOpen()); +// +// concurrentSession.sendMessage(message); +// assertFalse(blockingSession.isOpen()); +// } + }