diff --git a/.gitignore b/.gitignore
index 3e50bf2..9dc1a1c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,3 +24,4 @@ pnpm-debug.log*
 dist
 docs
 coverage
+shrinkwrap.yaml
diff --git a/README.md b/README.md
index 8501609..e7f04f5 100644
--- a/README.md
+++ b/README.md
@@ -55,6 +55,10 @@ const {
 } = loadCsv('./data.csv', {
   featureColumns: ['lat', 'lng', 'height'],
   labelColumns: ['temperature'],
+  mappings: {
+    height: (ft) => Number(ft) * 0.3048, // feet to meters
+    temperature: (f) => (Number(f) - 32) / 1.8, // fahrenheit to celsius
+  }, // Map values based on which column they are in before they are loaded into tensors.
   shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use it as a seed for the shuffling.
   splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data.
   prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for linear regression.
diff --git a/src/applyMappings.ts b/src/applyMappings.ts
new file mode 100644
index 0000000..efd250d
--- /dev/null
+++ b/src/applyMappings.ts
@@ -0,0 +1,39 @@
+import { CsvTable, CsvReadOptions } from './loadCsv.models';
+
+/**
+ * Applies per-column transformer functions to every data row of a table.
+ *
+ * The first row is treated as the header row: its cells are looked up in
+ * `mappings` and the row itself is returned unchanged. A table with fewer than
+ * two rows is returned as-is.
+ *
+ * @param table - Table whose first row holds the column names.
+ * @param mappings - Transformer functions keyed by column name.
+ * @returns A new table with mapped data rows, or `table` itself when it has
+ * fewer than two rows.
+ */
+const applyMappings = (
+  table: CsvTable,
+  mappings: NonNullable<CsvReadOptions['mappings']>
+) => {
+  if (table.length < 2) {
+    return table;
+  }
+  // Own-property check so header names such as "constructor" or "toString"
+  // never resolve to functions inherited from Object.prototype.
+  const mappingsByIndex = table[0].map((columnName) =>
+    Object.prototype.hasOwnProperty.call(mappings, columnName)
+      ? mappings[columnName]
+      : undefined
+  );
+  return table.map((row, rowIndex) =>
+    rowIndex === 0
+      ? row
+      : row.map((value, columnIndex) => {
+          const mapping = mappingsByIndex[columnIndex];
+          return mapping ? mapping(value) : value;
+        })
+  );
+};
+
+export default applyMappings;
diff --git a/src/loadCsv.models.ts b/src/loadCsv.models.ts
index 0ffb99a..2a22c50 100644
--- a/src/loadCsv.models.ts
+++ b/src/loadCsv.models.ts
@@ -7,6 +7,13 @@ export interface CsvReadOptions {
    * Names of the columns to be included in the labels and testLabels tensors.
    */
   labelColumns: string[];
+  /**
+   * Used for transforming values of entire columns. Key is column label, value is transformer function. Each value belonging to
+   * that column will be put through the transformer function and be overwritten with its return value.
+   */
+  mappings?: {
+    [columnName: string]: (value: string | number) => string | number;
+  };
   /**
    * If true, shuffles all rows with a fixed seed, meaning that shuffling the same data will always result in the same shuffled data.
    *
diff --git a/src/loadCsv.ts b/src/loadCsv.ts
index b7925c7..aaab508 100644
--- a/src/loadCsv.ts
+++ b/src/loadCsv.ts
@@ -6,6 +6,7 @@ import { shuffle } from 'shuffle-seed';
 import { CsvReadOptions, CsvTable } from './loadCsv.models';
 import filterColumns from './filterColumns';
 import splitTestData from './splitTestData';
+import applyMappings from './applyMappings';
 
 const defaultShuffleSeed = 'mncv9340ur';
 
@@ -13,6 +14,7 @@ const loadCsv = (filename: string, options: CsvReadOptions) => {
   const {
     featureColumns,
     labelColumns,
+    mappings = {},
     shuffle: shouldShuffle = false,
     splitTest = false,
     prependOnes = false,
@@ -43,6 +45,12 @@ const loadCsv = (filename: string, options: CsvReadOptions) => {
   tables.labels.shift();
   tables.features.shift();
 
+  // NOTE(review): applyMappings reads column names from row 0, but the header
+  // rows appear to have been removed by the shift() calls above -- confirm the ordering.
+  for (const key of Object.keys(tables)) {
+    tables[key] = applyMappings(tables[key], mappings);
+  }
+
   if (shouldShuffle) {
     const seed =
       typeof shouldShuffle === 'string' ? shouldShuffle : defaultShuffleSeed;
diff --git a/tests/applyMappings.test.ts b/tests/applyMappings.test.ts
new file mode 100644
index 0000000..efa0671
--- /dev/null
+++ b/tests/applyMappings.test.ts
@@ -0,0 +1,46 @@
+/* eslint-disable @typescript-eslint/ban-ts-comment */
+import { CsvReadOptions } from '../src/loadCsv.models';
+import applyMappings from '../src/applyMappings';
+
+const data = [
+  ['lat', 'lng', 'height', 'temperature'],
+  [0.234, 1.47, 849.7, 64.4],
+  [-293.2, 103.34, 715.2, 73.4],
+];
+
+const mappings: NonNullable<CsvReadOptions['mappings']> = {
+  height: (ft) => Number(ft) * 0.3048, // feet to meters
+  temperature: (f) => (Number(f) - 32) / 1.8, // fahrenheit to celsius
+};
+
+test('Applying mappings works correctly', () => {
+  const mappedData = applyMappings(data, mappings);
+  // @ts-ignore
+  expect(mappedData).toBeDeepCloseTo(
+    [
+      ['lat', 'lng', 'height', 'temperature'],
+      [0.234, 1.47, 258.98856, 18],
+      [-293.2, 103.34, 217.99296, 23],
+    ],
+    3
+  );
+});
+
+test('Applying mappings does not break with a table with just headers', () => {
+  const tableOnlyHeaders = [['lat', 'lng', 'height', 'temperature']];
+  const mappedData = applyMappings(tableOnlyHeaders, mappings);
+  expect(mappedData).toMatchObject(tableOnlyHeaders);
+});
+
+test('Applying mappings does not break with an empty table', () => {
+  const mappedData = applyMappings([], mappings);
+  expect(mappedData).toMatchObject([]);
+});
+
+test('Applying mappings ignores keys inherited from Object.prototype', () => {
+  const table = [
+    ['constructor', 'toString'],
+    [1, 2],
+  ];
+  expect(applyMappings(table, {})).toEqual(table);
+});