From 940b44d07364adc963070abf85a6643e00beffb2 Mon Sep 17 00:00:00 2001 From: cyruszhang Date: Wed, 11 Dec 2024 11:10:56 -0800 Subject: [PATCH] add test case for load strategy wild card matching --- tests/core/test_dataload_strategy.py | 116 +++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tests/core/test_dataload_strategy.py diff --git a/tests/core/test_dataload_strategy.py b/tests/core/test_dataload_strategy.py new file mode 100644 index 000000000..b3f489cfb --- /dev/null +++ b/tests/core/test_dataload_strategy.py @@ -0,0 +1,116 @@ +import unittest +from data_juicer.core.data.load_strategy import ( + DataLoadStrategyRegistry, DataLoadStrategy, StrategyKey +) + +class MockStrategy(DataLoadStrategy): + def load_data(self): + pass + +class TestDataLoadStrategyRegistry(unittest.TestCase): + def setUp(self): + # Clear existing strategies before each test + DataLoadStrategyRegistry._strategies = {} + + def test_exact_match(self): + # Register a specific strategy + @DataLoadStrategyRegistry.register('local', 'ondisk', 'json') + class TestStrategy(MockStrategy): + pass + + # Test exact match + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + self.assertEqual(strategy, TestStrategy) + + # Test no match + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv') + self.assertIsNone(strategy) + + def test_wildcard_matching(self): + # Register strategies with different wildcard patterns + @DataLoadStrategyRegistry.register('local', 'ondisk', '*') + class AllFilesStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', '*', '*') + class AllLocalStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('*', '*', '*') + class FallbackStrategy(MockStrategy): + pass + + # Test specific matches + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + self.assertEqual(strategy, AllFilesStrategy) # Should match most specific wildcard + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json') + self.assertEqual(strategy, AllLocalStrategy) # Should match second level wildcard + + strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json') + self.assertEqual(strategy, FallbackStrategy) # Should match most general wildcard + + def test_specificity_priority(self): + @DataLoadStrategyRegistry.register('*', '*', '*') + class GeneralStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', '*', '*') + class LocalStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', 'ondisk', '*') + class LocalOndiskStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', 'ondisk', 'json') + class ExactStrategy(MockStrategy): + pass + + # Test matching priority + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'json') + self.assertEqual(strategy, ExactStrategy) # Should match exact first + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'csv') + self.assertEqual(strategy, LocalOndiskStrategy) # Should match one wildcard + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'remote', 'json') + self.assertEqual(strategy, LocalStrategy) # Should match two wildcards + + strategy = DataLoadStrategyRegistry.get_strategy_class('ray', 'remote', 'json') + self.assertEqual(strategy, GeneralStrategy) # Should match general wildcard + + def test_pattern_matching(self): + @DataLoadStrategyRegistry.register('local', 'ondisk', '*.json') + class JsonStrategy(MockStrategy): + pass + + @DataLoadStrategyRegistry.register('local', 'ondisk', 'data_[0-9]*') + class NumberedDataStrategy(MockStrategy): + pass + + # Test pattern matching + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.json') + self.assertEqual(strategy, JsonStrategy) + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'data_123') + self.assertEqual(strategy, NumberedDataStrategy) + + strategy = DataLoadStrategyRegistry.get_strategy_class('local', 'ondisk', 'test.csv') + self.assertIsNone(strategy) + + def test_strategy_key_matches(self): + # Test StrategyKey matching directly + wildcard_key = StrategyKey('*', 'ondisk', '*.json') + specific_key = StrategyKey('local', 'ondisk', 'test.json') + + self.assertTrue(wildcard_key.matches(specific_key)) + self.assertFalse(specific_key.matches(wildcard_key)) # Exact keys don't match wildcards + + # Test pattern matching + pattern_key = StrategyKey('local', '*', 'data_[0-9]*') + match_key = StrategyKey('local', 'ondisk', 'data_123') + no_match_key = StrategyKey('local', 'ondisk', 'data_abc') + + self.assertTrue(pattern_key.matches(match_key)) + self.assertFalse(pattern_key.matches(no_match_key))