Skip to content

Commit

Permalink
update correlation command
Browse files Browse the repository at this point in the history
Signed-off-by: YANGDB <[email protected]>
  • Loading branch information
YANG-DB committed Oct 8, 2023
1 parent e9d1589 commit f998427
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package org.opensearch.sql.ast.expression;

import org.opensearch.sql.ast.AbstractNodeVisitor;

import java.util.List;

public class FieldsMapping extends UnresolvedExpression {


private final List<UnresolvedExpression> fieldsMappingList;

public <R> FieldsMapping(List<UnresolvedExpression> fieldsMappingList) {
this.fieldsMappingList = fieldsMappingList;
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visit(this, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.opensearch.sql.ast.expression;

/**
 * Scope expression node: a {@link Span} made of a field expression, a scope value and a unit.
 */
public class Scope extends Span {

    /**
     * Builds a scope by forwarding all three arguments, unchanged, to the {@link Span} constructor.
     *
     * @param field the field expression the scope applies to
     * @param value the scope value expression
     * @param unit  the unit of the scope value
     */
    public Scope(UnresolvedExpression field, UnresolvedExpression value, SpanUnit unit) {
        super(field, value, unit);
    }
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
package org.opensearch.sql.ast.tree;

import org.opensearch.flint.spark.ppl.OpenSearchPPLParser;
import com.google.common.collect.ImmutableList;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Scope;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedExpression;

import java.util.List;

/** Logical plan node of correlation, the interface for building the searching sources. */

public class Correlation extends UnresolvedPlan {
private final CorrelationType correlationTypeContext;
private final List<OpenSearchPPLParser.FieldExpressionContext> fieldExpression;
private final OpenSearchPPLParser.ScopeClauseContext contextParamContext;
private final OpenSearchPPLParser.MappingListContext mappingListContext;
private UnresolvedPlan child;
public Correlation(OpenSearchPPLParser.CorrelationTypeContext correlationTypeContext, OpenSearchPPLParser.FieldListContext fieldListContext, OpenSearchPPLParser.ScopeClauseContext contextParamContext, OpenSearchPPLParser.MappingListContext mappingListContext) {
this.correlationTypeContext = CorrelationType.valueOf(correlationTypeContext.getText());
this.fieldExpression = fieldListContext.fieldExpression();
this.contextParamContext = contextParamContext;
private final CorrelationType correlationType;
private final List<UnresolvedExpression> fieldsList;
private final Scope scope;
private final FieldsMapping mappingListContext;
private UnresolvedPlan child ;
/**
 * Creates a correlation node.
 *
 * @param correlationType    name of a {@link CorrelationType} constant; resolved with
 *                           {@code CorrelationType.valueOf}, which matches the constant name exactly
 * @param fieldsList         the field expressions of the correlation
 * @param scope              the scope of the correlation
 * @param mappingListContext the fields mapping of the correlation
 * @throws IllegalArgumentException if {@code correlationType} names no {@link CorrelationType} constant
 */
public Correlation(String correlationType, List<UnresolvedExpression> fieldsList, Scope scope, FieldsMapping mappingListContext) {
    this.fieldsList = fieldsList;
    this.scope = scope;
    this.mappingListContext = mappingListContext;
    this.correlationType = CorrelationType.valueOf(correlationType);
}

Expand All @@ -25,15 +30,37 @@ public <T, C> T accept(AbstractNodeVisitor<T, C> nodeVisitor, C context) {
return nodeVisitor.visitCorrelation(this, context);
}

/**
 * Returns the attached child plan as a single-element immutable list.
 *
 * <p>Guava's {@code ImmutableList} does not accept null elements, so this is only usable once
 * a child has been set through {@code attach}.
 *
 * @return a one-element list holding the child plan
 */
@Override
public List<? extends Node> getChild() {
    final UnresolvedPlan attachedChild = this.child;
    return ImmutableList.of(attachedChild);
}

/**
 * Sets the child plan of this node, replacing any previously attached one.
 *
 * @param child the plan to attach
 * @return this node, so calls can be chained
 */
@Override
public Correlation attach(UnresolvedPlan child) {
    this.child = child;
    return this;
}

enum CorrelationType {

/**
 * Returns the correlation type resolved in the constructor.
 *
 * @return the correlation type of this node
 */
public CorrelationType getCorrelationType() {
    return this.correlationType;
}

/**
 * Returns the field expressions given to the constructor.
 *
 * @return the stored list itself, not a copy
 */
public List<UnresolvedExpression> getFieldsList() {
    return this.fieldsList;
}

/**
 * Returns the scope given to the constructor.
 *
 * @return the scope of this correlation
 */
public Scope getScope() {
    return this.scope;
}

/**
 * Returns the fields mapping given to the constructor.
 *
 * @return the fields mapping of this correlation
 */
public FieldsMapping getMappingListContext() {
    return this.mappingListContext;
}

/**
 * Kinds of correlation.
 *
 * <p>The constant names are deliberately lowercase: they are looked up with {@code valueOf}
 * from a plain string, which matches the constant name exactly.
 */
public enum CorrelationType {
    /** Exact correlation. */
    exact,

    /** Approximate correlation. */
    approximate
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.sql.ast.statement.Query;
import org.opensearch.sql.ast.statement.Statement;
import org.opensearch.sql.ast.tree.Aggregation;
import org.opensearch.sql.ast.tree.Correlation;
import org.opensearch.sql.ast.tree.Dedupe;
import org.opensearch.sql.ast.tree.Eval;
import org.opensearch.sql.ast.tree.Filter;
Expand All @@ -63,6 +64,7 @@
import static java.util.List.of;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate;
import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join;
import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window;

/**
Expand Down Expand Up @@ -110,6 +112,18 @@ public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) {
return context.plan(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(innerConditionExpression, p));
}

/**
 * Translates a {@link Correlation} node into a logical plan step.
 *
 * <p>The child plan is visited first so that the correlated sources are already on the
 * context when the join step is applied; this mirrors what the sibling visit methods
 * (aggregation, sort) do before handling their own expressions. The fields, the scope and
 * the mapping are then resolved through the context and handed to {@code join}.
 *
 * @param node    the correlation node to translate
 * @param context the plan context that accumulates expressions and the plan being built
 * @return the result of applying the join step to the plan held by the context
 */
@Override
public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) {
    // Build the plan for the source(s) being correlated before touching this node's expressions.
    node.getChild().get(0).accept(this, context);

    // Correlation fields: visit them, then collect what the visit left on the context.
    visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context);
    Seq<Expression> fields = context.retainAllNamedParseExpressions(e -> e);

    // Scope: analysed as a span; its single resulting expression is taken off the stack.
    expressionAnalyzer.visitSpan(node.getScope(), context);
    Expression scope = context.getNamedParseExpressions().pop();

    // Mapping: visit the mapping node, then collect whatever expressions it produced.
    node.getMappingListContext().accept(this, context);
    Seq<Expression> mapping = context.retainAllNamedParseExpressions(e -> e);

    return context.plan(p -> join(node.getCorrelationType(), fields, scope, mapping, p));
}


@Override
public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
Expand All @@ -130,7 +144,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext contex
// build the aggregation logical step
return extractedAggregation(context);
}

private static LogicalPlan extractedAggregation(CatalystPlanContext context) {
Seq<Expression> groupingExpression = context.retainAllGroupingNamedParseExpressions(p -> p);
Seq<NamedExpression> aggregateExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
Expand Down Expand Up @@ -161,7 +175,7 @@ public LogicalPlan visitProject(Project node, CatalystPlanContext context) {
}
return child;
}

@Override
public LogicalPlan visitSort(Sort node, CatalystPlanContext context) {
node.getChild().get(0).accept(this, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@
import org.opensearch.sql.ast.expression.Alias;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Map;
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Scope;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.tree.Aggregation;
Expand All @@ -42,9 +45,12 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;


/** Class of building the AST. Refines the visit path and build the AST nodes */
public class AstBuilder extends OpenSearchPPLParserBaseVisitor<UnresolvedPlan> {
Expand Down Expand Up @@ -102,7 +108,17 @@ public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext

/**
 * Builds a {@link Correlation} node from a parsed correlate command.
 *
 * <p>The correlation type is passed on as its raw text, each field expression and each
 * mapping clause is converted with {@code internalVisitExpression}, and the scope clause is
 * turned into a {@link Scope} from its field expression, value and unit. When the command has
 * no mapping list, an empty {@link FieldsMapping} is used.
 *
 * @param ctx the parse-tree context of the correlate command
 * @return the correlation plan node
 */
@Override
public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) {
    String correlationType = ctx.correlationType().getText();

    List<UnresolvedExpression> fields = ctx.fieldList().fieldExpression().stream()
            .map(this::internalVisitExpression)
            .collect(Collectors.toList());

    // Fetch the scope clause once instead of re-querying the context for every part of it.
    OpenSearchPPLParser.ScopeClauseContext scopeCtx = ctx.scopeClause();
    Scope scope = new Scope(
            expressionBuilder.visit(scopeCtx.fieldExpression()),
            expressionBuilder.visit(scopeCtx.value),
            SpanUnit.of(scopeCtx.unit.getText()));

    // The mapping list is optional; an absent one becomes an empty mapping rather than null.
    List<UnresolvedExpression> mappings = emptyList();
    if (!Objects.isNull(ctx.mappingList())) {
        mappings = ctx.mappingList().mappingClause().stream()
                .map(this::internalVisitExpression)
                .collect(Collectors.toList());
    }

    return new Correlation(correlationType, fields, scope, new FieldsMapping(mappings));
}

/** Fields command. */
Expand Down Expand Up @@ -155,7 +171,7 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext
getTextInQuery(groupCtx),
internalVisitExpression(groupCtx)))
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
.orElse(emptyList());

UnresolvedExpression span =
Optional.ofNullable(ctx.statsByClause())
Expand All @@ -166,7 +182,7 @@ public UnresolvedPlan visitStatsCommand(OpenSearchPPLParser.StatsCommandContext
Aggregation aggregation =
new Aggregation(
aggListBuilder.build(),
Collections.emptyList(),
emptyList(),
groupList,
span,
ArgumentFactory.getArgumentList(ctx));
Expand Down Expand Up @@ -260,7 +276,7 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandCo
@Override
public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
List<UnresolvedExpression> groupList =
ctx.byClause() == null ? Collections.emptyList() : getGroupByList(ctx.byClause());
ctx.byClause() == null ? emptyList() : getGroupByList(ctx.byClause());
return new RareTopN(
RareTopN.CommandType.TOP,
ArgumentFactory.getArgumentList(ctx),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.opensearch.sql.ppl.utils;

import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.opensearch.sql.ast.tree.Correlation;
import scala.collection.Seq;

/**
 * Holder for the static helper that turns a correlation specification into a join plan.
 */
public interface JoinSpecTransformer {

    /**
     * Placeholder for building a join from a correlation specification.
     *
     * <p>The join is not implemented yet: the method returns {@code p} unchanged and does not
     * read any of the other arguments.
     *
     * @param correlationType the kind of correlation requested
     * @param fields          the correlation field expressions
     * @param valueExpression the scope expression
     * @param mapping         the field-mapping expressions
     * @param p               the plan the join should be applied to
     * @return {@code p}, as received
     */
    static LogicalPlan join(
            Correlation.CorrelationType correlationType,
            Seq<Expression> fields,
            Expression valueExpression,
            Seq<Expression> mapping,
            LogicalPlan p) {
        // TODO: create a join statement
        return p;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,44 @@ class PPLLogicalPlanCorrelationQueriesTranslatorTestSuite


test("Search multiple tables with correlation - translated into join call with fields") {
  // Expected plan: every source table projected with `*`, combined by a by-name union
  // that tolerates missing columns.
  val expectedPlan = Union(
    Seq("table1", "table2").map(tableName =>
      Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq(tableName)))),
    byName = true,
    allowMissingCol = true)

  val context = new CatalystPlanContext
  val logPlan = planTrnasformer.visit(
    plan(
      pplParser,
      "source = table1, table2 | correlate exact fields(ip, port) scope(@timestamp, 1d)",
      false),
    context)

  assertEquals(expectedPlan, logPlan)
}
test("Search multiple tables with correlation with filters - translated into join call with fields") {
  // Expected plan: every source table projected with `*`, combined by a by-name union
  // that tolerates missing columns.
  // NOTE(review): the query has a `where` clause but the expected plan contains no filter
  // node - confirm this expectation is intended.
  val expectedPlan = Union(
    Seq("table1", "table2").map(tableName =>
      Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq(tableName)))),
    byName = true,
    allowMissingCol = true)

  val context = new CatalystPlanContext
  val logPlan = planTrnasformer.visit(
    plan(
      pplParser,
      "source = table1, table2 | where @timestamp=`2018-07-02T22:23:00` AND ip=`10.0.0.1` AND cloud.provider=`aws` | correlate exact fields(ip, port) scope(@timestamp, 1d)",
      false),
    context)

  assertEquals(expectedPlan, logPlan)
}
test("Search multiple tables with correlation - translated into join call with different fields mapping ") {
val context = new CatalystPlanContext
val query = "source = table1, table2 | where @timestamp=`2018-07-02T22:23:00` AND ip=`10.0.0.1` AND cloud.provider=`aws` | correlate exact fields(ip, port) scope(@timestamp, 1d)" +
" mapping( alb_logs.ip = traces.source_ip, alb_logs.port = metrics.target_port )"
Expand Down

0 comments on commit f998427

Please sign in to comment.