Skip to content

Commit

Permalink
Implement "Run From Here" and fix some graph processor bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
abrenneke committed Sep 4, 2024
1 parent 3f8d710 commit b25cb35
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 21 deletions.
6 changes: 5 additions & 1 deletion packages/app-executor/bin/executor.mts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ const rivetDebugger = startDebuggerServer({
port,
allowGraphUpload: true,
datasetProvider,
dynamicGraphRun: async ({ graphId, inputs, runToNodeIds, contextValues }) => {
dynamicGraphRun: async ({ graphId, inputs, runToNodeIds, contextValues, runFromNodeId }) => {
console.log(`Running graph ${graphId} with inputs:`, inputs);

const project = currentDebuggerState.uploadedProject;
Expand Down Expand Up @@ -146,6 +146,10 @@ const rivetDebugger = startDebuggerServer({
processor.processor.runToNodeIds = runToNodeIds;
}

if (runFromNodeId) {
processor.processor.runFromNodeId = runFromNodeId;
}

await processor.run();
} catch (err) {
console.error(err);
Expand Down
6 changes: 4 additions & 2 deletions packages/app/src/components/NodeCanvas.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,8 @@ export const NodeCanvas: FC<NodeCanvasProps> = ({
handleContextMenu(e);
});

const lastRunPerNode = useRecoilValue(lastRunDataByNodeState);

const hydratedContextMenuData = useMemo((): ContextMenuContext | null => {
if (contextMenuData.data?.type.startsWith('node-')) {
const nodeType = contextMenuData.data.type.replace('node-', '');
Expand All @@ -574,6 +576,7 @@ export const NodeCanvas: FC<NodeCanvasProps> = ({
data: {
nodeType,
nodeId,
canRunFromHere: lastRunPerNode[nodeId] != null,
},
};
}
Expand All @@ -582,7 +585,7 @@ export const NodeCanvas: FC<NodeCanvasProps> = ({
type: 'blankArea',
data: {},
};
}, [contextMenuData]);
}, [contextMenuData, lastRunPerNode]);

// Idk, before we were able to unmount the context menu, but safari be weird,
// so we move it off screen instead
Expand All @@ -596,7 +599,6 @@ export const NodeCanvas: FC<NodeCanvasProps> = ({
const pinnedNodes = useRecoilValue(pinnedNodesState);

const nodeTypes = useNodeTypes();
const lastRunPerNode = useRecoilValue(lastRunDataByNodeState);
const selectedProcessPagePerNode = useRecoilValue(selectedProcessPageNodesState);

const isZoomedOut = canvasPosition.zoom < 0.4;
Expand Down
10 changes: 10 additions & 0 deletions packages/app/src/hooks/useContextMenuConfiguration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ export function useContextMenuConfiguration() {
contextType: type<{
nodeType: string;
nodeId: NodeId;
canRunFromHere: boolean;
}>(),
items: [
{
Expand Down Expand Up @@ -98,6 +99,15 @@ export function useContextMenuConfiguration() {
label: 'Run to Here',
icon: PlayIcon,
},
{
id: 'node-run-from-here',
label: 'Run from Here',
icon: PlayIcon,
conditional: (context) => {
const { canRunFromHere } = context as { canRunFromHere: boolean };
return canRunFromHere;
},
},
{
id: 'node-delete',
label: 'Delete',
Expand Down
5 changes: 5 additions & 0 deletions packages/app/src/hooks/useGraphBuilderContextMenuHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ export function useGraphBuilderContextMenuHandler() {

tryRunGraph({ to: [nodeId] });
})
.with('node-run-from-here', () => {
const { nodeId } = context.data as { nodeId: NodeId };

tryRunGraph({ from: nodeId });
})
.with('node-copy', () => {
const { nodeId } = context.data as { nodeId: NodeId };
copyNodes(nodeId);
Expand Down
48 changes: 48 additions & 0 deletions packages/app/src/hooks/useLocalExecutor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
type GraphOutputs,
globalRivetNodeRegistry,
type GraphId,
type Outputs,
} from '@ironclad/rivet-core';
import { produce } from 'immer';
import { useRef } from 'react';
Expand All @@ -27,6 +28,7 @@ import { trivetState } from '../state/trivet';
import { runTrivet } from '@ironclad/trivet';
import { audioProvider, datasetProvider } from '../utils/globals';
import { entries } from '../../../core/src/utils/typeSafety';
import { type RunDataByNodeId, lastRunData, lastRunDataByNodeState } from '../state/dataFlow';

export function useLocalExecutor() {
const project = useRecoilValue(projectState);
Expand All @@ -43,6 +45,7 @@ export function useLocalExecutor() {
const recordExecutions = useRecoilValue(recordExecutionsState);
const projectData = useRecoilValue(projectDataState);
const projectContext = useRecoilValue(projectContextState(project.metadata.id));
const lastRunData = useRecoilValue(lastRunDataByNodeState);

function attachGraphEvents(processor: GraphProcessor) {
processor.on('nodeStart', currentExecution.onNodeStart);
Expand Down Expand Up @@ -90,6 +93,7 @@ export function useLocalExecutor() {
options: {
graphId?: GraphId;
to?: NodeId[];
from?: NodeId;
} = {},
) => {
try {
Expand Down Expand Up @@ -120,6 +124,11 @@ export function useLocalExecutor() {
processor.runToNodeIds = options.to;
}

if (options.from) {
preloadDependentDataForNode(processor, options.from, lastRunData);
processor.runFromNodeId = options.from;
}

if (recordExecutions) {
recorder.record(processor);
}
Expand Down Expand Up @@ -257,3 +266,42 @@ export function useLocalExecutor() {
tryRunTests,
};
}

function preloadDependentDataForNode(processor: GraphProcessor, nodeId: NodeId, previousRunData: RunDataByNodeId) {
const dependencyNodes = processor.getDependencyNodesDeep(nodeId);

for (const dependencyNode of dependencyNodes) {
const dependencyNodeData = previousRunData[dependencyNode];

if (!dependencyNodeData) {
throw new Error(`Node ${dependencyNode} was not found in the previous run data, cannot continue preloading data`);
}

const firstExecution = dependencyNodeData[0];

if (!firstExecution?.data.outputData) {
throw new Error(
`Node ${dependencyNode} has no output data in the previous run data, cannot continue preloading data`,
);
}

const { outputData } = firstExecution.data;

// Convert back to DataValue from DataValueWithRefs
const outputDataWithoutRefs = Object.fromEntries(
Object.entries(outputData).map(([portId, dataValueWithRefs]) => {
if (dataValueWithRefs.type === 'image') {
throw new Error('Not implemented yed');
} else if (dataValueWithRefs.type === 'binary') {
throw new Error('Not implemented yed');
} else if (dataValueWithRefs.type === 'audio') {
throw new Error('Not implemented yed');
} else {
return [portId, dataValueWithRefs];
}
}),
) as Outputs;

processor.preloadNodeData(dependencyNode, outputDataWithoutRefs);
}
}
1 change: 0 additions & 1 deletion packages/app/src/hooks/useRemoteDebugger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ export function useRemoteDebugger(options: { onConnect?: () => void; onDisconnec

const connectRef = useRef<((url: string) => void) | undefined>();
const reconnectingTimeout = useRef<ReturnType<typeof setTimeout> | undefined>();
const selectedExecutor = useRecoilValue(selectedExecutorState);

connectRef.current = (url: string) => {
if (!url) {
Expand Down
63 changes: 61 additions & 2 deletions packages/app/src/hooks/useRemoteExecutor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import {
serializeDatasets,
type GraphId,
type DataValue,
GraphProcessor,
type Outputs,
} from '@ironclad/rivet-core';
import { useCurrentExecution } from './useCurrentExecution';
import { useRecoilState, useRecoilValue, useSetRecoilState } from 'recoil';
Expand All @@ -25,6 +27,7 @@ import { pluginsState } from '../state/plugins';
import { entries } from '../../../core/src/utils/typeSafety';
import { selectedExecutorState } from '../state/execution';
import { datasetProvider } from '../utils/globals';
import { type RunDataByNodeId, lastRunDataByNodeState } from '../state/dataFlow';

// TODO: This allows us to retrieve the GraphOutputs from the remote debugger.
// If the remote debugger events had a unique ID for each run, this would feel a lot less hacky.
Expand All @@ -46,6 +49,7 @@ export function useRemoteExecutor() {
const setUserInputQuestions = useSetRecoilState(userInputModalQuestionsState);
const selectedExecutor = useRecoilValue(selectedExecutorState);
const projectContext = useRecoilValue(projectContextState(project.metadata.id));
const lastRunData = useRecoilValue(lastRunDataByNodeState);

const remoteDebugger = useRemoteDebugger({
onDisconnect: () => {
Expand Down Expand Up @@ -119,7 +123,7 @@ export function useRemoteExecutor() {
}
});

const tryRunGraph = async (options: { to?: NodeId[]; graphId?: GraphId } = {}) => {
const tryRunGraph = async (options: { to?: NodeId[]; from?: NodeId; graphId?: GraphId } = {}) => {
if (
!remoteDebugger.remoteDebuggerState.started ||
remoteDebugger.remoteDebuggerState.socket?.readyState !== WebSocket.OPEN
Expand Down Expand Up @@ -171,7 +175,21 @@ export function useRemoteExecutor() {
{} as Record<string, DataValue>,
);

remoteDebugger.send('run', { graphId: graphToRun, runToNodeIds: options.to, contextValues });
if (options.from) {
// Use a local graph processor to get dependency nodes instead of asking the remote debugger
const processor = new GraphProcessor(project, graph.metadata!.id!);
const dependencyNodes = processor.getDependencyNodesDeep(options.from);
const preloadData = getDependentDataForNodeForPreload(dependencyNodes, lastRunData);

remoteDebugger.send('preload', { nodeData: preloadData });
}

remoteDebugger.send('run', {
graphId: graphToRun,
runToNodeIds: options.to,
contextValues,
runFromNodeIds: options.from,
});
} catch (e) {
console.error(e);
}
Expand Down Expand Up @@ -304,3 +322,44 @@ export function useRemoteExecutor() {
tryRunTests,
};
}

function getDependentDataForNodeForPreload(dependencyNodes: NodeId[], previousRunData: RunDataByNodeId) {
const preloadData: Record<NodeId, Outputs> = {};

for (const dependencyNode of dependencyNodes) {
const dependencyNodeData = previousRunData[dependencyNode];

if (!dependencyNodeData) {
throw new Error(`Node ${dependencyNode} was not found in the previous run data, cannot continue preloading data`);
}

const firstExecution = dependencyNodeData[0];

if (!firstExecution?.data.outputData) {
throw new Error(
`Node ${dependencyNode} has no output data in the previous run data, cannot continue preloading data`,
);
}

const { outputData } = firstExecution.data;

// Convert back to DataValue from DataValueWithRefs
const outputDataWithoutRefs = Object.fromEntries(
Object.entries(outputData).map(([portId, dataValueWithRefs]) => {
if (dataValueWithRefs.type === 'image') {
throw new Error('Not implemented yed');
} else if (dataValueWithRefs.type === 'binary') {
throw new Error('Not implemented yed');
} else if (dataValueWithRefs.type === 'audio') {
throw new Error('Not implemented yed');
} else {
return [portId, dataValueWithRefs];
}
}),
) as Outputs;

preloadData[dependencyNode] = outputDataWithoutRefs;
}

return preloadData;
}
Loading

0 comments on commit b25cb35

Please sign in to comment.