-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add method to get and cache dataplane data
- Loading branch information
1 parent
56ff775
commit ffaab5d
Showing
3 changed files
with
176 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
package httpclient | ||
|
||
import ( | ||
"fmt" | ||
"time" | ||
|
||
"github.com/databricks/databricks-sdk-go/credentials" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
type DataPlaneInfoKey struct { | ||
ServiceName string | ||
Path string | ||
} | ||
|
||
// TODO: Replace by the DataPlaneInfo generated by SDK Generator in the package oauth2. | ||
// This will be generated once the first service uses the annotation for the DataPlaneInfo. | ||
type DataPlaneInfo struct { | ||
EndpointUrl string | ||
AuthorizationDetails string | ||
} | ||
|
||
type DataPlaneData struct { | ||
DataPlaneInfo *DataPlaneInfo | ||
Token *credentials.OAuthToken | ||
ExpirationTime *time.Time | ||
InfoRefresher func() (*DataPlaneInfo, error) | ||
} | ||
|
||
func (d *DataPlaneData) RefreshInfo() error { | ||
newInfo, err := d.InfoRefresher() | ||
if err != nil { | ||
return err | ||
} | ||
d.DataPlaneInfo = newInfo | ||
return nil | ||
} | ||
|
||
func (d *DataPlaneData) Expired() bool { | ||
return d.Token == nil || d.ExpirationTime == nil || time.Now().After(*d.ExpirationTime) | ||
} | ||
|
||
func (d *DataPlaneData) ExpiresIn(duration time.Duration) bool { | ||
return d.Token == nil || d.ExpirationTime == nil || time.Now().Add(duration).After(*d.ExpirationTime) | ||
} | ||
|
||
// TODO: Maybe add locks around the map? | ||
type dataPlaneCache map[DataPlaneInfoKey]*DataPlaneData | ||
|
||
// Refreshes the DataPlaneInfo | ||
// This is only required if the backend updates the required permissions | ||
// to call an endpoint, or the endpoint_url changes (which should be very rare). | ||
// Calling this method also invalidates the token. | ||
func (c *ApiClient) RefreshInfo(serviceName string, | ||
path string) error { | ||
key := DataPlaneInfoKey{ | ||
ServiceName: serviceName, | ||
Path: path, | ||
} | ||
data, ok := c.dataPlaneCache[key] | ||
if !ok { | ||
return fmt.Errorf("data not found for service %s and endpoint %s", serviceName, path) | ||
} | ||
err := data.RefreshInfo() | ||
if err != nil { | ||
return err | ||
} | ||
data.Token = nil | ||
data.ExpirationTime = nil | ||
return nil | ||
} | ||
|
||
func (c *ApiClient) GetOAuthTokenForDataPlane( | ||
serviceName string, | ||
path string, | ||
dataPlaneInfoGetter func() (*DataPlaneInfo, error), | ||
controlPlaneTokenProvider func() (*oauth2.Token, error)) (*credentials.OAuthToken, error) { | ||
key := DataPlaneInfoKey{ | ||
ServiceName: serviceName, | ||
Path: path, | ||
} | ||
data, ok := c.dataPlaneCache[key] | ||
if !ok { | ||
data = &DataPlaneData{ | ||
InfoRefresher: dataPlaneInfoGetter, | ||
} | ||
data.RefreshInfo() | ||
c.dataPlaneCache[key] = data | ||
} | ||
if data.ExpiresIn(2 * time.Minute) { | ||
err := c.refreshToken(data, controlPlaneTokenProvider) | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
return data.Token, nil | ||
} | ||
|
||
func (c *ApiClient) refreshToken(data *DataPlaneData, controlPlaneTokenProvider func() (*oauth2.Token, error)) error { | ||
controlPlaneToken, err := controlPlaneTokenProvider() | ||
if err != nil { | ||
return err | ||
} | ||
newToken, err := c.GetOAuthToken(data.DataPlaneInfo.AuthorizationDetails, controlPlaneToken) | ||
if err != nil { | ||
return err | ||
} | ||
data.Token = newToken | ||
expirationTime := time.Now().Add(time.Second * time.Duration(newToken.ExpiresIn)) | ||
data.ExpirationTime = &expirationTime | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package httpclient | ||
|
||
import ( | ||
"encoding/json" | ||
"io" | ||
"net/http" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/databricks/databricks-sdk-go/credentials" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"golang.org/x/oauth2" | ||
) | ||
|
||
var mockTokenProvider = func() (*oauth2.Token, error) { | ||
return &oauth2.Token{ | ||
AccessToken: "dummy", | ||
}, nil | ||
} | ||
|
||
var mockInfoGetter = func() (*DataPlaneInfo, error) { | ||
return &DataPlaneInfo{ | ||
EndpointUrl: "endpoint_url", | ||
AuthorizationDetails: "details", | ||
}, nil | ||
} | ||
|
||
func TestGetOAuthTokenForDataPlane(t *testing.T) { | ||
tokenResponse := &credentials.OAuthToken{ | ||
AccessToken: "dummyDataPlane", | ||
ExpiresIn: 3600, | ||
Scope: "scope", | ||
TokenType: "Bearer", | ||
} | ||
marshalled, err := json.Marshal(tokenResponse) | ||
require.NoError(t, err) | ||
c := NewApiClient(ClientConfig{ | ||
Transport: hc(func(r *http.Request) (*http.Response, error) { | ||
return &http.Response{ | ||
StatusCode: 200, | ||
Body: io.NopCloser(strings.NewReader(string(marshalled))), | ||
Request: r, | ||
}, nil | ||
}), | ||
}) | ||
token, err := c.GetOAuthTokenForDataPlane("service", "path", mockInfoGetter, mockTokenProvider) | ||
assert.NoError(t, err) | ||
assert.Equal(t, "dummyDataPlane", token.AccessToken) | ||
cachedData, ok := c.dataPlaneCache[DataPlaneInfoKey{ | ||
ServiceName: "service", | ||
Path: "path", | ||
}] | ||
assert.True(t, ok) | ||
assert.Equal(t, "details", cachedData.DataPlaneInfo.AuthorizationDetails) | ||
assert.Equal(t, "endpoint_url", cachedData.DataPlaneInfo.EndpointUrl) | ||
assert.Equal(t, "dummyDataPlane", cachedData.Token.AccessToken) | ||
assert.False(t, cachedData.Expired()) | ||
} |