Skip to content

Commit

Permalink
fix: include groups in tokens
Browse files Browse the repository at this point in the history
Fixes #176
  • Loading branch information
jagregory committed May 30, 2022
1 parent 4863d54 commit 996dcde
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/__tests__/mockUserPoolService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export const newMockUserPoolService = (
getUserByRefreshToken: jest.fn(),
getUserByUsername: jest.fn(),
listGroups: jest.fn(),
listUserGroupMembership: jest.fn(),
listUsers: jest.fn(),
options: config,
removeUserFromGroup: jest.fn(),
Expand Down
2 changes: 1 addition & 1 deletion src/bin/start.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const logger = Pino(
singleLine: true,
messageFormat: (log, messageKey) =>
`${log["reqId"] ?? "NONE"} ${log["target"] ?? "NONE"} ${log[messageKey]}`,
}) as any
}) as any // eslint-disable-line @typescript-eslint/no-explicit-any
);

createDefaultServer(logger)
Expand Down
74 changes: 74 additions & 0 deletions src/services/tokenGenerator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
TDB.appClient(),
{ client: "metadata" },
"RefreshTokens"
Expand Down Expand Up @@ -80,6 +81,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
TDB.appClient(),
{ client: "metadata" },
"RefreshTokens"
Expand Down Expand Up @@ -112,6 +114,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
TDB.appClient(),
{ client: "metadata" },
"RefreshTokens"
Expand Down Expand Up @@ -162,6 +165,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
TDB.appClient(),
{ client: "metadata" },
"RefreshTokens"
Expand All @@ -183,6 +187,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
userPoolClient,
{ client: "metadata" },
"RefreshTokens"
Expand Down Expand Up @@ -243,6 +248,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
userPoolClient,
{ client: "metadata" },
"RefreshTokens"
Expand Down Expand Up @@ -278,6 +284,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
userPoolClient,
{ client: "metadata" },
"RefreshTokens"
Expand Down Expand Up @@ -309,6 +316,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
userPoolClient,
{ client: "metadata" },
"RefreshTokens"
Expand Down Expand Up @@ -344,6 +352,7 @@ describe("JwtTokenGenerator", () => {
const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
userPoolClient,
{ client: "metadata" },
"RefreshTokens"
Expand All @@ -361,4 +370,69 @@ describe("JwtTokenGenerator", () => {
});
});
});

describe("groups", () => {
it("does not include a cognito:groups claim if the user has no groups", async () => {
mockTriggers.enabled.mockReturnValue(false);

const userPoolClient = TDB.appClient({
AccessTokenValidity: 10,
IdTokenValidity: 20,
RefreshTokenValidity: 30,
TokenValidityUnits: {
AccessToken: "seconds",
IdToken: "minutes",
RefreshToken: "hours",
},
});

const tokens = await tokenGenerator.generate(
TestContext,
user,
[],
userPoolClient,
{ client: "metadata" },
"RefreshTokens"
);

expect(
(jwt.decode(tokens.AccessToken) as any)["cognito:groups"]
).toBeUndefined();
expect(
(jwt.decode(tokens.IdToken) as any)["cognito:groups"]
).toBeUndefined();
});

it("includes a cognito:groups claim with the user's groups", async () => {
mockTriggers.enabled.mockReturnValue(false);

const userPoolClient = TDB.appClient({
AccessTokenValidity: 10,
IdTokenValidity: 20,
RefreshTokenValidity: 30,
TokenValidityUnits: {
AccessToken: "seconds",
IdToken: "minutes",
RefreshToken: "hours",
},
});

const tokens = await tokenGenerator.generate(
TestContext,
user,
["group1", "group2"],
userPoolClient,
{ client: "metadata" },
"RefreshTokens"
);

expect((jwt.decode(tokens.AccessToken) as any)["cognito:groups"]).toEqual(
["group1", "group2"]
);
expect((jwt.decode(tokens.IdToken) as any)["cognito:groups"]).toEqual([
"group1",
"group2",
]);
});
});
});
63 changes: 36 additions & 27 deletions src/services/tokenGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,15 @@ const RESERVED_CLAIMS = [
"token_use",
];

type RawToken = Record<
string,
string | number | boolean | undefined | readonly string[]
>;

const applyTokenOverrides = (
token: Record<string, string | number | boolean | undefined>,
token: RawToken,
overrides: TokenOverrides
): Record<string, string | number | boolean | undefined> => {
): RawToken => {
// TODO: support group overrides

const claimsToSuppress = (overrides?.claimsToSuppress ?? []).filter(
Expand Down Expand Up @@ -89,6 +94,7 @@ export interface TokenGenerator {
generate(
ctx: Context,
user: User,
userGroups: readonly string[],
userPoolClient: AppClient,
clientMetadata: Record<string, string> | undefined,
source:
Expand Down Expand Up @@ -124,6 +130,7 @@ export class JwtTokenGenerator implements TokenGenerator {
public async generate(
ctx: Context,
user: User,
userGroups: readonly string[],
userPoolClient: AppClient,
clientMetadata: Record<string, string> | undefined,
source:
Expand All @@ -137,7 +144,18 @@ export class JwtTokenGenerator implements TokenGenerator {
const authTime = Math.floor(this.clock.get().getTime() / 1000);
const sub = attributeValue("sub", user.Attributes);

let idToken: Record<string, string | number | boolean | undefined> = {
const accessToken: RawToken = {
auth_time: authTime,
client_id: userPoolClient.ClientId,
event_id: eventId,
iat: authTime,
jti: uuid.v4(),
scope: "aws.cognito.signin.user.admin", // TODO: scopes
sub,
token_use: "access",
username: user.Username,
};
let idToken: RawToken = {
"cognito:username": user.Username,
auth_time: authTime,
email: attributeValue("email", user.Attributes),
Expand All @@ -152,6 +170,11 @@ export class JwtTokenGenerator implements TokenGenerator {
...attributesToRecord(customAttributes(user.Attributes)),
};

if (userGroups.length) {
accessToken["cognito:groups"] = userGroups;
idToken["cognito:groups"] = userGroups;
}

if (this.triggers.enabled("PreTokenGeneration")) {
const result = await this.triggers.preTokenGeneration(ctx, {
clientId: userPoolClient.ClientId,
Expand All @@ -174,30 +197,16 @@ export class JwtTokenGenerator implements TokenGenerator {
const issuer = `${this.tokenConfig.IssuerDomain}/${userPoolClient.UserPoolId}`;

return {
AccessToken: jwt.sign(
{
auth_time: authTime,
client_id: userPoolClient.ClientId,
event_id: eventId,
iat: authTime,
jti: uuid.v4(),
scope: "aws.cognito.signin.user.admin", // TODO: scopes
sub,
token_use: "access",
username: user.Username,
},
PrivateKey.pem,
{
algorithm: "RS256",
issuer,
expiresIn: formatExpiration(
userPoolClient.AccessTokenValidity,
userPoolClient.TokenValidityUnits?.AccessToken ?? "hours",
"24h"
),
keyid: "CognitoLocal",
}
),
AccessToken: jwt.sign(accessToken, PrivateKey.pem, {
algorithm: "RS256",
issuer,
expiresIn: formatExpiration(
userPoolClient.AccessTokenValidity,
userPoolClient.TokenValidityUnits?.AccessToken ?? "hours",
"24h"
),
keyid: "CognitoLocal",
}),
IdToken: jwt.sign(idToken, PrivateKey.pem, {
algorithm: "RS256",
issuer,
Expand Down
50 changes: 49 additions & 1 deletion src/services/userPoolService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ describe("User Pool Service", () => {

ds.get.mockImplementation((ctx, key) => {
if (key === "Groups") {
return Promise.resolve([]);
return Promise.resolve({});
}

return Promise.resolve(null);
Expand Down Expand Up @@ -774,4 +774,52 @@ describe("User Pool Service", () => {
);
});
});

describe("listUserGroupMembership", () => {
it("returns all the groups that the user is a member", async () => {
const ds = newMockDataStore();
const userPool = new UserPoolServiceImpl(
mockClientsDataStore,
clock,
ds,
{
Id: "test",
}
);

const user = TDB.user();
const group1 = TDB.group({
GroupName: "Group1",
members: [user.Username],
});
const group2 = TDB.group({
GroupName: "Group2",
members: [user.Username],
});
const group3 = TDB.group({
GroupName: "Group3",
members: [],
});
const groups = {
[group1.GroupName]: group1,
[group2.GroupName]: group2,
[group3.GroupName]: group3,
};

ds.get.mockImplementation((ctx, key) => {
if (key === "Groups") {
return Promise.resolve(groups);
}

return Promise.resolve(null);
});

const groupMembership = await userPool.listUserGroupMembership(
TestContext,
user
);

expect(groupMembership).toEqual([group1.GroupName, group2.GroupName]);
});
});
});
21 changes: 21 additions & 0 deletions src/services/userPoolService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ export interface UserPoolService {
): Promise<User | null>;
listGroups(ctx: Context): Promise<readonly Group[]>;
listUsers(ctx: Context): Promise<readonly User[]>;
listUserGroupMembership(ctx: Context, user: User): Promise<readonly string[]>;
updateOptions(ctx: Context, userPool: UserPool): Promise<void>;
removeUserFromGroup(ctx: Context, group: Group, user: User): Promise<void>;
saveGroup(ctx: Context, group: Group): Promise<void>;
Expand Down Expand Up @@ -410,6 +411,26 @@ export class UserPoolServiceImpl implements UserPoolService {
await this.dataStore.set<Group>(ctx, ["Groups", group.GroupName], group);
}

async listUserGroupMembership(
ctx: Context,
user: User
): Promise<readonly string[]> {
ctx.logger.debug(
{ username: user.Username },
"UserPoolServiceImpl.listUserGroupMembership"
);

// could optimise this by dual-writing group membership to both the group and
// the user records, but for an initial version this is probably fine unless
// you have a lot of groups
const groups = await this.listGroups(ctx);

return groups
.filter((x) => x.members?.includes(user.Username))
.map((x) => x.GroupName)
.sort((a, b) => a.localeCompare(b));
}

async storeRefreshToken(
ctx: Context,
refreshToken: string,
Expand Down
4 changes: 4 additions & 0 deletions src/targets/adminInitiateAuth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ describe("AdminInitiateAuth target", () => {
const existingUser = TDB.user();

mockUserPoolService.getUserByUsername.mockResolvedValue(existingUser);
mockUserPoolService.listUserGroupMembership.mockResolvedValue([]);

const response = await adminInitiateAuth(TestContext, {
AuthFlow: "ADMIN_USER_PASSWORD_AUTH",
Expand Down Expand Up @@ -72,6 +73,7 @@ describe("AdminInitiateAuth target", () => {
expect(mockTokenGenerator.generate).toHaveBeenCalledWith(
TestContext,
existingUser,
[],
userPoolClient,
{
client: "metadata",
Expand All @@ -92,6 +94,7 @@ describe("AdminInitiateAuth target", () => {
});

mockUserPoolService.getUserByRefreshToken.mockResolvedValue(existingUser);
mockUserPoolService.listUserGroupMembership.mockResolvedValue([]);

const response = await adminInitiateAuth(TestContext, {
AuthFlow: "REFRESH_TOKEN_AUTH",
Expand Down Expand Up @@ -120,6 +123,7 @@ describe("AdminInitiateAuth target", () => {
expect(mockTokenGenerator.generate).toHaveBeenCalledWith(
TestContext,
existingUser,
[],
userPoolClient,
{
client: "metadata",
Expand Down
Loading

0 comments on commit 996dcde

Please sign in to comment.