diff --git a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java index c21a3ac0c6..ba2ad448c6 100644 --- a/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java +++ b/spring-web/src/main/java/org/springframework/web/filter/ForwardedHeaderFilter.java @@ -34,20 +34,23 @@ import org.springframework.http.HttpRequest; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UrlPathHelper; + /** * Filter that wraps the request in order to override its * {@link HttpServletRequest#getServerName() getServerName()}, * {@link HttpServletRequest#getServerPort() getServerPort()}, * {@link HttpServletRequest#getScheme() getScheme()}, and * {@link HttpServletRequest#isSecure() isSecure()} methods with values derived - * from "Fowarded" or "X-Forwarded-*" headers. In effect the wrapped request + * from "Forwarded" or "X-Forwarded-*" headers. In effect the wrapped request * reflects the client-originated protocol and address. * * @author Rossen Stoyanchev + * @author Eddú Meléndez * @since 4.3 */ public class ForwardedHeaderFilter extends OncePerRequestFilter { @@ -55,11 +58,12 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private static final Set FORWARDED_HEADER_NAMES; static { - FORWARDED_HEADER_NAMES = new HashSet(4); + FORWARDED_HEADER_NAMES = new HashSet(5); FORWARDED_HEADER_NAMES.add("Forwarded"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Host"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Port"); FORWARDED_HEADER_NAMES.add("X-Forwarded-Proto"); + FORWARDED_HEADER_NAMES.add("X-Forwarded-Prefix"); } @@ -131,6 +135,7 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { private final Map> headers; + public ForwardedHeaderRequestWrapper(HttpServletRequest request, ContextPathHelper pathHelper) { super(request); @@ -143,13 +148,23 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { this.host = uriComponents.getHost(); this.port = (port == -1 ? (this.secure ? 443 : 80) : port); - this.contextPath = (pathHelper != null ? pathHelper.getContextPath(request) : request.getContextPath()); - this.requestUri = (pathHelper != null ? pathHelper.getRequestUri(request) : request.getRequestURI()); + this.contextPath = initContextPath(request, pathHelper); + this.requestUri = initRequestUri(request, pathHelper); this.requestUrl = initRequestUrl(this.scheme, this.host, port, this.requestUri); - this.headers = initHeaders(request); } + + private static String initContextPath(HttpServletRequest request, ContextPathHelper pathHelper) { + String contextPath = (pathHelper != null ? pathHelper.getContextPath(request) : request.getContextPath()); + return prependForwardedPrefix(request, contextPath); + } + + private static String initRequestUri(HttpServletRequest request, ContextPathHelper pathHelper) { + String requestUri = (pathHelper != null ? pathHelper.getRequestUri(request) : request.getRequestURI()); + return prependForwardedPrefix(request, requestUri); + } + private static StringBuffer initRequestUrl(String scheme, String host, int port, String path) { StringBuffer sb = new StringBuffer(); sb.append(scheme).append("://").append(host); @@ -174,6 +189,17 @@ public class ForwardedHeaderFilter extends OncePerRequestFilter { return headers; } + private static String prependForwardedPrefix(HttpServletRequest request, String value) { + String header = request.getHeader("X-Forwarded-Prefix"); + if (StringUtils.hasText(header)) { + while (header.endsWith("/")) { + header = header.substring(0, header.length() - 1); + } + value = header + value; + } + return value; + } + @Override public String getScheme() { diff --git a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java index bd7a375bd6..4c9c45256f 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ForwardedHeaderFilterTests.java @@ -35,6 +35,7 @@ import static org.junit.Assert.assertTrue; /** * Unit tests for {@link ForwardedHeaderFilter}. * @author Rossen Stoyanchev + * @author Eddú Meléndez */ public class ForwardedHeaderFilterTests { @@ -170,6 +171,41 @@ public class ForwardedHeaderFilterTests { assertEquals("bar", actual.getHeader("foo")); } + @Test + public void requestUriWithForwardedPrefix() throws Exception { + this.request.addHeader("X-Forwarded-Prefix", "/prefix"); + this.request.setRequestURI("/mvc-showcase"); + + HttpServletRequest actual = filterAndGetWrappedRequest(); + assertEquals("http://localhost/prefix/mvc-showcase", actual.getRequestURL().toString()); + } + + @Test + public void requestUriWithForwardedPrefixTrailingSlash() throws Exception { + this.request.addHeader("X-Forwarded-Prefix", "/prefix/"); + this.request.setRequestURI("/mvc-showcase"); + + HttpServletRequest actual = filterAndGetWrappedRequest(); + assertEquals("http://localhost/prefix/mvc-showcase", actual.getRequestURL().toString()); + } + + @Test + public void contextPathWithForwardedPrefix() throws Exception { + this.request.addHeader("X-Forwarded-Prefix", "/prefix"); + this.request.setContextPath("/mvc-showcase"); + + String actual = filterAndGetContextPath(); + assertEquals("/prefix/mvc-showcase", actual); + } + + @Test + public void contextPathWithForwardedPrefixTrailingSlash() throws Exception { + this.request.addHeader("X-Forwarded-Prefix", "/prefix/"); + this.request.setContextPath("/mvc-showcase"); + + String actual = filterAndGetContextPath(); + assertEquals("/prefix/mvc-showcase", actual); + } private String filterAndGetContextPath() throws ServletException, IOException { return filterAndGetWrappedRequest().getContextPath();