diff --git a/server/deviceflowhandlers.go b/server/deviceflowhandlers.go index d7bb11c59c..4a8b382d30 100644 --- a/server/deviceflowhandlers.go +++ b/server/deviceflowhandlers.go @@ -36,6 +36,10 @@ func (s *Server) getDeviceVerificationURI() string { func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: + // Grab the parameter(s) from the query. + // If "user_code" is set, pre-populate the user code text field. + // If "invalid" is set, set the invalidAttempt boolean, which will display a message to the user that they + // attempted to redeem an invalid or expired user code. userCode := r.URL.Query().Get("user_code") invalidAttempt, err := strconv.ParseBool(r.URL.Query().Get("invalid")) if err != nil { diff --git a/server/deviceflowhandlers_test.go b/server/deviceflowhandlers_test.go index f82b15db8f..5ab3ddb635 100644 --- a/server/deviceflowhandlers_test.go +++ b/server/deviceflowhandlers_test.go @@ -31,7 +31,7 @@ func TestDeviceVerificationURI(t *testing.T) { u, err := url.Parse(s.issuerURL.String()) if err != nil { - t.Errorf("Could not parse issuer URL %v", err) + t.Fatalf("Could not parse issuer URL %v", err) } u.Path = path.Join(u.Path, "/device/auth/verify_code") @@ -49,16 +49,25 @@ func TestHandleDeviceCode(t *testing.T) { tests := []struct { testName string clientID string + requestType string scopes []string expectedResponseCode int expectedServerResponse string }{ { - testName: "New Valid Code", + testName: "New Code", clientID: "test", + requestType: "POST", scopes: []string{"openid", "profile", "email"}, expectedResponseCode: http.StatusOK, }, + { + testName: "Invalid request Type (GET)", + clientID: "test", + requestType: "GET", + scopes: []string{"openid", "profile", "email"}, + expectedResponseCode: http.StatusBadRequest, + }, } for _, tc := range tests { t.Run(tc.testName, func(t *testing.T) { @@ -74,7 +83,7 @@ func TestHandleDeviceCode(t *testing.T) { u, err := url.Parse(s.issuerURL.String()) if err != nil { - t.Errorf("Could not parse issuer URL %v", err) + t.Fatalf("Could not parse issuer URL %v", err) } u.Path = path.Join(u.Path, "device/code") @@ -83,7 +92,7 @@ func TestHandleDeviceCode(t *testing.T) { for _, scope := range tc.scopes { data.Add("scope", scope) } - req, _ := http.NewRequest("POST", u.String(), bytes.NewBufferString(data.Encode())) + req, _ := http.NewRequest(tc.requestType, u.String(), bytes.NewBufferString(data.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") rr := httptest.NewRecorder() @@ -102,9 +111,6 @@ func TestHandleDeviceCode(t *testing.T) { t.Errorf("Unexpected Device Code Response Format %v", string(body)) } } - if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized { - expectErrorResponse(tc.testName, body, tc.expectedServerResponse, t) - } }) } } @@ -322,15 +328,15 @@ func TestDeviceCallback(t *testing.T) { defer httpServer.Close() if err := s.storage.CreateAuthCode(tc.testAuthCode); err != nil { - t.Errorf("failed to create auth code: %v", err) + t.Fatalf("failed to create auth code: %v", err) } if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { - t.Errorf("failed to create device request: %v", err) + t.Fatalf("failed to create device request: %v", err) } if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil { - t.Errorf("failed to create device token: %v", err) + t.Fatalf("failed to create device token: %v", err) } client := storage.Client{ @@ -344,7 +350,7 @@ func TestDeviceCallback(t *testing.T) { u, err := url.Parse(s.issuerURL.String()) if err != nil { - t.Errorf("Could not parse issuer URL %v", err) + t.Fatalf("Could not parse issuer URL %v", err) } u.Path = path.Join(u.Path, "device/callback") q := u.Query() @@ -506,16 +512,16 @@ func TestDeviceTokenResponse(t *testing.T) { defer httpServer.Close() if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { - t.Errorf("Failed to store device token %v", err) + t.Fatalf("Failed to store device token %v", err) } if err := s.storage.CreateDeviceToken(tc.testDeviceToken); err != nil { - t.Errorf("Failed to store device token %v", err) + t.Fatalf("Failed to store device token %v", err) } u, err := url.Parse(s.issuerURL.String()) if err != nil { - t.Errorf("Could not parse issuer URL %v", err) + t.Fatalf("Could not parse issuer URL %v", err) } u.Path = path.Join(u.Path, "device/token") @@ -540,7 +546,7 @@ func TestDeviceTokenResponse(t *testing.T) { t.Errorf("Could read token response %v", err) } if tc.expectedResponseCode == http.StatusBadRequest || tc.expectedResponseCode == http.StatusUnauthorized { - expectErrorResponse(tc.testName, body, tc.expectedServerResponse, t) + expectJsonErrorResponse(tc.testName, body, tc.expectedServerResponse, t) } else if string(body) != tc.expectedServerResponse { t.Errorf("Unexpected Server Response. Expected %v got %v", tc.expectedServerResponse, string(body)) } @@ -548,7 +554,7 @@ func TestDeviceTokenResponse(t *testing.T) { } } -func expectErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) { +func expectJsonErrorResponse(testCase string, body []byte, expectedError string, t *testing.T) { jsonMap := make(map[string]interface{}) err := json.Unmarshal(body, &jsonMap) if err != nil { @@ -637,12 +643,12 @@ func TestVerifyCodeResponse(t *testing.T) { defer httpServer.Close() if err := s.storage.CreateDeviceRequest(tc.testDeviceRequest); err != nil { - t.Errorf("Failed to store device token %v", err) + t.Fatalf("Failed to store device token %v", err) } u, err := url.Parse(s.issuerURL.String()) if err != nil { - t.Errorf("Could not parse issuer URL %v", err) + t.Fatalf("Could not parse issuer URL %v", err) } u.Path = path.Join(u.Path, "device/auth/verify_code") diff --git a/storage/conformance/conformance.go b/storage/conformance/conformance.go index 0b73ce1556..a550a530f2 100644 --- a/storage/conformance/conformance.go +++ b/storage/conformance/conformance.go @@ -1030,6 +1030,6 @@ func testDeviceTokenCRUD(t *testing.T, s storage.Storage) { t.Fatalf("update failed, wanted token status=%v got %v", "complete", got.Status) } if got.Token != "token data" { - t.Fatalf("update failed, wanted token =%v got %v", "token data", got.Token) + t.Fatalf("update failed, wanted token %v got %v", "token data", got.Token) } }