Skip to content

Commit

Permalink
add test case for load strategy wild card matching
Browse files Browse the repository at this point in the history
  • Loading branch information
cyruszhang committed Dec 11, 2024
1 parent 7c48892 commit 940b44d
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions tests/core/test_dataload_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import unittest
from data_juicer.core.data.load_strategy import (
DataLoadStrategyRegistry, DataLoadStrategy, StrategyKey
)

class MockStrategy(DataLoadStrategy):
def load_data(self):
pass

class TestDataLoadStrategyRegistry(unittest.TestCase):
def setUp(self):
# Clear existing strategies before each test
DataLoadStrategyRegistry._strategies = {}

def test_exact_match(self):
# Register a specific strategy
@DataLoadStrategyRegistry.register('local', 'ondisk', 'json')
class TestStrategy(MockStrategy):
pass

# Test exact match
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json')
self.assertEqual(strategy, TestStrategy)

# Test no match
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv')
self.assertIsNone(strategy)

def test_wildcard_matching(self):
# Register strategies with different wildcard patterns
@DataLoadStrategyRegistry.register('local', 'ondisk', '*')
class AllFilesStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register('local', '*', '*')
class AllLocalStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register('*', '*', '*')
class FallbackStrategy(MockStrategy):
pass

# Test specific matches
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json')
self.assertEqual(strategy, AllFilesStrategy) # Should match most specific wildcard

strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json')
self.assertEqual(strategy, AllLocalStrategy) # Should match second level wildcard

strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json')
self.assertEqual(strategy, FallbackStrategy) # Should match most general wildcard

def test_specificity_priority(self):
@DataLoadStrategyRegistry.register('*', '*', '*')
class GeneralStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register('local', '*', '*')
class LocalStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register('local', 'ondisk', '*')
class LocalOndiskStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register('local', 'ondisk', 'json')
class ExactStrategy(MockStrategy):
pass

# Test matching priority
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json')
self.assertEqual(strategy, ExactStrategy) # Should match exact first

strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv')
self.assertEqual(strategy, LocalOndiskStrategy) # Should match one wildcard

strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json')
self.assertEqual(strategy, LocalStrategy) # Should match two wildcards

strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json')
self.assertEqual(strategy, GeneralStrategy) # Should match general wildcard

def test_pattern_matching(self):
@DataLoadStrategyRegistry.register('local', 'ondisk', '*.json')
class JsonStrategy(MockStrategy):
pass

@DataLoadStrategyRegistry.register('local', 'ondisk', 'data_[0-9]*')
class NumberedDataStrategy(MockStrategy):
pass

# Test pattern matching
strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.json')
self.assertEqual(strategy, JsonStrategy)

strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'data_123')
self.assertEqual(strategy, NumberedDataStrategy)

strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.csv')
self.assertIsNone(strategy)

def test_strategy_key_matches(self):
# Test StrategyKey matching directly
wildcard_key = StrategyKey('*', 'ondisk', '*.json')
specific_key = StrategyKey('local', 'ondisk', 'test.json')

self.assertTrue(wildcard_key.matches(specific_key))
self.assertFalse(specific_key.matches(wildcard_key)) # Exact keys don't match wildcards

# Test pattern matching
pattern_key = StrategyKey('local', '*', 'data_[0-9]*')
match_key = StrategyKey('local', 'ondisk', 'data_123')
no_match_key = StrategyKey('local', 'ondisk', 'data_abc')

self.assertTrue(pattern_key.matches(match_key))
self.assertFalse(pattern_key.matches(no_match_key))

0 comments on commit 940b44d

Please sign in to comment.