Skip to content

Commit

Permalink
Modifying the spring cloud gateway
Browse files Browse the repository at this point in the history
  • Loading branch information
ChengJie1053 committed Nov 10, 2023
1 parent c604486 commit f094821
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 129 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.linkis.gateway.springcloud.constant;

public class GatewayConstant {

public static final String FIXED_INSTANCE = "client-ip";
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.linkis.gateway.security.LinkisPreFilter$;
import org.apache.linkis.gateway.security.SecurityFilter;
import org.apache.linkis.gateway.springcloud.SpringCloudGatewayConfiguration;
import org.apache.linkis.gateway.springcloud.constant.GatewayConstant;
import org.apache.linkis.server.Message;

import org.apache.commons.lang3.StringUtils;
Expand Down Expand Up @@ -118,7 +119,8 @@ private BaseGatewayContext getBaseGatewayContext(ServerWebExchange exchange, Rou
return gatewayContext;
}

private Route getRealRoute(Route route, ServiceInstance serviceInstance) {
private Route getRealRoute(
Route route, ServiceInstance serviceInstance, ServerWebExchange exchange) {
String routeUri = route.getUri().toString();
String scheme = route.getUri().getScheme();
if (routeUri.startsWith(SpringCloudGatewayConfiguration.ROUTE_URI_FOR_WEB_SOCKET_HEADER())) {
Expand All @@ -130,7 +132,10 @@ private Route getRealRoute(Route route, ServiceInstance serviceInstance) {
}
String uri = scheme + serviceInstance.getApplicationName();
if (StringUtils.isNotBlank(serviceInstance.getInstance())) {
uri = scheme + SpringCloudGatewayConfiguration.mergeServiceInstance(serviceInstance);
exchange
.getRequest()
.mutate()
.header(GatewayConstant.FIXED_INSTANCE, serviceInstance.getInstance());
}
return Route.async()
.id(route.getId())
Expand Down Expand Up @@ -196,7 +201,7 @@ private Mono<Void> gatewayDeal(
}
Route route = exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR);
if (serviceInstance != null) {
Route realRoute = getRealRoute(route, serviceInstance);
Route realRoute = getRealRoute(route, serviceInstance, exchange);
exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_ROUTE_ATTR, realRoute);
} else {
RouteDefinition realRd = null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.linkis.gateway.springcloud.loadbalancer;

import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClients;
import org.springframework.context.annotation.Configuration;

@Configuration
@LoadBalancerClients(defaultConfiguration = {LinkisLoadBalancerClientConfiguration.class})
public class GatewayLoadBalancerConfiguration {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.linkis.gateway.springcloud.loadbalancer;

import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.loadbalancer.core.ReactorLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.core.env.Environment;

public class LinkisLoadBalancerClientConfiguration {
@Bean
public ReactorLoadBalancer<ServiceInstance> customLoadBalancer(
Environment environment, LoadBalancerClientFactory loadBalancerClientFactory) {
String name = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
return new ServiceInstancePriorityLoadBalancer(
loadBalancerClientFactory.getLazyProvider(name, ServiceInstanceListSupplier.class), name);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.linkis.gateway.springcloud.loadbalancer;

import org.apache.linkis.gateway.springcloud.constant.GatewayConstant;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.*;
import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.SelectedInstanceCallback;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;

import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;

import reactor.core.publisher.Mono;

public class ServiceInstancePriorityLoadBalancer implements ReactorServiceInstanceLoadBalancer {

private static final Log log = LogFactory.getLog(ServiceInstancePriorityLoadBalancer.class);
private final String serviceId;

final AtomicInteger position;
private final ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;

public ServiceInstancePriorityLoadBalancer(
ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider,
String serviceId) {
this(serviceInstanceListSupplierProvider, serviceId, (new Random()).nextInt(1000));
}

public ServiceInstancePriorityLoadBalancer(
ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider,
String serviceId,
int seedPosition) {
this.serviceId = serviceId;
this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
this.position = new AtomicInteger(seedPosition);
}

@Override
public Mono<Response<ServiceInstance>> choose(Request request) {
List<String> clientIpList =
((RequestDataContext) request.getContext())
.getClientRequest()
.getHeaders()
.get(GatewayConstant.FIXED_INSTANCE);
String clientIp = CollectionUtils.isNotEmpty(clientIpList) ? clientIpList.get(0) : null;
ServiceInstanceListSupplier supplier =
serviceInstanceListSupplierProvider.getIfAvailable(NoopServiceInstanceListSupplier::new);
return supplier
.get(request)
.next()
.map(serviceInstances -> processInstanceResponse(supplier, serviceInstances, clientIp));
}

private Response<ServiceInstance> processInstanceResponse(
ServiceInstanceListSupplier supplier,
List<ServiceInstance> serviceInstances,
String clientIp) {
Response<ServiceInstance> serviceInstanceResponse =
getInstanceResponse(serviceInstances, clientIp);
if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
((SelectedInstanceCallback) supplier)
.selectedServiceInstance(serviceInstanceResponse.getServer());
}
return serviceInstanceResponse;
}

private Response<ServiceInstance> getInstanceResponse(
List<ServiceInstance> instances, String clientIp) {
if (instances.isEmpty()) {
if (log.isWarnEnabled()) {
log.warn("No servers available for service: " + serviceId);
}
return new EmptyResponse();
}
int pos = this.position.incrementAndGet() & Integer.MAX_VALUE;

if (StringUtils.isEmpty(clientIp)) {
return new DefaultResponse(instances.get(pos % instances.size()));
}
for (ServiceInstance instance : instances) {
String[] ipAndPort = clientIp.split(":");
if (ipAndPort.length == 2
&& Objects.equals(ipAndPort[0], instance.getHost())
&& Objects.equals(ipAndPort[1], instance.getPort())) {
return new DefaultResponse(instance);
}
}

return new DefaultResponse(instances.get(pos % instances.size()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,11 @@ import org.apache.linkis.gateway.springcloud.http.{
}
import org.apache.linkis.gateway.springcloud.websocket.SpringCloudGatewayWebsocketFilter
import org.apache.linkis.rpc.Sender
import org.apache.linkis.rpc.interceptor.ServiceInstanceUtils
import org.apache.linkis.server.conf.ServerConfiguration

import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.autoconfigure.AutoConfigureAfter
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty
import org.springframework.cloud.client
import org.springframework.cloud.client.DefaultServiceInstance
import org.springframework.cloud.client.loadbalancer.LoadBalancerClient
import org.springframework.cloud.gateway.config.{GatewayAutoConfiguration, GatewayProperties}
import org.springframework.cloud.gateway.filter._
Expand All @@ -44,8 +41,6 @@ import org.springframework.cloud.gateway.route.builder.{
PredicateSpec,
RouteLocatorBuilder
}
import org.springframework.cloud.loadbalancer.blocking.client.BlockingLoadBalancerClient
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory
import org.springframework.context.annotation.{Bean, Configuration}
import org.springframework.web.reactive.socket.client.WebSocketClient
import org.springframework.web.reactive.socket.server.WebSocketService
Expand Down Expand Up @@ -133,127 +128,6 @@ class SpringCloudGatewayConfiguration {
)
.build()

// @Bean
// def createLoadBalancerClient(springClientFactory: SpringClientFactory): RibbonLoadBalancerClient =
// new RibbonLoadBalancerClient(springClientFactory) {
//
// override def getServer(serviceId: String): Server = if (isMergeModuleInstance(serviceId)) {
// val serviceInstance = getServiceInstance(serviceId)
// logger.info("redirect to " + serviceInstance)
// val lb = this.getLoadBalancer(serviceInstance.getApplicationName)
// lb.getAllServers.asScala.find(_.getHostPort == serviceInstance.getInstance).get
// } else super.getServer(serviceId)
//
// def isSecure(server: Server, serviceId: String) = {
// val config = springClientFactory.getClientConfig(serviceId)
// val serverIntrospector = serverIntrospectorFun(serviceId)
// RibbonUtils.isSecure(config, serverIntrospector, server)
// }
//
// def serverIntrospectorFun(serviceId: String) = {
// var serverIntrospector =
// springClientFactory.getInstance(serviceId, classOf[ServerIntrospector])
// if (serverIntrospector == null) serverIntrospector = new DefaultServerIntrospector
// serverIntrospector
// }
//
// override def choose(serviceId: String, hint: Any): client.ServiceInstance =
// if (isMergeModuleInstance(serviceId)) {
// val serviceInstance = getServiceInstance(serviceId)
// logger.info("redirect to " + serviceInstance)
// val lb = this.getLoadBalancer(serviceInstance.getApplicationName)
// val serverOption =
// lb.getAllServers.asScala.find(_.getHostPort == serviceInstance.getInstance)
// if (serverOption.isDefined) {
// val server = serverOption.get
// new RibbonLoadBalancerClient.RibbonServer(
// serviceId,
// server,
// isSecure(server, serviceId),
// serverIntrospectorFun(serviceId).getMetadata(server)
// )
// } else {
// logger.warn(
// "RibbonLoadBalancer not have Server, execute default super choose method" + serviceInstance
// )
// super.choose(serviceInstance.getApplicationName, hint)
// }
// } else super.choose(serviceId, hint)
//
// }

@Bean
def createLoadBalancerClient(
loadBalancerClientFactory: LoadBalancerClientFactory
): BlockingLoadBalancerClient =
new BlockingLoadBalancerClient(loadBalancerClientFactory) {

override def choose(serviceId: String): client.ServiceInstance = {
// serviceId = merge-gw-18linkis-cg-entrance192—168—217–172—9104
if (isMergeModuleInstance(serviceId)) {
// serviceInstance = (linkis-cg-entrance,192.168.217.172:9104)
val serviceInstance = getServiceInstance(serviceId)
logger.info("redirect to " + serviceInstance)

val serverOption: Option[ServiceInstance] = ServiceInstanceUtils.getRPCServerLoader
.getServiceInstances(serviceInstance.getApplicationName)
.find(_.getInstance == serviceInstance.getInstance)

if (serverOption.isDefined) {
val server = serverOption.get
// serviceInstance.getApplicationName = linkis-cg-entrance
// super.choose(server.getApplicationName)
// SpringCloudFeignConfigurationCache.getDiscoveryClient
// .getInstances(server.getApplicationName).get(0)
val hostAndPort: Array[String] = server.getInstance.split(":")
new DefaultServiceInstance(
server.getApplicationName,
serviceId,
hostAndPort.head,
hostAndPort.last.toInt,
true
)
} else {
logger.warn(
"BlockingLoadBalancer not have Server, execute default super choose method" + serviceInstance
)
super.choose(serviceInstance.getApplicationName)
}
} else super.choose(serviceId)
}

}

// @Bean
// def createLoadBalancerClient(
// loadBalancerClientFactory: LoadBalancerClientFactory
// ): BlockingLoadBalancerClient =
// new BlockingLoadBalancerClient(loadBalancerClientFactory) {
// override def choose(serviceId: String): client.ServiceInstance = {
// // serviceId = merge-gw-18linkis-cg-entrance192—168—217–172—9104
// if (isMergeModuleInstance(serviceId)) {
// // serviceInstance = (linkis-cg-entrance,192.168.217.172:9104)
// val serviceInstance = getServiceInstance(serviceId)
// logger.info("redirect to " + serviceInstance)
//
// val serverOption: Option[client.ServiceInstance] = SpringCloudFeignConfigurationCache.getDiscoveryClient
// .getInstances(serviceInstance.getApplicationName).iterator()
// .asScala.find(s => s.getHost + ":" + s.getPort == serviceInstance.getInstance)
//
// if (serverOption.isDefined) {
// val server: client.ServiceInstance = serverOption.get
// // serviceInstance.getApplicationName = linkis-cg-entrance
// // super.choose(serviceInstance.getApplicationName)
// super.choose(server.getServiceId)
// } else {
// logger.warn(
// "BlockingLoadBalancer not have Server, execute default super choose method" + serviceInstance
// )
// super.choose(serviceId)
// }
// } else super.choose(serviceId)
// }
// }
@Bean
@ConditionalOnProperty(name = Array("spring.cloud.gateway.url.enabled"), matchIfMissing = true)
def linkisGatewayHttpHeadersFilter(): LinkisGatewayHttpHeadersFilter = {
Expand Down

0 comments on commit f094821

Please sign in to comment.