Skip to content

Commit

Permalink
Add support for importing vectors (#1208)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexis Rico <[email protected]>
  • Loading branch information
SferaDev authored Oct 11, 2023
1 parent b4396bf commit 9865f12
Show file tree
Hide file tree
Showing 6 changed files with 522 additions and 234 deletions.
5 changes: 5 additions & 0 deletions .changeset/wise-impalas-kneel.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@xata.io/importer": patch
---

Add support for importing vectors
12 changes: 4 additions & 8 deletions cli/src/commands/import/csv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ export default class ImportCSV extends BaseCommand<typeof ImportCSV> {
'null-value': nullValues
} = flags;
const header = !noHeader;
let columns = flagsToColumns(flags);
const flagColumns = flagsToColumns(flags);

const csvOptions = {
delimiter,
delimitersToGuess: delimitersToGuess ? splitCommas(delimitersToGuess) : undefined,
header,
nullValues,
columns,
columns: flagColumns,
limit
};

Expand All @@ -107,12 +107,8 @@ export default class ImportCSV extends BaseCommand<typeof ImportCSV> {
if (!parseResults.success) {
throw new Error(`Failed to parse CSV file ${parseResults.errors.join(' ')}`);
}
if (!columns) {
columns = parseResults.columns;
}
if (!columns) {
throw new Error('No columns found');
}

const { columns } = parseResults;
await this.migrateSchema({ table, columns, create });

let importSuccessCount = 0;
Expand Down
32 changes: 29 additions & 3 deletions packages/importer/src/columns.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ const parseMultiple = (value: string): string[] | null => {

const isGuessableMultiple = <T>(value: T): boolean => Array.isArray(value) || tryIsJsonArray(String(value));

const isGuessableVectorColumn = <T>(values: T[]): boolean => {
const checks = values.map((value) => {
const isMultiple = isGuessableMultiple(value);
if (!isMultiple) return null;

const array = parseMultiple(String(value)) ?? [];
if (array.some((item) => !isFloat(item))) return null;

return array.length;
});

return checks.every((length) => length !== null && length > 50);
};

const isMultiple = <T>(value: T): boolean => isGuessableMultiple(value) || tryIsCsvArray(String(value));

const isMaybeMultiple = <T>(value: T): boolean => isMultiple(value) || typeof value === 'string';
Expand Down Expand Up @@ -101,6 +115,9 @@ export const guessColumnTypes = <T>(
if (columnValues.every(isEmail)) {
return 'email';
}
if (isGuessableVectorColumn(columnValues)) {
return 'vector';
}
if (columnValues.some(isGuessableMultiple)) {
return 'multiple';
}
Expand All @@ -121,7 +138,7 @@ export type CoercedValue = {

export const coerceValue = async (
value: unknown,
type: Schemas.Column['type'],
column: Schemas.Column,
options: ColumnOptions = {}
): Promise<CoercedValue> => {
const { isNull = defaultIsNull, toBoolean = defaultToBoolean, proxyFunction } = options;
Expand All @@ -130,7 +147,7 @@ export const coerceValue = async (
return { value: null, isError: false };
}

switch (type) {
switch (column.type) {
case 'string':
case 'text':
case 'link': {
Expand All @@ -153,6 +170,15 @@ export const coerceValue = async (
const date = anyToDate(value);
return date.invalid ? { value: null, isError: true } : { value: date, isError: false };
}
case 'vector': {
if (!isMaybeMultiple(value)) return { value: null, isError: true };
const array = parseMultiple(String(value));

return {
value: array,
isError: array?.some((item) => !isFloat(item)) || array?.length !== column.vector?.dimension
};
}
case 'multiple': {
return isMaybeMultiple(value)
? { value: parseMultiple(String(value)), isError: false }
Expand Down Expand Up @@ -218,7 +244,7 @@ export const coerceRows = async <T extends Record<string, unknown>>(
for (const row of rows) {
const mappedRow: Record<string, CoercedValue> = {};
for (const column of columns) {
mappedRow[column.name] = await coerceValue(row[column.name], column.type, options);
mappedRow[column.name] = await coerceValue(row[column.name], column, options);
}
mapped.push(mappedRow);
}
Expand Down
3 changes: 2 additions & 1 deletion packages/importer/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ export const importColumnTypes = z.enum([
'link',
'multiple',
'file',
'file[]'
'file[]',
'vector'
]);
22 changes: 20 additions & 2 deletions packages/importer/src/parsers/jsonParser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,35 @@ import JSON from 'json5';
import { coerceRows, guessColumns } from '../columns';
import { ParseJsonOptions, ParseResults } from '../types';
import { isDefined, isObject, isXataFile, partition } from '../utils/lang';
import { Schemas } from '@xata.io/client';

const arrayToObject = (array: unknown[]) => {
return Object.fromEntries(array.map((value, index) => [index, value]));
};

// Some columns need to be prepared before coercing the rows
const prepareColumns = (columns: Schemas.Column[], values: { data: Record<string, unknown> }[]): Schemas.Column[] => {
return columns.map((column) => {
switch (column.type) {
case 'vector': {
// We get the dimension from the first row
const dimension = (values[0]?.data?.[column.name] as unknown[])?.length ?? 0;
return { ...column, vector: { dimension } };
}
default:
return column;
}
});
};

export const parseJson = async (options: ParseJsonOptions, startIndex = 0): Promise<ParseResults> => {
const { data: input, columns: externalColumns, limit } = options;

const array = Array.isArray(input) ? input : isObject(input) ? [input] : JSON.parse(input);

const arrayUpToLimit = isDefined(limit) ? array.slice(0, limit) : array;
const columns = externalColumns ?? guessColumns(arrayUpToLimit, options);
const item = await coerceRows(arrayUpToLimit, columns, options);
const columnsGuessed = externalColumns ?? guessColumns(arrayUpToLimit, options);
const item = await coerceRows(arrayUpToLimit, columnsGuessed, options);

const data = item.map((row, index) => {
const original = Array.isArray(arrayUpToLimit[index])
Expand All @@ -39,5 +55,7 @@ export const parseJson = async (options: ParseJsonOptions, startIndex = 0): Prom
};
});

const columns = prepareColumns(columnsGuessed, data);

return { success: true, columns, warnings: [], data };
};
Loading

0 comments on commit 9865f12

Please sign in to comment.