Skip to content

Commit

Permalink
AccessControl interface accepts Request parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ilamhs committed Nov 7, 2024
1 parent e35f43c commit 3839bac
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.pinot.core.auth.FineGrainedAccessControl;
import org.apache.pinot.spi.annotations.InterfaceAudience;
import org.apache.pinot.spi.annotations.InterfaceStability;
import org.glassfish.grizzly.http.server.Request;


@InterfaceAudience.Public
Expand All @@ -40,6 +41,7 @@ public interface AccessControl extends FineGrainedAccessControl {
* @param endpointUrl the request url for which this access control is called
* @return whether the client has permission
*/
@Deprecated
default boolean hasAccess(@Nullable String tableName, AccessType accessType, HttpHeaders httpHeaders,
String endpointUrl) {
return true;
Expand All @@ -53,10 +55,40 @@ default boolean hasAccess(@Nullable String tableName, AccessType accessType, Htt
* @param endpointUrl the request url for which this access control is called
* @return whether the client has permission
*/
@Deprecated
default boolean hasAccess(AccessType accessType, HttpHeaders httpHeaders, String endpointUrl) {
return hasAccess(null, accessType, httpHeaders, endpointUrl);
}

/**
* Return whether the client has permission to the given table
*
* @param tableName name of the table to be accessed
* @param accessType type of the access
* @param httpHeaders HTTP headers containing requester identity
* @param request the request for which this access control is called
* @param endpointUrl the request url for which this access control is called
* @return whether the client has permission
*/
default boolean hasAccess(@Nullable String tableName, AccessType accessType, HttpHeaders httpHeaders,
@Nullable Request request, @Nullable String endpointUrl) {
return hasAccess(tableName, accessType, httpHeaders, endpointUrl);
}

/**
* Return whether the client has permission to access the endpoints with are not table level
*
* @param accessType type of the access
* @param httpHeaders HTTP headers
* @param request the request for which this access control is called
* @param endpointUrl the request url for which this access control is called
* @return whether the client has permission
*/
default boolean hasAccess(AccessType accessType, HttpHeaders httpHeaders, @Nullable Request request,
@Nullable String endpointUrl) {
return hasAccess(null, accessType, httpHeaders, request, endpointUrl);
}

/**
* Determine whether authentication is required for annotated (controller) endpoints only
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.controller.api.exception.ControllerApplicationException;
import org.apache.pinot.spi.utils.builder.TableNameBuilder;
import org.glassfish.grizzly.http.server.Request;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -43,24 +44,25 @@ private AccessControlUtils() {
/**
* Validate permission for the given access type against the given table
*
* @param tableName name of the table to be accessed (post database name translation)
* @param accessType type of the access
* @param httpHeaders HTTP headers containing requester identity required by access control object
* @param endpointUrl the request url for which this access control is called
* @param tableName name of the table to be accessed (post database name translation)
* @param accessType type of the access
* @param httpHeaders HTTP headers containing requester identity required by access control object
* @param request the request for which this access controll is called
* @param endpointUrl the request url for which this access control is called
* @param accessControl AccessControl object which does the actual validation
*/
public static void validatePermission(@Nullable String tableName, AccessType accessType,
@Nullable HttpHeaders httpHeaders, String endpointUrl, AccessControl accessControl) {
@Nullable HttpHeaders httpHeaders, Request request, String endpointUrl, AccessControl accessControl) {
String userMessage = getUserMessage(tableName, accessType, endpointUrl);
String rawTableName = TableNameBuilder.extractRawTableName(tableName);

try {
if (rawTableName == null) {
if (accessControl.hasAccess(accessType, httpHeaders, endpointUrl)) {
if (accessControl.hasAccess(accessType, httpHeaders, request, endpointUrl)) {
return;
}
} else {
if (accessControl.hasAccess(rawTableName, accessType, httpHeaders, endpointUrl)) {
if (accessControl.hasAccess(rawTableName, accessType, httpHeaders, request, endpointUrl)) {
return;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public void filter(ContainerRequestContext requestContext)
tableName = DatabaseUtils.translateTableName(tableName, _httpHeaders);
}
AccessType accessType = extractAccessType(endpointMethod);
AccessControlUtils.validatePermission(tableName, accessType, _httpHeaders, endpointUrl, accessControl);
AccessControlUtils.validatePermission(tableName, accessType, _httpHeaders, request, endpointUrl, accessControl);

FineGrainedAuthUtils.validateFineGrainedAuth(endpointMethod, uriInfo, _httpHeaders, accessControl);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.pinot.core.auth.BasicAuthPrincipal;
import org.apache.pinot.core.auth.BasicAuthUtils;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.glassfish.grizzly.http.server.Request;


/**
Expand Down Expand Up @@ -77,13 +78,14 @@ public boolean protectAnnotatedOnly() {
}

@Override
public boolean hasAccess(String tableName, AccessType accessType, HttpHeaders httpHeaders, String endpointUrl) {
return getPrincipal(httpHeaders)
.filter(p -> p.hasTable(tableName) && p.hasPermission(Objects.toString(accessType))).isPresent();
public boolean hasAccess(String tableName, AccessType accessType, HttpHeaders httpHeaders, Request request,
String endpointUrl) {
return getPrincipal(httpHeaders).filter(
p -> p.hasTable(tableName) && p.hasPermission(Objects.toString(accessType))).isPresent();
}

@Override
public boolean hasAccess(AccessType accessType, HttpHeaders httpHeaders, String endpointUrl) {
public boolean hasAccess(AccessType accessType, HttpHeaders httpHeaders, Request request, String endpointUrl) {
if (getPrincipal(httpHeaders).isEmpty()) {
throw new NotAuthorizedException("Basic");
}
Expand All @@ -101,8 +103,7 @@ private Optional<BasicAuthPrincipal> getPrincipal(HttpHeaders headers) {
}

return authHeaders.stream().map(org.apache.pinot.common.auth.BasicAuthUtils::normalizeBase64Token)
.map(_token2principal::get)
.filter(Objects::nonNull).findFirst();
.map(_token2principal::get).filter(Objects::nonNull).findFirst();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.pinot.core.auth.ZkBasicAuthPrincipal;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.apache.pinot.spi.utils.builder.TableNameBuilder;
import org.glassfish.grizzly.http.server.Request;


/**
Expand Down Expand Up @@ -79,14 +80,15 @@ public boolean protectAnnotatedOnly() {
}

@Override
public boolean hasAccess(String tableName, AccessType accessType, HttpHeaders httpHeaders, String endpointUrl) {
public boolean hasAccess(String tableName, AccessType accessType, HttpHeaders httpHeaders, Request request,
String endpointUrl) {
return getPrincipal(httpHeaders).filter(
p -> p.hasTable(TableNameBuilder.extractRawTableName(tableName))
&& p.hasPermission(Objects.toString(accessType))).isPresent();
p -> p.hasTable(TableNameBuilder.extractRawTableName(tableName)) && p.hasPermission(
Objects.toString(accessType))).isPresent();
}

@Override
public boolean hasAccess(AccessType accessType, HttpHeaders httpHeaders, String endpointUrl) {
public boolean hasAccess(AccessType accessType, HttpHeaders httpHeaders, Request request, String endpointUrl) {
return getPrincipal(httpHeaders).isPresent();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.pinot.core.auth.Actions;
import org.apache.pinot.core.auth.Authorize;
import org.apache.pinot.core.auth.TargetType;
import org.glassfish.grizzly.http.server.Request;

import static org.apache.pinot.spi.utils.CommonConstants.SWAGGER_AUTHORIZATION_KEY;

Expand All @@ -59,6 +60,9 @@ public class PinotControllerAuthResource {
@Context
HttpHeaders _httpHeaders;

@Context
Request _request;

/**
* Verify a token is both authenticated and authorized to perform an operation.
*
Expand All @@ -81,7 +85,7 @@ public boolean verify(@ApiParam(value = "Table name without type") @QueryParam("
@ApiParam(value = "API access type") @DefaultValue("READ") @QueryParam("accessType") AccessType accessType,
@ApiParam(value = "Endpoint URL") @QueryParam("endpointUrl") String endpointUrl) {
AccessControl accessControl = _accessControlFactory.create();
return accessControl.hasAccess(tableName, accessType, _httpHeaders, endpointUrl);
return accessControl.hasAccess(tableName, accessType, _httpHeaders, _request, endpointUrl);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
import org.apache.pinot.sql.parsers.PinotSqlType;
import org.apache.pinot.sql.parsers.SqlCompilationException;
import org.apache.pinot.sql.parsers.SqlNodeAndOptions;
import org.glassfish.grizzly.http.server.Request;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -101,10 +102,16 @@ public class PinotQueryResource {
@Inject
ControllerConf _controllerConf;

@Context
HttpHeaders _httpHeaders;

@Context
Request _request;

@POST
@Path("sql")
@ManualAuthorization // performed by broker
public String handlePostSql(String requestJsonStr, @Context HttpHeaders httpHeaders) {
public String handlePostSql(String requestJsonStr) {
try {
JsonNode requestJson = JsonUtils.stringToJsonNode(requestJsonStr);
if (!requestJson.has("sql")) {
Expand All @@ -121,7 +128,7 @@ public String handlePostSql(String requestJsonStr, @Context HttpHeaders httpHead
queryOptions = requestJson.get("queryOptions").asText();
}
LOGGER.debug("Trace: {}, Running query: {}", traceEnabled, sqlQuery);
return executeSqlQuery(httpHeaders, sqlQuery, traceEnabled, queryOptions, "/sql");
return executeSqlQuery(_httpHeaders, _request, sqlQuery, traceEnabled, queryOptions, "/sql");
} catch (ProcessingException pe) {
LOGGER.error("Caught exception while processing post request {}", pe.getMessage());
return constructQueryExceptionResponse(pe);
Expand All @@ -138,10 +145,10 @@ public String handlePostSql(String requestJsonStr, @Context HttpHeaders httpHead
@Path("sql")
@ManualAuthorization
public String handleGetSql(@QueryParam("sql") String sqlQuery, @QueryParam("trace") String traceEnabled,
@QueryParam("queryOptions") String queryOptions, @Context HttpHeaders httpHeaders) {
@QueryParam("queryOptions") String queryOptions) {
try {
LOGGER.debug("Trace: {}, Running query: {}", traceEnabled, sqlQuery);
return executeSqlQuery(httpHeaders, sqlQuery, traceEnabled, queryOptions, "/sql");
return executeSqlQuery(_httpHeaders, _request, sqlQuery, traceEnabled, queryOptions, "/sql");
} catch (ProcessingException pe) {
LOGGER.error("Caught exception while processing get request {}", pe.getMessage());
return constructQueryExceptionResponse(pe);
Expand All @@ -154,7 +161,7 @@ public String handleGetSql(@QueryParam("sql") String sqlQuery, @QueryParam("trac
}
}

private String executeSqlQuery(@Context HttpHeaders httpHeaders, String sqlQuery, String traceEnabled,
private String executeSqlQuery(HttpHeaders httpHeaders, Request request, String sqlQuery, String traceEnabled,
@Nullable String queryOptions, String endpointUrl)
throws Exception {
SqlNodeAndOptions sqlNodeAndOptions;
Expand All @@ -173,15 +180,16 @@ private String executeSqlQuery(@Context HttpHeaders httpHeaders, String sqlQuery
if (Boolean.parseBoolean(options.get(QueryOptionKey.USE_MULTISTAGE_ENGINE))) {
if (_controllerConf.getProperty(CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_ENABLED,
CommonConstants.Helix.DEFAULT_MULTI_STAGE_ENGINE_ENABLED)) {
return getMultiStageQueryResponse(sqlQuery, queryOptions, httpHeaders, endpointUrl, traceEnabled);
return getMultiStageQueryResponse(sqlQuery, queryOptions, httpHeaders, request, endpointUrl, traceEnabled);
} else {
throw QueryException.getException(QueryException.INTERNAL_ERROR, "V2 Multi-Stage query engine not enabled.");
}
} else {
PinotSqlType sqlType = sqlNodeAndOptions.getSqlType();
switch (sqlType) {
case DQL:
return getQueryResponse(sqlQuery, sqlNodeAndOptions.getSqlNode(), traceEnabled, queryOptions, httpHeaders);
return getQueryResponse(sqlQuery, sqlNodeAndOptions.getSqlNode(), traceEnabled, queryOptions, httpHeaders,
request);
case DML:
Map<String, String> headers =
httpHeaders.getRequestHeaders().entrySet().stream().filter(entry -> !entry.getValue().isEmpty())
Expand All @@ -194,14 +202,14 @@ private String executeSqlQuery(@Context HttpHeaders httpHeaders, String sqlQuery
}
}

private String getMultiStageQueryResponse(String query, String queryOptions, HttpHeaders httpHeaders,
private String getMultiStageQueryResponse(String query, String queryOptions, HttpHeaders httpHeaders, Request request,
String endpointUrl, String traceEnabled)
throws ProcessingException {

// Validate data access
// we don't have a cross table access control rule so only ADMIN can make request to multi-stage engine.
AccessControl accessControl = _accessControlFactory.create();
if (!accessControl.hasAccess(AccessType.READ, httpHeaders, endpointUrl)) {
if (!accessControl.hasAccess(AccessType.READ, httpHeaders, request, endpointUrl)) {
throw new WebApplicationException("Permission denied", Response.Status.FORBIDDEN);
}

Expand Down Expand Up @@ -253,7 +261,7 @@ private String getMultiStageQueryResponse(String query, String queryOptions, Htt
}

private String getQueryResponse(String query, @Nullable SqlNode sqlNode, String traceEnabled, String queryOptions,
HttpHeaders httpHeaders)
HttpHeaders httpHeaders, Request request)
throws ProcessingException {
// Get resource table name.
String tableName;
Expand Down Expand Up @@ -289,7 +297,7 @@ private String getQueryResponse(String query, @Nullable SqlNode sqlNode, String

// Validate data access
AccessControl accessControl = _accessControlFactory.create();
if (!accessControl.hasAccess(rawTableName, AccessType.READ, httpHeaders, Actions.Table.QUERY)) {
if (!accessControl.hasAccess(rawTableName, AccessType.READ, httpHeaders, request, Actions.Table.QUERY)) {
return QueryException.ACCESS_DENIED_ERROR.toString();
}

Expand Down
Loading

0 comments on commit 3839bac

Please sign in to comment.