diff --git a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java index c1f600b1c0..6e94ea7840 100644 --- a/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java +++ b/spring-web/src/main/java/org/springframework/http/client/MultipartBodyBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,12 +22,15 @@ import java.util.Map; import java.util.function.Consumer; import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ResolvableType; +import org.springframework.core.io.buffer.DataBuffer; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; +import org.springframework.http.codec.multipart.Part; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -121,15 +124,21 @@ public final class MultipartBodyBuilder { /** * Add an asynchronous part with {@link Publisher}-based content. * @param name the name of the part to add - * @param publisher the part contents + * @param publisher a Publisher of content for the part * @param elementClass the type of elements contained in the publisher * @return builder that allows for further customization of part headers */ + @SuppressWarnings("unchecked") public > PartBuilder asyncPart(String name, P publisher, Class elementClass) { Assert.hasLength(name, "'name' must not be empty"); Assert.notNull(publisher, "'publisher' must not be null"); Assert.notNull(elementClass, "'elementClass' must not be null"); + if (Part.class.isAssignableFrom(elementClass)) { + publisher = (P) Mono.from(publisher).flatMapMany(p -> ((Part) p).content()); + elementClass = (Class) DataBuffer.class; + } + HttpHeaders headers = new HttpHeaders(); PublisherPartBuilder builder = new PublisherPartBuilder<>(headers, publisher, elementClass); this.parts.add(name, builder); diff --git a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java index 7e5aaa4b5d..e39967f476 100644 --- a/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java +++ b/spring-web/src/test/java/org/springframework/http/codec/multipart/MultipartHttpMessageWriterTests.java @@ -17,6 +17,7 @@ package org.springframework.http.codec.multipart; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Collections; import java.util.List; @@ -46,6 +47,7 @@ import org.springframework.mock.http.server.reactive.test.MockServerHttpResponse import org.springframework.util.MultiValueMap; import static org.junit.Assert.*; +import static org.mockito.Mockito.*; /** * @author Sebastien Deleuze @@ -94,7 +96,13 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTestCas } }; - Publisher publisher = Flux.just("foo", "bar", "baz"); + Flux bufferPublisher = Flux.just( + this.bufferFactory.wrap("Aa".getBytes(StandardCharsets.UTF_8)), + this.bufferFactory.wrap("Bb".getBytes(StandardCharsets.UTF_8)), + this.bufferFactory.wrap("Cc".getBytes(StandardCharsets.UTF_8)) + ); + Part mockPart = mock(Part.class); + when(mockPart.content()).thenReturn(bufferPublisher); MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder(); bodyBuilder.part("name 1", "value 1"); @@ -103,14 +111,15 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTestCas bodyBuilder.part("logo", logo); bodyBuilder.part("utf8", utf8); bodyBuilder.part("json", new Foo("bar"), MediaType.APPLICATION_JSON); - bodyBuilder.asyncPart("publisher", publisher, String.class); + bodyBuilder.asyncPart("publisher", Flux.just("foo", "bar", "baz"), String.class); + bodyBuilder.asyncPart("partPublisher", Mono.just(mockPart), Part.class); Mono>> result = Mono.just(bodyBuilder.build()); Map hints = Collections.emptyMap(); this.writer.write(result, null, MediaType.MULTIPART_FORM_DATA, this.response, hints).block(Duration.ofSeconds(5)); MultiValueMap requestParts = parse(hints); - assertEquals(6, requestParts.size()); + assertEquals(7, requestParts.size()); Part part = requestParts.getFirst("name 1"); assertTrue(part instanceof FormFieldPart); @@ -145,21 +154,25 @@ public class MultipartHttpMessageWriterTests extends AbstractLeakCheckingTestCas part = requestParts.getFirst("json"); assertEquals("json", part.name()); assertEquals(MediaType.APPLICATION_JSON, part.headers().getContentType()); - - String value = StringDecoder.textPlainOnly(false).decodeToMono(part.content(), - ResolvableType.forClass(String.class), MediaType.TEXT_PLAIN, - Collections.emptyMap()).block(Duration.ZERO); - + String value = decodeToString(part); assertEquals("{\"bar\":\"bar\"}", value); part = requestParts.getFirst("publisher"); assertEquals("publisher", part.name()); + value = decodeToString(part); + assertEquals("foobarbaz", value); - value = StringDecoder.textPlainOnly(false).decodeToMono(part.content(), - ResolvableType.forClass(String.class), MediaType.TEXT_PLAIN, - Collections.emptyMap()).block(Duration.ZERO); + part = requestParts.getFirst("partPublisher"); + assertEquals("partPublisher", part.name()); + value = decodeToString(part); + assertEquals("AaBbCc", value); + } - assertEquals("foobarbaz", value); + @SuppressWarnings("ConstantConditions") + private String decodeToString(Part part) { + return StringDecoder.textPlainOnly().decodeToMono(part.content(), + ResolvableType.forClass(String.class), MediaType.TEXT_PLAIN, + Collections.emptyMap()).block(Duration.ZERO); } @Test // SPR-16402