From 4525c6a5371dbae3a618b54d2b0393d97a1529b7 Mon Sep 17 00:00:00 2001 From: Arjen Poutsma Date: Thu, 4 May 2017 12:21:48 +0200 Subject: [PATCH] Add support for Flux in BodyExtractors This commit adds a `toParts` method in `BodyExtractors`, returning a BodyExtractor. --- .../web/reactive/function/BodyExtractors.java | 63 +++++++++------- .../function/BodyExtractorsTests.java | 71 +++++++++++++++++++ .../function/MultipartIntegrationTests.java | 40 +++++++++-- 3 files changed, 144 insertions(+), 30 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java index 3e98ba59ec..36999443fa 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/BodyExtractors.java @@ -48,12 +48,14 @@ import org.springframework.util.MultiValueMap; */ public abstract class BodyExtractors { - private static final ResolvableType FORM_TYPE = + private static final ResolvableType FORM_MAP_TYPE = ResolvableType.forClassWithGenerics(MultiValueMap.class, String.class, String.class); - private static final ResolvableType MULTIPART_TYPE = ResolvableType.forClassWithGenerics( + private static final ResolvableType MULTIPART_MAP_TYPE = ResolvableType.forClassWithGenerics( MultiValueMap.class, String.class, Part.class); + private static final ResolvableType PART_TYPE = ResolvableType.forClass(Part.class); + /** * Return a {@code BodyExtractor} that reads into a Reactor {@link Mono}. @@ -133,15 +135,16 @@ public abstract class BodyExtractors { public static BodyExtractor>, ServerHttpRequest> toFormData() { return (serverRequest, context) -> { HttpMessageReader> messageReader = - formMessageReader(context); + messageReader(FORM_MAP_TYPE, MediaType.APPLICATION_FORM_URLENCODED, context); return context.serverResponse() - .map(serverResponse -> messageReader.readMono(FORM_TYPE, FORM_TYPE, serverRequest, serverResponse, context.hints())) - .orElseGet(() -> messageReader.readMono(FORM_TYPE, serverRequest, context.hints())); + .map(serverResponse -> messageReader.readMono(FORM_MAP_TYPE, FORM_MAP_TYPE, serverRequest, serverResponse, context.hints())) + .orElseGet(() -> messageReader.readMono(FORM_MAP_TYPE, serverRequest, context.hints())); }; } /** - * Return a {@code BodyExtractor} that reads form data into a {@link MultiValueMap}. + * Return a {@code BodyExtractor} that reads multipart (i.e. file upload) form data into a + * {@link MultiValueMap}. * @return a {@code BodyExtractor} that reads multipart data */ // Note that the returned BodyExtractor is parameterized to ServerHttpRequest, not @@ -150,10 +153,29 @@ public abstract class BodyExtractors { public static BodyExtractor>, ServerHttpRequest> toMultipartData() { return (serverRequest, context) -> { HttpMessageReader> messageReader = - multipartMessageReader(context); + messageReader(MULTIPART_MAP_TYPE, MediaType.MULTIPART_FORM_DATA, context); return context.serverResponse() - .map(serverResponse -> messageReader.readMono(MULTIPART_TYPE, MULTIPART_TYPE, serverRequest, serverResponse, context.hints())) - .orElseGet(() -> messageReader.readMono(MULTIPART_TYPE, serverRequest, context.hints())); + .map(serverResponse -> messageReader.readMono(MULTIPART_MAP_TYPE, + MULTIPART_MAP_TYPE, serverRequest, serverResponse, context.hints())) + .orElseGet(() -> messageReader.readMono(MULTIPART_MAP_TYPE, serverRequest, context.hints())); + }; + } + + /** + * Return a {@code BodyExtractor} that reads multipart (i.e. file upload) form data into a + * {@link MultiValueMap}. + * @return a {@code BodyExtractor} that reads multipart data + */ + // Note that the returned BodyExtractor is parameterized to ServerHttpRequest, not + // ReactiveHttpInputMessage like other methods, since reading form data only typically happens on + // the server-side + public static BodyExtractor, ServerHttpRequest> toParts() { + return (serverRequest, context) -> { + HttpMessageReader messageReader = + messageReader(PART_TYPE, MediaType.MULTIPART_FORM_DATA, context); + return context.serverResponse() + .map(serverResponse -> messageReader.read(PART_TYPE, PART_TYPE, serverRequest, serverResponse, context.hints())) + .orElseGet(() -> messageReader.read(PART_TYPE, serverRequest, context.hints())); }; } @@ -191,26 +213,15 @@ public abstract class BodyExtractors { }); } - private static HttpMessageReader> formMessageReader(BodyExtractor.Context context) { + private static HttpMessageReader messageReader(ResolvableType elementType, + MediaType mediaType, BodyExtractor.Context context) { return context.messageReaders().get() - .filter(messageReader -> messageReader - .canRead(FORM_TYPE, MediaType.APPLICATION_FORM_URLENCODED)) + .filter(messageReader -> messageReader.canRead(elementType, mediaType)) .findFirst() - .map(BodyExtractors::>cast) - .orElseThrow(() -> new IllegalStateException( - "Could not find HttpMessageReader that supports " + - MediaType.APPLICATION_FORM_URLENCODED_VALUE)); - } - - private static HttpMessageReader> multipartMessageReader(BodyExtractor.Context context) { - return context.messageReaders().get() - .filter(messageReader -> messageReader - .canRead(MULTIPART_TYPE, MediaType.MULTIPART_FORM_DATA)) - .findFirst() - .map(BodyExtractors::>cast) + .map(BodyExtractors::cast) .orElseThrow(() -> new IllegalStateException( - "Could not find HttpMessageReader that supports " + - MediaType.MULTIPART_FORM_DATA)); + "Could not find HttpMessageReader that supports \"" + mediaType + + "\" and \"" + elementType + "\"")); } private static MediaType contentType(HttpMessage message) { diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java index f83350f987..46d2718e45 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/BodyExtractorsTests.java @@ -36,6 +36,8 @@ import reactor.test.StepVerifier; import org.springframework.core.codec.ByteBufferDecoder; import org.springframework.core.codec.StringDecoder; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.Resource; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; @@ -45,10 +47,16 @@ import org.springframework.http.codec.DecoderHttpMessageReader; import org.springframework.http.codec.FormHttpMessageReader; import org.springframework.http.codec.HttpMessageReader; import org.springframework.http.codec.json.Jackson2JsonDecoder; +import org.springframework.http.codec.multipart.FilePart; +import org.springframework.http.codec.multipart.FormFieldPart; +import org.springframework.http.codec.multipart.MultipartHttpMessageReader; +import org.springframework.http.codec.multipart.Part; +import org.springframework.http.codec.multipart.SynchronossPartHttpMessageReader; import org.springframework.http.codec.xml.Jaxb2XmlDecoder; import org.springframework.http.server.reactive.ServerHttpRequest; import org.springframework.http.server.reactive.ServerHttpResponse; import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest; +import org.springframework.util.FileCopyUtils; import org.springframework.util.MultiValueMap; import static org.junit.Assert.*; @@ -72,6 +80,11 @@ public class BodyExtractorsTests { messageReaders.add(new DecoderHttpMessageReader<>(StringDecoder.allMimeTypes(true))); messageReaders.add(new DecoderHttpMessageReader<>(new Jaxb2XmlDecoder())); messageReaders.add(new DecoderHttpMessageReader<>(new Jackson2JsonDecoder())); + messageReaders.add(new FormHttpMessageReader()); + SynchronossPartHttpMessageReader partReader = new SynchronossPartHttpMessageReader(); + messageReaders.add(partReader); + messageReaders.add(new MultipartHttpMessageReader(partReader)); + messageReaders.add(new FormHttpMessageReader()); this.context = new BodyExtractor.Context() { @@ -249,6 +262,64 @@ public class BodyExtractorsTests { .verify(); } + @Test + public void toParts() throws Exception { + BodyExtractor, ServerHttpRequest> extractor = BodyExtractors.toParts(); + + String bodyContents = "-----------------------------9051914041544843365972754266\r\n" + + "Content-Disposition: form-data; name=\"text\"\r\n" + + "\r\n" + + "text default\r\n" + + "-----------------------------9051914041544843365972754266\r\n" + + "Content-Disposition: form-data; name=\"file1\"; filename=\"a.txt\"\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + "Content of a.txt.\r\n" + + "\r\n" + + "-----------------------------9051914041544843365972754266\r\n" + + "Content-Disposition: form-data; name=\"file2\"; filename=\"a.html\"\r\n" + + "Content-Type: text/html\r\n" + + "\r\n" + + "Content of a.html.\r\n" + + "\r\n" + + "-----------------------------9051914041544843365972754266--\r\n"; + + DefaultDataBufferFactory factory = new DefaultDataBufferFactory(); + DefaultDataBuffer dataBuffer = + factory.wrap(ByteBuffer.wrap(bodyContents.getBytes(StandardCharsets.UTF_8))); + Flux body = Flux.just(dataBuffer); + + MockServerHttpRequest request = MockServerHttpRequest.post("/") + .header("Content-Type", "multipart/form-data; boundary=---------------------------9051914041544843365972754266") + .body(body); + + Flux result = extractor.extract(request, this.context); + + StepVerifier.create(result) + .consumeNextWith(part -> { + assertEquals("text", part.getName()); + assertTrue(part instanceof FormFieldPart); + FormFieldPart formFieldPart = (FormFieldPart) part; + assertEquals("text default", formFieldPart.getValue()); + }) + .consumeNextWith(part -> { + assertEquals("file1", part.getName()); + assertTrue(part instanceof FilePart); + FilePart filePart = (FilePart) part; + assertEquals("a.txt", filePart.getFilename()); + assertEquals(MediaType.TEXT_PLAIN, filePart.getHeaders().getContentType()); + }) + .consumeNextWith(part -> { + assertEquals("file2", part.getName()); + assertTrue(part instanceof FilePart); + FilePart filePart = (FilePart) part; + assertEquals("a.html", filePart.getFilename()); + assertEquals(MediaType.TEXT_HTML, filePart.getHeaders().getContentType()); + }) + .expectComplete() + .verify(); + } + @Test public void toDataBuffers() throws Exception { BodyExtractor, ReactiveHttpInputMessage> extractor = BodyExtractors.toDataBuffers(); diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java index 624239f931..a0ede6d495 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/MultipartIntegrationTests.java @@ -16,6 +16,7 @@ package org.springframework.web.reactive.function; +import java.util.List; import java.util.Map; import org.junit.Test; @@ -48,10 +49,25 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration private final WebClient webClient = WebClient.create(); @Test - public void multipart() { + public void multipartData() { Mono result = webClient .post() - .uri("http://localhost:" + this.port + "/") + .uri("http://localhost:" + this.port + "/multipartData") + .contentType(MediaType.MULTIPART_FORM_DATA) + .body(BodyInserters.fromMultipartData(generateBody())) + .exchange(); + + StepVerifier + .create(result) + .consumeNextWith(response -> assertEquals(HttpStatus.OK, response.statusCode())) + .verifyComplete(); + } + + @Test + public void parts() { + Mono result = webClient + .post() + .uri("http://localhost:" + this.port + "/parts") .contentType(MediaType.MULTIPART_FORM_DATA) .body(BodyInserters.fromMultipartData(generateBody())) .exchange(); @@ -77,12 +93,13 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration @Override protected RouterFunction routerFunction() { MultipartHandler multipartHandler = new MultipartHandler(); - return route(POST("/"), multipartHandler::handle); + return route(POST("/multipartData"), multipartHandler::multipartData) + .andRoute(POST("/parts"), multipartHandler::parts); } private static class MultipartHandler { - public Mono handle(ServerRequest request) { + public Mono multipartData(ServerRequest request) { return request .body(BodyExtractors.toMultipartData()) .flatMap(map -> { @@ -98,6 +115,21 @@ public class MultipartIntegrationTests extends AbstractRouterFunctionIntegration return ServerResponse.ok().build(); }); } + + public Mono parts(ServerRequest request) { + return request.body(BodyExtractors.toParts()).collectList() + .flatMap(parts -> { + try { + assertEquals(2, parts.size()); + assertEquals("foo.txt", ((FilePart) parts.get(0)).getFilename()); + assertEquals("bar", ((FormFieldPart) parts.get(1)).getValue()); + } + catch(Exception e) { + return Mono.error(e); + } + return ServerResponse.ok().build(); + }); + } } }