Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ljfgem committed Aug 27, 2024
1 parent aafd59d commit 9e68ce2
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
* Class to resolve hive function names in SQL to Function.
*/
public class HiveFunctionResolver {
private static final String CLASS_NAME_PREFIX = "coral_udf_version_(\\d+|x)_(\\d+|x)_(\\d+|x)";
private static final String VERSIONED_UDF_CLASS_NAME_PREFIX = "coral_udf_version_(\\d+|x)_(\\d+|x)_(\\d+|x)";

public final FunctionRegistry registry;
private final ConcurrentHashMap<String, Function> dynamicFunctionRegistry;
Expand Down Expand Up @@ -112,22 +112,23 @@ public SqlOperator resolveBinaryOperator(String name) {
* this attempts to match dali-style function names (DB_TABLE_VERSION_FUNCTION).
* Right now, this method does not validate parameters leaving it to
* the subsequent validator and analyzer phases to validate parameter types.
* @param functionName hive function name
* @param originalViewTextFunctionName original function name in view text to resolve
* @param hiveTable handle to Hive table representing metastore information. This is used for resolving
* Dali function names, which are resolved using table parameters
* @param numOfOperands number of operands this function takes. This is needed to
* create SqlOperandTypeChecker to resolve Dali function dynamically
* @return resolved hive functions
* @throws UnknownSqlFunctionException if the function name can not be resolved.
*/
public Function tryResolve(@Nonnull String functionName, @Nullable Table hiveTable, int numOfOperands) {
checkNotNull(functionName);
Collection<Function> functions = registry.lookup(functionName);
public Function tryResolve(@Nonnull String originalViewTextFunctionName, @Nullable Table hiveTable,
int numOfOperands) {
checkNotNull(originalViewTextFunctionName);
Collection<Function> functions = registry.lookup(originalViewTextFunctionName);
if (functions.isEmpty() && hiveTable != null) {
functions = tryResolveAsDaliFunction(functionName, hiveTable, numOfOperands);
functions = tryResolveAsDaliFunction(originalViewTextFunctionName, hiveTable, numOfOperands);
}
if (functions.isEmpty()) {
throw new UnknownSqlFunctionException(functionName);
throw new UnknownSqlFunctionException(originalViewTextFunctionName);
}
if (functions.size() == 1) {
return functions.iterator().next();
Expand Down Expand Up @@ -161,39 +162,40 @@ public Collection<Function> resolve(String functionName) {
/**
* Tries to resolve function name as Dali function name using the provided Hive table catalog information.
* This uses table parameters 'function' property to resolve the function name to the implementing class.
* @param functionName function name to resolve
* @param originalViewTextFunctionName original function name in view text to resolve
* @param table Hive metastore table handle
* @param numOfOperands number of operands this function takes. This is needed to
* create SqlOperandTypeChecker to resolve Dali function dynamically
* @return list of matching Functions or empty list if the function name is not in the dali function name format
* of `databaseName_tableName_udfName` or `udfName` (without `databaseName_tableName_` prefix)
* @throws UnknownSqlFunctionException if the function name is in Dali function name format but there is no mapping
*/
public Collection<Function> tryResolveAsDaliFunction(String functionName, @Nonnull Table table, int numOfOperands) {
public Collection<Function> tryResolveAsDaliFunction(String originalViewTextFunctionName, @Nonnull Table table,
int numOfOperands) {
Preconditions.checkNotNull(table);
String functionPrefix = String.format("%s_%s_", table.getDbName(), table.getTableName());
if (!functionName.toLowerCase().startsWith(functionPrefix.toLowerCase())) {
// if functionName is not in `databaseName_tableName_udfName` format, we don't require the `databaseName_tableName_` prefix
if (!originalViewTextFunctionName.toLowerCase().startsWith(functionPrefix.toLowerCase())) {
// if originalViewTextFunctionName is not in `databaseName_tableName_udfName` format, we don't require the `databaseName_tableName_` prefix
functionPrefix = "";
}
String funcBaseName = functionName.substring(functionPrefix.length());
String funcBaseName = originalViewTextFunctionName.substring(functionPrefix.length());
HiveTable hiveTable = new HiveTable(table);
Map<String, String> functionParams = hiveTable.getDaliFunctionParams();
String funcClassName = functionParams.get(funcBaseName);
if (funcClassName == null) {
String functionClassName = functionParams.get(funcBaseName);
if (functionClassName == null) {
return ImmutableList.of();
}
// If the UDF class name is versioned, remove the versioning prefix, which allows user to
// register the unversioned UDF once and use different versioning prefix in the view
final Collection<Function> functions = registry.lookup(removeVersioningPrefix(funcClassName));
final Collection<Function> functions = registry.lookup(removeVersioningPrefix(functionClassName));
if (functions.isEmpty()) {
Collection<Function> dynamicResolvedFunctions =
resolveDaliFunctionDynamically(functionName, funcClassName, hiveTable, numOfOperands);
resolveDaliFunctionDynamically(originalViewTextFunctionName, functionClassName, hiveTable, numOfOperands);

if (dynamicResolvedFunctions.isEmpty()) {
// we want to see class name in the exception message for coverage testing
// so throw exception here
throw new UnknownSqlFunctionException(funcClassName);
throw new UnknownSqlFunctionException(functionClassName);
}

return dynamicResolvedFunctions;
Expand All @@ -202,28 +204,28 @@ public Collection<Function> tryResolveAsDaliFunction(String functionName, @Nonnu
return functions.stream()
.map(f -> new Function(f.getFunctionName(),
new VersionedSqlUserDefinedFunction((SqlUserDefinedFunction) f.getSqlOperator(),
hiveTable.getDaliUdfDependencies(), functionName, funcClassName)))
hiveTable.getDaliUdfDependencies(), originalViewTextFunctionName, functionClassName)))
.collect(Collectors.toList());
}

public void addDynamicFunctionToTheRegistry(String funcClassName, Function function) {
if (!dynamicFunctionRegistry.contains(funcClassName)) {
dynamicFunctionRegistry.put(funcClassName, function);
public void addDynamicFunctionToTheRegistry(String functionClassName, Function function) {
if (!dynamicFunctionRegistry.contains(functionClassName)) {
dynamicFunctionRegistry.put(functionClassName, function);
}
}

private @Nonnull Collection<Function> resolveDaliFunctionDynamically(String functionName, String funcClassName,
HiveTable hiveTable, int numOfOperands) {
if (dynamicFunctionRegistry.contains(funcClassName)) {
return ImmutableList.of(dynamicFunctionRegistry.get(functionName));
private @Nonnull Collection<Function> resolveDaliFunctionDynamically(String originalViewTextFunctionName,
String functionClassName, HiveTable hiveTable, int numOfOperands) {
if (dynamicFunctionRegistry.contains(functionClassName)) {
return ImmutableList.of(dynamicFunctionRegistry.get(originalViewTextFunctionName));
}
Function function = new Function(funcClassName,
Function function = new Function(functionClassName,
new VersionedSqlUserDefinedFunction(
new SqlUserDefinedFunction(new SqlIdentifier(funcClassName, ZERO),
new HiveGenericUDFReturnTypeInference(funcClassName, hiveTable.getDaliUdfDependencies()), null,
new SqlUserDefinedFunction(new SqlIdentifier(functionClassName, ZERO),
new HiveGenericUDFReturnTypeInference(functionClassName, hiveTable.getDaliUdfDependencies()), null,
createSqlOperandTypeChecker(numOfOperands), null, null),
hiveTable.getDaliUdfDependencies(), functionName, funcClassName));
dynamicFunctionRegistry.put(funcClassName, function);
hiveTable.getDaliUdfDependencies(), originalViewTextFunctionName, functionClassName));
dynamicFunctionRegistry.put(functionClassName, function);
return ImmutableList.of(function);
}

Expand All @@ -246,14 +248,14 @@ public void addDynamicFunctionToTheRegistry(String funcClassName, Function funct
/**
* Removes the versioning prefix from a given UDF class name if it is present.
* A class name is considered versioned if the prefix before the first dot
* follows {@link HiveFunctionResolver#CLASS_NAME_PREFIX} format
* follows {@link HiveFunctionResolver#VERSIONED_UDF_CLASS_NAME_PREFIX} format
*/
private String removeVersioningPrefix(String className) {
if (className != null && !className.isEmpty()) {
int firstDotIndex = className.indexOf('.');
if (firstDotIndex != -1) {
String prefix = className.substring(0, firstDotIndex);
if (prefix.matches(CLASS_NAME_PREFIX)) {
if (prefix.matches(VERSIONED_UDF_CLASS_NAME_PREFIX)) {
return className.substring(firstDotIndex + 1);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,36 +55,36 @@ public class VersionedSqlUserDefinedFunction extends SqlUserDefinedFunction {

// The view-dependent function name in the format of "dbName_viewName_functionName",
// where functionName is defined in the "functions" property of the view.
private final String viewDependentFunctionName;
private final String originalViewTextFunctionName;

// The UDF class name value defined in the "functions" property of the view.
// i.e. "functions = <viewDependentFunctionName> : <funcClassName>"
private final String funcClassName;
// i.e. "functions = <originalViewTextFunctionName> : <functionClassName>"
private final String functionClassName;

private VersionedSqlUserDefinedFunction(SqlIdentifier opName, SqlReturnTypeInference returnTypeInference,
SqlOperandTypeInference operandTypeInference, SqlOperandTypeChecker operandTypeChecker,
List<RelDataType> paramTypes, Function function, List<String> ivyDependencies, String viewDependentFunctionName,
String funcClassName) {
List<RelDataType> paramTypes, Function function, List<String> ivyDependencies,
String originalViewTextFunctionName, String functionClassName) {
super(opName, returnTypeInference, operandTypeInference, operandTypeChecker, paramTypes, function,
SqlFunctionCategory.USER_DEFINED_FUNCTION);
this.ivyDependencies = ivyDependencies;
this.viewDependentFunctionName = viewDependentFunctionName;
this.funcClassName = funcClassName;
this.originalViewTextFunctionName = originalViewTextFunctionName;
this.functionClassName = functionClassName;
}

public VersionedSqlUserDefinedFunction(SqlUserDefinedFunction sqlUdf, List<String> ivyDependencies,
String viewDependentFunctionName, String funcClassName) {
String originalViewTextFunctionName, String functionClassName) {
this(new SqlIdentifier(ImmutableList.of(sqlUdf.getName()), SqlParserPos.ZERO), sqlUdf.getReturnTypeInference(),
null, sqlUdf.getOperandTypeChecker(), sqlUdf.getParamTypes(), sqlUdf.getFunction(), ivyDependencies,
viewDependentFunctionName, funcClassName);
originalViewTextFunctionName, functionClassName);
}

public List<String> getIvyDependencies() {
return ivyDependencies;
}

public String getViewDependentFunctionName() {
return viewDependentFunctionName;
public String getOriginalViewTextFunctionName() {
return originalViewTextFunctionName;
}

/**
Expand All @@ -106,20 +106,20 @@ public String getShortFunctionName() {
return caseConverter.convert(nameSplit[nameSplit.length - 1]);
}

public String getFuncClassName() {
return funcClassName;
public String getFunctionClassName() {
return functionClassName;
}

// This method is called during SQL validation. The super-class implementation resets the call's sqlOperator to one
// that is looked up from the StaticHiveFunctionRegistry or inferred dynamically if it's a Dali UDF. Since UDFs in the StaticHiveFunctionRegistry are not
// versioned, this method overrides the super-class implementation to properly restore the call's operator as
// a VersionedSqlUserDefinedFunction based on the already existing call's sqlOperator obtained from the
// StaticHiveFunctionRegistry, and hence preserve ivyDependencies and viewDependentFunctionName.
// StaticHiveFunctionRegistry, and hence preserve ivyDependencies and originalViewTextFunctionName.
@Override
public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
RelDataType relDataType = super.deriveType(validator, scope, call);
((SqlBasicCall) call).setOperator(new VersionedSqlUserDefinedFunction((SqlUserDefinedFunction) (call.getOperator()),
ivyDependencies, viewDependentFunctionName, funcClassName));
ivyDependencies, originalViewTextFunctionName, functionClassName));
return relDataType;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 LinkedIn Corporation. All rights reserved.
* Copyright 2019-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -166,7 +166,7 @@ private String getVersionedFunctionName(RexCall rexCall) {
}

final VersionedSqlUserDefinedFunction versionedFunction = (VersionedSqlUserDefinedFunction) rexCall.getOperator();
return String.join("_", PIG_UDF_ALIAS_TEMPLATE, versionedFunction.getViewDependentFunctionName())
return String.join("_", PIG_UDF_ALIAS_TEMPLATE, versionedFunction.getOriginalViewTextFunctionName())
.replace(NOT_ALPHA_NUMERIC_UNDERSCORE_REGEX, "_");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ protected SqlCall transform(SqlCall sqlCall) {
if (UNSUPPORTED_HIVE_UDFS.contains(operatorName)) {
throw new UnsupportedUDFException(operatorName);
}
final String viewDependentFunctionName = operator.getViewDependentFunctionName();
final String originalViewTextFunctionName = operator.getOriginalViewTextFunctionName();
final List<String> dependencies = operator.getIvyDependencies();
List<URI> listOfUris = dependencies.stream().map(URI::create).collect(Collectors.toList());
LOG.info("Function: {} is not a Builtin UDF or Transport UDF. We fall back to its Hive "
+ "function with ivy dependency: {}", operatorName, String.join(",", dependencies));
final SparkUDFInfo sparkUDFInfo = new SparkUDFInfo(operator.getFuncClassName(), viewDependentFunctionName,
final SparkUDFInfo sparkUDFInfo = new SparkUDFInfo(operator.getFunctionClassName(), originalViewTextFunctionName,
listOfUris, SparkUDFInfo.UDFTYPE.HIVE_CUSTOM_UDF);
sparkUDFInfos.add(sparkUDFInfo);
final SqlOperator convertedFunction =
createSqlOperator(viewDependentFunctionName, operator.getReturnTypeInference());
createSqlOperator(originalViewTextFunctionName, operator.getReturnTypeInference());
return convertedFunction.createCall(sqlCall.getParserPosition(), sqlCall.getOperandList());
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -73,13 +73,13 @@ protected boolean condition(SqlCall sqlCall) {
@Override
protected SqlCall transform(SqlCall sqlCall) {
final VersionedSqlUserDefinedFunction operator = (VersionedSqlUserDefinedFunction) sqlCall.getOperator();
final String viewDependentFunctionName = operator.getViewDependentFunctionName();
sparkUDFInfos.add(new SparkUDFInfo(sparkUDFClassName, viewDependentFunctionName,
final String originalViewTextFunctionName = operator.getOriginalViewTextFunctionName();
sparkUDFInfos.add(new SparkUDFInfo(sparkUDFClassName, originalViewTextFunctionName,
Collections.singletonList(
URI.create(scalaVersion == ScalaVersion.SCALA_2_11 ? artifactoryUrlSpark211 : artifactoryUrlSpark212)),
SparkUDFInfo.UDFTYPE.TRANSPORTABLE_UDF));
final SqlOperator convertedFunction =
createSqlOperator(viewDependentFunctionName, operator.getReturnTypeInference());
createSqlOperator(originalViewTextFunctionName, operator.getReturnTypeInference());
return convertedFunction.createCall(sqlCall.getParserPosition(), sqlCall.getOperandList());
}

Expand Down

0 comments on commit 9e68ce2

Please sign in to comment.