Skip to content

Commit

Permalink
chore: Implement DataConnector interface in PostgresClient (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
nadeesha authored Nov 27, 2024
1 parent 80725f9 commit 16313cc
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
25 changes: 13 additions & 12 deletions data-connector/src/postgres.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import { approvalRequest, blob, ContextInput, Inferable } from "inferable";
import pg from "pg";
import { z } from "zod";
import crypto from "crypto";
import type { DataConnector } from "./types";

export class PostgresClient {
export class PostgresClient implements DataConnector {
private client: pg.Client | null = null;
private initialized: Promise<void>;

Expand All @@ -15,7 +16,7 @@ export class PostgresClient {
connectionString: string;
privacyMode: boolean;
paranoidMode: boolean;
}
},
) {
assert(params.schema, "Schema parameter is required");
this.initialized = this.initialize();
Expand All @@ -28,7 +29,7 @@ export class PostgresClient {
console.log(`Initial probe successful: ${res.rows[0].now}`);
if (this.params.privacyMode) {
console.log(
"Privacy mode is enabled, table data will not be sent to the model."
"Privacy mode is enabled, table data will not be sent to the model.",
);
}

Expand Down Expand Up @@ -65,12 +66,12 @@ export class PostgresClient {
const client = await this.getClient();
const res = await client.query(
"SELECT * FROM pg_catalog.pg_tables WHERE schemaname = $1",
[this.params.schema]
[this.params.schema],
);
return res.rows;
};

getDatabaseContext = async () => {
getContext = async () => {
await this.initialized;
const client = await this.getClient();
const tables = await this.getAllTables();
Expand All @@ -79,7 +80,7 @@ export class PostgresClient {

for (const table of tables) {
const sample = await client.query(
`SELECT * FROM ${this.params.schema}.${table.tablename} LIMIT 1`
`SELECT * FROM ${this.params.schema}.${table.tablename} LIMIT 1`,
);

if (sample.rows.length > 0) {
Expand All @@ -91,8 +92,8 @@ export class PostgresClient {
? []
: sample.rows.map((row) =>
Object.values(row).map((value) =>
String(value).substring(0, 50)
)
String(value).substring(0, 50),
),
)[0],
};
context.push(tableContext);
Expand Down Expand Up @@ -152,16 +153,16 @@ export class PostgresClient {
});

service.register({
name: "getDatabaseContext",
func: this.getDatabaseContext,
description: "Gets the context of the database",
name: "getContext",
func: this.getContext,
description: "Gets the schema of the database.",
});

service.register({
name: "executeQuery",
func: this.executeQuery,
description:
"Executes a raw SQL query. If this fails, you need to getDatabaseContext to learn the schema first.",
"Executes a raw SQL query. If this fails, you need to getContext to learn the schema first.",
schema: {
input: z.object({
query: z.string().describe("The query to execute"),
Expand Down
7 changes: 7 additions & 0 deletions data-connector/src/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import type { ContextInput, Inferable } from "inferable";

export interface DataConnector {
getContext(): Promise<any[]>;
executeQuery(input: { query: string }, ctx: ContextInput): Promise<any>;
createService(client: Inferable): any;
}

0 comments on commit 16313cc

Please sign in to comment.