Skip to content

Commit

Permalink
feat(lib): Allow per-column standardisation (#202)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: standardise does not allow boolean values anymore,
mean and variance are no longer returned
  • Loading branch information
isair authored Mar 8, 2021
1 parent 34a5794 commit 93c9296
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 102 deletions.
20 changes: 5 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,33 +53,23 @@ Advanced usage:
```js
import loadCsv from 'tensorflow-load-csv';

const {
features,
labels,
testFeatures,
testLabels,
mean, // tensor holding mean of features, ignores testFeatures
variance, // tensor holding variance of features, ignores testFeatures
} = loadCsv('./data.csv', {
const { features, labels, testFeatures, testLabels } = loadCsv('./data.csv', {
featureColumns: ['lat', 'lng', 'height'],
labelColumns: ['temperature'],
mappings: {
height: (ft) => ft * 0.3048, // feet to meters
temperature: (f) => (f < 50 ? [1, 0] : [0, 1]), // cold or hot classification
}, // Map values based on which column they are in before they are loaded into tensors.
flatten: ['temperature'], // Flattens the array result of a mapping so that each member is a new column.
shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use it as a seed for the shuffling.
splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data, or a percentage string (e.g. 10%).
prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for linear regression.
standardise: true, // Calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use as a seed for the shuffling.
splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data, or a percentage string (e.g. '10%').
standardise: ['height'], // Calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for regression problems.
});

features.print();
labels.print();

testFeatures.print();
testLabels.print();

mean.print();
variance.print();
```
8 changes: 4 additions & 4 deletions jest.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ module.exports = {
coveragePathIgnorePatterns: ['/node_modules/', '/tests/'],
coverageThreshold: {
global: {
branches: 90,
functions: 95,
lines: 95,
statements: 95,
branches: 100,
functions: 100,
lines: 100,
statements: 100,
},
},
collectCoverageFrom: ['src/*.{js,ts}'],
Expand Down
4 changes: 2 additions & 2 deletions src/loadCsv.models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ export interface CsvReadOptions {
*/
prependOnes?: boolean;
/**
* If true, calculates mean and variance for each feature column using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
* Calculates mean and variance for given columns using data only in features, then standardises the values in features and testFeatures. Does not touch labels.
*/
standardise?: boolean | string[];
standardise?: string[];
/**
* Useful for classification problems, if you have mapped a column's values to an array using `mappings`, you can choose to flatten it here so that each element becomes a new column.
*
Expand Down
28 changes: 16 additions & 12 deletions src/loadCsv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import filterColumns from './filterColumns';
import splitTestData from './splitTestData';
import applyMappings from './applyMappings';
import shuffle from './shuffle';
import standardise from './standardise';

const defaultShuffleSeed = 'mncv9340ur';

Expand All @@ -31,10 +32,10 @@ const loadCsv = (
featureColumns,
labelColumns,
mappings = {},
shuffle: shouldShuffle = false,
shuffle: shouldShuffleOrSeed = false,
splitTest,
prependOnes = false,
standardise = false,
standardise: columnsToStandardise = [],
flatten = [],
}: CsvReadOptions
) => {
Expand All @@ -54,11 +55,13 @@ const loadCsv = (
};

tables.labels.shift();
tables.features.shift();
const featureColumnNames = tables.features.shift() as string[];

if (shouldShuffle) {
if (shouldShuffleOrSeed) {
const seed =
typeof shouldShuffle === 'string' ? shouldShuffle : defaultShuffleSeed;
typeof shouldShuffleOrSeed === 'string'
? shouldShuffleOrSeed
: defaultShuffleSeed;
tables.features = shuffle(tables.features, seed);
tables.labels = shuffle(tables.labels, seed);
}
Expand All @@ -76,11 +79,14 @@ const loadCsv = (
const labels = tf.tensor(tables.labels);
const testLabels = tf.tensor(tables.testLabels);

const { mean, variance } = tf.moments(features, 0);

if (standardise) {
features = features.sub(mean).div(variance.pow(0.5));
testFeatures = testFeatures.sub(mean).div(variance.pow(0.5));
if (columnsToStandardise.length > 0) {
const result = standardise(
features,
testFeatures,
featureColumnNames.map((c) => columnsToStandardise.includes(c))
);
features = result.features;
testFeatures = result.testFeatures;
}

if (prependOnes) {
Expand All @@ -93,8 +99,6 @@ const loadCsv = (
labels,
testFeatures,
testLabels,
mean,
variance,
};
};

Expand Down
68 changes: 68 additions & 0 deletions src/standardise.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import * as tf from '@tensorflow/tfjs';

/**
 * Standardises (zero mean, unit variance) selected columns of the feature
 * tensors. Statistics are computed from `features` only and then applied to
 * both `features` and `testFeatures`, so the test set never leaks into the
 * scaling.
 *
 * @param features - 2-D tensor of training features (rows x columns).
 * @param testFeatures - 2-D tensor of test features with the same column count.
 * @param indicesToStandardise - One boolean per column; `true` marks a column
 * to be standardised, `false` leaves it untouched.
 * @returns New `features` and `testFeatures` tensors with the flagged columns
 * standardised.
 * @throws If either tensor has fewer than two dimensions, if their column
 * counts differ, or if the flag array length does not match the column count.
 */
const standardise = (
  features: tf.Tensor<tf.Rank>,
  testFeatures: tf.Tensor<tf.Rank>,
  indicesToStandardise: boolean[]
): {
  features: tf.Tensor<tf.Rank>;
  testFeatures: tf.Tensor<tf.Rank>;
} => {
  if (features.shape.length < 2 || testFeatures.shape.length < 2) {
    throw new Error(
      'features and testFeatures must have at least two dimensions'
    );
  }

  if (features.shape[1] !== testFeatures.shape[1]) {
    throw new Error(
      'Length of the second dimension of features and testFeatures must be the same'
    );
  }

  if (features.shape[1] !== indicesToStandardise.length) {
    throw new Error(
      'Length of indicesToStandardise must match the length of the second dimension of features'
    );
  }

  const columnCount = features.shape[1];

  // Nothing to slice or rebuild when there are no columns at all.
  if (columnCount === 0) {
    return { features, testFeatures };
  }

  const featureColumns: tf.Tensor<tf.Rank>[] = [];
  const testFeatureColumns: tf.Tensor<tf.Rank>[] = [];

  for (let column = 0; column < columnCount; column++) {
    let featureColumn = features.slice([0, column], [features.shape[0], 1]);
    let testFeatureColumn = testFeatures.slice(
      [0, column],
      [testFeatures.shape[0], 1]
    );

    if (indicesToStandardise[column]) {
      // Mean/variance come from the training column only; the same
      // statistics are reused for the test column.
      // NOTE(review): a constant column has zero variance, so the division
      // below yields NaN — presumably callers avoid flagging such columns;
      // confirm before hardening.
      const { mean, variance } = tf.moments(featureColumn);
      const standardDeviation = variance.pow(0.5);
      featureColumn = featureColumn.sub(mean).div(standardDeviation);
      testFeatureColumn = testFeatureColumn.sub(mean).div(standardDeviation);
    }

    featureColumns.push(featureColumn);
    testFeatureColumns.push(testFeatureColumn);
  }

  // Reassemble the columns in their original order along the column axis.
  return {
    features: tf.concat(featureColumns, 1),
    testFeatures: tf.concat(testFeatureColumns, 1),
  };
};

export default standardise;
97 changes: 28 additions & 69 deletions tests/loadCsv.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,95 +46,54 @@ test('Loading with only the required options should work', () => {
]);
});

test('Shuffling should work and preserve feature - label pairs', () => {
const { features, labels } = loadCsv(filePath, {
test('Loading with all extra options should work', () => {
const { features, labels, testFeatures, testLabels } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
mappings: {
country: (name) => (name as string).toUpperCase(),
lat: (lat) => ((lat as number) > 0 ? [0, 1] : [1, 0]), // South or North classification
},
flatten: ['lat'],
shuffle: true,
splitTest: true,
prependOnes: true,
standardise: ['lng'],
});
// @ts-ignore
expect(features.arraySync()).toBeDeepCloseTo(
[
[5, 40.34],
[0.234, 1.47],
[-93.2, 103.34],
[102, -164],
[1, 0, 1, 1],
[1, 0, 1, -1],
],
3
);
expect(labels.arraySync()).toMatchObject([
['Landistan'],
['SomeCountria'],
['SomeOtherCountria'],
['Landotzka'],
]);
});

test('Shuffling with a custom seed should work', () => {
const { features, labels } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
shuffle: 'hello-is-it-me-you-are-looking-for',
});
expect(labels.arraySync()).toMatchObject([['LANDISTAN'], ['SOMECOUNTRIA']]);
// @ts-ignore
expect(features.arraySync()).toBeDeepCloseTo(
expect(testFeatures.arraySync()).toBeDeepCloseTo(
[
[-93.2, 103.34],
[102, -164],
[5, 40.34],
[0.234, 1.47],
[1, 1, 0, 4.241],
[1, 0, 1, -9.514],
],
3
);
expect(labels.arraySync()).toMatchObject([
['SomeOtherCountria'],
['Landotzka'],
['Landistan'],
['SomeCountria'],
expect(testLabels.arraySync()).toMatchObject([
['SOMEOTHERCOUNTRIA'],
['LANDOTZKA'],
]);
});

test('Loading with all extra options other than shuffle as true should work', () => {
const {
features,
labels,
testFeatures,
testLabels,
mean,
variance,
} = loadCsv(filePath, {
test('Loading with custom seed should use the custom seed', () => {
const { features } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
mappings: {
country: (name) => (name as string).toUpperCase(),
},
splitTest: true,
prependOnes: true,
standardise: true,
shuffle: true,
});
const { features: featuresCustom } = loadCsv(filePath, {
featureColumns: ['lat', 'lng'],
labelColumns: ['country'],
shuffle: 'sdhjhdf',
});
// @ts-ignore
expect(features.arraySync()).toBeDeepCloseTo(
[
[1, 1, -1],
[1, -1, 1],
],
3
);
expect(labels.arraySync()).toMatchObject([
['SOMECOUNTRIA'],
['SOMEOTHERCOUNTRIA'],
]);
// @ts-ignore
expect(testFeatures.arraySync()).toBeDeepCloseTo(
[
[1, 1.102, -0.236],
[1, 3.178, -4.248],
],
3
);
expect(testLabels.arraySync()).toMatchObject([['LANDISTAN'], ['LANDOTZKA']]);
// @ts-ignore
expect(mean.arraySync()).toBeDeepCloseTo([-46.482, 52.404], 3);
// @ts-ignore
expect(variance.arraySync()).toBeDeepCloseTo([2182.478, 2594.374], 3);
expect(features).not.toBeDeepCloseTo(featuresCustom, 1);
});
Loading

0 comments on commit 93c9296

Please sign in to comment.