Skip to content

Commit

Permalink
fix various issues
Browse files Browse the repository at this point in the history
  • Loading branch information
au-re committed Nov 30, 2024
1 parent 8a0b342 commit db2c795
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 40 deletions.
11 changes: 0 additions & 11 deletions docs/typings.d.ts

This file was deleted.

15 changes: 8 additions & 7 deletions src/guest.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { extractMethods, isWorker } from "./helpers";
import { extractMethods, getEventData, isWorker } from "./helpers";
import { registerLocalMethods, registerRemoteMethods } from "./rpc";
import { actions, EventHandlers, events, IConnection, ISchema } from "./types";

Expand All @@ -8,16 +8,17 @@ function connect(schema: ISchema = {}, eventHandlers?: EventHandlers): Promise<I

// on handshake response
async function handleHandshakeResponse(event: any) {
if (event.data.action !== actions.HANDSHAKE_REPLY) return;
const eventData = getEventData(event);
if (eventData?.action !== actions.HANDSHAKE_REPLY) return;

// register local methods
const unregisterLocal = registerLocalMethods(schema, localMethods, event.data.connectionID);
const unregisterLocal = registerLocalMethods(schema, localMethods, eventData.connectionID);

// register remote methods
const { remote, unregisterRemote } = registerRemoteMethods(
event.data.schema,
event.data.methods,
event.data.connectionID,
eventData.schema,
eventData.methods,
eventData.connectionID,
event
);

Expand All @@ -26,7 +27,7 @@ function connect(schema: ISchema = {}, eventHandlers?: EventHandlers): Promise<I
// send a HANDSHAKE REPLY to the host
const payload = {
action: actions.HANDSHAKE_REPLY,
connectionID: event.data.connectionID,
connectionID: eventData.connectionID,
};

if (isWorker()) self.postMessage(payload);
Expand Down
56 changes: 54 additions & 2 deletions src/helpers.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
export const CONNECTION_TIMEOUT = 1000;

/**
* check if run in a webworker
*
Expand Down Expand Up @@ -109,3 +107,57 @@ export function generateId(length: number = 10): string {
}
return result;
}

export interface NodeWorker {
on(event: string, handler: any): void;
off(event: string, handler: any): void;
postMessage(message: any): void;
terminate(): void;
}

// Type that captures common properties between Web Workers and Node Workers
export type WorkerLike = Worker | NodeWorker;

let NodeWorkerClass: any = null;

if (isNodeEnv()) {
try {
const workerThreads = require('worker_threads');
NodeWorkerClass = workerThreads.Worker;
} catch {
}
}

export function isNodeWorker(target: any): target is NodeWorker {
return NodeWorkerClass !== null && target instanceof NodeWorkerClass;
}

export function isWorkerLike(target: any): target is WorkerLike {
return isNodeWorker(target) || target instanceof Worker;
}

export function addEventListener(target: Window | WorkerLike | HTMLIFrameElement, event: string, handler: any) {
if (isNodeWorker(target)) {
target.on(event, handler);
} else if ('addEventListener' in target) {
target.addEventListener(event, handler);
}
}

export function removeEventListener(target: Window | WorkerLike | HTMLIFrameElement, event: string, handler: any) {
if (isNodeWorker(target)) {
target.off(event, handler);
} else if ('removeEventListener' in target) {
target.removeEventListener(event, handler);
}
}

/**
* Normalize message event data across Web and Node.js environments
* In web, data is in event.data
* In Node.js, the event itself contains the data
*/
export function getEventData(event: any): any {
return event.data || event;
}

58 changes: 38 additions & 20 deletions src/host.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
import { extractMethods, generateId, getOriginFromURL, isNodeEnv } from "./helpers";
import { extractMethods, generateId, getOriginFromURL, isNodeEnv, addEventListener, removeEventListener, isNodeWorker, NodeWorker, getEventData } from "./helpers";
import { registerLocalMethods, registerRemoteMethods } from "./rpc";
import { actions, events, IConnection, IConnections, ISchema } from "./types";

const connections: IConnections = {};

function isValidTarget(iframe: HTMLIFrameElement, event: any) {
const childURL = iframe.getAttribute("src");
const childOrigin = getOriginFromURL(childURL);
const hasProperOrigin = event.origin === childOrigin;
const hasProperSource = event.source === iframe.contentWindow;

return hasProperOrigin && hasProperSource;
function isValidTarget(guest: HTMLIFrameElement | Worker | NodeWorker, event: any) {
// If it's a worker, we don't need to validate origin
if (isNodeWorker(guest) || (typeof Worker !== 'undefined' && guest instanceof Worker)) {
return true;
}

// For iframes, check origin and source
const iframe = guest as HTMLIFrameElement;
try {
const childURL = iframe.src;
const childOrigin = getOriginFromURL(childURL);
const hasProperOrigin = event.origin === childOrigin;
const hasProperSource = event.source === iframe.contentWindow;

return (hasProperOrigin && hasProperSource) || !childURL;
} catch (e) {
console.warn('Error checking iframe target:', e);
return false;
}
}

/**
Expand All @@ -21,19 +33,22 @@ function isValidTarget(iframe: HTMLIFrameElement, event: any) {
* @param schema
* @returns Promise
*/
function connect(guest: HTMLIFrameElement | Worker, schema: ISchema = {}): Promise<IConnection> {
function connect(guest: HTMLIFrameElement | Worker | NodeWorker, schema: ISchema = {}): Promise<IConnection> {
if (!guest) throw new Error("a target is required");

const guestIsWorker = (guest as Worker).onerror !== undefined && (guest as Worker).onmessage !== undefined;
const guestIsWorker = isNodeWorker(guest) || ((guest as Worker).onerror !== undefined && (guest as Worker).onmessage !== undefined);
const listeners = guestIsWorker || isNodeEnv() ? guest : window;

return new Promise((resolve) => {
const connectionID = generateId();

// on handshake request
function handleHandshake(event: any) {
if (!guestIsWorker && !isNodeEnv() && !isValidTarget(guest as HTMLIFrameElement, event)) return;
if (event.data.action !== actions.HANDSHAKE_REQUEST) return;

if (!guestIsWorker && !isNodeEnv() && !isValidTarget(guest, event)) return;

const eventData = getEventData(event);
if (eventData?.action !== actions.HANDSHAKE_REQUEST) return;

// register local methods
const localMethods = extractMethods(schema);
Expand All @@ -46,8 +61,8 @@ function connect(guest: HTMLIFrameElement | Worker, schema: ISchema = {}): Promi

// register remote methods
const { remote, unregisterRemote } = registerRemoteMethods(
event.data.schema,
event.data.methods,
eventData.schema,
eventData.methods,
connectionID,
event,
guestIsWorker || isNodeEnv() ? (guest as Worker) : undefined
Expand All @@ -66,26 +81,29 @@ function connect(guest: HTMLIFrameElement | Worker, schema: ISchema = {}): Promi

// close the connection and all listeners when called
const close = () => {
listeners.removeEventListener(events.MESSAGE, handleHandshake);
removeEventListener(listeners, events.MESSAGE, handleHandshake);
unregisterRemote();
unregisterLocal();
if (guestIsWorker) (guest as Worker).terminate();
if (guestIsWorker) {
(guest as Worker).terminate();
}
};

const connection: IConnection = { remote, close };
connections[connectionID] = connection;
}

// subscribe to HANDSHAKE MESSAGES
listeners.addEventListener(events.MESSAGE, handleHandshake);
addEventListener(listeners, events.MESSAGE, handleHandshake);

// on handshake reply
function handleHandshakeReply(event: any) {
if (event.data.action !== actions.HANDSHAKE_REPLY) return;
return resolve(connections[event.data.connectionID]);
const eventData = getEventData(event);
if (eventData?.action !== actions.HANDSHAKE_REPLY) return;
return resolve(connections[eventData.connectionID]);
}

listeners.addEventListener(events.MESSAGE, handleHandshakeReply);
addEventListener(listeners, events.MESSAGE, handleHandshakeReply);
});
}

Expand Down

0 comments on commit db2c795

Please sign in to comment.