Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into inference_metadata_…
Browse files Browse the repository at this point in the history
…fields
  • Loading branch information
jimczi committed Nov 30, 2024
2 parents 9acf74d + deb838c commit 7bcd2b1
Show file tree
Hide file tree
Showing 21 changed files with 790 additions and 444 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

package org.elasticsearch.entitlement.instrumentation.impl;

import org.elasticsearch.entitlement.instrumentation.CheckerMethod;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
import org.elasticsearch.entitlement.instrumentation.Instrumenter;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
Expand All @@ -20,37 +20,23 @@
import org.objectweb.asm.Type;

import java.io.IOException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Stream;

public class InstrumentationServiceImpl implements InstrumentationService {

@Override
public Instrumenter newInstrumenter(String classNameSuffix, Map<MethodKey, CheckerMethod> instrumentationMethods) {
return new InstrumenterImpl(classNameSuffix, instrumentationMethods);
}

/**
* @return a {@link MethodKey} suitable for looking up the given {@code targetMethod} in the entitlements trampoline
*/
public MethodKey methodKeyForTarget(Method targetMethod) {
Type actualType = Type.getMethodType(Type.getMethodDescriptor(targetMethod));
return new MethodKey(
Type.getInternalName(targetMethod.getDeclaringClass()),
targetMethod.getName(),
Stream.of(actualType.getArgumentTypes()).map(Type::getInternalName).toList()
);
public Instrumenter newInstrumenter(Map<MethodKey, CheckMethod> checkMethods) {
return InstrumenterImpl.create(checkMethods);
}

@Override
public Map<MethodKey, CheckerMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
public Map<MethodKey, CheckMethod> lookupMethodsToInstrument(String entitlementCheckerClassName) throws ClassNotFoundException,
IOException {
var methodsToInstrument = new HashMap<MethodKey, CheckerMethod>();
var methodsToInstrument = new HashMap<MethodKey, CheckMethod>();
var checkerClass = Class.forName(entitlementCheckerClassName);
var classFileInfo = InstrumenterImpl.getClassFileInfo(checkerClass);
ClassReader reader = new ClassReader(classFileInfo.bytecodes());
Expand All @@ -69,9 +55,9 @@ public MethodVisitor visitMethod(
var methodToInstrument = parseCheckerMethodSignature(checkerMethodName, checkerMethodArgumentTypes);

var checkerParameterDescriptors = Arrays.stream(checkerMethodArgumentTypes).map(Type::getDescriptor).toList();
var checkerMethod = new CheckerMethod(Type.getInternalName(checkerClass), checkerMethodName, checkerParameterDescriptors);
var checkMethod = new CheckMethod(Type.getInternalName(checkerClass), checkerMethodName, checkerParameterDescriptors);

methodsToInstrument.put(methodToInstrument, checkerMethod);
methodsToInstrument.put(methodToInstrument, checkMethod);

return mv;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

package org.elasticsearch.entitlement.instrumentation.impl;

import org.elasticsearch.entitlement.instrumentation.CheckerMethod;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.Instrumenter;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.objectweb.asm.AnnotationVisitor;
Expand Down Expand Up @@ -37,30 +37,43 @@

public class InstrumenterImpl implements Instrumenter {

private static final String checkerClassDescriptor;
private static final String handleClass;
static {
private final String getCheckerClassMethodDescriptor;
private final String handleClass;

/**
* To avoid class name collisions during testing without an agent to replace classes in-place.
*/
private final String classNameSuffix;
private final Map<MethodKey, CheckMethod> checkMethods;

InstrumenterImpl(
String handleClass,
String getCheckerClassMethodDescriptor,
String classNameSuffix,
Map<MethodKey, CheckMethod> checkMethods
) {
this.handleClass = handleClass;
this.getCheckerClassMethodDescriptor = getCheckerClassMethodDescriptor;
this.classNameSuffix = classNameSuffix;
this.checkMethods = checkMethods;
}

static String getCheckerClassName() {
int javaVersion = Runtime.version().feature();
final String classNamePrefix;
if (javaVersion >= 23) {
classNamePrefix = "Java23";
} else {
classNamePrefix = "";
}
String checkerClass = "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
handleClass = checkerClass + "Handle";
checkerClassDescriptor = Type.getObjectType(checkerClass).getDescriptor();
return "org/elasticsearch/entitlement/bridge/" + classNamePrefix + "EntitlementChecker";
}

/**
* To avoid class name collisions during testing without an agent to replace classes in-place.
*/
private final String classNameSuffix;
private final Map<MethodKey, CheckerMethod> instrumentationMethods;

public InstrumenterImpl(String classNameSuffix, Map<MethodKey, CheckerMethod> instrumentationMethods) {
this.classNameSuffix = classNameSuffix;
this.instrumentationMethods = instrumentationMethods;
public static InstrumenterImpl create(Map<MethodKey, CheckMethod> checkMethods) {
String checkerClass = getCheckerClassName();
String handleClass = checkerClass + "Handle";
String getCheckerClassMethodDescriptor = Type.getMethodDescriptor(Type.getObjectType(checkerClass));
return new InstrumenterImpl(handleClass, getCheckerClassMethodDescriptor, "", checkMethods);
}

public ClassFileInfo instrumentClassFile(Class<?> clazz) throws IOException {
Expand Down Expand Up @@ -156,7 +169,7 @@ public MethodVisitor visitMethod(int access, String name, String descriptor, Str
boolean isStatic = (access & ACC_STATIC) != 0;
boolean isCtor = "<init>".equals(name);
var key = new MethodKey(className, name, Stream.of(Type.getArgumentTypes(descriptor)).map(Type::getInternalName).toList());
var instrumentationMethod = instrumentationMethods.get(key);
var instrumentationMethod = checkMethods.get(key);
if (instrumentationMethod != null) {
// LOGGER.debug("Will instrument method {}", key);
return new EntitlementMethodVisitor(Opcodes.ASM9, mv, isStatic, isCtor, descriptor, instrumentationMethod);
Expand Down Expand Up @@ -190,7 +203,7 @@ class EntitlementMethodVisitor extends MethodVisitor {
private final boolean instrumentedMethodIsStatic;
private final boolean instrumentedMethodIsCtor;
private final String instrumentedMethodDescriptor;
private final CheckerMethod instrumentationMethod;
private final CheckMethod checkMethod;
private boolean hasCallerSensitiveAnnotation = false;

EntitlementMethodVisitor(
Expand All @@ -199,13 +212,13 @@ class EntitlementMethodVisitor extends MethodVisitor {
boolean instrumentedMethodIsStatic,
boolean instrumentedMethodIsCtor,
String instrumentedMethodDescriptor,
CheckerMethod instrumentationMethod
CheckMethod checkMethod
) {
super(api, methodVisitor);
this.instrumentedMethodIsStatic = instrumentedMethodIsStatic;
this.instrumentedMethodIsCtor = instrumentedMethodIsCtor;
this.instrumentedMethodDescriptor = instrumentedMethodDescriptor;
this.instrumentationMethod = instrumentationMethod;
this.checkMethod = checkMethod;
}

@Override
Expand Down Expand Up @@ -278,19 +291,19 @@ private void forwardIncomingArguments() {
private void invokeInstrumentationMethod() {
mv.visitMethodInsn(
INVOKEINTERFACE,
instrumentationMethod.className(),
instrumentationMethod.methodName(),
checkMethod.className(),
checkMethod.methodName(),
Type.getMethodDescriptor(
Type.VOID_TYPE,
instrumentationMethod.parameterDescriptors().stream().map(Type::getType).toArray(Type[]::new)
checkMethod.parameterDescriptors().stream().map(Type::getType).toArray(Type[]::new)
),
true
);
}
}

protected void pushEntitlementChecker(MethodVisitor mv) {
mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", "()" + checkerClassDescriptor, false);
mv.visitMethodInsn(INVOKESTATIC, handleClass, "instance", getCheckerClassMethodDescriptor, false);
}

public record ClassFileInfo(String fileName, byte[] bytecodes) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

package org.elasticsearch.entitlement.instrumentation.impl;

import org.elasticsearch.entitlement.instrumentation.CheckerMethod;
import org.elasticsearch.entitlement.instrumentation.CheckMethod;
import org.elasticsearch.entitlement.instrumentation.InstrumentationService;
import org.elasticsearch.entitlement.instrumentation.MethodKey;
import org.elasticsearch.test.ESTestCase;
Expand Down Expand Up @@ -52,15 +52,15 @@ interface TestCheckerCtors {
}

public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestChecker.class.getName());

assertThat(methodsMap, aMapWithSize(3));
assertThat(checkMethods, aMapWithSize(3));
assertThat(
methodsMap,
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "staticMethod", List.of("I", "java/lang/String", "java/lang/Object"))),
equalTo(
new CheckerMethod(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker",
"check$org_example_TestTargetClass$staticMethod",
List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;", "Ljava/lang/Object;")
Expand All @@ -69,7 +69,7 @@ public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundE
)
);
assertThat(
methodsMap,
checkMethods,
hasEntry(
equalTo(
new MethodKey(
Expand All @@ -79,7 +79,7 @@ public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundE
)
),
equalTo(
new CheckerMethod(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker",
"check$$instanceMethodNoArgs",
List.of(
Expand All @@ -91,7 +91,7 @@ public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundE
)
);
assertThat(
methodsMap,
checkMethods,
hasEntry(
equalTo(
new MethodKey(
Expand All @@ -101,7 +101,7 @@ public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundE
)
),
equalTo(
new CheckerMethod(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestChecker",
"check$$instanceMethodWithArgs",
List.of(
Expand All @@ -117,15 +117,15 @@ public void testInstrumentationTargetLookup() throws IOException, ClassNotFoundE
}

public void testInstrumentationTargetLookupWithOverloads() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerOverloads.class.getName());

assertThat(methodsMap, aMapWithSize(2));
assertThat(checkMethods, aMapWithSize(2));
assertThat(
methodsMap,
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "staticMethodWithOverload", List.of("I", "java/lang/String"))),
equalTo(
new CheckerMethod(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerOverloads",
"check$org_example_TestTargetClass$staticMethodWithOverload",
List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;")
Expand All @@ -134,11 +134,11 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException, C
)
);
assertThat(
methodsMap,
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "staticMethodWithOverload", List.of("I", "I"))),
equalTo(
new CheckerMethod(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerOverloads",
"check$org_example_TestTargetClass$staticMethodWithOverload",
List.of("Ljava/lang/Class;", "I", "I")
Expand All @@ -149,15 +149,15 @@ public void testInstrumentationTargetLookupWithOverloads() throws IOException, C
}

public void testInstrumentationTargetLookupWithCtors() throws IOException, ClassNotFoundException {
Map<MethodKey, CheckerMethod> methodsMap = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());
Map<MethodKey, CheckMethod> checkMethods = instrumentationService.lookupMethodsToInstrument(TestCheckerCtors.class.getName());

assertThat(methodsMap, aMapWithSize(2));
assertThat(checkMethods, aMapWithSize(2));
assertThat(
methodsMap,
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "<init>", List.of("I", "java/lang/String"))),
equalTo(
new CheckerMethod(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerCtors",
"check$org_example_TestTargetClass$",
List.of("Ljava/lang/Class;", "I", "Ljava/lang/String;")
Expand All @@ -166,11 +166,11 @@ public void testInstrumentationTargetLookupWithCtors() throws IOException, Class
)
);
assertThat(
methodsMap,
checkMethods,
hasEntry(
equalTo(new MethodKey("org/example/TestTargetClass", "<init>", List.of())),
equalTo(
new CheckerMethod(
new CheckMethod(
"org/elasticsearch/entitlement/instrumentation/impl/InstrumentationServiceImplTests$TestCheckerCtors",
"check$org_example_TestTargetClass$",
List.of("Ljava/lang/Class;")
Expand Down
Loading

0 comments on commit 7bcd2b1

Please sign in to comment.