diff --git a/invoker/core/src/main/java/com/google/cloud/functions/invoker/http/TimeoutFilter.java b/invoker/core/src/main/java/com/google/cloud/functions/invoker/http/TimeoutFilter.java new file mode 100644 index 00000000..e0577f9b --- /dev/null +++ b/invoker/core/src/main/java/com/google/cloud/functions/invoker/http/TimeoutFilter.java @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.google.cloud.functions.invoker.http; + +import java.io.IOException; +import java.util.Timer; +import java.util.TimerTask; +import java.util.logging.Logger; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletResponse; + +public class TimeoutFilter implements Filter { + + private static final Logger logger = Logger.getLogger(TimeoutFilter.class.getName()); + private final int timeoutMs; + + public TimeoutFilter(int timeoutSeconds) { + this.timeoutMs = timeoutSeconds * 1000; // Convert seconds to milliseconds + } + + @Override + public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) + throws IOException, ServletException { + Timer timer = new Timer(true); + TimerTask timeoutTask = + new TimerTask() { + @Override + public void run() { + if (response instanceof HttpServletResponse) { + try { + ((HttpServletResponse) response) + .sendError(HttpServletResponse.SC_REQUEST_TIMEOUT, "Request timed out"); + } catch (IOException e) { + logger.warning("Error while sending HTTP response: " + e.toString()); + } + } else { + try { + response.getWriter().write("Request timed out"); + } catch (IOException e) { + logger.warning("Error while writing response: " + e.toString()); + } + } + } + }; + + timer.schedule(timeoutTask, timeoutMs); + + try { + chain.doFilter(request, response); + timeoutTask.cancel(); + } finally { + timer.purge(); + } + } +} diff --git a/invoker/core/src/main/java/com/google/cloud/functions/invoker/runner/Invoker.java b/invoker/core/src/main/java/com/google/cloud/functions/invoker/runner/Invoker.java index 892d6038..da5e72ec 100644 --- a/invoker/core/src/main/java/com/google/cloud/functions/invoker/runner/Invoker.java +++ b/invoker/core/src/main/java/com/google/cloud/functions/invoker/runner/Invoker.java @@ -25,6 +25,7 @@ import com.google.cloud.functions.invoker.HttpFunctionExecutor; import com.google.cloud.functions.invoker.TypedFunctionExecutor; import com.google.cloud.functions.invoker.gcf.JsonLogHandler; +import com.google.cloud.functions.invoker.http.TimeoutFilter; import java.io.File; import java.io.IOException; import java.io.UncheckedIOException; @@ -38,6 +39,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.EnumSet; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -48,6 +50,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Stream; +import javax.servlet.DispatcherType; import javax.servlet.MultipartConfigElement; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; @@ -59,6 +62,7 @@ import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.handler.HandlerWrapper; +import org.eclipse.jetty.servlet.FilterHolder; import org.eclipse.jetty.servlet.ServletContextHandler; import org.eclipse.jetty.servlet.ServletHolder; import org.eclipse.jetty.util.thread.QueuedThreadPool; @@ -324,6 +328,7 @@ private void startServer(boolean join) throws Exception { ServletHolder servletHolder = new ServletHolder(servlet); servletHolder.getRegistration().setMultipartConfig(new MultipartConfigElement("")); servletContextHandler.addServlet(servletHolder, "/*"); + servletContextHandler = addTimerFilterForRequestTimeout(servletContextHandler); server.start(); logServerInfo(); @@ -393,6 +398,18 @@ private HttpServlet servletForDeducedSignatureType(Class functionClass) { throw new RuntimeException(error); } + private ServletContextHandler addTimerFilterForRequestTimeout( + ServletContextHandler servletContextHandler) { + String timeoutSeconds = System.getenv("CLOUD_RUN_TIMEOUT_SECONDS"); + if (timeoutSeconds == null) { + return servletContextHandler; + } + int seconds = Integer.parseInt(timeoutSeconds); + FilterHolder holder = new FilterHolder(new TimeoutFilter(seconds)); + servletContextHandler.addFilter(holder, "/*", EnumSet.of(DispatcherType.REQUEST)); + return servletContextHandler; + } + static URL[] classpathToUrls(String classpath) { String[] components = classpath.split(File.pathSeparator); List urls = new ArrayList<>(); diff --git a/invoker/core/src/test/java/com/google/cloud/functions/invoker/IntegrationTest.java b/invoker/core/src/test/java/com/google/cloud/functions/invoker/IntegrationTest.java index 3f3de837..82197547 100644 --- a/invoker/core/src/test/java/com/google/cloud/functions/invoker/IntegrationTest.java +++ b/invoker/core/src/test/java/com/google/cloud/functions/invoker/IntegrationTest.java @@ -51,6 +51,7 @@ import java.time.OffsetDateTime; import java.time.ZoneOffset; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -252,6 +253,34 @@ public void helloWorld() throws Exception { ROBOTS_TXT_TEST_CASE)); } + @Test + public void timeoutHttpSuccess() throws Exception { + testFunction( + SignatureType.HTTP, + fullTarget("TimeoutHttp"), + ImmutableList.of(), + ImmutableList.of( + TestCase.builder() + .setExpectedResponseText("finished\n") + .setExpectedResponseText(Optional.empty()) + .build()), + ImmutableMap.of("CLOUD_RUN_TIMEOUT_SECONDS", "3")); + } + + @Test + public void timeoutHttpTimesOut() throws Exception { + testFunction( + SignatureType.HTTP, + fullTarget("TimeoutHttp"), + ImmutableList.of(), + ImmutableList.of( + TestCase.builder() + .setExpectedResponseCode(408) + .setExpectedResponseText(Optional.empty()) + .build()), + ImmutableMap.of("CLOUD_RUN_TIMEOUT_SECONDS", "1")); + } + @Test public void exceptionHttp() throws Exception { String exceptionExpectedOutput = @@ -290,7 +319,8 @@ public void exceptionBackground() throws Exception { .setRequestText(gcfRequestText) .setExpectedResponseCode(500) .setExpectedOutput(exceptionExpectedOutput) - .build())); + .build()), + Collections.emptyMap()); } @Test @@ -400,7 +430,8 @@ public void typedFunction() throws Exception { TestCase.builder() .setRequestText(originalJson) .setExpectedResponseText("{\"fullName\":\"JohnDoe\"}") - .build())); + .build()), + Collections.emptyMap()); } @Test @@ -410,7 +441,8 @@ public void typedVoidFunction() throws Exception { fullTarget("TypedVoid"), ImmutableList.of(), ImmutableList.of( - TestCase.builder().setRequestText("{}").setExpectedResponseCode(204).build())); + TestCase.builder().setRequestText("{}").setExpectedResponseCode(204).build()), + Collections.emptyMap()); } @Test @@ -424,7 +456,8 @@ public void typedCustomFormat() throws Exception { .setRequestText("abc\n123\n$#@\n") .setExpectedResponseText("abc123$#@") .setExpectedResponseCode(200) - .build())); + .build()), + Collections.emptyMap()); } private void backgroundTest(String target) throws Exception { @@ -595,7 +628,8 @@ public void classpathOptionHttp() throws Exception { SignatureType.HTTP, "com.example.functionjar.Foreground", ImmutableList.of("--classpath", functionJarString()), - ImmutableList.of(testCase)); + ImmutableList.of(testCase), + Collections.emptyMap()); } /** Like {@link #classpathOptionHttp} but for background functions. */ @@ -612,7 +646,8 @@ public void classpathOptionBackground() throws Exception { SignatureType.BACKGROUND, "com.example.functionjar.Background", ImmutableList.of("--classpath", functionJarString()), - ImmutableList.of(TestCase.builder().setRequestText(json.toString()).build())); + ImmutableList.of(TestCase.builder().setRequestText(json.toString()).build()), + Collections.emptyMap()); } /** Like {@link #classpathOptionHttp} but for typed functions. */ @@ -629,7 +664,8 @@ public void classpathOptionTyped() throws Exception { TestCase.builder() .setRequestText(originalJson) .setExpectedResponseText("{\"fullName\":\"JohnDoe\"}") - .build())); + .build()), + Collections.emptyMap()); } // In these tests, we test a number of different functions that express the same functionality @@ -643,7 +679,12 @@ private void backgroundTest( for (TestCase testCase : testCases) { File snoopFile = testCase.snoopFile().get(); snoopFile.delete(); - testFunction(signatureType, functionTarget, ImmutableList.of(), ImmutableList.of(testCase)); + testFunction( + signatureType, + functionTarget, + ImmutableList.of(), + ImmutableList.of(testCase), + Collections.emptyMap()); String snooped = new String(Files.readAllBytes(snoopFile.toPath()), StandardCharsets.UTF_8); Gson gson = new Gson(); JsonObject snoopedJson = gson.fromJson(snooped, JsonObject.class); @@ -667,16 +708,18 @@ private void checkSnoopFile(TestCase testCase) throws IOException { } private void testHttpFunction(String target, List testCases) throws Exception { - testFunction(SignatureType.HTTP, target, ImmutableList.of(), testCases); + testFunction(SignatureType.HTTP, target, ImmutableList.of(), testCases, Collections.emptyMap()); } private void testFunction( SignatureType signatureType, String target, ImmutableList extraArgs, - List testCases) + List testCases, + Map environmentVariables) throws Exception { - ServerProcess serverProcess = startServer(signatureType, target, extraArgs); + ServerProcess serverProcess = + startServer(signatureType, target, extraArgs, environmentVariables); try { HttpClient httpClient = new HttpClient(); httpClient.start(); @@ -772,7 +815,10 @@ public void close() { } private ServerProcess startServer( - SignatureType signatureType, String target, ImmutableList extraArgs) + SignatureType signatureType, + String target, + ImmutableList extraArgs, + Map environmentVariables) throws IOException, InterruptedException { File javaHome = new File(System.getProperty("java.home")); assertThat(javaHome.exists()).isTrue(); @@ -798,6 +844,7 @@ private ServerProcess startServer( "FUNCTION_TARGET", target); processBuilder.environment().putAll(environment); + processBuilder.environment().putAll(environmentVariables); Process serverProcess = processBuilder.start(); CountDownLatch ready = new CountDownLatch(1); StringBuilder output = new StringBuilder(); diff --git a/invoker/core/src/test/java/com/google/cloud/functions/invoker/testfunctions/TimeoutHttp.java b/invoker/core/src/test/java/com/google/cloud/functions/invoker/testfunctions/TimeoutHttp.java new file mode 100644 index 00000000..c73e52d2 --- /dev/null +++ b/invoker/core/src/test/java/com/google/cloud/functions/invoker/testfunctions/TimeoutHttp.java @@ -0,0 +1,18 @@ +package com.google.cloud.functions.invoker.testfunctions; + +import com.google.cloud.functions.HttpFunction; +import com.google.cloud.functions.HttpRequest; +import com.google.cloud.functions.HttpResponse; + +public class TimeoutHttp implements HttpFunction { + + @Override + public void service(HttpRequest request, HttpResponse response) throws Exception { + try { + Thread.sleep(2000); + } catch (InterruptedException e) { + response.getWriter().close(); + } + response.getWriter().write("finished\n"); + } +}