Skip to content

Commit

Permalink
Inherit AwsSigV4 in .child (#725)
Browse files Browse the repository at this point in the history
* Adds child client support for AWS SigV4 connections

Signed-off-by: Bandini Bhopi <[email protected]>

* Adds auth in transport

Signed-off-by: Bandini Bhopi <[email protected]>

* Remove unused code

Signed-off-by: Bandini Bhopi <[email protected]>

* Add test to check transport

Signed-off-by: Bandini Bhopi <[email protected]>

* Removed comments

Signed-off-by: Bandini Bhopi <[email protected]>

* Adds changelog and fix lint issue

Signed-off-by: Bandini Bhopi <[email protected]>

* Adds missing type

Signed-off-by: Bandini Bhopi <[email protected]>

* Update changelog and remove console log

Signed-off-by: Bandini Bhopi <[email protected]>

---------

Signed-off-by: Bandini Bhopi <[email protected]>
  • Loading branch information
bandinib-amzn authored Mar 5, 2024
1 parent d960512 commit 8966691
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased]
### Added
- Inherit AwsSigV4 in .child ([#725](https://github.com/opensearch-project/opensearch-js/pull/725))
### Dependencies
- Bumps `prettier` from 3.1.1 to 3.2.5
- Bumps `@aws-sdk/types` from 3.485.0 to 3.523.0
Expand Down
3 changes: 2 additions & 1 deletion index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import {
CloudConnectionPool,
ResurrectEvent,
BasicAuth,
AwsSigv4Auth,
} from './lib/pool';
import Serializer from './lib/Serializer';
import Helpers from './lib/Helpers';
Expand Down Expand Up @@ -126,7 +127,7 @@ interface ClientOptions {
opaqueIdPrefix?: string;
generateRequestId?: generateRequestIdFn;
name?: string | symbol;
auth?: BasicAuth;
auth?: BasicAuth | AwsSigv4Auth;
context?: Context;
proxy?: string | URL;
enableMetaHeader?: boolean;
Expand Down
1 change: 1 addition & 0 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class Client extends OpenSearchAPI {
opaqueIdPrefix: options.opaqueIdPrefix,
context: options.context,
memoryCircuitBreaker: options.memoryCircuitBreaker,
auth: options.auth,
});

this.helpers = new Helpers({
Expand Down
4 changes: 2 additions & 2 deletions lib/Connection.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import { URL } from 'url';
import { inspect, InspectOptions } from 'util';
import { Readable as ReadableStream } from 'stream';
import { BasicAuth } from './pool';
import { BasicAuth, AwsSigv4Auth } from './pool';
import * as http from 'http';
import * as https from 'https';
import * as hpagent from 'hpagent';
Expand All @@ -48,7 +48,7 @@ export interface ConnectionOptions {
agent?: AgentOptions | agentFn;
status?: string;
roles?: ConnectionRoles;
auth?: BasicAuth;
auth?: BasicAuth | AwsSigv4Auth;
proxy?: string | URL;
}

Expand Down
3 changes: 2 additions & 1 deletion lib/Transport.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import { Readable as ReadableStream } from 'stream';
import Connection from './Connection';
import * as errors from './errors';
import { CloudConnectionPool, ConnectionPool } from './pool';
import { CloudConnectionPool, ConnectionPool, BasicAuth, AwsSigv4Auth } from './pool';
import Serializer from './Serializer';

export type ApiError =
Expand Down Expand Up @@ -82,6 +82,7 @@ interface TransportOptions {
name?: string;
opaqueIdPrefix?: string;
memoryCircuitBreaker?: MemoryCircuitBreakerOptions;
auth?: BasicAuth | AwsSigv4Auth;
}

export interface RequestEvent<TResponse = Record<string, any>, TContext = Context> {
Expand Down
4 changes: 4 additions & 0 deletions lib/Transport.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class Transport {
this._sniffEnabled = typeof this.sniffInterval === 'number';
this._nextSniff = this._sniffEnabled ? Date.now() + this.sniffInterval : 0;
this._isSniffing = false;
this._auth = opts.auth;

if (opts.sniffOnStart === true) {
// timer needed otherwise it will clash
Expand Down Expand Up @@ -495,6 +496,9 @@ class Transport {
Object.assign({}, params.querystring, options.querystring)
);
}
if (this._auth !== null && typeof this._auth === 'object' && 'credentials' in this._auth) {
params.auth = this._auth;
}

// handles request timeout
params.timeout = toMs(options.requestTimeout || this.requestTimeout);
Expand Down
12 changes: 12 additions & 0 deletions lib/aws/AwsSigv4Signer.js
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ function AwsSigv4Signer(opts = {}) {
request.region = opts.region;
request.headers = request.headers || {};
request.headers['host'] = request.hostname;

if (request['auth']) {
const awssigv4Cred = request['auth'];
credentialsState.credentials = {
accessKeyId: awssigv4Cred.credentials.accessKeyId,
secretAccessKey: awssigv4Cred.credentials.secretAccessKey,
sessionToken: awssigv4Cred.credentials.sessionToken,
};
request.region = awssigv4Cred.region;
request.service = awssigv4Cred.service;
delete request['auth'];
}
const signed = aws4.sign(request, credentialsState.credentials);
signed.headers['x-amz-content-sha256'] = crypto
.createHash('sha256')
Expand Down
15 changes: 13 additions & 2 deletions lib/pool/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ interface BaseConnectionPoolOptions {
ssl?: SecureContextOptions;
agent?: AgentOptions;
proxy?: string | URL;
auth?: BasicAuth;
auth?: BasicAuth | AwsSigv4Auth;
emit: (event: string | symbol, ...args: any[]) => boolean;
Connection: typeof Connection;
}
Expand All @@ -62,6 +62,16 @@ interface BasicAuth {
password: string;
}

interface AwsSigv4Auth {
credentials : {
accessKeyId: string;
secretAccessKey: string;
sessionToken: string;
}
region: string;
service: string;
}

interface resurrectOptions {
now?: number;
requestId: string;
Expand All @@ -85,7 +95,7 @@ declare class BaseConnectionPool {
_ssl: SecureContextOptions | null;
_agent: AgentOptions | null;
_proxy: string | URL;
auth: BasicAuth;
auth: BasicAuth | AwsSigv4Auth;
Connection: typeof Connection;
constructor(opts?: BaseConnectionPoolOptions);
/**
Expand Down Expand Up @@ -235,6 +245,7 @@ export {
ConnectionPoolOptions,
getConnectionOptions,
BasicAuth,
AwsSigv4Auth,
internals,
resurrectOptions,
ResurrectEvent,
Expand Down
106 changes: 106 additions & 0 deletions test/unit/lib/aws/awssigv4signer.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const AwsSigv4Signer = require('../../../../lib/aws/AwsSigv4Signer');
const AwsSigv4SignerError = require('../../../../lib/aws/errors');
const { Connection } = require('../../../../index');
const { Client, buildServer } = require('../../../utils');
const { debug } = require('console');

test('Sign with SigV4', (t) => {
t.plan(4);
Expand Down Expand Up @@ -594,3 +595,108 @@ test('Basic aws sdk v3 when token expires later than `requestTimeout` ms in the
.catch(t.fail);
});
});

test('Should create child client', (t) => {
t.plan(8);
const childClientCred = {
auth: {
credentials: {
accessKeyId: 'foo',
secretAccessKey: 'bar',
sessionToken: 'foobar',
},
region: 'eu-west-1',
service: 'es',
},
};
const childClientCred2 = {
auth: {
credentials: {
accessKeyId: 'foo2',
secretAccessKey: 'bar2',
sessionToken: 'foobar2',
},
region: 'eu-west-2',
service: 'es-2',
},
};
let count = 0;
function handler(req, res) {
res.setHeader('Content-Type', 'application/json;utf=8');
res.end(JSON.stringify({ hello: 'world' }));
}

buildServer(handler, ({ port }, server) => {
const mockRegion = 'us-east-1';

let getCredentialsCalled = 0;
const getCredentials = () =>
new Promise((resolve) => {
setTimeout(() => {
getCredentialsCalled++;
resolve({
accessKeyId: uuidv4(),
secretAccessKey: uuidv4(),
sessionToken: uuidv4(),
});
}, 100);
});

const AwsSigv4SignerOptions = {
getCredentials: getCredentials,
region: mockRegion,
};

const auth = AwsSigv4Signer(AwsSigv4SignerOptions);

const client = new Client({
...auth,
node: `http://localhost:${port}`,
});
const child = client.child(childClientCred);
const child2 = client.child(childClientCred2);

client
.search({
index: 'test',
q: 'foo:bar',
})
.then(({ body }) => {
t.same(body, { hello: 'world' });
t.same(getCredentialsCalled, 1);
child
.search({
index: 'test',
q: 'foo:bar',
})
.then(({ body }) => {
t.same(body, { hello: 'world' });
t.same(getCredentialsCalled, 1);
child2
.search({
index: 'test',
q: 'foo:bar',
})
.then(() => {
server.stop();
})
.catch(t.fail);
})
.catch(t.fail);
})
.catch(t.fail);

child.on('request', (err, { meta }) => {
debug('Count', count);
if (count === 0) {
t.equal(JSON.stringify(meta.request.params.auth), undefined);
} else if (count === 1) {
t.equal(JSON.stringify(meta.request.params.auth), JSON.stringify(childClientCred.auth));
} else if (count === 2) {
t.equal(JSON.stringify(meta.request.params.auth), JSON.stringify(childClientCred2.auth));
}
count++;
});
t.not_same(child.transport._auth, child2.transport._auth);
});
});

0 comments on commit 8966691

Please sign in to comment.