diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java
index 0f989b1a44..a519f4581e 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecorator.java
@@ -34,7 +34,7 @@ import java.util.concurrent.locks.ReentrantLock;
* only one thread can send messages at a time.
*
*
If a send is slow, subsequent attempts to send more messages from a different
- * thread will fail to acquire the lock and the messages will be buffered instead --
+ * thread will fail to acquire the flushLock and the messages will be buffered instead --
* at that time the specified buffer size limit and send time limit will be checked
* and the session closed if the limits are exceeded.
*
@@ -46,23 +46,28 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
private static Log logger = LogFactory.getLog(ConcurrentWebSocketSessionDecorator.class);
- private final int sendTimeLimit;
-
- private final int bufferSizeLimit;
-
private final Queue> buffer = new LinkedBlockingQueue>();
private final AtomicInteger bufferSize = new AtomicInteger();
+ private final int bufferSizeLimit;
+
+
private volatile long sendStartTime;
- private final Lock lock = new ReentrantLock();
+ private final int sendTimeLimit;
+
+
+ private volatile boolean sessionLimitExceeded;
- public ConcurrentWebSocketSessionDecorator(
- WebSocketSession delegateSession, int sendTimeLimit, int bufferSizeLimit) {
+ private final Lock flushLock = new ReentrantLock();
- super(delegateSession);
+ private final Lock closeLock = new ReentrantLock();
+
+
+ public ConcurrentWebSocketSessionDecorator(WebSocketSession delegate, int sendTimeLimit, int bufferSizeLimit) {
+ super(delegate);
this.sendTimeLimit = sendTimeLimit;
this.bufferSizeLimit = bufferSizeLimit;
}
@@ -72,7 +77,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
return this.bufferSize.get();
}
- public long getInProgressSendTime() {
+ public long getTimeSinceSendStarted() {
long start = this.sendStartTime;
return (start > 0 ? (System.currentTimeMillis() - start) : 0);
}
@@ -80,11 +85,20 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
public void sendMessage(WebSocketMessage> message) throws IOException {
+ if (this.sessionLimitExceeded) {
+ return;
+ }
+
this.buffer.add(message);
this.bufferSize.addAndGet(message.getPayloadLength());
do {
if (!tryFlushMessageBuffer()) {
+ if (logger.isDebugEnabled()) {
+ logger.debug("Another send already in progress, session id '" +
+ getId() + "'" + ", in-progress send time " + getTimeSinceSendStarted() +
+ " (ms)" + ", buffer size " + this.bufferSize + " bytes");
+ }
checkSessionLimits();
break;
}
@@ -93,8 +107,7 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
}
private boolean tryFlushMessageBuffer() throws IOException {
-
- if (this.lock.tryLock()) {
+ if (this.flushLock.tryLock() && !this.sessionLimitExceeded) {
try {
while (true) {
WebSocketMessage> messageToSend = this.buffer.poll();
@@ -109,34 +122,36 @@ public class ConcurrentWebSocketSessionDecorator extends WebSocketSessionDecorat
}
finally {
this.sendStartTime = 0;
- lock.unlock();
+ flushLock.unlock();
}
return true;
}
-
return false;
}
private void checkSessionLimits() throws IOException {
- if (logger.isDebugEnabled()) {
- logger.debug("Another send already in progress, session id '" + getId() + "'" +
- ", in-progress send time " + getInProgressSendTime() + " (ms), " +
- ", buffer size " + this.bufferSize + " bytes");
- }
- if (getInProgressSendTime() > this.sendTimeLimit) {
- logError("A message could not be sent due to a timeout");
- getDelegate().close();
- }
- else if (this.bufferSize.get() > this.bufferSizeLimit) {
- logError("The total send buffer byte count '" + this.bufferSize.get() +
- "' for session '" + getId() + "' exceeds the allowed limit '" + this.bufferSizeLimit + "'");
- getDelegate().close();
+ if (this.closeLock.tryLock() && !this.sessionLimitExceeded) {
+ try {
+ if (getTimeSinceSendStarted() > this.sendTimeLimit) {
+ sessionLimitReached(
+ "Message send time " + getTimeSinceSendStarted() +
+ " (ms) exceeded the allowed limit " + this.sendTimeLimit);
+ }
+ else if (this.bufferSize.get() > this.bufferSizeLimit) {
+ sessionLimitReached(
+ "The send buffer size " + this.bufferSize.get() + " bytes for " +
+ "session '" + getId() + " exceeded the allowed limit " + this.bufferSizeLimit);
+ }
+ }
+ finally {
+ this.closeLock.unlock();
+ }
}
}
- private void logError(String reason) {
- logger.error(reason + ", number of buffered messages is '" + this.buffer.size() +
- "', time since the last send started is '" + getInProgressSendTime() + "' (ms)");
+ private void sessionLimitReached(String reason) {
+ this.sessionLimitExceeded = true;
+ throw new SessionLimitExceededException(reason);
}
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java b/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java
new file mode 100644
index 0000000000..5f5bad5666
--- /dev/null
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/handler/SessionLimitExceededException.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2002-2014 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.web.socket.handler;
+
+/**
+ * Raised when a WebSocket session has exceeded limits it has been configured
+ * for, e.g. timeout, buffer size, etc.
+ *
+ * @author Rossen Stoyanchev
+ * @since 3.0.4
+ */
+@SuppressWarnings("serial")
+public class SessionLimitExceededException extends RuntimeException {
+
+
+ public SessionLimitExceededException(String message) {
+ super(message);
+ }
+
+}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
index 7811b8f691..e9de1bbd34 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java
@@ -43,6 +43,7 @@ import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
+import org.springframework.web.socket.handler.SessionLimitExceededException;
/**
* A {@link SubProtocolHandler} for STOMP that supports versions 1.0, 1.1, and 1.2
@@ -202,6 +203,10 @@ public class StompSubProtocolHandler implements SubProtocolHandler {
session.sendMessage(textMessage);
}
+ catch (SessionLimitExceededException ex) {
+ // Bad session, just get out
+ throw ex;
+ }
catch (Throwable ex) {
sendErrorMessage(session, ex);
}
diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
index 670544c580..21b10891a1 100644
--- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
+++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java
@@ -43,6 +43,7 @@ import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
+import org.springframework.web.socket.handler.SessionLimitExceededException;
/**
* An implementation of {@link WebSocketHandler} that delegates incoming WebSocket
@@ -75,9 +76,9 @@ public class SubProtocolWebSocketHandler
private final Map sessions = new ConcurrentHashMap();
- private int sendTimeLimit = 20 * 1000;
+ private int sendTimeLimit = 10 * 1000;
- private int sendBufferSizeLimit = 1024 * 1024;
+ private int sendBufferSizeLimit = 64 * 1024;
private Object lifecycleMonitor = new Object();
@@ -282,6 +283,18 @@ public class SubProtocolWebSocketHandler
try {
findProtocolHandler(session).handleMessageToClient(session, message);
}
+ catch (SessionLimitExceededException e) {
+ try {
+ logger.error("Terminating session id '" + sessionId + "'", e);
+
+ // Session may be unresponsive so clear first
+ clearSession(session, CloseStatus.NO_STATUS_CODE);
+ session.close();
+ }
+ catch (Exception secondException) {
+ logger.error("Exception terminating session id '" + sessionId + "'", secondException);
+ }
+ }
catch (Exception e) {
logger.error("Failed to send message to client " + message, e);
}
@@ -309,6 +322,10 @@ public class SubProtocolWebSocketHandler
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
+ clearSession(session, closeStatus);
+ }
+
+ private void clearSession(WebSocketSession session, CloseStatus closeStatus) throws Exception {
this.sessions.remove(session.getId());
findProtocolHandler(session).afterSessionEnded(session, closeStatus, this.clientInboundChannel);
}
diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java
index 21413c32c6..6bfad1f910 100644
--- a/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java
+++ b/spring-websocket/src/test/java/org/springframework/web/socket/handler/ConcurrentWebSocketSessionDecoratorTests.java
@@ -29,6 +29,7 @@ import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
/**
* Unit tests for
@@ -55,7 +56,7 @@ public class ConcurrentWebSocketSessionDecoratorTests {
assertEquals(textMessage, session.getSentMessages().get(0));
assertEquals(0, concurrentSession.getBufferSize());
- assertEquals(0, concurrentSession.getInProgressSendTime());
+ assertEquals(0, concurrentSession.getTimeSinceSendStarted());
assertTrue(session.isOpen());
}
@@ -86,14 +87,14 @@ public class ConcurrentWebSocketSessionDecoratorTests {
// ensure some send time elapses
Thread.sleep(100);
- assertTrue(concurrentSession.getInProgressSendTime() > 0);
+ assertTrue(concurrentSession.getTimeSinceSendStarted() > 0);
TextMessage payload = new TextMessage("payload");
for (int i=0; i < 5; i++) {
concurrentSession.sendMessage(payload);
}
- assertTrue(concurrentSession.getInProgressSendTime() > 0);
+ assertTrue(concurrentSession.getTimeSinceSendStarted() > 0);
assertEquals(5 * payload.getPayloadLength(), concurrentSession.getBufferSize());
assertTrue(blockingSession.isOpen());
}
@@ -129,10 +130,13 @@ public class ConcurrentWebSocketSessionDecoratorTests {
// ensure some send time elapses
Thread.sleep(sendTimeLimit + 100);
- TextMessage payload = new TextMessage("payload");
- concurrentSession.sendMessage(payload);
-
- assertFalse(blockingSession.isOpen());
+ try {
+ TextMessage payload = new TextMessage("payload");
+ concurrentSession.sendMessage(payload);
+ fail("Expected exception");
+ }
+ catch (SessionLimitExceededException ex) {
+ }
}
@Test
@@ -174,8 +178,12 @@ public class ConcurrentWebSocketSessionDecoratorTests {
assertEquals(1023, concurrentSession.getBufferSize());
assertTrue(blockingSession.isOpen());
- concurrentSession.sendMessage(message);
- assertFalse(blockingSession.isOpen());
+ try {
+ concurrentSession.sendMessage(message);
+ fail("Expected exception");
+ }
+ catch (SessionLimitExceededException ex) {
+ }
}
@@ -217,4 +225,47 @@ public class ConcurrentWebSocketSessionDecoratorTests {
}
}
+// @Test
+// public void sendSessionLimitException() throws IOException, InterruptedException {
+//
+// BlockingSession blockingSession = new BlockingSession();
+// blockingSession.setOpen(true);
+// CountDownLatch sentMessageLatch = blockingSession.getSentMessageLatch();
+//
+// int sendTimeLimit = 10 * 1000;
+// int bufferSizeLimit = 1024;
+//
+// final ConcurrentWebSocketSessionDecorator concurrentSession =
+// new ConcurrentWebSocketSessionDecorator(blockingSession, sendTimeLimit, bufferSizeLimit);
+//
+// Executors.newSingleThreadExecutor().submit(new Runnable() {
+// @Override
+// public void run() {
+// TextMessage textMessage = new TextMessage("slow message");
+// try {
+// concurrentSession.sendMessage(textMessage);
+// }
+// catch (IOException e) {
+// e.printStackTrace();
+// }
+// }
+// });
+//
+// assertTrue(sentMessageLatch.await(5, TimeUnit.SECONDS));
+//
+// StringBuilder sb = new StringBuilder();
+// for (int i=0 ; i < 1023; i++) {
+// sb.append("a");
+// }
+//
+// TextMessage message = new TextMessage(sb.toString());
+// concurrentSession.sendMessage(message);
+//
+// assertEquals(1023, concurrentSession.getBufferSize());
+// assertTrue(blockingSession.isOpen());
+//
+// concurrentSession.sendMessage(message);
+// assertFalse(blockingSession.isOpen());
+// }
+
}