diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java index 016b7ef924..54eebd1be2 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java @@ -64,6 +64,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.hamcrest.Matchers.*; +import static org.junit.Assert.fail; /** * Integration tests using the @@ -296,6 +297,8 @@ public abstract class AbstractSockJsIntegrationTests { private volatile WebSocketSession session; + private volatile Throwable transportError; + private volatile CloseStatus closeStatus; @@ -309,6 +312,11 @@ public abstract class AbstractSockJsIntegrationTests { this.receivedMessages.add(message); } + @Override + public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { + this.transportError = exception; + } + @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { this.closeStatus = status; @@ -321,12 +329,22 @@ public abstract class AbstractSockJsIntegrationTests { public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException { TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS); - assertNotNull("Timed out waiting for [" + expected + "]", actual); - assertEquals(expected, actual); + if (actual != null) { + assertEquals(expected, actual); + } + else if (this.transportError != null) { + throw new AssertionError("Transport error", this.transportError); + } + else { + fail("Timed out waiting for [" + expected + "]"); + } } public CloseStatus awaitCloseStatus(long timeToWait) throws InterruptedException { - awaitEvent(() -> this.closeStatus != null, timeToWait, " CloseStatus"); + awaitEvent(() -> this.closeStatus != null || this.transportError != null, timeToWait, " CloseStatus"); + if (this.transportError != null) { + throw new AssertionError("Transport error", this.transportError); + } return this.closeStatus; } }