Skip to content

Commit

Permalink
added custom error handler support for credentials provider so middle…
Browse files Browse the repository at this point in the history
…auth can catch terms of service errors and prompt the user to agree to the terms of service
  • Loading branch information
chrisj committed Jan 12, 2024
1 parent 45f6578 commit 736a926
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 78 deletions.
11 changes: 8 additions & 3 deletions src/credentials_provider/http_request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ export async function fetchWithCredentials<Credentials, T>(
credentials: Credentials,
requestInit: RequestInit,
) => RequestInit,
errorHandler: (httpError: HttpError, credentials: Credentials) => "refresh",
errorHandler: (
httpError: HttpError,
credentials: Credentials,
) => "refresh" | Promise<"refresh">,
cancellationToken: CancellationToken = uncancelableToken,
): Promise<T> {
let credentials: CredentialsWithGeneration<Credentials> | undefined;
for (let credentialsAttempt = 0; ; ) {
credentialsLoop: for (let credentialsAttempt = 0; ; ) {
throwIfCanceled(cancellationToken);
if (credentialsAttempt > 1) {
// Don't delay on the first attempt, and also don't delay on the second attempt, since if the
Expand All @@ -65,7 +68,9 @@ export async function fetchWithCredentials<Credentials, T>(
);
} catch (error) {
if (error instanceof HttpError) {
if (errorHandler(error, credentials.credentials) === "refresh") {
if (
(await errorHandler(error, credentials.credentials)) === "refresh"
) {
if (++credentialsAttempt === maxCredentialsAttempts) throw error;
continue;
}
Expand Down
22 changes: 22 additions & 0 deletions src/credentials_provider/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import {
} from "#/util/cancellation";
import { Owned, RefCounted } from "#/util/disposable";
import { StringMemoize } from "#/util/memoize";
import { HttpError } from "#/util/http_request";
import { OAuth2Credentials } from "#/credentials_provider/oauth2";

/**
* Wraps an arbitrary JSON credentials object with a generation number.
Expand All @@ -46,6 +48,26 @@ export abstract class CredentialsProvider<Credentials> extends RefCounted {
invalidCredentials?: CredentialsWithGeneration<Credentials>,
cancellationToken?: CancellationToken,
) => Promise<CredentialsWithGeneration<Credentials>>;

errorHandler? = async (
error: HttpError,
credentials: OAuth2Credentials,
): Promise<"refresh"> => {
const { status } = error;
if (status === 401) {
// 401: Authorization needed. OAuth2 token may have expired.
return "refresh";
} else if (status === 403 && !credentials.accessToken) {
// Anonymous access denied. Request credentials.
return "refresh";
}
if (error instanceof Error && credentials.email !== undefined) {
error.message += ` (Using credentials for ${JSON.stringify(
credentials.email,
)})`;
}
throw error;
};
}

export function makeCachedCredentialsGetter<Credentials>(
Expand Down
18 changes: 1 addition & 17 deletions src/credentials_provider/oauth2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,7 @@ export function fetchWithOAuth2Credentials<T>(
);
return { ...init, headers };
},
(error, credentials) => {
const { status } = error;
if (status === 401) {
// 401: Authorization needed. OAuth2 token may have expired.
return "refresh";
}
if (status === 403 && !credentials.accessToken) {
// Anonymous access denied. Request credentials.
return "refresh";
}
if (error instanceof Error && credentials.email !== undefined) {
error.message += ` (Using credentials for ${JSON.stringify(
credentials.email,
)})`;
}
throw error;
},
credentialsProvider.errorHandler!,
cancellationToken,
);
}
158 changes: 100 additions & 58 deletions src/datasource/middleauth/credentials_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import {
verifyString,
verifyStringArray,
} from "#/util/json";
import { HttpError } from "#/util/http_request";
import { OAuth2Credentials } from "#/credentials_provider/oauth2";

export type MiddleAuthToken = {
tokenType: string;
Expand All @@ -49,86 +51,88 @@ function openPopupCenter(url: string, width: number, height: number) {
);
}

async function waitForLogin(serverUrl: string): Promise<MiddleAuthToken> {
async function waitForRemoteFlow(
url: string,
startMessage: string,
startAction: string,
retryMessage: string,
closedMessage: string,
): Promise<any> {
const status = new StatusMessage(/*delay=*/ false);

const res: Promise<MiddleAuthToken> = new Promise((f) => {
function writeLoginStatus(message: string, buttonMessage: string) {
function writeStatus(message: string, buttonMessage: string) {
status.element.textContent = message + " ";
const button = document.createElement("button");
button.textContent = buttonMessage;
status.element.appendChild(button);

button.addEventListener("click", () => {
writeLoginStatus(
`Waiting for login to middleauth server ${serverUrl}...`,
"Retry",
);

const auth_popup = openPopupCenter(
`${serverUrl}/api/v1/authorize`,
400,
650,
);

const closeAuthPopup = () => {
auth_popup?.close();
writeStatus(retryMessage, "Retry");
const popup = openPopupCenter(url, 400, 650);
const closePopup = () => {
popup?.close();
};

window.addEventListener("beforeunload", closeAuthPopup);
window.addEventListener("beforeunload", closePopup);
const checkClosed = setInterval(() => {
if (auth_popup?.closed) {
if (popup?.closed) {
clearInterval(checkClosed);
writeLoginStatus(
`Login window closed for middleauth server ${serverUrl}.`,
"Retry",
);
writeStatus(closedMessage, "Retry");
}
}, 1000);

const tokenListener = async (ev: MessageEvent) => {
if (ev.source === auth_popup) {
const messageListener = async (ev: MessageEvent) => {
if (ev.source === popup) {
clearInterval(checkClosed);
window.removeEventListener("message", tokenListener);
window.removeEventListener("beforeunload", closeAuthPopup);
closeAuthPopup();

verifyObject(ev.data);
const accessToken = verifyObjectProperty(
ev.data,
"token",
verifyString,
);
const appUrls = verifyObjectProperty(
ev.data,
"app_urls",
verifyStringArray,
);

const token: MiddleAuthToken = {
tokenType: "Bearer",
accessToken,
url: serverUrl,
appUrls,
};
f(token);
window.removeEventListener("message", messageListener);
window.removeEventListener("beforeunload", closePopup);
closePopup();
f(ev.data);
}
};

window.addEventListener("message", tokenListener);
window.addEventListener("message", messageListener);
});
}

writeLoginStatus(`middleauth server ${serverUrl} login required.`, "Login");
writeStatus(startMessage, startAction);
});

try {
return await res;
} finally {
status.dispose();
}
}

async function waitForLogin(serverUrl: string): Promise<MiddleAuthToken> {
console.log("wait for login");
const data = await waitForRemoteFlow(
`${serverUrl}/api/v1/authorize`,
`middleauth server ${serverUrl} login required.`,
"Login",
`Waiting for login to middleauth server ${serverUrl}...`,
`Login window closed for middleauth server ${serverUrl}.`,
);
verifyObject(data);
const accessToken = verifyObjectProperty(data, "token", verifyString);
const appUrls = verifyObjectProperty(data, "app_urls", verifyStringArray);
const token: MiddleAuthToken = {
tokenType: "Bearer",
accessToken,
url: serverUrl,
appUrls,
};
return token;
}

async function showTosForm(url: string, tosName: string) {
const data = await waitForRemoteFlow(
url,
`Before you can access ${tosName}, you need to accept its Terms of Service.`,
"Open",
`Waiting for Terms of Service agreement...`,
`Terms of Service closed for ${tosName}.`,
);
return data === "success";
}

const LOCAL_STORAGE_AUTH_KEY = "auth_token_v2";

function getAuthTokenFromLocalStorage(authURL: string) {
Expand All @@ -154,17 +158,14 @@ export class MiddleAuthCredentialsProvider extends CredentialsProvider<MiddleAut
}
get = makeCredentialsGetter(async () => {
let token = undefined;

if (!this.alreadyTriedLocalStorage) {
this.alreadyTriedLocalStorage = true;
token = getAuthTokenFromLocalStorage(this.serverUrl);
}

if (!token) {
token = await waitForLogin(this.serverUrl);
saveAuthTokenToLocalStorage(this.serverUrl, token);
}

return token;
});
}
Expand All @@ -181,6 +182,7 @@ export class UnverifiedApp extends Error {
export class MiddleAuthAppCredentialsProvider extends CredentialsProvider<MiddleAuthToken> {
private credentials: CredentialsWithGeneration<MiddleAuthToken> | undefined =
undefined;
agreedToTos = false;

constructor(
private serverUrl: string,
Expand All @@ -190,21 +192,61 @@ export class MiddleAuthAppCredentialsProvider extends CredentialsProvider<Middle
}

get = makeCredentialsGetter(async () => {
if (this.credentials && this.agreedToTos) {
return this.credentials.credentials;
}
this.agreedToTos = false;
const authInfo = await fetch(`${this.serverUrl}/auth_info`).then((res) =>
res.json(),
);
const provider = this.credentialsManager.getCredentialsProvider(
"middleauth",
authInfo.login_url,
) as MiddleAuthCredentialsProvider;

this.credentials = await provider.get(this.credentials);

if (this.credentials.credentials.appUrls.includes(this.serverUrl)) {
return this.credentials.credentials;
}
const status = new StatusMessage(/*delay=*/ false);
status.setText(`middleauth: unverified app ${this.serverUrl}`);
throw new UnverifiedApp(this.serverUrl);
});

errorHandler = async (
error: HttpError,
credentials: OAuth2Credentials,
): Promise<"refresh"> => {
console.log("ma handle error", error);
const { status } = error;
if (status === 401) {
// 401: Authorization needed. OAuth2 token may have expired.
return "refresh";
} else if (status === 403) {
// Anonymous access denied. Request credentials.
const { response } = error;
if (response) {
const { headers } = response;
const contentType = headers.get("content-type");
if (contentType === "application/json") {
const json = await response.json();
if (json.error && json.error === "missing_tos") {
const url = new URL(json.data.tos_form_url);
url.searchParams.set("client", "ng");
const success = await showTosForm(
url.toString(),
json.data.tos_name,
);
if (success) {
this.agreedToTos = true;
return "refresh";
}
}
}
}
if (!credentials.accessToken) {
return "refresh";
}
}
throw error;
};
}

0 comments on commit 736a926

Please sign in to comment.