From 17737cf63590f9fe1a0d7c4c01f1dfe45872bd8d Mon Sep 17 00:00:00 2001
From: Prudhvi Godithi <pgodithi@amazon.com>
Date: Sat, 26 Oct 2024 11:53:11 -0700
Subject: [PATCH] Add support for a custom synonym_analyzer on the synonym and synonym_graph token filters

Signed-off-by: Prudhvi Godithi <pgodithi@amazon.com>
---
 .../gradle/testclusters/OpenSearchNode.java   | 12 ++--
 .../common/CommonAnalysisModulePlugin.java    | 33 ++++++++-
 .../SynonymGraphTokenFilterFactory.java       | 11 ++-
 .../common/SynonymTokenFilterFactory.java     | 28 +++++++-
 .../indices/analysis/AnalysisModule.java      |  7 +-
 .../opensearch/plugins/AnalysisPlugin.java    |  9 +++
 .../indices/analysis/AnalysisModuleTests.java | 69 +++++++++++++++++++
 7 files changed, 157 insertions(+), 12 deletions(-)

diff --git a/buildSrc/src/main/java/org/opensearch/gradle/testclusters/OpenSearchNode.java b/buildSrc/src/main/java/org/opensearch/gradle/testclusters/OpenSearchNode.java
index cd22560af9a96..bb409c2afd871 100644
--- a/buildSrc/src/main/java/org/opensearch/gradle/testclusters/OpenSearchNode.java
+++ b/buildSrc/src/main/java/org/opensearch/gradle/testclusters/OpenSearchNode.java
@@ -1216,14 +1216,18 @@ private void createConfiguration() {
             );
 
             final List<Path> configFiles;
-            try (Stream<Path> stream = Files.list(getDistroDir().resolve("config"))) {
+            try (Stream<Path> stream = Files.walk(getDistroDir().resolve("config"))) {
                 configFiles = stream.collect(Collectors.toList());
             }
             logToProcessStdout("Copying additional config files from distro " + configFiles);
             for (Path file : configFiles) {
-                Path dest = configFile.getParent().resolve(file.getFileName());
-                if (Files.exists(dest) == false) {
-                    Files.copy(file, dest);
+                Path relativePath = getDistroDir().resolve("config").relativize(file);
+                Path dest = configFile.getParent().resolve(relativePath);
+                if (Files.isDirectory(file)) {
+                    Files.createDirectories(dest);
+                } else {
+                    Files.createDirectories(dest.getParent());
+                    Files.copy(file, dest, StandardCopyOption.REPLACE_EXISTING);
                 }
             }
         } catch (IOException e) {
diff --git a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/CommonAnalysisModulePlugin.java b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/CommonAnalysisModulePlugin.java
index f14e499081ce9..763c1783c2d28 100644
--- a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/CommonAnalysisModulePlugin.java
+++ b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/CommonAnalysisModulePlugin.java
@@ -146,6 +146,7 @@
 import org.opensearch.index.analysis.PreConfiguredTokenizer;
 import org.opensearch.index.analysis.TokenFilterFactory;
 import org.opensearch.index.analysis.TokenizerFactory;
+import org.opensearch.indices.analysis.AnalysisModule;
 import org.opensearch.indices.analysis.AnalysisModule.AnalysisProvider;
 import org.opensearch.indices.analysis.PreBuiltCacheFactory.CachingStrategy;
 import org.opensearch.plugins.AnalysisPlugin;
@@ -332,8 +333,6 @@ public Map<String, AnalysisProvider<TokenFilterFactory>> getTokenFilters() {
         filters.put("sorani_normalization", SoraniNormalizationFilterFactory::new);
         filters.put("stemmer_override", requiresAnalysisSettings(StemmerOverrideTokenFilterFactory::new));
         filters.put("stemmer", StemmerTokenFilterFactory::new);
-        filters.put("synonym", requiresAnalysisSettings(SynonymTokenFilterFactory::new));
-        filters.put("synonym_graph", requiresAnalysisSettings(SynonymGraphTokenFilterFactory::new));
         filters.put("trim", TrimTokenFilterFactory::new);
         filters.put("truncate", requiresAnalysisSettings(TruncateTokenFilterFactory::new));
         filters.put("unique", UniqueTokenFilterFactory::new);
@@ -343,6 +342,36 @@ public Map<String, AnalysisProvider<TokenFilterFactory>> getTokenFilters() {
         return filters;
     }
 
+    @Override
+    public Map<String, AnalysisProvider<TokenFilterFactory>> getTokenFilters(AnalysisModule analysisModule) {
+        Map<String, AnalysisProvider<TokenFilterFactory>> filters = getTokenFilters();
+        filters.put(
+            "synonym",
+            requiresAnalysisSettings(
+                (indexSettings, environment, name, settings) -> new SynonymTokenFilterFactory(
+                    indexSettings,
+                    environment,
+                    name,
+                    settings,
+                    analysisModule.getAnalysisRegistry()
+                )
+            )
+        );
+        filters.put(
+            "synonym_graph",
+            requiresAnalysisSettings(
+                (indexSettings, environment, name, settings) -> new SynonymGraphTokenFilterFactory(
+                    indexSettings,
+                    environment,
+                    name,
+                    settings,
+                    analysisModule.getAnalysisRegistry()
+                )
+            )
+        );
+        return filters;
+    }
+
     @Override
     public Map<String, AnalysisProvider<CharFilterFactory>> getCharFilters() {
         Map<String, AnalysisProvider<CharFilterFactory>> filters = new TreeMap<>();
diff --git a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/SynonymGraphTokenFilterFactory.java b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/SynonymGraphTokenFilterFactory.java
index fed959108c411..c2e20e99473de 100644
--- a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/SynonymGraphTokenFilterFactory.java
+++ b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/SynonymGraphTokenFilterFactory.java
@@ -40,6 +40,7 @@
 import org.opensearch.env.Environment;
 import org.opensearch.index.IndexSettings;
 import org.opensearch.index.analysis.AnalysisMode;
+import org.opensearch.index.analysis.AnalysisRegistry;
 import org.opensearch.index.analysis.CharFilterFactory;
 import org.opensearch.index.analysis.TokenFilterFactory;
 import org.opensearch.index.analysis.TokenizerFactory;
@@ -49,8 +50,14 @@
 
 public class SynonymGraphTokenFilterFactory extends SynonymTokenFilterFactory {
 
-    SynonymGraphTokenFilterFactory(IndexSettings indexSettings, Environment env, String name, Settings settings) {
-        super(indexSettings, env, name, settings);
+    SynonymGraphTokenFilterFactory(
+        IndexSettings indexSettings,
+        Environment env,
+        String name,
+        Settings settings,
+        AnalysisRegistry analysisRegistry
+    ) {
+        super(indexSettings, env, name, settings, analysisRegistry);
     }
 
     @Override
diff --git a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/SynonymTokenFilterFactory.java b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/SynonymTokenFilterFactory.java
index 01a65e87d7466..86417a579628c 100644
--- a/modules/analysis-common/src/main/java/org/opensearch/analysis/common/SynonymTokenFilterFactory.java
+++ b/modules/analysis-common/src/main/java/org/opensearch/analysis/common/SynonymTokenFilterFactory.java
@@ -44,11 +44,13 @@
 import org.opensearch.index.analysis.AbstractTokenFilterFactory;
 import org.opensearch.index.analysis.Analysis;
 import org.opensearch.index.analysis.AnalysisMode;
+import org.opensearch.index.analysis.AnalysisRegistry;
 import org.opensearch.index.analysis.CharFilterFactory;
 import org.opensearch.index.analysis.CustomAnalyzer;
 import org.opensearch.index.analysis.TokenFilterFactory;
 import org.opensearch.index.analysis.TokenizerFactory;
 
+import java.io.IOException;
 import java.io.Reader;
 import java.io.StringReader;
 import java.util.List;
@@ -64,8 +66,16 @@ public class SynonymTokenFilterFactory extends AbstractTokenFilterFactory {
     protected final Settings settings;
     protected final Environment environment;
     protected final AnalysisMode analysisMode;
-
-    SynonymTokenFilterFactory(IndexSettings indexSettings, Environment env, String name, Settings settings) {
+    private final String synonymAnalyzer;
+    private final AnalysisRegistry analysisRegistry;
+
+    SynonymTokenFilterFactory(
+        IndexSettings indexSettings,
+        Environment env,
+        String name,
+        Settings settings,
+        AnalysisRegistry analysisRegistry
+    ) {
         super(indexSettings, name, settings);
         this.settings = settings;
 
@@ -83,6 +93,8 @@ public class SynonymTokenFilterFactory extends AbstractTokenFilterFactory {
         boolean updateable = settings.getAsBoolean("updateable", false);
         this.analysisMode = updateable ? AnalysisMode.SEARCH_TIME : AnalysisMode.ALL;
         this.environment = env;
+        this.synonymAnalyzer = settings.get("synonym_analyzer", null);
+        this.analysisRegistry = analysisRegistry;
     }
 
     @Override
@@ -137,6 +149,17 @@ Analyzer buildSynonymAnalyzer(
         List<TokenFilterFactory> tokenFilters,
         Function<String, TokenFilterFactory> allFilters
     ) {
+        if (synonymAnalyzer != null) {
+            Analyzer customSynonymAnalyzer;
+            try {
+                customSynonymAnalyzer = analysisRegistry.getAnalyzer(synonymAnalyzer);
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+            if (customSynonymAnalyzer != null) {
+                return customSynonymAnalyzer;
+            }
+        }
         return new CustomAnalyzer(
             tokenizer,
             charFilters.toArray(new CharFilterFactory[0]),
@@ -177,5 +200,4 @@ Reader getRulesFromSettings(Environment env) {
         }
         return rulesReader;
     }
-
 }
diff --git a/server/src/main/java/org/opensearch/indices/analysis/AnalysisModule.java b/server/src/main/java/org/opensearch/indices/analysis/AnalysisModule.java
index 0926d497087d1..dbb3035a18f74 100644
--- a/server/src/main/java/org/opensearch/indices/analysis/AnalysisModule.java
+++ b/server/src/main/java/org/opensearch/indices/analysis/AnalysisModule.java
@@ -165,7 +165,12 @@ public boolean requiresAnalysisSettings() {
             )
         );
 
-        tokenFilters.extractAndRegister(plugins, AnalysisPlugin::getTokenFilters);
+        for (AnalysisPlugin plugin : plugins) {
+            Map<String, AnalysisProvider<TokenFilterFactory>> filters = plugin.getTokenFilters(this);
+            for (Map.Entry<String, AnalysisProvider<TokenFilterFactory>> entry : filters.entrySet()) {
+                tokenFilters.register(entry.getKey(), entry.getValue());
+            }
+        }
         return tokenFilters;
     }
 
diff --git a/server/src/main/java/org/opensearch/plugins/AnalysisPlugin.java b/server/src/main/java/org/opensearch/plugins/AnalysisPlugin.java
index 53dcc916b244f..a7c4604a30553 100644
--- a/server/src/main/java/org/opensearch/plugins/AnalysisPlugin.java
+++ b/server/src/main/java/org/opensearch/plugins/AnalysisPlugin.java
@@ -47,6 +47,7 @@
 import org.opensearch.index.analysis.PreConfiguredTokenizer;
 import org.opensearch.index.analysis.TokenFilterFactory;
 import org.opensearch.index.analysis.TokenizerFactory;
+import org.opensearch.indices.analysis.AnalysisModule;
 import org.opensearch.indices.analysis.AnalysisModule.AnalysisProvider;
 
 import java.io.IOException;
@@ -84,6 +85,14 @@ default Map<String, AnalysisProvider<CharFilterFactory>> getCharFilters() {
         return emptyMap();
     }
 
+    /**
+     * Override to register additional {@link TokenFilter}s that need access to the {@link AnalysisModule}
+     * (e.g. to resolve analyzers through its registry). The default delegates to {@link #getTokenFilters()}
+     * so existing plugins remain backward compatible.
+     */
+    default Map<String, AnalysisProvider<TokenFilterFactory>> getTokenFilters(AnalysisModule analysisModule) {
+        return getTokenFilters();
+    }
+
     /**
      * Override to add additional {@link TokenFilter}s. See {@link #requiresAnalysisSettings(AnalysisProvider)}
      * how to on get the configuration from the index.
diff --git a/server/src/test/java/org/opensearch/indices/analysis/AnalysisModuleTests.java b/server/src/test/java/org/opensearch/indices/analysis/AnalysisModuleTests.java
index c9e26d6d6159a..1a4b7bf831ff0 100644
--- a/server/src/test/java/org/opensearch/indices/analysis/AnalysisModuleTests.java
+++ b/server/src/test/java/org/opensearch/indices/analysis/AnalysisModuleTests.java
@@ -56,6 +56,8 @@
 import org.opensearch.index.analysis.CustomAnalyzer;
 import org.opensearch.index.analysis.IndexAnalyzers;
 import org.opensearch.index.analysis.MyFilterTokenFilterFactory;
+import org.opensearch.index.analysis.NameOrDefinition;
+import org.opensearch.index.analysis.NamedAnalyzer;
 import org.opensearch.index.analysis.PreConfiguredCharFilter;
 import org.opensearch.index.analysis.PreConfiguredTokenFilter;
 import org.opensearch.index.analysis.PreConfiguredTokenizer;
@@ -80,6 +82,7 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -521,4 +524,70 @@ public boolean incrementToken() throws IOException {
         }
     }
 
+    public void testTokenFilterRegistrationWithModuleReference() throws IOException {
+        // Create a test plugin that adds a custom token filter
+        class TestPlugin implements AnalysisPlugin {
+            @Override
+            public Map<String, AnalysisProvider<TokenFilterFactory>> getTokenFilters(AnalysisModule module) {
+                return Map.of(
+                    "test_filter",
+                    (indexSettings, env, name, settings) -> AppendTokenFilter.factoryForSuffix("_" + module.hashCode())
+                );
+            }
+        }
+
+        // Create environment settings for analysis configuration
+        Settings settings = Settings.builder()
+            .put(Environment.PATH_HOME_SETTING.getKey(), createTempDir().toString())
+            .put(IndexMetadata.SETTING_VERSION_CREATED, Version.CURRENT)
+            .put("index.analysis.analyzer.my_analyzer.tokenizer", "standard")
+            .put("index.analysis.analyzer.my_analyzer.filter", "test_filter")
+            .build();
+
+        // Create analysis module with our test plugin
+        Environment environment = TestEnvironment.newEnvironment(settings);
+        AnalysisModule module = new AnalysisModule(environment, singletonList(new TestPlugin()));
+
+        // Get the registry
+        AnalysisRegistry registry = module.getAnalysisRegistry();
+
+        // Create index settings for testing
+        IndexSettings indexSettings = IndexSettingsModule.newIndexSettings("test", Settings.builder().put(settings).build());
+
+        // Build all token filter factories
+        Map<String, TokenFilterFactory> tokenFilterFactories = registry.buildTokenFilterFactories(indexSettings);
+
+        // Verify our custom filter is registered
+        assertTrue("Token filter 'test_filter' should be registered", tokenFilterFactories.containsKey("test_filter"));
+
+        // Test the analyzer with our custom filter
+        IndexAnalyzers analyzers = registry.build(indexSettings);
+        String testText = "test";
+        TokenStream tokenStream = analyzers.get("my_analyzer").tokenStream("", testText);
+        CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
+
+        tokenStream.reset();
+        assertTrue("Should have found a token", tokenStream.incrementToken());
+        assertEquals("Token should have expected suffix", "test_" + module.hashCode(), charTermAttribute.toString());
+        assertFalse("Should not have additional tokens", tokenStream.incrementToken());
+        tokenStream.close();
+
+        // Verify we can build a custom analyzer using our filter
+        NamedAnalyzer customAnalyzer = registry.buildCustomAnalyzer(
+            indexSettings,
+            false,
+            new NameOrDefinition("standard"),
+            Collections.emptyList(),
+            Collections.singletonList(new NameOrDefinition("test_filter"))
+        );
+
+        tokenStream = customAnalyzer.tokenStream("", testText);
+        charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
+
+        tokenStream.reset();
+        assertTrue("Custom analyzer should produce a token", tokenStream.incrementToken());
+        assertEquals("Custom analyzer token should have expected suffix", "test_" + module.hashCode(), charTermAttribute.toString());
+        assertFalse("Custom analyzer should not produce additional tokens", tokenStream.incrementToken());
+        tokenStream.close();
+    }
 }