diff --git a/runner-core/src/main/java/org/apache/apisix/plugin/runner/handler/RpcCallHandler.java b/runner-core/src/main/java/org/apache/apisix/plugin/runner/handler/RpcCallHandler.java index f054ccfe..4876356e 100644 --- a/runner-core/src/main/java/org/apache/apisix/plugin/runner/handler/RpcCallHandler.java +++ b/runner-core/src/main/java/org/apache/apisix/plugin/runner/handler/RpcCallHandler.java @@ -17,39 +17,38 @@ package org.apache.apisix.plugin.runner.handler; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.Map; -import java.util.Objects; -import java.util.Queue; -import java.util.Set; - import com.google.common.cache.Cache; import io.github.api7.A6.Err.Code; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import lombok.RequiredArgsConstructor; import org.apache.apisix.plugin.runner.A6Conf; -import org.apache.apisix.plugin.runner.A6ErrRequest; -import org.apache.apisix.plugin.runner.A6ErrResponse; import org.apache.apisix.plugin.runner.A6Request; -import org.apache.apisix.plugin.runner.ExtraInfoRequest; -import org.apache.apisix.plugin.runner.ExtraInfoResponse; import org.apache.apisix.plugin.runner.HttpRequest; -import org.apache.apisix.plugin.runner.HttpResponse; import org.apache.apisix.plugin.runner.PostRequest; +import org.apache.apisix.plugin.runner.HttpResponse; import org.apache.apisix.plugin.runner.PostResponse; +import org.apache.apisix.plugin.runner.A6ErrRequest; +import org.apache.apisix.plugin.runner.A6ErrResponse; +import org.apache.apisix.plugin.runner.ExtraInfoResponse; +import org.apache.apisix.plugin.runner.ExtraInfoRequest; +import org.apache.apisix.plugin.runner.constants.Constants; +import org.apache.apisix.plugin.runner.filter.PluginFilter; +import org.apache.apisix.plugin.runner.filter.PluginFilterChain; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.util.CollectionUtils; -import lombok.RequiredArgsConstructor; -import org.apache.apisix.plugin.runner.constants.Constants; -import org.apache.apisix.plugin.runner.filter.PluginFilter; -import org.apache.apisix.plugin.runner.filter.PluginFilterChain; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.HashSet; +import java.util.Collection; +import java.util.Objects; @RequiredArgsConstructor public class RpcCallHandler extends SimpleChannelInboundHandler { @@ -168,7 +167,7 @@ private void handleHttpRespCall(ChannelHandlerContext ctx, PostRequest request) // save HttpCallRequest postReq = request; - postResp = new PostResponse(postReq.getRequestId()); + postResp = new PostResponse(postReq.getRequestId(), request.getUpstreamHeaders()); confToken = postReq.getConfToken(); A6Conf conf = cache.getIfPresent(confToken); diff --git a/runner-core/src/test/java/org/apache/apisix/plugin/runner/handler/PostFilterTest.java b/runner-core/src/test/java/org/apache/apisix/plugin/runner/handler/PostFilterTest.java index 98b7e9e0..c2912bce 100644 --- a/runner-core/src/test/java/org/apache/apisix/plugin/runner/handler/PostFilterTest.java +++ b/runner-core/src/test/java/org/apache/apisix/plugin/runner/handler/PostFilterTest.java @@ -82,7 +82,7 @@ public void postFilter(PostRequest request, PostResponse response, PluginFilterC System.out.println("do post filter: UpStreamFilter, order: " + chain.getIndex()); System.out.println("do post filter: UpStreamFilter, conf: " + request.getConfig(this)); System.out.println("do post filter: UpStreamFilter, upstreamStatusCode: " + request.getUpstreamStatusCode()); - for (Map.Entry header : request.getUpstreamHeaders().entrySet()) { + for (Map.Entry> header : request.getUpstreamHeaders().entrySet()) { System.out.println("do post filter: UpStreamFilter, upstreamHeader key: " + header.getKey()); System.out.println("do post filter: UpStreamFilter, upstreamHeader value: " + header.getValue()); } @@ -150,6 +150,6 @@ void doPostFilter() { Assertions.assertTrue(bytes.toString().contains("do post filter: UpStreamFilter, conf: {\"conf_key1\":\"conf_value1\",\"conf_key2\":2}")); Assertions.assertTrue(bytes.toString().contains("do post filter: UpStreamFilter, upstreamStatusCode: 418")); Assertions.assertTrue(bytes.toString().contains("do post filter: UpStreamFilter, upstreamHeader key: headerKey")); - Assertions.assertTrue(bytes.toString().contains("do post filter: UpStreamFilter, upstreamHeader value: headerValue")); + Assertions.assertTrue(bytes.toString().contains("do post filter: UpStreamFilter, upstreamHeader value: [headerValue]")); } } diff --git a/runner-plugin-sdk/src/main/java/org/apache/apisix/plugin/runner/PostRequest.java b/runner-plugin-sdk/src/main/java/org/apache/apisix/plugin/runner/PostRequest.java index 88315f59..6a10fe5d 100644 --- a/runner-plugin-sdk/src/main/java/org/apache/apisix/plugin/runner/PostRequest.java +++ b/runner-plugin-sdk/src/main/java/org/apache/apisix/plugin/runner/PostRequest.java @@ -25,8 +25,10 @@ import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.ArrayList; public class PostRequest implements A6Request { private final Req req; @@ -35,7 +37,7 @@ public class PostRequest implements A6Request { private Map config; - private Map headers; + private Map> headers; private Integer status; @@ -76,12 +78,13 @@ public String getConfig(PluginFilter filter) { return config.getOrDefault(filter.name(), null); } - public Map getUpstreamHeaders() { + public Map> getUpstreamHeaders() { if (Objects.isNull(headers)) { headers = new HashMap<>(); for (int i = 0; i < req.headersLength(); i++) { TextEntry header = req.headers(i); - headers.put(header.name(), header.value()); + headers.putIfAbsent(header.name(), new ArrayList<>()); + headers.get(header.name()).add(header.value()); } } return headers; diff --git a/runner-plugin-sdk/src/main/java/org/apache/apisix/plugin/runner/PostResponse.java b/runner-plugin-sdk/src/main/java/org/apache/apisix/plugin/runner/PostResponse.java index 127de5a0..f4736f1c 100644 --- a/runner-plugin-sdk/src/main/java/org/apache/apisix/plugin/runner/PostResponse.java +++ b/runner-plugin-sdk/src/main/java/org/apache/apisix/plugin/runner/PostResponse.java @@ -29,8 +29,10 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.ArrayList; public class PostResponse implements A6Response { @@ -42,12 +44,13 @@ public class PostResponse implements A6Response { private Integer statusCode; - private Map headers; + private Map> headers; private Charset charset; - public PostResponse(long requestId) { + public PostResponse(long requestId, Map> headers) { this.requestId = requestId; + this.headers = headers != null ? new HashMap<>(headers) : new HashMap<>(); this.charset = StandardCharsets.UTF_8; } @@ -63,16 +66,27 @@ public ByteBuffer encode() { int headerIndex = -1; if (!CollectionUtils.isEmpty(headers)) { - int[] headerTexts = new int[headers.size()]; + int hsize = 0; + for (String hkey: headers.keySet()) { + List headerValues = headers.get(hkey); + hsize += CollectionUtils.isEmpty(headerValues) ? 0 : headerValues.size(); + } + + int[] headerTexts = new int[hsize]; int i = -1; - for (Map.Entry header : headers.entrySet()) { + for (Map.Entry> header : headers.entrySet()) { int key = builder.createString(header.getKey()); - int value = 0; - if (!Objects.isNull(header.getValue())) { - value = builder.createString(header.getValue()); + List headerValues = header.getValue(); + if (!CollectionUtils.isEmpty(headerValues)) { + for (String hv: headerValues) { + int value = 0; + if (!Objects.isNull(hv)) { + value = builder.createString(hv); + } + int text = TextEntry.createTextEntry(builder, key, value); + headerTexts[++i] = text; + } } - int text = TextEntry.createTextEntry(builder, key, value); - headerTexts[++i] = text; } headerIndex = Resp.createHeadersVector(builder, headerTexts); } @@ -116,7 +130,27 @@ public void setHeader(String headerKey, String headerValue) { if (Objects.isNull(headers)) { headers = new HashMap<>(); } - headers.put(headerKey, headerValue); + + headers.put(headerKey, new ArrayList<>()); + headers.get(headerKey).add(headerValue); + } + + private void addHeader(String headerKey, String headerValue) { + if (headerKey == null) { + logger.warn("headerKey is null, ignore it"); + return; + } + + if (Objects.isNull(headers)) { + headers = new HashMap<>(); + } + + headers.putIfAbsent(headerKey, new ArrayList<>()); + headers.get(headerKey).add(headerValue); + } + + private Map> headers() { + return headers; } public void setBody(String body) { diff --git a/runner-plugin-sdk/src/test/java/org/apache/apisix/plugin/runner/PostResponseTest.java b/runner-plugin-sdk/src/test/java/org/apache/apisix/plugin/runner/PostResponseTest.java index 1c8a4de8..2b2ba854 100644 --- a/runner-plugin-sdk/src/test/java/org/apache/apisix/plugin/runner/PostResponseTest.java +++ b/runner-plugin-sdk/src/test/java/org/apache/apisix/plugin/runner/PostResponseTest.java @@ -23,9 +23,11 @@ import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.ArrayList; +import java.util.Map; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -37,8 +39,9 @@ void testEncodeWithSetCharset() { long requestId = 1L; String body = "dummy body"; Charset charset = StandardCharsets.UTF_16; + Map> headers = new HashMap<>(); - PostResponse postResponse = new PostResponse(requestId); + PostResponse postResponse = new PostResponse(requestId, headers); postResponse.setBody(body); postResponse.setCharset(charset); @@ -53,8 +56,9 @@ void testEncodeWithoutSetCharset() { long requestId = 1L; String body = "dummy body"; Charset charset = StandardCharsets.UTF_8; + Map> headers = new HashMap<>(); - PostResponse postResponse = new PostResponse(requestId); + PostResponse postResponse = new PostResponse(requestId, headers); postResponse.setBody(body); ByteBuffer encoded = postResponse.encode(); diff --git a/sample/src/main/java/org/apache/apisix/plugin/runner/filter/ResponseFilter.java b/sample/src/main/java/org/apache/apisix/plugin/runner/filter/ResponseFilter.java index 7a45aa32..e45fa4d9 100644 --- a/sample/src/main/java/org/apache/apisix/plugin/runner/filter/ResponseFilter.java +++ b/sample/src/main/java/org/apache/apisix/plugin/runner/filter/ResponseFilter.java @@ -21,8 +21,10 @@ import org.apache.apisix.plugin.runner.PostRequest; import org.apache.apisix.plugin.runner.PostResponse; import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; import java.util.HashMap; +import java.util.List; import java.util.Map; @Component @@ -39,8 +41,8 @@ public void postFilter(PostRequest request, PostResponse response, PluginFilterC Map conf = new HashMap<>(); conf = gson.fromJson(configStr, conf.getClass()); - Map headers = request.getUpstreamHeaders(); - String contentType = headers.get("Content-Type"); + Map> headers = request.getUpstreamHeaders(); + String contentType = CollectionUtils.isEmpty(headers.get("Content-Type")) ? null : headers.get("Content-Type").get(0); Integer upstreamStatusCode = request.getUpstreamStatusCode(); response.setStatusCode(Double.valueOf(conf.get("response_code").toString()).intValue());