Improve ConcurrentWebSocketSessionDecorator

Before this change the decorator ensured that for a specific WebSocket
session only one thread at a time can send a message. Other threads
attempting to send would have their messages buffered and each time
that occurs, a check is also made to see if the buffer limit has been
reached or the send time limit has been exceeded and if so the session
is closed.

This change adds further protection to ensure only one thread at a time
can perform the session limit checks and attempt to close the session.
Furthermore if the session has timed out and become unresponsive,
attempts to close it may block yet another thread. Taking this into
consideration this change also ensures that state associated with the
session is cleaned first before an attempt is made to close the session.

Issue: SPR-11450
master
Rossen Stoyanchev 11 years ago
parent 299be08268
commit ffac748f1c
  1. 75
      spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java
  2. 34
      spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java
  3. 5
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
  4. 21
      spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
  5. 69
      spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java

@ -34,7 +34,7 @@ import java.util.concurrent.locks.ReentrantLock;
* only one thread can send messages at a time.
*
* <p>If a send is slow, subsequent attempts to send more messages from a different
* thread will fail to acquire the lock and the messages will be buffered instead --
* thread will fail to acquire the flushLock and the messages will be buffered instead --
* at that time the specified buffer size limit and send time limit will be checked
* and the session closed if the limits are exceeded.
*
@ -46,23 +46,28 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
private static Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class);
private final int sendTimeLimit;
private final int bufferSizeLimit;
private final Queue<WebSocketMessage<?>> buffer = new LinkedBlockingQueue<WebSocketMessage<?>>();
private final AtomicInteger bufferSize = new AtomicInteger();
private final int bufferSizeLimit;
private volatile long sendStartTime;
private final Lock lock = new ReentrantLock();
private final int sendTimeLimit;
private volatile boolean sessionLimitExceeded;
public ConcurrentWebSocketSessionDecorator(
WebSocketSession delegateSession, int sendTimeLimit, int bufferSizeLimit) {
private final Lock flushLock = new ReentrantLock();
super(delegateSession);
private final Lock closeLock = new ReentrantLock();
public ConcurrentWebSocketSessionDecorator(WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit) {
super(delegate);
this.sendTimeLimit = sendTimeLimit;
this.bufferSizeLimit = bufferSizeLimit;
}
@ -72,7 +77,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
return this.bufferSize.get();
}
public long getInProgressSendTime() {
public long getTimeSinceSendStarted() {
long start = this.sendStartTime;
return (start > 0 ? (System.currentTimeMillis() - start) : 0);
}
@ -80,11 +85,20 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
public void sendMessage(WebSocketMessage<?> message) throws IOException {
if (this.sessionLimitExceeded) {
return;
}
this.buffer.add(message);
this.bufferSize.addAndGet(message.getPayloadLength());
do {
if (!tryFlushMessageBuffer()) {
if (logger.isDebugEnabled()) {
logger.debug("Another send already in progress, session id '" +
getId() + "'" + ", in-progress send time " + getTimeSinceSendStarted() +
" (ms)" + ", buffer size " + this.bufferSize + " bytes");
}
checkSessionLimits();
break;
}
@ -93,8 +107,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
}
private boolean tryFlushMessageBuffer() throws IOException {
if (this.lock.tryLock()) {
if (this.flushLock.tryLock() && !this.sessionLimitExceeded) {
try {
while (true) {
WebSocketMessage<?> messageToSend = this.buffer.poll();
@ -109,34 +122,36 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
}
finally {
this.sendStartTime = 0;
lock.unlock();
flushLock.unlock();
}
return true;
}
return false;
}
private void checkSessionLimits() throws IOException {
if (logger.isDebugEnabled()) {
logger.debug("Another send already in progress, session id '" + getId() + "'" +
", in-progress send time " + getInProgressSendTime() + " (ms), " +
", buffer size " + this.bufferSize + " bytes");
}
if (getInProgressSendTime() > this.sendTimeLimit) {
logError("A message could not be sent due to a timeout");
getDelegate().close();
}
else if (this.bufferSize.get() > this.bufferSizeLimit) {
logError("The total send buffer byte count '" + this.bufferSize.get() +
"' for session '" + getId() + "' exceeds the allowed limit '" + this.bufferSizeLimit + "'");
getDelegate().close();
if (this.closeLock.tryLock() && !this.sessionLimitExceeded) {
try {
if (getTimeSinceSendStarted() > this.sendTimeLimit) {
sessionLimitReached(
"Message send time " + getTimeSinceSendStarted() +
" (ms) exceeded the allowed limit " + this.sendTimeLimit);
}
else if (this.bufferSize.get() > this.bufferSizeLimit) {
sessionLimitReached(
"The send buffer size " + this.bufferSize.get() + " bytes for " +
"session '" + getId() + " exceeded the allowed limit " + this.bufferSizeLimit);
}
}
finally {
this.closeLock.unlock();
}
}
}
private void logError(String reason) {
logger.error(reason + ", number of buffered messages is '" + this.buffer.size() +
"', time since the last send started is '" + getInProgressSendTime() + "' (ms)");
private void sessionLimitReached(String reason) {
this.sessionLimitExceeded = true;
throw new SessionLimitExceededException(reason);
}
}

@ -0,0 +1,34 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.handler;
/**
* Raised when a WebSocket session has exceeded limits it has been configured
* for, e.g. timeout, buffer size, etc.
*
* @author Rossen Stoyanchev
* @since 3.0.4
*/
@SuppressWarnings("serial")
public class SessionLimitExceededException extends RuntimeException {
public SessionLimitExceededException(String message) {
super(message);
}
}

@ -43,6 +43,7 @@ import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.SessionLimitExceededException;
/**
* A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2
@ -202,6 +203,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler {
session.sendMessage(textMessage);
}
catch (SessionLimitExceededException ex) {
// Bad session, just get out
throw ex;
}
catch (Throwable ex) {
sendErrorMessage(session, ex);
}

@ -43,6 +43,7 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.SessionLimitExceededException;
/**
* An implementation of {@link WebSocketHandler} that delegates incoming WebSocket
@ -75,9 +76,9 @@ public class SubProtocolWebSocketHandler
private final Map<String, WebSocketSession> sessions = new ConcurrentHashMap<String, WebSocketSession>();
private int sendTimeLimit = 20 * 1000;
private int sendTimeLimit = 10 * 1000;
private int sendBufferSizeLimit = 1024 * 1024;
private int sendBufferSizeLimit = 64 * 1024;
private Object lifecycleMonitor = new Object();
@ -282,6 +283,18 @@ public class SubProtocolWebSocketHandler
try {
findProtocolHandler(session).handleMessageToClient(session, message);
}
catch (SessionLimitExceededException e) {
try {
logger.error("Terminating session id '" + sessionId + "'", e);
// Session may be unresponsive so clear first
clearSession(session, CloseStatus.NO_STATUS_CODE);
session.close();
}
catch (Exception secondException) {
logger.error("Exception terminating session id '" + sessionId + "'", secondException);
}
}
catch (Exception e) {
logger.error("Failed to send message to client " + message, e);
}
@ -309,6 +322,10 @@ public class SubProtocolWebSocketHandler
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
clearSession(session, closeStatus);
}
private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception {
this.sessions.remove(session.getId());
findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel);
}

@ -29,6 +29,7 @@ import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* Unit tests for
@ -55,7 +56,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
assertEquals(textMessage, session.getSentMessages().get(0));
assertEquals(0, concurrentSession.getBufferSize());
assertEquals(0, concurrentSession.getInProgressSendTime());
assertEquals(0, concurrentSession.getTimeSinceSendStarted());
assertTrue(session.isOpen());
}
@ -86,14 +87,14 @@ public class ConcurrentWebSocketSessionDecoratorTests {
// ensure some send time elapses
Thread.sleep(100);
assertTrue(concurrentSession.getInProgressSendTime() > 0);
assertTrue(concurrentSession.getTimeSinceSendStarted() > 0);
TextMessage payload = new TextMessage("payload");
for (int i=0; i < 5; i++) {
concurrentSession.sendMessage(payload);
}
assertTrue(concurrentSession.getInProgressSendTime() > 0);
assertTrue(concurrentSession.getTimeSinceSendStarted() > 0);
assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize());
assertTrue(blockingSession.isOpen());
}
@ -129,10 +130,13 @@ public class ConcurrentWebSocketSessionDecoratorTests {
// ensure some send time elapses
Thread.sleep(sendTimeLimit + 100);
TextMessage payload = new TextMessage("payload");
concurrentSession.sendMessage(payload);
assertFalse(blockingSession.isOpen());
try {
TextMessage payload = new TextMessage("payload");
concurrentSession.sendMessage(payload);
fail("Expected exception");
}
catch (SessionLimitExceededException ex) {
}
}
@Test
@ -174,8 +178,12 @@ public class ConcurrentWebSocketSessionDecoratorTests {
assertEquals(1023, concurrentSession.getBufferSize());
assertTrue(blockingSession.isOpen());
concurrentSession.sendMessage(message);
assertFalse(blockingSession.isOpen());
try {
concurrentSession.sendMessage(message);
fail("Expected exception");
}
catch (SessionLimitExceededException ex) {
}
}
@ -217,4 +225,47 @@ public class ConcurrentWebSocketSessionDecoratorTests {
}
}
// @Test
// public void sendSessionLimitException() throws IOException, InterruptedException {
//
// BlockingSession blockingSession = new BlockingSession();
// blockingSession.setOpen(true);
// CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch();
//
// int sendTimeLimit = 10 * 1000;
// int bufferSizeLimit = 1024;
//
// final ConcurrentWebSocketSessionDecorator concurrentSession =
// new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit);
//
// Executors.newSingleThreadExecutor().submit(new Runnable() {
// @Override
// public void run() {
// TextMessage textMessage = new TextMessage("slow message");
// try {
// concurrentSession.sendMessage(textMessage);
// }
// catch (IOException e) {
// e.printStackTrace();
// }
// }
// });
//
// assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS));
//
// StringBuilder sb = new StringBuilder();
// for (int i=0 ; i < 1023; i++) {
// sb.append("a");
// }
//
// TextMessage message = new TextMessage(sb.toString());
// concurrentSession.sendMessage(message);
//
// assertEquals(1023, concurrentSession.getBufferSize());
// assertTrue(blockingSession.isOpen());
//
// concurrentSession.sendMessage(message);
// assertFalse(blockingSession.isOpen());
// }
}

Loading…
Cancel
Save