From 523d21c26ed6dc711e9fcf24522315fb3136e7ab Mon Sep 17 00:00:00 2001 From: Nadeesha Cabral Date: Tue, 3 Dec 2024 21:54:53 +1100 Subject: [PATCH] feat: Add graphql to data-connector (#209) * update * update * fix: Add graphql to data-connector * update * update --- data-connector/.env.example | 9 +- data-connector/README.md | 19 +- data-connector/config.json | 8 + data-connector/package-lock.json | 36 +- data-connector/package.json | 8 +- data-connector/src/graphql/graphql.ts | 315 ++++++++++++++++++ data-connector/src/index.ts | 44 ++- data-connector/src/{ => open-api}/open-api.ts | 2 +- data-connector/src/{ => postgres}/postgres.ts | 9 +- 9 files changed, 418 insertions(+), 32 deletions(-) create mode 100644 data-connector/src/graphql/graphql.ts rename data-connector/src/{ => open-api}/open-api.ts (99%) rename data-connector/src/{ => postgres}/postgres.ts (94%) diff --git a/data-connector/.env.example b/data-connector/.env.example index ba8237b0..25f50972 100644 --- a/data-connector/.env.example +++ b/data-connector/.env.example @@ -14,4 +14,11 @@ POSTGRES_SCHEMA=public # # OPENAPI_SPEC_URL=https://api.inferable.ai/public/oas.json # SERVER_URL=https://api.inferable.ai -# SERVER_AUTH_HEADER=Authorization: Bearer your_api_key_here +# SERVER_AUTH_HEADER=Bearer your_api_key_here\ + +# +# GraphQL config example. +# +# GRAPHQL_SCHEMA_URL=https://docs.github.com/public/fpt/schema.docs.graphql +# GRAPHQL_ENDPOINT=https://api.github.com/graphql +# GRAPHQL_AUTH_HEADER="bearer ghp_xxx" diff --git a/data-connector/README.md b/data-connector/README.md index 924c6116..de86594f 100644 --- a/data-connector/README.md +++ b/data-connector/README.md @@ -13,11 +13,11 @@ Inferable Data Connector is a bridge between your data systems and Inferable. Co ## Connectors -- [x] [Postgres](./src/postgres.ts) -- [x] [OpenAPI](./src/open-api.ts) -- [ ] [GraphQL](./src/graphql.ts) -- [ ] [MySQL](./src/mysql.ts) -- [ ] [SQLite](./src/sqlite.ts) +- [x] [Postgres](./src/postgres/postgres.ts) +- [x] [OpenAPI](./src/open-api/open-api.ts) +- [x] [GraphQL](./src/graphql/graphql.ts) +- [ ] [MySQL](./src/mysql/mysql.ts) +- [ ] [SQLite](./src/sqlite/sqlite.ts) ## Quick Start @@ -114,6 +114,15 @@ Each connector is defined in the `config.connectors` array. +
+GraphQL Connector Configuration + +- `config.connectors[].schemaUrl`: The URL to your GraphQL schema. Must be publicly accessible. +- `config.connectors[].endpoint`: The endpoint to use. (e.g. `https://api.inferable.ai`) +- `config.connectors[].defaultHeaders`: The default headers to use. (e.g. `{"Authorization": "Bearer "}`) + +
+ ### config.privacyMode When enabled (`config.privacyMode=1`), raw data is never sent to the model. Instead: diff --git a/data-connector/config.json b/data-connector/config.json index 5720f380..16f158d0 100644 --- a/data-connector/config.json +++ b/data-connector/config.json @@ -15,6 +15,14 @@ "defaultHeaders": { "Authorization": "process.env.SERVER_AUTH_HEADER" } + }, + { + "type": "graphql", + "schemaUrl": "process.env.GRAPHQL_SCHEMA_URL", + "endpoint": "process.env.GRAPHQL_ENDPOINT", + "defaultHeaders": { + "Authorization": "process.env.GRAPHQL_AUTH_HEADER" + } } ] } diff --git a/data-connector/package-lock.json b/data-connector/package-lock.json index 107a4e6e..de4dc235 100644 --- a/data-connector/package-lock.json +++ b/data-connector/package-lock.json @@ -1,15 +1,17 @@ { - "name": "inferable-sql-manager", + "name": "inferable-data-connector", "version": "1.0.0", "lockfileVersion": 3, "requires": true, "packages": { "": { - "name": "inferable-sql-manager", + "name": "inferable-data-connector", "version": "1.0.0", "dependencies": { "dotenv": "^16.4.5", "fastify": "^4.28.1", + "graphql": "^16.9.0", + "graphql-tag": "^2.12.6", "inferable": "^0.30.48", "pg": "^8.13.1", "tsx": "^4.19.2" @@ -818,6 +820,30 @@ "url": "https://github.com/privatenumber/get-tsconfig?sponsor=1" } }, + "node_modules/graphql": { + "version": "16.9.0", + "resolved": "https://registry.npmjs.org/graphql/-/graphql-16.9.0.tgz", + "integrity": "sha512-GGTKBX4SD7Wdb8mqeDLni2oaRGYQWjWHGKPQ24ZMnUtKfcsVoiv4uX8+LJr1K6U5VW2Lu1BwJnj7uiori0YtRw==", + "license": "MIT", + "engines": { + "node": "^12.22.0 || ^14.16.0 || ^16.0.0 || >=17.0.0" + } + }, + "node_modules/graphql-tag": { + "version": "2.12.6", + "resolved": "https://registry.npmjs.org/graphql-tag/-/graphql-tag-2.12.6.tgz", + "integrity": "sha512-FdSNcu2QQcWnM2VNvSCCDCVS5PpPqpzgFT8+GXzqJuoDd0CBncxCY278u4mhRO7tMgo2JjgJA5aZ+nWSQ/Z+xg==", + "license": "MIT", + "dependencies": { + "tslib": "^2.1.0" + }, + "engines": { + "node": ">=10" + }, + "peerDependencies": { + "graphql": "^0.9.0 || ^0.10.0 || ^0.11.0 || ^0.12.0 || ^0.13.0 || ^14.0.0 || ^15.0.0 || ^16.0.0" + } + }, "node_modules/inferable": { "version": "0.30.48", "resolved": "https://registry.npmjs.org/inferable/-/inferable-0.30.48.tgz", @@ -1233,6 +1259,12 @@ "node": ">=12" } }, + "node_modules/tslib": { + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" + }, "node_modules/tsx": { "version": "4.19.2", "resolved": "https://registry.npmjs.org/tsx/-/tsx-4.19.2.tgz", diff --git a/data-connector/package.json b/data-connector/package.json index c8fac5b4..38afdbff 100644 --- a/data-connector/package.json +++ b/data-connector/package.json @@ -12,13 +12,15 @@ "dependencies": { "dotenv": "^16.4.5", "fastify": "^4.28.1", + "graphql": "^16.9.0", + "graphql-tag": "^2.12.6", "inferable": "^0.30.48", "pg": "^8.13.1", "tsx": "^4.19.2" }, "devDependencies": { + "@faker-js/faker": "^9.1.0", "@types/node": "^20.11.19", - "typescript": "^5.3.3", - "@faker-js/faker": "^9.1.0" + "typescript": "^5.3.3" } -} \ No newline at end of file +} diff --git a/data-connector/src/graphql/graphql.ts b/data-connector/src/graphql/graphql.ts new file mode 100644 index 00000000..e9de219f --- /dev/null +++ b/data-connector/src/graphql/graphql.ts @@ -0,0 +1,315 @@ +import crypto from "crypto"; +import { + buildSchema, + GraphQLSchema, + GraphQLType, + isInputObjectType, + isListType, + isNonNullType, + isObjectType, +} from "graphql"; +import { approvalRequest, blob, ContextInput, Inferable } from "inferable"; +import { FunctionRegistrationInput } from "inferable/bin/types"; +import { z } from "zod"; +import type { DataConnector } from "../types"; + +export class GraphQLClient implements DataConnector { + private schema: GraphQLSchema | null = null; + private initialized = false; + private operations: ReturnType< + typeof this.graphqlOperationToInferableFunction + >[] = []; + + constructor( + private params: { + name?: string; + schemaUrl: string; + endpoint: string; + defaultHeaders?: Record; + privacyMode: boolean; + paranoidMode: boolean; + }, + ) {} + + public initialize = async () => { + try { + const response = await fetch(this.params.schemaUrl); + const schemaSDL = await response.text(); + this.schema = buildSchema(schemaSDL); + console.log( + `GraphQL schema loaded successfully from ${this.params.schemaUrl}`, + ); + + // Parse the schema and extract operations + const queryType = this.schema.getQueryType(); + const mutationType = this.schema.getMutationType(); + + if (queryType) { + const fields = queryType.getFields(); + for (const [fieldName, field] of Object.entries(fields)) { + const operation = this.graphqlOperationToInferableFunction( + fieldName, + field, + "query", + ); + this.operations.push(operation); + } + } + + if (mutationType) { + const fields = mutationType.getFields(); + for (const [fieldName, field] of Object.entries(fields)) { + const operation = this.graphqlOperationToInferableFunction( + fieldName, + field, + "mutation", + ); + this.operations.push(operation); + } + } + + console.log( + `Loaded ${this.operations.length} operations from GraphQL schema`, + ); + + if (this.params.privacyMode) { + console.log( + "Privacy mode is enabled, response data will not be sent to the model.", + ); + } + + this.initialized = true; + } catch (error) { + console.error("Failed to initialize GraphQL connection:", error); + throw error; + } + }; + + private graphqlOperationToInferableFunction = ( + fieldName: string, + field: any, + operationType: "query" | "mutation", + ): FunctionRegistrationInput => { + const args = field.args || []; + const hasArguments = args.length > 0; + + const operationDefinition = `${operationType} ${fieldName}${ + hasArguments + ? "($args: " + + field.args.map((arg: any) => `${arg.name}: ${arg.type}`).join(", ") + + ")" + : "" + } { + ${fieldName}${hasArguments ? "(args: $args)" : ""} + }`; + + return { + name: `${operationType}${fieldName}`, + description: `GraphQL ${operationType} operation: ${operationDefinition} +To understand the input and output types for this operation, use the searchGraphQLDefinition function with: +{ + "operation": "${operationType}", + "fieldName": "${fieldName}" +}`, + func: this.executeRequest, + schema: { + input: z.object({ + query: z + .string() + .describe( + `The full graphql ${operationType} to execute: ${operationDefinition} ${fieldName} ${hasArguments ? `($args)` : ``} { ... }`, + ), + variables: hasArguments + ? z + .record(z.any()) + .describe( + `Operation variables. Use searchGraphQLDefinition to see the required types.`, + ) + : z.undefined(), + }), + }, + }; + }; + + executeRequest = async ( + input: { + query: string; + variables?: Record; + }, + ctx: ContextInput, + ) => { + if (this.params.paranoidMode) { + if (!ctx.approved) { + console.log("Request requires approval"); + return approvalRequest(); + } else { + console.log("Request approved"); + } + } + + if (!this.initialized) throw new Error("GraphQL schema not initialized"); + if (!this.schema) throw new Error("GraphQL schema not initialized"); + + // Merge default headers with the Content-Type header + const headers = { + "Content-Type": "application/json", + ...this.params.defaultHeaders, + }; + + const response = await fetch(this.params.endpoint, { + method: "POST", + headers, + body: JSON.stringify(input), + }); + + const data = await response.json(); + + if (this.params.privacyMode) { + return { + message: + "This request was executed in privacy mode. Data was returned to the user directly.", + blob: blob({ + name: "Results", + type: "application/json", + data, + }), + }; + } + + return data; + }; + + private connectionStringHash = () => { + return crypto + .createHash("sha256") + .update(this.params.schemaUrl) + .digest("hex") + .substring(0, 8); + }; + + createService = (client: Inferable) => { + const service = client.service({ + name: this.params.name ?? `graphql${this.connectionStringHash()}`, + }); + + // Register the search function first + service.register(this.searchGraphQLDefinitionToInferableFunction()); + + // Then register all the operations + this.operations.forEach((operation) => { + service.register(operation); + }); + + return service; + }; + + private searchGraphQLDefinitionToInferableFunction = + (): FunctionRegistrationInput => { + return { + name: "searchGraphQLDefinition", + description: + "Search for input and output type definitions of a GraphQL operation", + func: async (input: { + operation: "query" | "mutation"; + fieldName: string; + }) => { + if (!this.schema) throw new Error("GraphQL schema not initialized"); + + const operationType = + input.operation === "query" + ? this.schema.getQueryType() + : this.schema.getMutationType(); + + if (!operationType) { + throw new Error(`No ${input.operation} type found in schema`); + } + + const field = operationType.getFields()[input.fieldName]; + if (!field) { + throw new Error( + `No field ${input.fieldName} found in ${input.operation} type`, + ); + } + + const args = field.args || []; + const argTypes: Record = {}; + + args.forEach((arg) => { + argTypes[arg.name] = { + type: this.getTypeString(arg.type), + definition: this.getTypeDefinition(arg.type), + }; + }); + + const returnType = { + type: this.getTypeString(field.type), + definition: this.getTypeDefinition(field.type), + }; + + return { + operation: input.operation, + fieldName: input.fieldName, + inputTypes: argTypes, + outputType: returnType, + }; + }, + schema: { + input: z.object({ + operation: z.enum(["query", "mutation"]), + fieldName: z.string(), + }), + }, + }; + }; + + private getTypeString = (type: GraphQLType): string => { + if (isNonNullType(type)) { + return `${this.getTypeString(type.ofType)}!`; + } + if (isListType(type)) { + return `[${this.getTypeString(type.ofType)}]`; + } + return type.toString(); + }; + + private getTypeDefinition = (type: GraphQLType): any => { + if (isNonNullType(type)) { + return this.getTypeDefinition(type.ofType); + } + if (isListType(type)) { + return { + type: "list", + ofType: this.getTypeDefinition(type.ofType), + }; + } + if (isObjectType(type) || isInputObjectType(type)) { + const fields = type.getFields(); + const fieldDefs: Record = {}; + + Object.entries(fields).forEach(([fieldName, field]) => { + fieldDefs[fieldName] = { + type: this.getTypeString(field.type), + description: field.description || undefined, + }; + + if ("args" in field && field.args.length > 0) { + fieldDefs[fieldName].args = field.args.map((arg) => ({ + name: arg.name, + type: this.getTypeString(arg.type), + description: arg.description || undefined, + })); + } + }); + + return { + type: "object", + fields: fieldDefs, + }; + } + + return { + type: "scalar", + name: type.toString(), + }; + }; +} diff --git a/data-connector/src/index.ts b/data-connector/src/index.ts index 314a8dd7..2aabddfd 100644 --- a/data-connector/src/index.ts +++ b/data-connector/src/index.ts @@ -1,32 +1,33 @@ import "dotenv/config"; import { Inferable } from "inferable"; -import { PostgresClient } from "./postgres"; +import { PostgresClient } from "./postgres/postgres"; import { RegisteredService } from "inferable/bin/types"; -import { OpenAPIClient } from "./open-api"; +import { OpenAPIClient } from "./open-api/open-api"; +import { GraphQLClient } from "./graphql/graphql"; -const parseConfig = () => { - const config = require("../config.json"); - - config.connectors.forEach((connector: any) => { - for (const [key, value] of Object.entries(connector)) { - if (typeof value === "string" && value.startsWith("process.env.")) { - const actual = process.env[value.replace("process.env.", "")]; - if (!actual) { - throw new Error(`Environment variable ${value} not found`); - } - connector[key] = actual; +const parseConfig = (connector: any) => { + for (const [key, value] of Object.entries(connector)) { + if (typeof value === "object") { + console.log("Parsing object", key); + connector[key] = parseConfig(value); + } else if (typeof value === "string" && value.startsWith("process.env.")) { + const actual = process.env[value.replace("process.env.", "")]; + if (!actual) { + throw new Error(`Environment variable ${value} not found`); } + connector[key] = actual; } - }); + } - return config; + return connector; }; (async function main() { const client = new Inferable(); - const config = parseConfig(); + const config = require("../config.json"); + config.connectors = parseConfig(config.connectors); if (config.connectors.length === 0) { throw new Error("No connectors found in config.json"); @@ -54,6 +55,17 @@ const parseConfig = () => { await openAPIClient.initialize(); const service = openAPIClient.createService(client); services.push(service); + } else if (connector.type === "graphql") { + const graphQLClient = new GraphQLClient({ + ...connector, + paranoidMode: config.paranoidMode === 1, + privacyMode: config.privacyMode === 1, + }); + await graphQLClient.initialize(); + const service = graphQLClient.createService(client); + services.push(service); + } else { + throw new Error(`Unknown connector type: ${connector.type}`); } } diff --git a/data-connector/src/open-api.ts b/data-connector/src/open-api/open-api.ts similarity index 99% rename from data-connector/src/open-api.ts rename to data-connector/src/open-api/open-api.ts index 07d850f2..bda5852b 100644 --- a/data-connector/src/open-api.ts +++ b/data-connector/src/open-api/open-api.ts @@ -1,7 +1,7 @@ import { approvalRequest, blob, ContextInput, Inferable } from "inferable"; import { z } from "zod"; import fetch from "node-fetch"; -import type { DataConnector } from "./types"; +import type { DataConnector } from "../types"; import { OpenAPIV3 } from "openapi-types"; import crypto from "crypto"; import { FunctionRegistrationInput } from "inferable/bin/types"; diff --git a/data-connector/src/postgres.ts b/data-connector/src/postgres/postgres.ts similarity index 94% rename from data-connector/src/postgres.ts rename to data-connector/src/postgres/postgres.ts index 4fdbc44b..b7bc33c2 100644 --- a/data-connector/src/postgres.ts +++ b/data-connector/src/postgres/postgres.ts @@ -3,11 +3,11 @@ import { approvalRequest, blob, ContextInput, Inferable } from "inferable"; import pg from "pg"; import { z } from "zod"; import crypto from "crypto"; -import type { DataConnector } from "./types"; +import type { DataConnector } from "../types"; export class PostgresClient implements DataConnector { private client: pg.Client | null = null; - private initialized: Promise; + private initialized = false; constructor( private params: { @@ -34,6 +34,7 @@ export class PostgresClient implements DataConnector { process.removeListener("SIGTERM", this.handleSigterm); process.on("SIGTERM", this.handleSigterm); + this.initialized = true; } catch (error) { console.error("Failed to initialize database connection:", error); throw error; @@ -71,7 +72,7 @@ export class PostgresClient implements DataConnector { }; getContext = async () => { - await this.initialized; + if (!this.initialized) throw new Error("Database not initialized"); const client = await this.getClient(); const tables = await this.getAllTables(); @@ -118,7 +119,7 @@ export class PostgresClient implements DataConnector { } } - await this.initialized; + if (!this.initialized) throw new Error("Database not initialized"); const client = await this.getClient(); const res = await client.query(input.query);