From ffac748f1c57c95d46acdb58f483a351906a694b Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Sat, 22 Mar 2014 17:06:45 -0400 Subject: [PATCH] Improve ConcurrentWebSocketSessionDecorator Before this change the decorator ensured that for a specific WebSocket session only one thread at a time can send a message. Other threads attempting to send would have their messages buffered and each time that occurs, a check is also made to see if the buffer limit has been reached or the send time limit has been exceeded and if so the session is closed. This change adds further protection to ensure only one thread at a time can perform the session limit checks and attempt to close the session. Furthermore if the session has timed out and become unresponsive, attempts to close it may block yet another thread. Taking this into consideration this change also ensures that state associated with the session is cleaned first before an attempt is made to close the session. Issue: SPR-11450 --- .../ConcurrentWebSocketSessionDecorator.java | 75 +++++++++++-------- .../SessionLimitExceededException.java | 34 +++++++++ .../messaging/StompSubProtocolHandler.java | 5 ++ .../SubProtocolWebSocketHandler.java | 21 +++++- ...currentWebSocketSessionDecoratorTests.java | 69 ++++++++++++++--- 5 files changed, 163 insertions(+), 41 deletions(-) create mode 100644 spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java index 0f989b1a44..a519f4581e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java @@ -34,7 +34,7 @@ import java.util.concurrent.locks.ReentrantLock; * only one thread can send messages at a time. * *

If a send is slow, subsequent attempts to send more messages from a different - * thread will fail to acquire the lock and the messages will be buffered instead -- + * thread will fail to acquire the flushLock and the messages will be buffered instead -- * at that time the specified buffer size limit and send time limit will be checked * and the session closed if the limits are exceeded. * @@ -46,23 +46,28 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat private static Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class); - private final int sendTimeLimit; - - private final int bufferSizeLimit; - private final Queue> buffer = new LinkedBlockingQueue>(); private final AtomicInteger bufferSize = new AtomicInteger(); + private final int bufferSizeLimit; + + private volatile long sendStartTime; - private final Lock lock = new ReentrantLock(); + private final int sendTimeLimit; + + + private volatile boolean sessionLimitExceeded; - public ConcurrentWebSocketSessionDecorator( - WebSocketSession delegateSession, int sendTimeLimit, int bufferSizeLimit) { + private final Lock flushLock = new ReentrantLock(); - super(delegateSession); + private final Lock closeLock = new ReentrantLock(); + + + public ConcurrentWebSocketSessionDecorator(WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit) { + super(delegate); this.sendTimeLimit = sendTimeLimit; this.bufferSizeLimit = bufferSizeLimit; } @@ -72,7 +77,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat return this.bufferSize.get(); } - public long getInProgressSendTime() { + public long getTimeSinceSendStarted() { long start = this.sendStartTime; return (start > 0 ? (System.currentTimeMillis() - start) : 0); } @@ -80,11 +85,20 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat public void sendMessage(WebSocketMessage message) throws IOException { + if (this.sessionLimitExceeded) { + return; + } + this.buffer.add(message); this.bufferSize.addAndGet(message.getPayloadLength()); do { if (!tryFlushMessageBuffer()) { + if (logger.isDebugEnabled()) { + logger.debug("Another send already in progress, session id '" + + getId() + "'" + ", in-progress send time " + getTimeSinceSendStarted() + + " (ms)" + ", buffer size " + this.bufferSize + " bytes"); + } checkSessionLimits(); break; } @@ -93,8 +107,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat } private boolean tryFlushMessageBuffer() throws IOException { - - if (this.lock.tryLock()) { + if (this.flushLock.tryLock() && !this.sessionLimitExceeded) { try { while (true) { WebSocketMessage messageToSend = this.buffer.poll(); @@ -109,34 +122,36 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat } finally { this.sendStartTime = 0; - lock.unlock(); + flushLock.unlock(); } return true; } - return false; } private void checkSessionLimits() throws IOException { - if (logger.isDebugEnabled()) { - logger.debug("Another send already in progress, session id '" + getId() + "'" + - ", in-progress send time " + getInProgressSendTime() + " (ms), " + - ", buffer size " + this.bufferSize + " bytes"); - } - if (getInProgressSendTime() > this.sendTimeLimit) { - logError("A message could not be sent due to a timeout"); - getDelegate().close(); - } - else if (this.bufferSize.get() > this.bufferSizeLimit) { - logError("The total send buffer byte count '" + this.bufferSize.get() + - "' for session '" + getId() + "' exceeds the allowed limit '" + this.bufferSizeLimit + "'"); - getDelegate().close(); + if (this.closeLock.tryLock() && !this.sessionLimitExceeded) { + try { + if (getTimeSinceSendStarted() > this.sendTimeLimit) { + sessionLimitReached( + "Message send time " + getTimeSinceSendStarted() + + " (ms) exceeded the allowed limit " + this.sendTimeLimit); + } + else if (this.bufferSize.get() > this.bufferSizeLimit) { + sessionLimitReached( + "The send buffer size " + this.bufferSize.get() + " bytes for " + + "session '" + getId() + " exceeded the allowed limit " + this.bufferSizeLimit); + } + } + finally { + this.closeLock.unlock(); + } } } - private void logError(String reason) { - logger.error(reason + ", number of buffered messages is '" + this.buffer.size() + - "', time since the last send started is '" + getInProgressSendTime() + "' (ms)"); + private void sessionLimitReached(String reason) { + this.sessionLimitExceeded = true; + throw new SessionLimitExceededException(reason); } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java new file mode 100644 index 0000000000..5f5bad5666 --- /dev/null +++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.web.socket.handler; + +/** + * Raised when a WebSocket session has exceeded limits it has been configured + * for, e.g. timeout, buffer size, etc. + * + * @author Rossen Stoyanchev + * @since 3.0.4 + */ +@SuppressWarnings("serial") +public class SessionLimitExceededException extends RuntimeException { + + + public SessionLimitExceededException(String message) { + super(message); + } + +} diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java index 7811b8f691..e9de1bbd34 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java @@ -43,6 +43,7 @@ import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; +import org.springframework.web.socket.handler.SessionLimitExceededException; /** * A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2 @@ -202,6 +203,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler { session.sendMessage(textMessage); } + catch (SessionLimitExceededException ex) { + // Bad session, just get out + throw ex; + } catch (Throwable ex) { sendErrorMessage(session, ex); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index 670544c580..21b10891a1 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -43,6 +43,7 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; +import org.springframework.web.socket.handler.SessionLimitExceededException; /** * An implementation of {@link WebSocketHandler} that delegates incoming WebSocket @@ -75,9 +76,9 @@ public class SubProtocolWebSocketHandler private final Map sessions = new ConcurrentHashMap(); - private int sendTimeLimit = 20 * 1000; + private int sendTimeLimit = 10 * 1000; - private int sendBufferSizeLimit = 1024 * 1024; + private int sendBufferSizeLimit = 64 * 1024; private Object lifecycleMonitor = new Object(); @@ -282,6 +283,18 @@ public class SubProtocolWebSocketHandler try { findProtocolHandler(session).handleMessageToClient(session, message); } + catch (SessionLimitExceededException e) { + try { + logger.error("Terminating session id '" + sessionId + "'", e); + + // Session may be unresponsive so clear first + clearSession(session, CloseStatus.NO_STATUS_CODE); + session.close(); + } + catch (Exception secondException) { + logger.error("Exception terminating session id '" + sessionId + "'", secondException); + } + } catch (Exception e) { logger.error("Failed to send message to client " + message, e); } @@ -309,6 +322,10 @@ public class SubProtocolWebSocketHandler @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { + clearSession(session, closeStatus); + } + + private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception { this.sessions.remove(session.getId()); findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java index 21413c32c6..6bfad1f910 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java @@ -29,6 +29,7 @@ import java.util.concurrent.atomic.AtomicReference; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * Unit tests for @@ -55,7 +56,7 @@ public class ConcurrentWebSocketSessionDecoratorTests { assertEquals(textMessage, session.getSentMessages().get(0)); assertEquals(0, concurrentSession.getBufferSize()); - assertEquals(0, concurrentSession.getInProgressSendTime()); + assertEquals(0, concurrentSession.getTimeSinceSendStarted()); assertTrue(session.isOpen()); } @@ -86,14 +87,14 @@ public class ConcurrentWebSocketSessionDecoratorTests { // ensure some send time elapses Thread.sleep(100); - assertTrue(concurrentSession.getInProgressSendTime() > 0); + assertTrue(concurrentSession.getTimeSinceSendStarted() > 0); TextMessage payload = new TextMessage("payload"); for (int i=0; i < 5; i++) { concurrentSession.sendMessage(payload); } - assertTrue(concurrentSession.getInProgressSendTime() > 0); + assertTrue(concurrentSession.getTimeSinceSendStarted() > 0); assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize()); assertTrue(blockingSession.isOpen()); } @@ -129,10 +130,13 @@ public class ConcurrentWebSocketSessionDecoratorTests { // ensure some send time elapses Thread.sleep(sendTimeLimit + 100); - TextMessage payload = new TextMessage("payload"); - concurrentSession.sendMessage(payload); - - assertFalse(blockingSession.isOpen()); + try { + TextMessage payload = new TextMessage("payload"); + concurrentSession.sendMessage(payload); + fail("Expected exception"); + } + catch (SessionLimitExceededException ex) { + } } @Test @@ -174,8 +178,12 @@ public class ConcurrentWebSocketSessionDecoratorTests { assertEquals(1023, concurrentSession.getBufferSize()); assertTrue(blockingSession.isOpen()); - concurrentSession.sendMessage(message); - assertFalse(blockingSession.isOpen()); + try { + concurrentSession.sendMessage(message); + fail("Expected exception"); + } + catch (SessionLimitExceededException ex) { + } } @@ -217,4 +225,47 @@ public class ConcurrentWebSocketSessionDecoratorTests { } } +// @Test +// public void sendSessionLimitException() throws IOException, InterruptedException { +// +// BlockingSession blockingSession = new BlockingSession(); +// blockingSession.setOpen(true); +// CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch(); +// +// int sendTimeLimit = 10 * 1000; +// int bufferSizeLimit = 1024; +// +// final ConcurrentWebSocketSessionDecorator concurrentSession = +// new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit); +// +// Executors.newSingleThreadExecutor().submit(new Runnable() { +// @Override +// public void run() { +// TextMessage textMessage = new TextMessage("slow message"); +// try { +// concurrentSession.sendMessage(textMessage); +// } +// catch (IOException e) { +// e.printStackTrace(); +// } +// } +// }); +// +// assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS)); +// +// StringBuilder sb = new StringBuilder(); +// for (int i=0 ; i < 1023; i++) { +// sb.append("a"); +// } +// +// TextMessage message = new TextMessage(sb.toString()); +// concurrentSession.sendMessage(message); +// +// assertEquals(1023, concurrentSession.getBufferSize()); +// assertTrue(blockingSession.isOpen()); +// +// concurrentSession.sendMessage(message); +// assertFalse(blockingSession.isOpen()); +// } + }