diff --git a/euphoria-core/src/test/java/cz/seznam/euphoria/core/client/operator/HintTest.java b/euphoria-core/src/test/java/cz/seznam/euphoria/core/client/operator/HintTest.java index dfb4f7b3..9c3ed8e8 100644 --- a/euphoria-core/src/test/java/cz/seznam/euphoria/core/client/operator/HintTest.java +++ b/euphoria-core/src/test/java/cz/seznam/euphoria/core/client/operator/HintTest.java @@ -1,3 +1,18 @@ +/* + * Copyright 2016-2018 Seznam.cz, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package cz.seznam.euphoria.core.client.operator; import cz.seznam.euphoria.core.client.dataset.Dataset; diff --git a/euphoria-core/src/test/java/cz/seznam/euphoria/core/client/operator/JoinTest.java b/euphoria-core/src/test/java/cz/seznam/euphoria/core/client/operator/JoinTest.java index bcb3c14c..c19880aa 100644 --- a/euphoria-core/src/test/java/cz/seznam/euphoria/core/client/operator/JoinTest.java +++ b/euphoria-core/src/test/java/cz/seznam/euphoria/core/client/operator/JoinTest.java @@ -219,8 +219,8 @@ public void testBuild_Hints() { assertTrue(outputDataset.getProducer().getHints().contains(SizeHint.FITS_IN_MEMORY)); Join join = (Join) flow.operators().stream().filter(op -> op instanceof Join).findFirst().get(); - assertTrue(join.listInputs().stream().anyMatch(input -> ((Dataset)input).getProducer().getHints().contains(new - Util.TestHint()))); + assertTrue(join.listInputs().stream().anyMatch(input -> + ((Dataset) input).getProducer().getHints().contains(new Util.TestHint()))); assertTrue(join.listInputs().stream().anyMatch(input -> ((Dataset) input).getProducer().getHints().contains(new Util.TestHint2()))); assertEquals(2, ((Dataset) join.listInputs().stream().findFirst().get()).getProducer().getHints().size()); diff --git a/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BroadcastHashJoinTranslator.java b/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BroadcastHashJoinTranslator.java index 0017a87c..09541656 100644 --- a/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BroadcastHashJoinTranslator.java +++ b/euphoria-flink/src/main/java/cz/seznam/euphoria/flink/batch/BroadcastHashJoinTranslator.java @@ -22,6 +22,7 @@ import cz.seznam.euphoria.core.client.functional.BinaryFunctor; import cz.seznam.euphoria.core.client.functional.UnaryFunction; import cz.seznam.euphoria.core.client.operator.Join; +import cz.seznam.euphoria.core.client.operator.Operator; import cz.seznam.euphoria.core.client.operator.hint.SizeHint; import cz.seznam.euphoria.core.client.util.Pair; import cz.seznam.euphoria.core.executor.util.MultiValueContext; @@ -41,13 +42,20 @@ public class BroadcastHashJoinTranslator implements BatchOperatorTranslator ((Dataset) input).getHints().contains(SizeHint.FITS_IN_MEMORY)) + .anyMatch(input -> hasSizeHint(((Dataset) input).getProducer())) && (o.getType() == Join.Type.LEFT || o.getType() == Join.Type.RIGHT) && !(o.getWindowing() instanceof MergingWindowing); } + static boolean hasSizeHint(Operator operator) { + return operator != null && + operator.getHints() != null && + operator.getHints().contains(SizeHint.FITS_IN_MEMORY); + } + @Override @SuppressWarnings("unchecked") public DataSet translate(FlinkOperator operator, BatchExecutorContext context) { diff --git a/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/BroadcastHashJoinTranslator.java b/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/BroadcastHashJoinTranslator.java index cc1efc41..79dc4608 100644 --- a/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/BroadcastHashJoinTranslator.java +++ b/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/BroadcastHashJoinTranslator.java @@ -24,6 +24,7 @@ import cz.seznam.euphoria.core.client.dataset.windowing.Windowing; import cz.seznam.euphoria.core.client.functional.UnaryFunction; import cz.seznam.euphoria.core.client.operator.Join; +import cz.seznam.euphoria.core.client.operator.Operator; import cz.seznam.euphoria.core.client.operator.hint.SizeHint; import cz.seznam.euphoria.core.client.util.Either; import cz.seznam.euphoria.core.client.util.Pair; @@ -58,11 +59,17 @@ public class BroadcastHashJoinTranslator implements SparkOperatorTranslator ((Dataset) input).getProducer().getHints().contains(SizeHint.FITS_IN_MEMORY)) + .anyMatch(input -> hasSizeHint(((Dataset) input).getProducer())) && (o.getType() == Join.Type.LEFT || o.getType() == Join.Type.RIGHT) && !(o.getWindowing() instanceof MergingWindowing); } + static boolean hasSizeHint(Operator operator) { + return operator != null && + operator.getHints() != null && + operator.getHints().contains(SizeHint.FITS_IN_MEMORY); + } + @Override @SuppressWarnings("unchecked") public JavaRDD translate(Join operator, SparkExecutorContext context) { @@ -71,7 +78,7 @@ public JavaRDD translate(Join operator, SparkExecutorContext context) { Preconditions.checkArgument( operator.listInputs() .stream() - .anyMatch(input -> ((Dataset) input).getProducer().getHints().contains(SizeHint.FITS_IN_MEMORY)), + .anyMatch(input -> hasSizeHint(((Dataset) input).getProducer())), "Missing broadcastHashJoin hint"); Preconditions.checkArgument( operator.getType() == Join.Type.LEFT || operator.getType() == Join.Type.RIGHT,