
Deduplication of text chunks with frequency count, training and encoding 5x speedup #82

Open
Majdoddin wants to merge 3 commits into master

Conversation

@Majdoddin commented Jun 8, 2024

In RegexTokenizer, the training text is first split into chunks, and all further processing is performed on individual chunks. This PR optimizes the process by retaining only the unique chunks together with their frequency counts. In practice this cuts the number of chunks to roughly 1/7th, resulting in a training speedup of at least 5x.
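
A minimal sketch of the idea (illustrative variable names and a toy split pattern, not the PR's exact code): each unique chunk is byte-encoded once, and its frequency is kept so that later pair statistics can be weighted by how often the chunk occurs.

    import collections
    import re  # minbpe uses the `regex` module with the GPT-4 split pattern; a toy pattern is used here

    text = "the cat sat on the mat because the cat liked the mat"
    toy_pattern = re.compile(r"\s*\S+")          # stand-in for the real split regex
    text_chunks = re.findall(toy_pattern, text)  # many chunks, several of them repeated

    # deduplicate: keep each distinct chunk once, together with its frequency
    chunk_counts = collections.Counter(text_chunks)
    ids = [list(chunk.encode("utf-8")) for chunk in chunk_counts]  # unique chunks only
    freqs = list(chunk_counts.values())          # multiplicity of each unique chunk

    # all further work (pair counting, merging) now loops over len(ids) unique chunks
    # instead of len(text_chunks) total chunks, with freqs weighting the counts
    print(len(text_chunks), len(ids))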

A similar optimization is applied to encode_ordinary(), where the tokenization of each chunk is cached, also giving a 5x speedup.
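
A rough sketch of the encoding-side cache (assuming minbpe's RegexTokenizer, whose encode_ordinary() splits the text with the regex and then BPE-encodes each chunk via an internal _encode_chunk helper; the cache handling here is illustrative, not necessarily the PR's exact code):

    def encode_ordinary(self, text):
        # split the text into chunks with the same regex as before
        text_chunks = re.findall(self.compiled_pattern, text)
        cache = {}  # chunk string -> already-computed token ids
        ids = []
        for chunk in text_chunks:
            if chunk not in cache:
                # BPE-encode each distinct chunk only once
                cache[chunk] = self._encode_chunk(chunk.encode("utf-8"))
            ids.extend(cache[chunk])
        return ids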

@Majdoddin changed the title from "Deduplication of text chunks with frequency count, 5x training speedup" to "Deduplication of text chunks with frequency count, training and encoding 5x speedup" on Jun 9, 2024

@ae99 left a comment

Independently arrived at this same change myself!

Took training of an 8192-size tokenizer on ~3M words from 8 hours down to 20 minutes. I'll likely now re-train on ~1B words, given that this makes training almost entirely independent of dataset size once the regex split is complete. Makes this repo production-viable!

@@ -41,17 +41,26 @@ def train(self, text, vocab_size, verbose=False):
text_chunks = re.findall(self.compiled_pattern, text)

Faster to do the counting ahead of converting to unicode

Suggested change:

        text_chunks = re.findall(self.compiled_pattern, text)
        chunks_counted = collections.Counter(text_chunks)
        text_chunks = [chunk for chunk, count in chunks_counted.items()]
        global_counts = [count for chunk, count in chunks_counted.items()]

Then further down we can just go:

            for chunk_ids, global_count in zip(ids, global_counts):
                # passing in stats will update it in place, adding up counts
                get_stats_n(chunk_ids, stats, global_count)
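
get_stats_n is not defined in the snippet above; a plausible version, extending minbpe's get_stats with a count multiplier, could look like this (an assumed shape, not necessarily the reviewer's actual helper):

    def get_stats_n(ids, stats=None, count=1):
        # like get_stats, but each consecutive pair is credited `count` times,
        # i.e. weighted by how often this chunk appears in the training text
        stats = {} if stats is None else stats
        for pair in zip(ids, ids[1:]):
            stats[pair] = stats.get(pair, 0) + count
        return stats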
