diff --git a/src/main/java/synapseawsconsolelogin/Auth.java b/src/main/java/synapseawsconsolelogin/Auth.java index a30f488..afae572 100644 --- a/src/main/java/synapseawsconsolelogin/Auth.java +++ b/src/main/java/synapseawsconsolelogin/Auth.java @@ -79,6 +79,9 @@ public class Auth extends HttpServlet { private static Logger logger = Logger.getLogger("Auth"); + private static final String StrictTransportSecurityHeaderName = "Strict-Transport-Security"; + private static final String StrictTransportSecurityHeaderValue = "max-age=31536000; includeSubDomains"; + private static final String TEAM_CLAIM_NAME = "team"; // templates for constructing the 'claims' part of the OIDC authorization request @@ -368,6 +371,9 @@ public String getAuthorizeUrl(String state) { @Override public void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException { + if (req.isSecure()) { + resp.setHeader(StrictTransportSecurityHeaderName, StrictTransportSecurityHeaderValue); + } resp.setContentType("text/plain"); try (ServletOutputStream os=resp.getOutputStream()) { os.println("Not found."); @@ -405,11 +411,14 @@ static int synapseExceptionStatus(SynapseServerException e) { // and we don't expect others. For the rest we'll just return 500. return 500; } - + @Override public void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException { try { + if (req.isSecure()) { + resp.setHeader(StrictTransportSecurityHeaderName, StrictTransportSecurityHeaderValue); + } doGetIntern(req, resp); } catch (Exception e) { handleException(e, resp); diff --git a/src/test/java/synapseawsconsolelogin/AuthTest.java b/src/test/java/synapseawsconsolelogin/AuthTest.java index ff70ae6..e754406 100644 --- a/src/test/java/synapseawsconsolelogin/AuthTest.java +++ b/src/test/java/synapseawsconsolelogin/AuthTest.java @@ -5,9 +5,10 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; -import static org.mockito.Matchers.eq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -669,4 +670,57 @@ public void testDoGet_PersonalAccessToken() throws Exception { verify(mockHttpResponse).setContentType("application/force-download"); } + private static void hstsIsSet(HttpServletResponse mockHttpResponse) { + verify(mockHttpResponse).setHeader("Strict-Transport-Security", "max-age=31536000; includeSubDomains"); + } + + private static void hstsIsNOTSet(HttpServletResponse mockHttpResponse) { + verify(mockHttpResponse, never()).setHeader(eq("Strict-Transport-Security"), anyString()); + } + + @Test + public void testHSTSSecureGet() throws Exception { + mockIncomingUrl("https://www.foo.com", "/unknown"); + when (mockHttpRequest.isSecure()).thenReturn(true); + + // method under test + auth.doGet(mockHttpRequest, mockHttpResponse); + + hstsIsSet(mockHttpResponse); + } + + @Test + public void testHSTSSecurePost() throws Exception { + mockIncomingUrl("https://www.foo.com", "/unknown"); + when (mockHttpRequest.isSecure()).thenReturn(true); + + // method under test + auth.doPost(mockHttpRequest, mockHttpResponse); + + hstsIsSet(mockHttpResponse); + } + + @Test + public void testHSTSInsecureGet() throws Exception { + mockIncomingUrl("http://www.foo.com", "/unknown"); + when (mockHttpRequest.isSecure()).thenReturn(false); + + // method under test + auth.doGet(mockHttpRequest, mockHttpResponse); + + hstsIsNOTSet(mockHttpResponse); + } + + @Test + public void testHSTSInsecurePost() throws Exception { + mockIncomingUrl("http://www.foo.com", "/unknown"); + when (mockHttpRequest.isSecure()).thenReturn(false); + + // method under test + auth.doPost(mockHttpRequest, mockHttpResponse); + + hstsIsNOTSet(mockHttpResponse); + } + + }