-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add test case for load strategy wild card matching
- Loading branch information
1 parent
7c48892
commit 940b44d
Showing
1 changed file
with
116 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import unittest | ||
from data_juicer.core.data.load_strategy import ( | ||
DataLoadStrategyRegistry, DataLoadStrategy, StrategyKey | ||
) | ||
|
||
class MockStrategy(DataLoadStrategy): | ||
def load_data(self): | ||
pass | ||
|
||
class TestDataLoadStrategyRegistry(unittest.TestCase): | ||
def setUp(self): | ||
# Clear existing strategies before each test | ||
DataLoadStrategyRegistry._strategies = {} | ||
|
||
def test_exact_match(self): | ||
# Register a specific strategy | ||
@DataLoadStrategyRegistry.register('local', 'ondisk', 'json') | ||
class TestStrategy(MockStrategy): | ||
pass | ||
|
||
# Test exact match | ||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') | ||
self.assertEqual(strategy, TestStrategy) | ||
|
||
# Test no match | ||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv') | ||
self.assertIsNone(strategy) | ||
|
||
def test_wildcard_matching(self): | ||
# Register strategies with different wildcard patterns | ||
@DataLoadStrategyRegistry.register('local', 'ondisk', '*') | ||
class AllFilesStrategy(MockStrategy): | ||
pass | ||
|
||
@DataLoadStrategyRegistry.register('local', '*', '*') | ||
class AllLocalStrategy(MockStrategy): | ||
pass | ||
|
||
@DataLoadStrategyRegistry.register('*', '*', '*') | ||
class FallbackStrategy(MockStrategy): | ||
pass | ||
|
||
# Test specific matches | ||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') | ||
self.assertEqual(strategy, AllFilesStrategy) # Should match most specific wildcard | ||
|
||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json') | ||
self.assertEqual(strategy, AllLocalStrategy) # Should match second level wildcard | ||
|
||
strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json') | ||
self.assertEqual(strategy, FallbackStrategy) # Should match most general wildcard | ||
|
||
def test_specificity_priority(self): | ||
@DataLoadStrategyRegistry.register('*', '*', '*') | ||
class GeneralStrategy(MockStrategy): | ||
pass | ||
|
||
@DataLoadStrategyRegistry.register('local', '*', '*') | ||
class LocalStrategy(MockStrategy): | ||
pass | ||
|
||
@DataLoadStrategyRegistry.register('local', 'ondisk', '*') | ||
class LocalOndiskStrategy(MockStrategy): | ||
pass | ||
|
||
@DataLoadStrategyRegistry.register('local', 'ondisk', 'json') | ||
class ExactStrategy(MockStrategy): | ||
pass | ||
|
||
# Test matching priority | ||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') | ||
self.assertEqual(strategy, ExactStrategy) # Should match exact first | ||
|
||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv') | ||
self.assertEqual(strategy, LocalOndiskStrategy) # Should match one wildcard | ||
|
||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json') | ||
self.assertEqual(strategy, LocalStrategy) # Should match two wildcards | ||
|
||
strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json') | ||
self.assertEqual(strategy, GeneralStrategy) # Should match general wildcard | ||
|
||
def test_pattern_matching(self): | ||
@DataLoadStrategyRegistry.register('local', 'ondisk', '*.json') | ||
class JsonStrategy(MockStrategy): | ||
pass | ||
|
||
@DataLoadStrategyRegistry.register('local', 'ondisk', 'data_[0-9]*') | ||
class NumberedDataStrategy(MockStrategy): | ||
pass | ||
|
||
# Test pattern matching | ||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.json') | ||
self.assertEqual(strategy, JsonStrategy) | ||
|
||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'data_123') | ||
self.assertEqual(strategy, NumberedDataStrategy) | ||
|
||
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.csv') | ||
self.assertIsNone(strategy) | ||
|
||
def test_strategy_key_matches(self): | ||
# Test StrategyKey matching directly | ||
wildcard_key = StrategyKey('*', 'ondisk', '*.json') | ||
specific_key = StrategyKey('local', 'ondisk', 'test.json') | ||
|
||
self.assertTrue(wildcard_key.matches(specific_key)) | ||
self.assertFalse(specific_key.matches(wildcard_key)) # Exact keys don't match wildcards | ||
|
||
# Test pattern matching | ||
pattern_key = StrategyKey('local', '*', 'data_[0-9]*') | ||
match_key = StrategyKey('local', 'ondisk', 'data_123') | ||
no_match_key = StrategyKey('local', 'ondisk', 'data_abc') | ||
|
||
self.assertTrue(pattern_key.matches(match_key)) | ||
self.assertFalse(pattern_key.matches(no_match_key)) |