From cd634633d8ba690ff5c91efdbd0ace19e1018fe2 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Tue, 17 Oct 2017 16:57:35 -0400 Subject: [PATCH] MockMvc re-uses response instance on async dispatch MockMvc now properly detects the presence of an AsyncContext and re-uses the response instance used to start it. This commit also includes a minor fix in ResponseBodyEmitterReturnValueHandler to ensure it does not disable ETag related content buffering for reactive return values that do not result in streaming (e.g. single value or collections). Issue: SPR-16067 --- .../test/web/servlet/MockMvc.java | 35 +++++++-- .../samples/standalone/FilterTests.java | 74 ++++++++++++++----- ...ResponseBodyEmitterReturnValueHandler.java | 14 ++-- 3 files changed, 91 insertions(+), 32 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/MockMvc.java b/spring-test/src/main/java/org/springframework/test/web/servlet/MockMvc.java index ccf80918ba..76b884a7cc 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/MockMvc.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/MockMvc.java @@ -18,9 +18,13 @@ package org.springframework.test.web.servlet; import java.util.ArrayList; import java.util.List; +import javax.servlet.AsyncContext; import javax.servlet.DispatcherType; import javax.servlet.Filter; import javax.servlet.ServletContext; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; import org.springframework.beans.Mergeable; import org.springframework.lang.Nullable; @@ -135,24 +139,35 @@ public final class MockMvc { } MockHttpServletRequest request = requestBuilder.buildRequest(this.servletContext); - MockHttpServletResponse response = new MockHttpServletResponse(); + + AsyncContext asyncContext = request.getAsyncContext(); + MockHttpServletResponse mockResponse; + HttpServletResponse servletResponse; + if (asyncContext != null) { + servletResponse = (HttpServletResponse) asyncContext.getResponse(); + mockResponse = unwrapResponseIfNecessary(servletResponse); + } + else { + mockResponse = new MockHttpServletResponse(); + servletResponse = mockResponse; + } if (requestBuilder instanceof SmartRequestBuilder) { request = ((SmartRequestBuilder) requestBuilder).postProcessRequest(request); } - final MvcResult mvcResult = new DefaultMvcResult(request, response); + final MvcResult mvcResult = new DefaultMvcResult(request, mockResponse); request.setAttribute(MVC_RESULT_ATTRIBUTE, mvcResult); RequestAttributes previousAttributes = RequestContextHolder.getRequestAttributes(); - RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, response)); + RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(request, servletResponse)); MockFilterChain filterChain = new MockFilterChain(this.servlet, this.filters); - filterChain.doFilter(request, response); + filterChain.doFilter(request, servletResponse); if (DispatcherType.ASYNC.equals(request.getDispatcherType()) && - request.getAsyncContext() != null & !request.isAsyncStarted()) { - request.getAsyncContext().complete(); + asyncContext != null & !request.isAsyncStarted()) { + asyncContext.complete(); } applyDefaultResultActions(mvcResult); @@ -176,6 +191,14 @@ public final class MockMvc { }; } + private MockHttpServletResponse unwrapResponseIfNecessary(ServletResponse servletResponse) { + while (servletResponse instanceof HttpServletResponseWrapper) { + servletResponse = ((HttpServletResponseWrapper) servletResponse).getResponse(); + } + Assert.isInstanceOf(MockHttpServletResponse.class, servletResponse); + return (MockHttpServletResponse) servletResponse; + } + private void applyDefaultResultActions(MvcResult mvcResult) throws Exception { diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java index e122f71760..0f7bc970c9 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/samples/standalone/FilterTests.java @@ -18,6 +18,7 @@ package org.springframework.test.web.servlet.samples.standalone; import java.io.IOException; import java.security.Principal; +import java.util.concurrent.CompletableFuture; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.ServletException; @@ -29,22 +30,34 @@ import javax.validation.Valid; import org.junit.Test; +import org.springframework.http.MediaType; import org.springframework.stereotype.Controller; import org.springframework.test.web.Person; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; import org.springframework.validation.Errors; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RequestMethod; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.filter.OncePerRequestFilter; +import org.springframework.web.filter.ShallowEtagHeaderFilter; import org.springframework.web.servlet.ModelAndView; import org.springframework.web.servlet.mvc.support.RedirectAttributes; -import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.*; -import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.*; -import static org.springframework.test.web.servlet.setup.MockMvcBuilders.*; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.flash; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.model; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; +import static org.springframework.test.web.servlet.setup.MockMvcBuilders.standaloneSetup; /** * Tests with {@link Filter}'s. - * * @author Rob Winch */ public class FilterTests { @@ -107,10 +120,29 @@ public class FilterTests { .andExpect(model().attribute("principal", WrappingRequestResponseFilter.PRINCIPAL_NAME)); } + @Test // SPR-16067 + public void filterWrapsRequestResponseWithAsyncDispatch() throws Exception { + MockMvc mockMvc = standaloneSetup(new PersonController()) + .addFilters(new ShallowEtagHeaderFilter()) + .build(); + + MvcResult mvcResult = mockMvc.perform(get("/persons/1").accept(MediaType.APPLICATION_JSON)) + .andExpect(request().asyncStarted()) + .andExpect(request().asyncResult(new Person("Lukas"))) + .andReturn(); + + mockMvc.perform(asyncDispatch(mvcResult)) + .andExpect(status().isOk()) + .andExpect(header().longValue("Content-Length", 53)) + .andExpect(header().string("ETag", "\"0e37becb4f0c90709cb2e1efcc61eaa00\"")) + .andExpect(content().string("{\"name\":\"Lukas\",\"someDouble\":0.0,\"someBoolean\":false}")); + } + @Controller private static class PersonController { - @RequestMapping(value="/persons", method=RequestMethod.POST) + + @PostMapping(path="/persons") public String save(@Valid Person person, Errors errors, RedirectAttributes redirectAttrs) { if (errors.hasErrors()) { return "person/add"; @@ -120,18 +152,25 @@ public class FilterTests { return "redirect:/person/{id}"; } - @RequestMapping(value="/user") + @PostMapping("/user") public ModelAndView user(Principal principal) { return new ModelAndView("user/view", "principal", principal.getName()); } - @RequestMapping(value="/forward") + @GetMapping("/forward") public String forward() { return "forward:/persons"; } + + @GetMapping("persons/{id}") + @ResponseBody + public CompletableFuture getPerson() { + return CompletableFuture.completedFuture(new Person("Lukas")); + } } private class ContinueFilter extends OncePerRequestFilter { + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { @@ -144,28 +183,25 @@ public class FilterTests { public static final String PRINCIPAL_NAME = "WrapRequestResponseFilterPrincipal"; + @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { filterChain.doFilter(new HttpServletRequestWrapper(request) { @Override public Principal getUserPrincipal() { - return new Principal() { - @Override - public String getName() { - return PRINCIPAL_NAME; - } - }; + return () -> PRINCIPAL_NAME; } }, new HttpServletResponseWrapper(response)); } } private class RedirectFilter extends OncePerRequestFilter { + @Override - protected void doFilterInternal(HttpServletRequest request, - HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, + FilterChain filterChain) throws ServletException, IOException { response.sendRedirect("/login"); } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java index 162795340a..d361090cc8 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ResponseBodyEmitterReturnValueHandler.java @@ -140,23 +140,23 @@ public class ResponseBodyEmitterReturnValueHandler implements HandlerMethodRetur ServletRequest request = webRequest.getNativeRequest(ServletRequest.class); Assert.state(request != null, "No ServletRequest"); - ShallowEtagHeaderFilter.disableContentCaching(request); ResponseBodyEmitter emitter; - if (returnValue instanceof ResponseBodyEmitter) { emitter = (ResponseBodyEmitter) returnValue; } else { emitter = this.reactiveHandler.handleValue(returnValue, returnType, mavContainer, webRequest); + if (emitter == null) { + // Not streaming.. + return; + } } - - if (emitter == null) { - return; - } - emitter.extendResponse(outputMessage); + // At this point we know we're streaming.. + ShallowEtagHeaderFilter.disableContentCaching(request); + // Commit the response and wrap to ignore further header changes outputMessage.getBody(); outputMessage.flush();