Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Paging support for azure wrapper #575

Merged
merged 15 commits into from
Sep 4, 2023
96 changes: 75 additions & 21 deletions internal/wrappers/azure-http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"net/http"
"reflect"
"time"

b64 "encoding/base64"
Expand All @@ -27,10 +28,12 @@ const (
azureBaseReposURL = "%s%s/%s/_apis/git/repositories"
azureBaseProjectsURL = "%s%s/_apis/projects"
azureTop = "$top"
azureTopValue = "1000000"
azurePage = "$skip"
azureLayoutTime = "2006-01-02"
basicFormat = "Basic %s"
failedAuth = "failed Azure Authentication"
unauthorized = "unauthorized: verify if the organization you provided is correct"
azurePageLenValue = 100
)

func NewAzureWrapper() AzureWrapper {
Expand All @@ -44,56 +47,82 @@ func (g *AzureHTTPWrapper) GetCommits(url, organizationName, projectName, reposi
error,
) {
var err error
var repository AzureRootCommit
var rootCommit AzureRootCommit
var pages []AzureRootCommit
var queryParams = make(map[string]string)

commitsURL := fmt.Sprintf(azureBaseCommitURL, url, organizationName, projectName, repositoryName)
queryParams[azureSearchDate] = getThreeMonthsTime()
queryParams[azureAPIVersion] = azureAPIVersionValue
queryParams[azureTop] = fmt.Sprintf("%d", azurePageLenValue)

err = g.get(commitsURL, encodeToken(token), &repository, queryParams, basicFormat)
err = g.paginateGetter(commitsURL, encodeToken(token), &AzureRootCommit{}, &pages, queryParams, basicFormat)
if err != nil {
return rootCommit, err
}

for _, commitPage := range pages {
rootCommit.Commits = append(rootCommit.Commits, commitPage.Commits...)
}

return repository, err
return rootCommit, err
}

func (g *AzureHTTPWrapper) GetRepositories(url, organizationName, projectName, token string) (AzureRootRepo, error) {
var err error
var repository AzureRootRepo
var rootRepo AzureRootRepo
var pages []AzureRootRepo
var queryParams = make(map[string]string)

reposURL := fmt.Sprintf(azureBaseReposURL, url, organizationName, projectName)
queryParams[azureTop] = azureTopValue
queryParams[azureAPIVersion] = azureAPIVersionValue

err = g.get(reposURL, encodeToken(token), &repository, queryParams, basicFormat)
// unfortunately, Azure DevOps does not support pagination for repositories so we have to fetch all the repos
err = g.paginateGetter(reposURL, encodeToken(token), &AzureRootRepo{}, &pages, queryParams, basicFormat)
if err != nil {
return rootRepo, err
}

for _, repositoryPage := range pages {
rootRepo.Repos = append(rootRepo.Repos, repositoryPage.Repos...)
}

return repository, err
return rootRepo, err
}

func (g *AzureHTTPWrapper) GetProjects(url, organizationName, token string) (AzureRootProject, error) {
var err error
var project AzureRootProject
var rootProject AzureRootProject
var pages []AzureRootProject
var queryParams = make(map[string]string)

reposURL := fmt.Sprintf(azureBaseProjectsURL, url, organizationName)
queryParams[azureAPIVersion] = azureAPIVersionValue
queryParams[azureTop] = fmt.Sprintf("%d", azurePageLenValue)

err = g.paginateGetter(reposURL, encodeToken(token), &AzureRootProject{}, &pages, queryParams, basicFormat)
if err != nil {
return rootProject, err
}

err = g.get(reposURL, encodeToken(token), &project, queryParams, basicFormat)
for _, projectPage := range pages {
rootProject.Projects = append(rootProject.Projects, projectPage.Projects...)
}

return project, err
return rootProject, err
}

func (g *AzureHTTPWrapper) get(
url, token string,
target interface{},
queryParams map[string]string,
authFormat string,
) error {
) (bool, error) {
var err error

req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return err
return false, err
}

if len(token) > 0 {
Expand All @@ -108,7 +137,7 @@ func (g *AzureHTTPWrapper) get(
resp, err := g.client.Do(req)

if err != nil {
return err
return false, err
}

logger.PrintRequest(req)
Expand All @@ -123,25 +152,50 @@ func (g *AzureHTTPWrapper) get(
case http.StatusOK:
err = json.NewDecoder(resp.Body).Decode(target)
if err != nil {
return err
return false, err
}
// State sent when expired token
case http.StatusNonAuthoritativeInfo:
err = errors.New(failedAuth)
return err
return false, err
// State sent when no token is provided
case http.StatusForbidden:
err = errors.New(failedAuth)
return err
case http.StatusNotFound:
// Case the commit/project does not exist in the organization
return nil
return false, err
case http.StatusUnauthorized:
return false, errors.New(unauthorized)
default:
body, err := io.ReadAll(resp.Body)
if err != nil {
return false, err
}
return false, errors.Errorf("%s - %s", string(body), resp.Status)
}
headerLink := resp.Header.Get("Link")
continuationToken := resp.Header.Get("X-Ms-Continuationtoken")
return headerLink != "" || continuationToken != "", nil
}

func (g *AzureHTTPWrapper) paginateGetter(url, token string, target, slice interface{}, queryParams map[string]string, format string) error {
var currentPage = 0
for {
queryParams[azurePage] = fmt.Sprintf("%d", currentPage)
hasNextPage, err := g.get(url, token, target, queryParams, format)
if err != nil {
return err
}
return errors.New(string(body))

slicePtr := reflect.ValueOf(slice)
sliceValue := slicePtr.Elem()
sliceValue.Set(reflect.Append(sliceValue, reflect.ValueOf(target).Elem()))

target = reflect.New(reflect.TypeOf(target).Elem()).Interface()

if !hasNextPage {
break
}

currentPage += azurePageLenValue
}
return nil
}
Expand Down
61 changes: 46 additions & 15 deletions test/integration/user-count-azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ import (
)

const (
envOrg = "AZURE_ORG"
envToken = "AZURE_TOKEN"
envProject = "AZURE_PROJECT"
envRepos = "AZURE_REPOS"
projectFlag = "projects"
envOrg = "AZURE_ORG"
envToken = "AZURE_TOKEN"
envProject = "AZURE_PROJECT"
envRepos = "AZURE_REPOS"
projectFlag = "projects"
)

func TestAzureUserCountOrgs(t *testing.T) {
_ = viper.BindEnv(pat)
buffer := executeCmdNilAssertion(
t,
"Counting contributors from checkmarxdev should pass",
"utils",
utilsCommand,
usercount.UcCommand,
usercount.AzureCommand,
flag(usercount.OrgsFlag),
Expand Down Expand Up @@ -56,7 +56,7 @@ func TestAzureUserCountProjects(t *testing.T) {
buffer := executeCmdNilAssertion(
t,
"Counting contributors from checkmarxdev should pass",
"utils",
utilsCommand,
usercount.UcCommand,
usercount.AzureCommand,
flag(usercount.OrgsFlag),
Expand Down Expand Up @@ -86,7 +86,7 @@ func TestAzureUserCountRepos(t *testing.T) {
buffer := executeCmdNilAssertion(
t,
"Counting contributors from checkmarxdev should pass",
"utils",
utilsCommand,
usercount.UcCommand,
usercount.AzureCommand,
flag(usercount.OrgsFlag),
Expand Down Expand Up @@ -117,23 +117,22 @@ func TestAzureUserCountOrgsFailed(t *testing.T) {
_ = viper.BindEnv(pat)
err, _ := executeCommand(
t,
"utils",
utilsCommand,
usercount.UcCommand,
usercount.AzureCommand,
flag(params.SCMTokenFlag),
os.Getenv(envToken),
flag(params.FormatFlag),
printer.FormatJSON,
)

assertError(t, err, "Provide at least one organization")
}

func TestAzureUserCountReposFailed(t *testing.T) {
_ = viper.BindEnv(pat)
err, _ := executeCommand(
t,
"utils",
utilsCommand,
usercount.UcCommand,
usercount.AzureCommand,
flag(usercount.OrgsFlag),
Expand All @@ -145,15 +144,14 @@ func TestAzureUserCountReposFailed(t *testing.T) {
flag(params.FormatFlag),
printer.FormatJSON,
)

assertError(t, err, "Provide at least one project")
}

func TestAzureCountMultipleWorkspaceFailed(t *testing.T) {
_ = viper.BindEnv(pat)
err, _ := executeCommand(
t,
"utils",
utilsCommand,
usercount.UcCommand,
usercount.AzureCommand,
flag(usercount.OrgsFlag),
Expand All @@ -167,6 +165,39 @@ func TestAzureCountMultipleWorkspaceFailed(t *testing.T) {
flag(params.FormatFlag),
printer.FormatJSON,
)

assertError(t, err, "You must provide a single org for repo counting")
}
}

func TestAzureUserCountWrongToken(t *testing.T) {
_ = viper.BindEnv(pat)
err, _ := executeCommand(
t,
utilsCommand,
usercount.UcCommand,
usercount.AzureCommand,
flag(usercount.OrgsFlag),
os.Getenv(envOrg),
flag(params.SCMTokenFlag),
"wrong",
flag(params.FormatFlag),
printer.FormatJSON,
)
assertError(t, err, "failed Azure Authentication")
}

func TestAzureUserCountWrongOrg(t *testing.T) {
_ = viper.BindEnv(pat)
err, _ := executeCommand(
t,
utilsCommand,
usercount.UcCommand,
usercount.AzureCommand,
flag(usercount.OrgsFlag),
"wrong",
flag(params.SCMTokenFlag),
os.Getenv(envToken),
flag(params.FormatFlag),
printer.FormatJSON,
)
assert.ErrorContains(t, err, "unauthorized")
}
Loading