diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java index 4da9e7ef84..53aa9699ca 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptor.java @@ -38,7 +38,15 @@ import org.springframework.web.socket.server.HandshakeInterceptor; */ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor { - private Collection attributeNames; + /** + * The name of the attribute under which the HTTP session id is exposed when + * {@link #setCopyHttpSessionId(boolean) copyHttpSessionId} is "true". + */ + public static final String HTTP_SESSION_ID_ATTR_NAME = "HTTP.SESSION.ID"; + + private final Collection attributeNames; + + private boolean copyHttpSessionId; /** @@ -56,6 +64,25 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor { this.attributeNames = attributeNames; } + /** + * When set to "true", the HTTP session id is copied to the WebSocket + * handshake attributes, and is subsequently available via + * {@link org.springframework.web.socket.WebSocketSession#getAttributes()} + * under the key {@link #HTTP_SESSION_ID_ATTR_NAME}. + *

By default this is "false". + * @param copyHttpSessionId whether to copy the HTTP session id. + */ + public void setCopyHttpSessionId(boolean copyHttpSessionId) { + this.copyHttpSessionId = copyHttpSessionId; + } + + /** + * Whether to copy the HTTP session id to the handshake attributes. + */ + public boolean isCopyHttpSessionId() { + return this.copyHttpSessionId; + } + @Override public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, @@ -72,6 +99,9 @@ public class HttpSessionHandshakeInterceptor implements HandshakeInterceptor { attributes.put(name, session.getAttribute(name)); } } + if (isCopyHttpSessionId()) { + attributes.put(HTTP_SESSION_ID_ATTR_NAME, session.getId()); + } } } return true; diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java index 6956051066..8ab3162c72 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/server/support/HttpSessionHandshakeInterceptorTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2013 the original author or authors. + * Copyright 2002-2014 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import java.util.Set; import org.junit.Test; import org.mockito.Mockito; +import org.springframework.mock.web.test.MockHttpSession; import org.springframework.web.socket.AbstractHttpRequestTests; import org.springframework.web.socket.WebSocketHandler; @@ -46,7 +47,7 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes this.servletRequest.getSession().setAttribute("bar", "baz"); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); - interceptor.beforeHandshake(request, response, wsHandler, attributes); + interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes); assertEquals(2, attributes.size()); assertEquals("bar", attributes.get("foo")); @@ -64,12 +65,28 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes Set names = Collections.singleton("foo"); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(names); - interceptor.beforeHandshake(request, response, wsHandler, attributes); + interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes); assertEquals(1, attributes.size()); assertEquals("bar", attributes.get("foo")); } + @Test + public void copyHttpSessionId() throws Exception { + + Map attributes = new HashMap(); + WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); + + this.servletRequest.setSession(new MockHttpSession(null, "foo")); + + HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); + interceptor.setCopyHttpSessionId(true); + interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes); + + assertEquals(1, attributes.size()); + assertEquals("foo", attributes.get(HttpSessionHandshakeInterceptor.HTTP_SESSION_ID_ATTR_NAME)); + } + @Test public void doNotCauseSessionCreation() throws Exception { @@ -77,7 +94,7 @@ public class HttpSessionHandshakeInterceptorTests extends AbstractHttpRequestTes WebSocketHandler wsHandler = Mockito.mock(WebSocketHandler.class); HttpSessionHandshakeInterceptor interceptor = new HttpSessionHandshakeInterceptor(); - interceptor.beforeHandshake(request, response, wsHandler, attributes); + interceptor.beforeHandshake(this.request, this.response, wsHandler, attributes); assertNull(this.servletRequest.getSession(false)); }