Creating a CSV Reading Dataset

Introduction

In this document, you learn how to create a dataset class that can prepare data from a .csv file to be consumed during batch training. To learn more about training, refer to the training document.

This example uses a .csv file to store the data. The .csv is available under this third-party repository.

The CSV file has the following format.

| url | isMalicious |
|---|---|
| sample.url.good.com | 0 |
| sample.url.bad.com | 1 |

DJL dataset blueprints

By default, DJL provides base classes for different styles of datasets.

  1. RandomAccessDataset - A dataset that supports random access to its data using indices.
  2. AbstractImageFolder - A dataset that loads images from folder structures.
  3. ArrayDataset - An array-based extension of RandomAccessDataset.

To read the CSV file, implement a CSVDataset class that extends RandomAccessDataset. The Dataset APIs follow the builder pattern.
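The builder pattern mentioned here typically relies on a self-typed base builder, so that options defined in a shared base class can still be chained with options from the concrete builder. The following is a minimal, library-free sketch of that idea. The names `SketchBaseBuilder`, `SketchBuilder`, and `SketchDataset` are hypothetical stand-ins for illustration and are not DJL classes.

```java
public class BuilderSketch {
    /** Hypothetical base builder: T is the concrete builder type, so shared setters can return it. */
    abstract static class SketchBaseBuilder<T extends SketchBaseBuilder<T>> {
        int batchSize = 1;
        boolean shuffle;

        // Shared option defined once in the base class; returns the concrete type for chaining.
        T setSampling(int batchSize, boolean shuffle) {
            this.batchSize = batchSize;
            this.shuffle = shuffle;
            return self();
        }

        // Each concrete builder returns itself, which keeps the chain correctly typed.
        abstract T self();
    }

    /** Hypothetical concrete builder with one option of its own. */
    static final class SketchBuilder extends SketchBaseBuilder<SketchBuilder> {
        String usage = "TRAIN";

        SketchBuilder optUsage(String usage) {
            this.usage = usage;
            return this;
        }

        @Override
        SketchBuilder self() {
            return this;
        }

        SketchDataset build() {
            return new SketchDataset(this);
        }
    }

    /** Hypothetical dataset that copies its configuration from the builder. */
    static final class SketchDataset {
        final String usage;
        final int batchSize;
        final boolean shuffle;

        SketchDataset(SketchBuilder builder) {
            this.usage = builder.usage;
            this.batchSize = builder.batchSize;
            this.shuffle = builder.shuffle;
        }
    }

    public static void main(String[] args) {
        // setSampling comes from the base class, yet optUsage can still be chained after it.
        SketchDataset dataset = new SketchBuilder().setSampling(128, true).optUsage("TEST").build();
        System.out.println(dataset.usage + " " + dataset.batchSize + " " + dataset.shuffle);
    }
}
```

Because `setSampling` returns `T` rather than the base type, the call order of shared and dataset-specific options does not matter.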

The CSVDataset

The CSVDataset definition looks like the following.

public class CSVDataset extends RandomAccessDataset {
    private static final int FEATURE_LENGTH = 1014;
    private static final String ALL_CHARS = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}";
    private List<Character> alphabets;
    private Map<Character, Integer> alphabetsIndex;
    private List<CSVRecord> dataset;

    private Shape initializeShape;

The CSVDataset class defines the parameters needed to process the input CSV entry into an encoded NDArray.
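The constructor that fills `alphabets` and `alphabetsIndex` is not shown in this document. As a rough sketch, assuming both are derived from `ALL_CHARS` in character order, they could be built as follows. The class name `AlphabetIndexSketch` and the helper methods `buildAlphabets` and `buildAlphabetsIndex` are hypothetical and use only the Java standard library.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class AlphabetIndexSketch {
    // Same character set as ALL_CHARS in the class definition above.
    static final String ALL_CHARS =
            "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+ =<>()[]{}";

    /** Lists the characters in order, so a row number maps back to a character. */
    static List<Character> buildAlphabets(String chars) {
        List<Character> alphabets = new ArrayList<>();
        for (char c : chars.toCharArray()) {
            alphabets.add(c);
        }
        return alphabets;
    }

    /** Maps each character to its position in the list, which becomes its one-hot row. */
    static Map<Character, Integer> buildAlphabetsIndex(List<Character> alphabets) {
        Map<Character, Integer> index = new HashMap<>();
        for (int i = 0; i < alphabets.size(); i++) {
            index.put(alphabets.get(i), i);
        }
        return index;
    }

    public static void main(String[] args) {
        List<Character> alphabets = buildAlphabets(ALL_CHARS);
        Map<Character, Integer> alphabetsIndex = buildAlphabetsIndex(alphabets);
        System.out.println("alphabet size: " + alphabets.size());
        System.out.println("row for 'a': " + alphabetsIndex.get('a'));
        System.out.println("row for '0': " + alphabetsIndex.get('0'));
    }
}
```

Note that `ALL_CHARS` contains only lowercase letters, so under this sketch an uppercase character has no entry in the index unless the input is lowercased first.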

Every class that extends RandomAccessDataset must implement a per-index getter method. The getter method returns a Record object that consists of an encoded input and label.

   @Override
   public Record get(NDManager manager, long index) {
       NDList datum = new NDList();
       NDList label = new NDList();
       CSVRecord record = dataset.get(Math.toIntExact(index));
       // Get a single data, label pair, encode them using helpers
       datum.add(encodeData(manager, record.get("url")));
       label.add(encodeLabel(manager, record.get("isMalicious")));
       return new Record(datum, label);
   }

The encodeData() method encodes the input text into NDArrays. The following example implements a one-hot encoding based on the work described in Character-level Convolutional Networks for Text Classification.

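private NDArray encodeData(NDManager manager, String url) {
    // One row per alphabet character, one column per character position in the URL
    NDArray encoded = manager.zeros(new Shape(alphabets.size(), FEATURE_LENGTH));
    char[] arrayText = url.toCharArray();
    for (int i = 0; i < url.length(); i++) {
        // Truncate long URLs; valid column indices end at FEATURE_LENGTH - 1
        if (i >= FEATURE_LENGTH) {
            break;
        }
        // Characters outside the alphabet leave their column as all zeros
        if (alphabetsIndex.containsKey(arrayText[i])) {
            encoded.set(new NDIndex(alphabetsIndex.get(arrayText[i]), i), 1);
        }
    }
    return encoded;
}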
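To make the layout of the encoded matrix concrete, here is a small sketch of the same one-hot scheme using plain Java arrays instead of NDArrays. The class `OneHotSketch`, the method `oneHot`, and the tiny alphabet are hypothetical and chosen only so the result is small enough to print.

```java
import java.util.Arrays;

public class OneHotSketch {
    /**
     * Encodes text as a matrix with one row per alphabet character and one column per
     * text position. Characters outside the alphabet leave their column all zeros.
     */
    static float[][] oneHot(String text, String alphabet, int featureLength) {
        float[][] encoded = new float[alphabet.length()][featureLength];
        // Stop at featureLength so longer inputs are truncated instead of overflowing.
        int limit = Math.min(text.length(), featureLength);
        for (int i = 0; i < limit; i++) {
            int row = alphabet.indexOf(text.charAt(i));
            if (row >= 0) {
                encoded[row][i] = 1f;
            }
        }
        return encoded;
    }

    public static void main(String[] args) {
        // A four-character alphabet and a feature length of 3 keep the matrix readable.
        float[][] encoded = oneHot("ab.x", "abc.", 3);
        for (float[] row : encoded) {
            System.out.println(Arrays.toString(row));
        }
    }
}
```

With the input `"ab.x"`, the rows for `a`, `b`, and `.` each contain a single 1 in columns 0, 1, and 2 respectively, the row for `c` stays all zeros, and the fourth character is dropped because it falls beyond the feature length.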

Define a builder class, which initializes the dataset object for TRAIN or TEST subsets.

public static final class Builder extends BaseBuilder<Builder> {
        private Usage usage;
        List<CSVRecord> dataset;

        Builder() {
            this.usage = Usage.TRAIN;
        }

        @Override
        protected Builder self() {
            return this;
        }

        Builder optUsage(Usage usage) {
            this.usage = usage;
            return this;
        }

        CSVDataset build() throws IOException {
            String csvFileLocation = "src/main/resources/malicious_url_data.csv";
            try (Reader reader = Files.newBufferedReader(Paths.get(csvFileLocation));
                    CSVParser csvParser =
                            new CSVParser(
                                    reader,
                                    CSVFormat.DEFAULT
                                            .withHeader("url", "isMalicious")
                                            .withFirstRecordAsHeader()
                                            .withIgnoreHeaderCase()
                                            .withTrim())) {
                List<CSVRecord> csvRecords = csvParser.getRecords();
                int index = (int)(csvRecords.size() * 0.8);
                // split the dataset into training and testing
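                switch (usage) {
                    case TRAIN: {
                        dataset = csvRecords.subList(0, index);
                        break;
                    }
                    case TEST: {
                        dataset = csvRecords.subList(index, csvRecords.size());
                        break;
                    }
                    default: {
                        // Only TRAIN and TEST subsets are defined for this dataset
                        throw new UnsupportedOperationException("Unsupported usage: " + usage);
                    }
                }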
                return new CSVDataset(this);
            }
        }
    }
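The 80/20 split in build() is purely positional: the first 80% of the rows become the training subset and the remainder becomes the test subset. The sketch below isolates that arithmetic with a plain list of strings. The class `SplitSketch` and the method `splitIndex` are hypothetical names used for illustration.

```java
import java.util.Arrays;
import java.util.List;

public class SplitSketch {
    /** Returns the index of the first test record for the given training fraction. */
    static int splitIndex(int size, double trainFraction) {
        // The cast truncates toward zero, so any fractional record goes to the test subset.
        return (int) (size * trainFraction);
    }

    public static void main(String[] args) {
        List<String> records =
                Arrays.asList("r0", "r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9");
        int index = splitIndex(records.size(), 0.8);
        // subList returns views over the original list, so no records are copied.
        List<String> train = records.subList(0, index);
        List<String> test = records.subList(index, records.size());
        System.out.println("train=" + train.size() + " test=" + test.size());
    }
}
```

Because the split follows file order, the order of rows in the CSV determines which records land in each subset. If the file happens to be sorted by label, shuffling the records before splitting may be worth considering.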

The following code illustrates a typical call flow to declare a dataset object based on CSVDataset.

// For train subset
int batchSize = 128;
CSVDataset csvDataset =
       new CSVDataset.Builder().optUsage(Usage.TRAIN).setSampling(batchSize, true).build();

After this, pass the dataset object to the trainer object. For more information, see the training documentation.