diff --git a/CHANGELOG.md b/CHANGELOG.md index f263a79cf..1b3e3f336 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ## [Unreleased] ### Added +- Inherit AwsSigV4 in .child ([#725](https://github.com/opensearch-project/opensearch-js/pull/725)) ### Dependencies - Bumps `prettier` from 3.1.1 to 3.2.5 - Bumps `@aws-sdk/types` from 3.485.0 to 3.523.0 diff --git a/index.d.ts b/index.d.ts index 2316ff203..9a12e8c16 100644 --- a/index.d.ts +++ b/index.d.ts @@ -54,6 +54,7 @@ import { CloudConnectionPool, ResurrectEvent, BasicAuth, + AwsSigv4Auth, } from './lib/pool'; import Serializer from './lib/Serializer'; import Helpers from './lib/Helpers'; @@ -126,7 +127,7 @@ interface ClientOptions { opaqueIdPrefix?: string; generateRequestId?: generateRequestIdFn; name?: string | symbol; - auth?: BasicAuth; + auth?: BasicAuth | AwsSigv4Auth; context?: Context; proxy?: string | URL; enableMetaHeader?: boolean; diff --git a/index.js b/index.js index 5e5154398..2987a72a4 100644 --- a/index.js +++ b/index.js @@ -192,6 +192,7 @@ class Client extends OpenSearchAPI { opaqueIdPrefix: options.opaqueIdPrefix, context: options.context, memoryCircuitBreaker: options.memoryCircuitBreaker, + auth: options.auth, }); this.helpers = new Helpers({ diff --git a/lib/Connection.d.ts b/lib/Connection.d.ts index 0dc40d3a1..6d5093e15 100644 --- a/lib/Connection.d.ts +++ b/lib/Connection.d.ts @@ -32,7 +32,7 @@ import { URL } from 'url'; import { inspect, InspectOptions } from 'util'; import { Readable as ReadableStream } from 'stream'; -import { BasicAuth } from './pool'; +import { BasicAuth, AwsSigv4Auth } from './pool'; import * as http from 'http'; import * as https from 'https'; import * as hpagent from 'hpagent'; @@ -48,7 +48,7 @@ export interface ConnectionOptions { agent?: AgentOptions | agentFn; status?: string; roles?: ConnectionRoles; - auth?: BasicAuth; + auth?: BasicAuth | AwsSigv4Auth; proxy?: string | URL; } diff --git a/lib/Transport.d.ts b/lib/Transport.d.ts index d79eabec7..ab46c81fa 100644 --- a/lib/Transport.d.ts +++ b/lib/Transport.d.ts @@ -30,7 +30,7 @@ import { Readable as ReadableStream } from 'stream'; import Connection from './Connection'; import * as errors from './errors'; -import { CloudConnectionPool, ConnectionPool } from './pool'; +import { CloudConnectionPool, ConnectionPool, BasicAuth, AwsSigv4Auth } from './pool'; import Serializer from './Serializer'; export type ApiError = @@ -82,6 +82,7 @@ interface TransportOptions { name?: string; opaqueIdPrefix?: string; memoryCircuitBreaker?: MemoryCircuitBreakerOptions; + auth?: BasicAuth | AwsSigv4Auth; } export interface RequestEvent, TContext = Context> { diff --git a/lib/Transport.js b/lib/Transport.js index f86f73c8f..b13ff9f2b 100644 --- a/lib/Transport.js +++ b/lib/Transport.js @@ -101,6 +101,7 @@ class Transport { this._sniffEnabled = typeof this.sniffInterval === 'number'; this._nextSniff = this._sniffEnabled ? Date.now() + this.sniffInterval : 0; this._isSniffing = false; + this._auth = opts.auth; if (opts.sniffOnStart === true) { // timer needed otherwise it will clash @@ -495,6 +496,9 @@ class Transport { Object.assign({}, params.querystring, options.querystring) ); } + if (this._auth !== null && typeof this._auth === 'object' && 'credentials' in this._auth) { + params.auth = this._auth; + } // handles request timeout params.timeout = toMs(options.requestTimeout || this.requestTimeout); diff --git a/lib/aws/AwsSigv4Signer.js b/lib/aws/AwsSigv4Signer.js index 6b8eb83f0..8e6ae8e33 100644 --- a/lib/aws/AwsSigv4Signer.js +++ b/lib/aws/AwsSigv4Signer.js @@ -79,6 +79,18 @@ function AwsSigv4Signer(opts = {}) { request.region = opts.region; request.headers = request.headers || {}; request.headers['host'] = request.hostname; + + if (request['auth']) { + const awssigv4Cred = request['auth']; + credentialsState.credentials = { + accessKeyId: awssigv4Cred.credentials.accessKeyId, + secretAccessKey: awssigv4Cred.credentials.secretAccessKey, + sessionToken: awssigv4Cred.credentials.sessionToken, + }; + request.region = awssigv4Cred.region; + request.service = awssigv4Cred.service; + delete request['auth']; + } const signed = aws4.sign(request, credentialsState.credentials); signed.headers['x-amz-content-sha256'] = crypto .createHash('sha256') diff --git a/lib/pool/index.d.ts b/lib/pool/index.d.ts index 9a1296673..4bdc6a888 100644 --- a/lib/pool/index.d.ts +++ b/lib/pool/index.d.ts @@ -38,7 +38,7 @@ interface BaseConnectionPoolOptions { ssl?: SecureContextOptions; agent?: AgentOptions; proxy?: string | URL; - auth?: BasicAuth; + auth?: BasicAuth | AwsSigv4Auth; emit: (event: string | symbol, ...args: any[]) => boolean; Connection: typeof Connection; } @@ -62,6 +62,16 @@ interface BasicAuth { password: string; } +interface AwsSigv4Auth { + credentials : { + accessKeyId: string; + secretAccessKey: string; + sessionToken: string; + } + region: string; + service: string; +} + interface resurrectOptions { now?: number; requestId: string; @@ -85,7 +95,7 @@ declare class BaseConnectionPool { _ssl: SecureContextOptions | null; _agent: AgentOptions | null; _proxy: string | URL; - auth: BasicAuth; + auth: BasicAuth | AwsSigv4Auth; Connection: typeof Connection; constructor(opts?: BaseConnectionPoolOptions); /** @@ -235,6 +245,7 @@ export { ConnectionPoolOptions, getConnectionOptions, BasicAuth, + AwsSigv4Auth, internals, resurrectOptions, ResurrectEvent, diff --git a/test/unit/lib/aws/awssigv4signer.test.js b/test/unit/lib/aws/awssigv4signer.test.js index 5c6d3b699..30fa33181 100644 --- a/test/unit/lib/aws/awssigv4signer.test.js +++ b/test/unit/lib/aws/awssigv4signer.test.js @@ -14,6 +14,7 @@ const AwsSigv4Signer = require('../../../../lib/aws/AwsSigv4Signer'); const AwsSigv4SignerError = require('../../../../lib/aws/errors'); const { Connection } = require('../../../../index'); const { Client, buildServer } = require('../../../utils'); +const { debug } = require('console'); test('Sign with SigV4', (t) => { t.plan(4); @@ -594,3 +595,108 @@ test('Basic aws sdk v3 when token expires later than `requestTimeout` ms in the .catch(t.fail); }); }); + +test('Should create child client', (t) => { + t.plan(8); + const childClientCred = { + auth: { + credentials: { + accessKeyId: 'foo', + secretAccessKey: 'bar', + sessionToken: 'foobar', + }, + region: 'eu-west-1', + service: 'es', + }, + }; + const childClientCred2 = { + auth: { + credentials: { + accessKeyId: 'foo2', + secretAccessKey: 'bar2', + sessionToken: 'foobar2', + }, + region: 'eu-west-2', + service: 'es-2', + }, + }; + let count = 0; + function handler(req, res) { + res.setHeader('Content-Type', 'application/json;utf=8'); + res.end(JSON.stringify({ hello: 'world' })); + } + + buildServer(handler, ({ port }, server) => { + const mockRegion = 'us-east-1'; + + let getCredentialsCalled = 0; + const getCredentials = () => + new Promise((resolve) => { + setTimeout(() => { + getCredentialsCalled++; + resolve({ + accessKeyId: uuidv4(), + secretAccessKey: uuidv4(), + sessionToken: uuidv4(), + }); + }, 100); + }); + + const AwsSigv4SignerOptions = { + getCredentials: getCredentials, + region: mockRegion, + }; + + const auth = AwsSigv4Signer(AwsSigv4SignerOptions); + + const client = new Client({ + ...auth, + node: `http://localhost:${port}`, + }); + const child = client.child(childClientCred); + const child2 = client.child(childClientCred2); + + client + .search({ + index: 'test', + q: 'foo:bar', + }) + .then(({ body }) => { + t.same(body, { hello: 'world' }); + t.same(getCredentialsCalled, 1); + child + .search({ + index: 'test', + q: 'foo:bar', + }) + .then(({ body }) => { + t.same(body, { hello: 'world' }); + t.same(getCredentialsCalled, 1); + child2 + .search({ + index: 'test', + q: 'foo:bar', + }) + .then(() => { + server.stop(); + }) + .catch(t.fail); + }) + .catch(t.fail); + }) + .catch(t.fail); + + child.on('request', (err, { meta }) => { + debug('Count', count); + if (count === 0) { + t.equal(JSON.stringify(meta.request.params.auth), undefined); + } else if (count === 1) { + t.equal(JSON.stringify(meta.request.params.auth), JSON.stringify(childClientCred.auth)); + } else if (count === 2) { + t.equal(JSON.stringify(meta.request.params.auth), JSON.stringify(childClientCred2.auth)); + } + count++; + }); + t.not_same(child.transport._auth, child2.transport._auth); + }); +});