diff --git a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/observability/ObservationRequestTracker.java b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/observability/ObservationRequestTracker.java index 79e2c4846..7e71c71e4 100644 --- a/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/observability/ObservationRequestTracker.java +++ b/spring-data-cassandra/src/main/java/org/springframework/data/cassandra/observability/ObservationRequestTracker.java @@ -19,8 +19,11 @@ import io.micrometer.observation.Observation.Context; import io.micrometer.observation.Observation.Event; +import java.util.function.Consumer; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; + import org.springframework.data.cassandra.observability.CassandraObservation.Events; import org.springframework.data.cassandra.observability.CassandraObservation.HighCardinalityKeyNames; import org.springframework.lang.Nullable; @@ -50,9 +53,9 @@ public enum ObservationRequestTracker implements RequestTracker { public void onSuccess(Request request, long latencyNanos, DriverExecutionProfile executionProfile, Node node, String requestLogPrefix) { - if (request instanceof CassandraObservationSupplier) { + if (request instanceof CassandraObservationSupplier supplier) { - Observation observation = ((CassandraObservationSupplier) request).getObservation(); + Observation observation = supplier.getObservation(); if (log.isDebugEnabled()) { log.debug("Closing observation [" + observation + "]"); @@ -66,9 +69,9 @@ public void onSuccess(Request request, long latencyNanos, DriverExecutionProfile public void onError(Request request, Throwable error, long latencyNanos, DriverExecutionProfile executionProfile, @Nullable Node node, String requestLogPrefix) { - if (request instanceof CassandraObservationSupplier) { + if (request instanceof CassandraObservationSupplier supplier) { - Observation observation = ((CassandraObservationSupplier) request).getObservation(); + Observation observation = supplier.getObservation(); observation.error(error); if (log.isDebugEnabled()) { @@ -83,22 +86,17 @@ public void onError(Request request, Throwable error, long latencyNanos, DriverE public void onNodeError(Request request, Throwable error, long latencyNanos, DriverExecutionProfile executionProfile, Node node, String requestLogPrefix) { - if (request instanceof CassandraObservationSupplier) { - - Observation observation = ((CassandraObservationSupplier) request).getObservation(); - Context context = observation.getContext(); - - if (context instanceof CassandraObservationContext) { + if (request instanceof CassandraObservationSupplier supplier) { - ((CassandraObservationContext) context).setNode(node); + Observation observation = supplier.getObservation(); + ifContextPresent(observation, CassandraObservationContext.class, context -> context.setNode(node)); - observation.highCardinalityKeyValue( - String.format(HighCardinalityKeyNames.NODE_ERROR_TAG.asString(), node.getEndPoint()), error.toString()); - observation.event(Event.of(Events.NODE_ERROR.getValue())); + observation.highCardinalityKeyValue( + String.format(HighCardinalityKeyNames.NODE_ERROR_TAG.asString(), node.getEndPoint()), error.toString()); + observation.event(Event.of(Events.NODE_ERROR.getValue())); - if (log.isDebugEnabled()) { - log.debug("Marking node error for [" + observation + "]"); - } + if (log.isDebugEnabled()) { + log.debug("Marking node error for [" + observation + "]"); } } } @@ -107,20 +105,15 @@ public void onNodeError(Request request, Throwable error, long latencyNanos, Dri public void onNodeSuccess(Request request, long latencyNanos, DriverExecutionProfile executionProfile, Node node, String requestLogPrefix) { - if (request instanceof CassandraObservationSupplier) { - - Observation observation = ((CassandraObservationSupplier) request).getObservation(); - Context context = observation.getContext(); + if (request instanceof CassandraObservationSupplier supplier) { - if (context instanceof CassandraObservationContext) { + Observation observation = supplier.getObservation(); + ifContextPresent(observation, CassandraObservationContext.class, context -> context.setNode(node)); - ((CassandraObservationContext) context).setNode(node); + observation.event(Event.of(Events.NODE_SUCCESS.getValue())); - observation.event(Event.of(Events.NODE_SUCCESS.getValue())); - - if (log.isDebugEnabled()) { - log.debug("Marking node success for [" + observation + "]"); - } + if (log.isDebugEnabled()) { + log.debug("Marking node success for [" + observation + "]"); } } } @@ -130,4 +123,21 @@ public void close() throws Exception { } + /** + * If the {@link Observation} is a real observation (i.e. not no-op) and the context is of the given type, apply the + * consumer function to the context. + */ + static void ifContextPresent(Observation observation, Class contextType, + Consumer contextConsumer) { + + if (observation.isNoop()) { + return; + } + + Context context = observation.getContext(); + if (contextType.isInstance(context)) { + contextConsumer.accept(contextType.cast(context)); + } + } + } diff --git a/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/observability/ObservationRequestTrackerUnitTests.java b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/observability/ObservationRequestTrackerUnitTests.java new file mode 100644 index 000000000..018911e8e --- /dev/null +++ b/spring-data-cassandra/src/test/java/org/springframework/data/cassandra/observability/ObservationRequestTrackerUnitTests.java @@ -0,0 +1,116 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.cassandra.observability; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import io.micrometer.observation.Observation; + +import java.net.InetSocketAddress; +import java.util.function.Consumer; + +import org.junit.jupiter.api.Test; +import org.mockito.Answers; + +import org.springframework.lang.Nullable; + +import com.datastax.oss.driver.api.core.session.Request; +import com.datastax.oss.driver.internal.core.context.InternalDriverContext; +import com.datastax.oss.driver.internal.core.metadata.DefaultEndPoint; +import com.datastax.oss.driver.internal.core.metadata.DefaultNode; + +/** + * Unit tests for {@link ObservationRequestTracker}. + * + * @author Mark Paluch + */ +class ObservationRequestTrackerUnitTests { + + @Test // GH-1541 + void shouldStopObservation() { + + Request request = mockRequest(null); + + ObservationRequestTracker.INSTANCE.onSuccess(request, 0, null, null, ""); + + verify(((CassandraObservationSupplier) request).getObservation()).stop(); + } + + @Test // GH-1541 + void shouldAssociateNodeWithContext() { + + CassandraObservationContext context = new CassandraObservationContext(null, "foo", false, "foo", "foo", "bar"); + + Request request = mockRequest(context); + InternalDriverContext driverContext = mock(InternalDriverContext.class, Answers.RETURNS_MOCKS); + + DefaultNode node = new DefaultNode(new DefaultEndPoint(InetSocketAddress.createUnresolved("localhost", 1234)), + driverContext); + ObservationRequestTracker.INSTANCE.onNodeSuccess(request, 0, null, node, ""); + + assertThat(context.getNode()).isEqualTo(node); + } + + @Test // GH-1541 + void noOpObservationShouldNotAssociateContext() { + + CassandraObservationContext context = new CassandraObservationContext(null, "foo", false, "foo", "foo", "bar"); + Request request = mockRequest(context, observation -> { + when(observation.isNoop()).thenReturn(true); + }); + InternalDriverContext driverContext = mock(InternalDriverContext.class, Answers.RETURNS_MOCKS); + + DefaultNode node = new DefaultNode(new DefaultEndPoint(InetSocketAddress.createUnresolved("localhost", 1234)), + driverContext); + ObservationRequestTracker.INSTANCE.onNodeSuccess(request, 0, null, node, ""); + + assertThat(context.getNode()).isNull(); + } + + @Test // GH-1541 + void observationWithOtherContextShouldNotAssociateContext() { + + Request request = mockRequest(mock(Observation.Context.class)); + InternalDriverContext driverContext = mock(InternalDriverContext.class, Answers.RETURNS_MOCKS); + + DefaultNode node = new DefaultNode(new DefaultEndPoint(InetSocketAddress.createUnresolved("localhost", 1234)), + driverContext); + + assertThatNoException().isThrownBy(() -> { + ObservationRequestTracker.INSTANCE.onNodeSuccess(request, 0, null, node, ""); + }); + } + + private static Request mockRequest(@Nullable Observation.Context context) { + return mockRequest(context, observation -> {}); + } + + private static Request mockRequest(@Nullable Observation.Context context, + Consumer observationCustomizer) { + + Request request = mock(Request.class, withSettings().extraInterfaces(CassandraObservationSupplier.class)); + + Observation observation = mock(Observation.class); + CassandraObservationSupplier supplier = (CassandraObservationSupplier) request; + when(supplier.getObservation()).thenReturn(observation); + when(observation.getContext()).thenReturn(context); + + observationCustomizer.accept(observation); + + return request; + } +}