Creating an Upstash Redis Checkpoint Saver #435

Open
krusadellc opened this issue Aug 31, 2024 · 8 comments

@krusadellc

We are using langchainjs / langgraphjs in a NextJS web app for building an AI Agent.
We want to store the state / history, but since the app runs on Vercel, we cannot use Memory Saver.

So we are trying to implement a Redis Checkpoint saver using Upstash Redis.

We have 2 questions -

  1. Should we be using checkpoint ts or checkpoint id as the row key?
const key = `checkpoint:${thread_id}:${checkpoint_ts_epoch}`;

or

const key = `checkpoint:${thread_id}:${checkpoint_id}`;

Some of the samples suggest using the checkpoint ts, whereas others use checkpoint_id.
We tried logging the checkpoint_id, but it looked like a UUID, which wouldn't allow sorting to retrieve the latest checkpoint from a list. Yes, we could override the way checkpoint_id is generated, but we would like to minimize the customization.

  2. What's the suggested approach for serializing / deserializing the data?

We used the default serde, which may be using JsonPlusSerializer, but the data dumped to Redis is binary encoded.
This makes it difficult to understand and debug what's going on. Is there a way to store plain text (JSON or stringified JSON) as a value in Redis and retrieve it back?

@Masstronaut
Contributor

Should we be using checkpoint ts or checkpoint id as the row key

You should use the checkpoint id. The ids are implemented using UUID v6, which is highly collision resistant and maintains chronological sort compatibility. It should meet your needs out of the box.
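To illustrate the sorting point, here is a minimal sketch, assuming checkpoint ids are canonical lowercase UUID v6 strings. The `checkpointKey` and `latestCheckpointId` helpers and the example ids are made up for illustration and are not part of LangGraph or Upstash.

```typescript
// Hypothetical key layout taken from the question above; not a LangGraph API.
function checkpointKey(threadId: string, checkpointId: string): string {
  return `checkpoint:${threadId}:${checkpointId}`;
}

// UUID v6 places the most significant timestamp bits first, so plain string
// comparison of canonical lowercase ids orders them chronologically.
function latestCheckpointId(ids: string[]): string | undefined {
  return ids.reduce<string | undefined>(
    (latest, id) => (latest === undefined || id > latest ? id : latest),
    undefined
  );
}

// Made-up ids in UUID v6 layout (third group starts with the version digit 6),
// listed oldest first.
const exampleIds: string[] = [
  "1ef67a2b-0c3d-6a10-8000-0242ac120002",
  "1ef67a2b-0c3d-6f55-8000-0242ac120002",
  "1ef67a2c-1b00-6001-8000-0242ac120002",
];

// The order in which ids come back from the store does not matter.
const newest = latestCheckpointId([exampleIds[1], exampleIds[2], exampleIds[0]]);
const newestKey = newest ? checkpointKey("thread-1", newest) : undefined;
```

In Redis, one way to get the same ordering server-side would be a sorted set per thread in which every member has the same score, since members with equal scores are ordered lexicographically.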

What's the suggested approach for serializing / deserializing the data

The binary data you're seeing is the raw bytes of a stringified JSON representation of the data. Here's an example from a postgres checkpointer I wrote where the stringified JSON is put into a node buffer.

To deserialize, you'll want to use the load function from @langchain/core, as I did here. Note that in this example I'm converting the raw Postgres bytea data type (represented in JS as a Node Buffer) to a string using Node's Buffer.toString().

In your checkpointer implementation you can skip the conversion to byte arrays / buffers, instead reading and writing a raw string in redis.

Still make sure you use the load function from @langchain/core correctly when deserializing though - some of the values being serialized are classes that need their prototype to be set up, unlike plain JS objects and TS types.
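As a rough sketch of the plain-string approach, the snippet below writes stringified JSON into a `Map` that stands in for Redis and reads it back. The helper names and the `ExampleMessage` class are hypothetical. Plain `JSON.parse` is used only to keep the example self-contained, and it also shows why a reviving loader such as `load` is needed for class instances.

```typescript
// A Map stands in for a string key-value store such as Redis (assumption).
const store = new Map<string, string>();

// Hypothetical helpers: values are written as readable JSON text, not bytes.
function encodeValue(value: unknown): string {
  return JSON.stringify(value);
}

function decodeValue<T>(text: string): T {
  // Plain JSON.parse returns bare objects. In a real checkpointer this is
  // where a reviving loader (for example `load` from @langchain/core) goes.
  return JSON.parse(text) as T;
}

// Illustrative class, not a LangChain type.
class ExampleMessage {
  constructor(public content: string) {}
  shout(): string {
    return this.content.toUpperCase();
  }
}

store.set("checkpoint:thread-1:example", encodeValue(new ExampleMessage("hi")));

const rawText = store.get("checkpoint:thread-1:example") ?? "";
const decoded = decodeValue<{ content: string }>(rawText);

// The data survives, but the prototype (and so the methods) does not.
const keptContent = decoded.content === "hi";
const keptPrototype = decoded instanceof ExampleMessage;
```

The stored value is human-readable (`{"content":"hi"}`), which helps with debugging, while `keptPrototype` being false is the reason the comment above recommends the `load` function rather than a bare parse.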

@brandon-pereira

@Masstronaut your implementation was super helpful! It looks like a newer version of LangGraph added support for putWrites, which also added a new database table. If anyone else finds this, this is how I got it working with Sequelize:

import type { RunnableConfig } from "@langchain/core/runnables";
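import { load } from "@langchain/core/load";
// Op and WhereOptions are used by list() below.
import { Op, type WhereOptions } from "sequelize";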
import {
  BaseCheckpointSaver,
  type Checkpoint,
  type CheckpointListOptions,
  type CheckpointTuple,
  type PendingWrite,
  type CheckpointMetadata,
} from "@langchain/langgraph-checkpoint";
import {
  ChatBotCheckpointModel,
  ChatBotCheckpointWritesModel,
} from "../../../../models";

/**
 * Custom LangGraph Checkpoint Saver adapted for Postgres/Sequelize
 *
 * Inspired By:
 * https://langchain-ai.github.io/langgraphjs/how-tos/persistence/
 * https://github.com/langchain-ai/langgraphjs/blob/main/libs/checkpoint-sqlite/src/index.ts
 * https://github.com/Masstronaut/langchain-endpoint/blob/6f8b22cb75bb405ede0e72bb5f308cd410411273/src/agent/postgrescheckpointer.ts
 * https://github.com/langchain-ai/langgraphjs/issues/435
 */
export class SequelizeSaver extends BaseCheckpointSaver {
  async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
    const {
      thread_id,
      checkpoint_ns = "",
      checkpoint_id,
    } = config.configurable ?? {};
    let row: ChatBotCheckpointModel | null;

    if (checkpoint_id) {
      row = await ChatBotCheckpointModel.findOne({
        where: { thread_id, checkpoint_ns, checkpoint_id },
      });
    } else {
      row = await ChatBotCheckpointModel.findOne({
        where: { thread_id, checkpoint_ns },
        order: [["checkpoint_id", "DESC"]],
      });
    }

    if (!row) {
      return undefined;
    }

    let finalConfig = config;
    if (!checkpoint_id) {
      finalConfig = {
        configurable: {
          thread_id: row.thread_id,
          checkpoint_ns,
          checkpoint_id: row.checkpoint_id,
        },
      };
    }

    if (
      !finalConfig.configurable?.thread_id ||
      !finalConfig.configurable?.checkpoint_id
    ) {
      throw new Error("Missing thread_id or checkpoint_id");
    }

    const pendingWritesRows = await ChatBotCheckpointWritesModel.findAll({
      where: {
        thread_id: finalConfig.configurable.thread_id.toString(),
        checkpoint_ns,
        checkpoint_id: finalConfig.configurable.checkpoint_id.toString(),
      },
    });

    // putWrites stores each value as JSON bytes, so decode it before returning.
    const pendingWrites = await Promise.all(
      pendingWritesRows.map(async (row) => {
        const record = row.toJSON();
        return [
          record.task_id,
          record.channel,
          await load(record.value.toString()),
        ] as [string, string, unknown];
      })
    );

    return {
      config: finalConfig,
      checkpoint: await load(row.checkpoint.toString()),
      metadata: await load(row.metadata.toString()),
      parentConfig: row.parent_checkpoint_id
        ? {
            configurable: {
              thread_id: row.thread_id,
              checkpoint_ns,
              checkpoint_id: row.parent_checkpoint_id,
            },
          }
        : undefined,
      pendingWrites,
    };
  }

  async *list(
    config: RunnableConfig,
    options?: CheckpointListOptions
  ): AsyncGenerator<CheckpointTuple> {
    const { limit, before } = options ?? {};

    const thread_id = config.configurable?.thread_id;
    const whereClause: WhereOptions<ChatBotCheckpointModel> = { thread_id };

    if (before) {
      whereClause.checkpoint_id = {
        [Op.lt]: before.configurable?.checkpoint_id,
      };
    }

    const rows = await ChatBotCheckpointModel.findAll({
      where: whereClause,
      order: [["checkpoint_id", "DESC"]],
      limit,
    });

    for (const row of rows) {
      yield {
        config: {
          configurable: {
            thread_id: row.thread_id,
            checkpoint_ns: row.checkpoint_ns,
            checkpoint_id: row.checkpoint_id,
          },
        },
        checkpoint: await load(row.checkpoint.toString()),
        metadata: await load(row.metadata.toString()),
        parentConfig: row.parent_checkpoint_id
          ? {
              configurable: {
                thread_id: row.thread_id,
                checkpoint_ns: row.checkpoint_ns,
                checkpoint_id: row.parent_checkpoint_id,
              },
            }
          : undefined,
      };
    }
  }

  async put(
    config: RunnableConfig,
    checkpoint: Checkpoint,
    metadata: CheckpointMetadata
  ): Promise<RunnableConfig> {
    const checkpointColumn = Buffer.from(JSON.stringify(checkpoint));
    const metadataColumn = Buffer.from(JSON.stringify(metadata));

    await ChatBotCheckpointModel.upsert(
      {
        thread_id: config.configurable?.thread_id?.toString(),
        checkpoint_ns: config.configurable?.checkpoint_ns,
        checkpoint_id: checkpoint.id,
        parent_checkpoint_id: config.configurable?.checkpoint_id,
        checkpoint: checkpointColumn,
        metadata: metadataColumn,
      },
      {
        conflictFields: ["thread_id", "checkpoint_ns", "checkpoint_id"],
      }
    );

    return {
      configurable: {
        thread_id: config.configurable?.thread_id,
        checkpoint_ns: config.configurable?.checkpoint_ns,
        checkpoint_id: checkpoint.id,
      },
    };
  }

  async putWrites(
    config: RunnableConfig,
    writes: PendingWrite[],
    taskId: string
  ): Promise<void> {
    const rows = writes.map((write, idx) => {
      const serializedWrite = Buffer.from(JSON.stringify(write[1]));
      return {
        thread_id: config.configurable?.thread_id,
        checkpoint_ns: config.configurable?.checkpoint_ns,
        checkpoint_id: config.configurable?.checkpoint_id,
        task_id: taskId,
        idx,
        channel: write[0],
        value: serializedWrite,
      };
    });

    await ChatBotCheckpointWritesModel.bulkCreate(rows, {
      updateOnDuplicate: ["value"],
    });
  }
}
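For anyone porting this to another store, the checkpoint table above can be read as the following rough in/out mapping. This is a simplified sketch inferred from the code in this comment: the `CheckpointRow`, `toRow`, and `fromRow` names are hypothetical, values are kept as JSON strings instead of Buffers, and `JSON.parse` stands in for `load`.

```typescript
// Row shape inferred from the Sequelize model fields used above (assumption).
interface CheckpointRow {
  thread_id: string;
  checkpoint_ns: string;
  checkpoint_id: string;
  parent_checkpoint_id?: string;
  checkpoint: string; // stringified checkpoint
  metadata: string; // stringified metadata
}

interface ConfigLike {
  configurable: { thread_id: string; checkpoint_ns: string; checkpoint_id: string };
}

interface TupleLike {
  config: ConfigLike;
  checkpoint: unknown;
  metadata: unknown;
  parentConfig?: ConfigLike;
}

// IN: what put() receives. The incoming config's checkpoint_id becomes the parent.
function toRow(
  configurable: { thread_id: string; checkpoint_ns?: string; checkpoint_id?: string },
  checkpoint: { id: string },
  metadata: unknown
): CheckpointRow {
  return {
    thread_id: configurable.thread_id,
    checkpoint_ns: configurable.checkpoint_ns ?? "",
    checkpoint_id: checkpoint.id,
    parent_checkpoint_id: configurable.checkpoint_id,
    checkpoint: JSON.stringify(checkpoint),
    metadata: JSON.stringify(metadata),
  };
}

// OUT: what getTuple() assembles from a stored row.
function fromRow(row: CheckpointRow): TupleLike {
  const base = { thread_id: row.thread_id, checkpoint_ns: row.checkpoint_ns };
  return {
    config: { configurable: { ...base, checkpoint_id: row.checkpoint_id } },
    checkpoint: JSON.parse(row.checkpoint),
    metadata: JSON.parse(row.metadata),
    parentConfig: row.parent_checkpoint_id
      ? { configurable: { ...base, checkpoint_id: row.parent_checkpoint_id } }
      : undefined,
  };
}

const exampleRow = toRow(
  { thread_id: "thread-1", checkpoint_id: "id-0" },
  { id: "id-1" },
  { step: 1 }
);
const exampleTuple = fromRow(exampleRow);
```

Each new checkpoint row points back at the checkpoint it was created from, which is how the parent config is rebuilt on read.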

@hgoona

hgoona commented Sep 26, 2024

Hi @brandon-pereira @Masstronaut! Does either of you have a way of explaining the checkpointers to me? I've been pulling my hair out looking through the docs and (JS) code for the MemorySaver and the MongoDB and Postgres versions, and I need to set one up for SurrealDB instead.

I'm very confused as to what is being stored where.

Is there a simple breakdown of the data type(s) stored in the DB?
Before checkpointers existed, I was storing each message in a separate collection, plus a thread containing a list of linked messages (via relations in SurrealDB). That kept the duplication of data super slim, and I was culling all the crap that comes with LangChain types out of it.

With checkpointers, I need to retain all that stuff, but I'm lost understanding where data transformations are occurring, what's being duplicated and where, what data type needs to be stored, and what data type is retrieved (or "assembled" upon retrieval).

Is anyone able to help with an overview of this? I'm currently trying to mirror what's happening in the MemorySaver version (with SERDE gutted from it), but my data looks very different.
I'd love to know/see an example of what's going IN and what's coming OUT (and how/if it needs transforming).
Even a suggested starting point, such as which method to fix first, would be greatly appreciated. I'm currently chasing my tail, it seems, with this...

*I hope this is not too much of a noob question.

@hgoona

hgoona commented Sep 26, 2024

EDIT: I think I have found my saving grace: the SQLite saver. I'm just following through the code and swapping parts out. I think it might at least be a good starting point 🤷🏾‍♂️??

Any gotchas I should be wary of? (I'm hoping to set up the checkpointer for production use eventually.)

@jacoblee93
Collaborator

Hey yes the SQLite one is a good place to start - definitely need better docs around the mechanics.

There's still a few improvements around structure from the Python library to port over but what's there should work just fine.

@hgoona

hgoona commented Sep 26, 2024

Hey yes the SQLite one is a good place to start - definitely need better docs around the mechanics.

There's still a few improvements around structure from the Python library to port over but what's there should work just fine.

🤞🏾 I'm still making my way through it. I noticed the code on GitHub contains all the types for reference, so this has helped! (When clicking through to the class inside the node module, the types are not available. I'm not sure if this is normal for the code of a package.)

@hgoona

hgoona commented Sep 27, 2024

@jacoblee93 I've refactored my code down, but putWrites is still confusing...

Looking through the SQLite saver, what are the types for rows in this part:

const transaction = this.db.transaction((rows) => { 

?

Sqlite putWrites code:

async putWrites(
		config: RunnableConfig,
		writes: PendingWrite[],
		taskId: string
	): Promise<void> {
		debugLog(debug_checkpointer, `${serviceName} putWrites▶`)

		const stmt = this.db.prepare(`
      INSERT OR REPLACE INTO writes 
      (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) 
      VALUES (?, ?, ?, ?, ?, ?, ?, ?)
    `);

		const transaction = this.db.transaction((rows) => {
			for (const row of rows) {
				stmt.run(...row);
			}
		});

		const rows = writes.map((write, idx) => {
			const [type, serializedWrite] = this.serde.dumpsTyped(write[1]);
			return [
				config.configurable?.thread_id,
				config.configurable?.checkpoint_ns,
				config.configurable?.checkpoint_id,
				taskId,
				idx,
				write[0],
				type,
				serializedWrite,
			];
		});

		transaction(rows);
	}
}

I'm trying to translate this code to a SurrealDB query.
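Going only by the snippet above, `rows` appears to be an array of 8-element tuples, one per pending write, in the same order as the columns of the INSERT statement. A hedged sketch of that shape follows. `WriteRow`, `buildWriteRows`, and `dumpsTypedStandIn` are hypothetical names, and the stand-in serializer returns a JSON string purely for readability (the real serializer may return bytes).

```typescript
// Mirrors how write[0] and write[1] are used in the snippet above.
type PendingWriteLike = [channel: string, value: unknown];

// One tuple per write, in the column order of the INSERT statement.
type WriteRow = [
  threadId: string | undefined,
  checkpointNs: string | undefined,
  checkpointId: string | undefined,
  taskId: string,
  idx: number,
  channel: string,
  type: string,
  value: string
];

// Stand-in for this.serde.dumpsTyped (assumption): returns [type tag, payload].
function dumpsTypedStandIn(value: unknown): [string, string] {
  return ["json", JSON.stringify(value)];
}

function buildWriteRows(
  configurable: { thread_id?: string; checkpoint_ns?: string; checkpoint_id?: string },
  writes: PendingWriteLike[],
  taskId: string
): WriteRow[] {
  return writes.map(([channel, value], idx): WriteRow => {
    const [type, serialized] = dumpsTypedStandIn(value);
    return [
      configurable.thread_id,
      configurable.checkpoint_ns,
      configurable.checkpoint_id,
      taskId,
      idx,
      channel,
      type,
      serialized,
    ];
  });
}

const exampleWriteRows = buildWriteRows(
  { thread_id: "thread-1", checkpoint_ns: "", checkpoint_id: "id-1" },
  [
    ["messages", { role: "user", content: "hi" }],
    ["agent", "done"],
  ],
  "task-1"
);
```

In a document or graph store, each tuple would likely become one record, keyed by thread id, namespace, checkpoint id, task id, and idx so that re-running a task replaces its earlier writes.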

@hgoona

hgoona commented Sep 27, 2024

UPDATE 1: I've got something working right now with SurrealDB. However, am I correct that with a simple chatbot setup (Input >> Agent Node >> End), if I send 1 message and get 1 reply back (with 0 tools invoked), my database stores:
3 checkpoint records [screenshot] +
4 writes records [screenshot]

A. Have I done something wrong here?

B. Why on earth is this LangChain API so damn verbose?! Does it need to be?

C. Is it normal?
In my previous setup I would have stored 3 records:
User msg,
AI msg - (tool calls + answer merged into this),
Thread - (which links to these 2 messages, with no storage duplication). However, I didn't have checkpointing/replay/branching, and I don't know if that is possible with that layout.
