diff --git a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java index 1a1c14c9b9..314f33327d 100644 --- a/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java +++ b/spring-web/src/main/java/org/springframework/web/util/ContentCachingResponseWrapper.java @@ -48,6 +48,8 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { private int statusCode = HttpServletResponse.SC_OK; + private Integer contentLength; + /** * Create a new ContentCachingResponseWrapper for the given servlet response. @@ -73,21 +75,34 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @Override public void sendError(int sc) throws IOException { - copyBodyToResponse(); - super.sendError(sc); + copyBodyToResponse(false); + try { + super.sendError(sc); + } + catch (IllegalStateException ex) { + // Possibly on Tomcat when called too late: fall back to silent setStatus + super.setStatus(sc); + } this.statusCode = sc; } @Override + @SuppressWarnings("deprecation") public void sendError(int sc, String msg) throws IOException { - copyBodyToResponse(); - super.sendError(sc, msg); + copyBodyToResponse(false); + try { + super.sendError(sc, msg); + } + catch (IllegalStateException ex) { + // Possibly on Tomcat when called too late: fall back to silent setStatus + super.setStatus(sc, msg); + } this.statusCode = sc; } @Override public void sendRedirect(String location) throws IOException { - copyBodyToResponse(); + copyBodyToResponse(false); super.sendRedirect(location); } @@ -109,6 +124,7 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { @Override public void setContentLength(int len) { this.content.resize(len); + this.contentLength = len; } // Overrides Servlet 3.1 setContentLengthLong(long) at runtime @@ -117,7 +133,9 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { throw new IllegalArgumentException("Content-Length exceeds ShallowEtagHeaderFilter's maximum (" + Integer.MAX_VALUE + "): " + len); } - this.content.resize((int) len); + int lenInt = (int) len; + this.content.resize(lenInt); + this.contentLength = lenInt; } @Override @@ -150,24 +168,44 @@ public class ContentCachingResponseWrapper extends HttpServletResponseWrapper { return this.content.toByteArray(); } + /** + * Return an {@link InputStream} to the cached content. + */ + public InputStream getContentInputStream(){ + return this.content.getInputStream(); + } + + /** + * Return the current size of the cached content. + */ + public int getContentSize(){ + return this.content.size(); + } + + /** + * Copy the complete cached body content to the response. + */ public void copyBodyToResponse() throws IOException { + copyBodyToResponse(true); + } + + /** + * Copy the cached body content to the response. + * @param complete whether to set a corresponding content length + * for the complete cached body content + */ + protected void copyBodyToResponse(boolean complete) throws IOException { if (this.content.size() > 0) { HttpServletResponse rawResponse = (HttpServletResponse) getResponse(); - if(! rawResponse.isCommitted()){ - rawResponse.setContentLength(this.content.size()); + if ((complete || this.contentLength != null) && !rawResponse.isCommitted()){ + rawResponse.setContentLength(complete ? this.content.size() : this.contentLength); + this.contentLength = null; } this.content.writeTo(rawResponse.getOutputStream()); this.content.reset(); } } - public int getContentSize(){ - return this.content.size(); - } - - public InputStream getContentInputStream(){ - return this.content.getInputStream(); - } private class ResponseServletOutputStream extends ServletOutputStream { diff --git a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java index d05356d308..ea78811f89 100644 --- a/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java +++ b/spring-web/src/test/java/org/springframework/web/filter/ShallowEtagHeaderFilterTests.java @@ -150,6 +150,7 @@ public class ShallowEtagHeaderFilterTests { final byte[] responseBody = "Hello World".getBytes("UTF-8"); FilterChain filterChain = (filterRequest, filterResponse) -> { assertEquals("Invalid request passed", request, filterRequest); + response.setContentLength(100); FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); ((HttpServletResponse) filterResponse).sendError(HttpServletResponse.SC_FORBIDDEN); }; @@ -157,7 +158,7 @@ public class ShallowEtagHeaderFilterTests { assertEquals("Invalid status", 403, response.getStatus()); assertNull("Invalid ETag header", response.getHeader("ETag")); - assertTrue("Invalid Content-Length header", response.getContentLength() > 0); + assertEquals("Invalid Content-Length header", 100, response.getContentLength()); assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); } @@ -169,6 +170,7 @@ public class ShallowEtagHeaderFilterTests { final byte[] responseBody = "Hello World".getBytes("UTF-8"); FilterChain filterChain = (filterRequest, filterResponse) -> { assertEquals("Invalid request passed", request, filterRequest); + response.setContentLength(100); FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); ((HttpServletResponse) filterResponse).sendError(HttpServletResponse.SC_FORBIDDEN, "ERROR"); }; @@ -176,7 +178,7 @@ public class ShallowEtagHeaderFilterTests { assertEquals("Invalid status", 403, response.getStatus()); assertNull("Invalid ETag header", response.getHeader("ETag")); - assertTrue("Invalid Content-Length header", response.getContentLength() > 0); + assertEquals("Invalid Content-Length header", 100, response.getContentLength()); assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); assertEquals("Invalid error message", "ERROR", response.getErrorMessage()); } @@ -189,6 +191,7 @@ public class ShallowEtagHeaderFilterTests { final byte[] responseBody = "Hello World".getBytes("UTF-8"); FilterChain filterChain = (filterRequest, filterResponse) -> { assertEquals("Invalid request passed", request, filterRequest); + response.setContentLength(100); FileCopyUtils.copy(responseBody, filterResponse.getOutputStream()); ((HttpServletResponse) filterResponse).sendRedirect("http://www.google.com"); }; @@ -196,7 +199,7 @@ public class ShallowEtagHeaderFilterTests { assertEquals("Invalid status", 302, response.getStatus()); assertNull("Invalid ETag header", response.getHeader("ETag")); - assertTrue("Invalid Content-Length header", response.getContentLength() > 0); + assertEquals("Invalid Content-Length header", 100, response.getContentLength()); assertArrayEquals("Invalid content", responseBody, response.getContentAsByteArray()); assertEquals("Invalid redirect URL", "http://www.google.com", response.getRedirectedUrl()); }