diff --git a/org.springframework.web/src/main/java/org/springframework/web/client/RestTemplate.java b/org.springframework.web/src/main/java/org/springframework/web/client/RestTemplate.java index 0006241760..6e5fa12385 100644 --- a/org.springframework.web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/org.springframework.web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -58,7 +58,8 @@ import org.springframework.web.util.UriUtils; * GET{@link #getForObject} HEAD{@link #headForHeaders} * OPTIONS{@link #optionsForAllow} POST{@link #postForLocation} * {@link #postForObject} PUT{@link #put} - * any{@link #execute} + * any{@link #exchange} + * {@link #execute} * *

For each of these HTTP methods, there are three corresponding Java methods in the {@code RestTemplate}. Two * variant take a {@code String} URI as first argument (eg. {@link #getForObject(String, Class, Object[])}, {@link @@ -360,6 +361,29 @@ public class RestTemplate extends HttpAccessor implements RestOperations { return headers.getAllow(); } + // exchange + + public HttpEntity exchange(String url, HttpMethod method, + HttpEntity requestEntity, Class responseType, Object... uriVariables) throws RestClientException { + HttpEntityRequestCallback requestCallback = new HttpEntityRequestCallback(requestEntity, responseType); + HttpEntityResponseExtractor responseExtractor = new HttpEntityResponseExtractor(responseType); + return execute(url, method, requestCallback, responseExtractor, uriVariables); + } + + public HttpEntity exchange(String url, HttpMethod method, + HttpEntity requestEntity, Class responseType, Map uriVariables) throws RestClientException { + HttpEntityRequestCallback requestCallback = new HttpEntityRequestCallback(requestEntity, responseType); + HttpEntityResponseExtractor responseExtractor = new HttpEntityResponseExtractor(responseType); + return execute(url, method, requestCallback, responseExtractor, uriVariables); + } + + public HttpEntity exchange(URI url, HttpMethod method, HttpEntity requestEntity, + Class responseType) throws RestClientException { + HttpEntityRequestCallback requestCallback = new HttpEntityRequestCallback(requestEntity, responseType); + HttpEntityResponseExtractor responseExtractor = new HttpEntityResponseExtractor(responseType); + return execute(url, method, requestCallback, responseExtractor); + } + // general execution public T execute(String url, HttpMethod method, RequestCallback requestCallback, diff --git a/org.springframework.web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java b/org.springframework.web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java index 09054f1169..ea5732bfa5 100644 --- a/org.springframework.web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java +++ b/org.springframework.web/src/test/java/org/springframework/web/client/RestTemplateIntegrationTests.java @@ -52,6 +52,7 @@ import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.http.HttpHeaders; import org.springframework.http.client.CommonsClientHttpRequestFactory; import org.springframework.util.FileCopyUtils; import org.springframework.util.LinkedMultiValueMap; @@ -177,6 +178,16 @@ public class RestTemplateIntegrationTests { template.postForLocation(URI + "/multipart", parts); } + @Test + public void exchange() throws Exception { + HttpHeaders requestHeaders = new HttpHeaders(); + requestHeaders.set("MyHeader", "MyValue"); + HttpEntity requestEntity = new HttpEntity(requestHeaders); + HttpEntity response = + template.exchange(URI + "/{method}", HttpMethod.GET, requestEntity, String.class, "get"); + assertEquals("Invalid content", helloWorld, response.getBody()); + } + /** Servlet that returns and error message for a given status code. */ private static class ErrorServlet extends GenericServlet { diff --git a/org.springframework.web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/org.springframework.web/src/test/java/org/springframework/web/client/RestTemplateTests.java index 87bf2bab90..aa3217777c 100644 --- a/org.springframework.web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/org.springframework.web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -538,6 +538,40 @@ public class RestTemplateTests { verifyMocks(); } + @Test + public void exchange() throws Exception { + MediaType textPlain = new MediaType("text", "plain"); + expect(converter.canRead(Integer.class, null)).andReturn(true); + expect(converter.getSupportedMediaTypes()).andReturn(Collections.singletonList(textPlain)); + expect(requestFactory.createRequest(new URI("http://example.com"), HttpMethod.POST)).andReturn(this.request); + HttpHeaders requestHeaders = new HttpHeaders(); + expect(this.request.getHeaders()).andReturn(requestHeaders).times(2); + expect(converter.canWrite(String.class, null)).andReturn(true); + String body = "Hello World"; + converter.write(body, null, this.request); + expect(this.request.execute()).andReturn(response); + expect(errorHandler.hasError(response)).andReturn(false); + HttpHeaders responseHeaders = new HttpHeaders(); + responseHeaders.setContentType(textPlain); + expect(response.getHeaders()).andReturn(responseHeaders).times(2); + Integer expected = 42; + expect(converter.canRead(Integer.class, textPlain)).andReturn(true); + expect(converter.read(Integer.class, response)).andReturn(expected); + response.close(); + + replayMocks(); + + HttpEntity requestEntity = new HttpEntity(body, Collections.singletonMap("MyHeader", "MyValue")); + HttpEntity result = template.exchange("http://example.com", HttpMethod.POST, requestEntity, Integer.class); + assertEquals("Invalid POST result", expected, result.getBody()); + assertEquals("Invalid Content-Type", textPlain, result.getHeaders().getContentType()); + assertEquals("Invalid Accept header", textPlain.toString(), requestHeaders.getFirst("Accept")); + assertEquals("Invalid custom header", "MyValue", requestHeaders.getFirst("MyHeader")); + + verifyMocks(); + } + + private void replayMocks() { replay(requestFactory, request, response, errorHandler, converter); }