Does this project support checking whether a given user has insert/delete/update privileges on a given table? #17
It only parses SQL; you can combine the parsed information to implement permission checking yourself. For example, our implementation:
/**
* @TODO
* 1. support privilege checks for stored procedures
* 2. support privilege checks for flink cdc
*/
@Service
public class AuthorizationService {
private static final Logger LOG = LoggerFactory.getLogger(AuthorizationService.class);
@Autowired
private TableAccessLogService tableAccessLogService;
@Autowired
private UserInfoService userInfoService;
@Autowired
private SuperiorBeeConfigClient configClient;
@Autowired
private FunctionService functionService;
@Autowired
private DataSourceService dataSourceService;
@Autowired
private TableService tableService;
@Autowired
private WorkspaceService workspaceService;
@Autowired
private SecTablePrivsService tablePrivsService;
@Autowired
private SecDatabasePrivsService databasePrivsService;
@Transactional
public void checkAuthority(AuthContext context, String sql, String[] sparkTempTables) {
String userId = context.getUserId();
if (StringUtils.isBlank(userId)) {
throw new IllegalArgumentException("userId can not be empty");
}
Statement statement = SparkSqlHelper.parseStatement(sql);
boolean supportedSql = SparkSqlHelper.checkSupportedSQL(statement.getStatementType());
if (!supportedSql) {
throw new SQLParserException("unsupported sql: " + sql);
}
String[] superTableOwners = configClient.getStringArray(SuperiorConf.SUPERIOR_SKIP_TABLE_AUTH_CHECK_USERS);
if (!ArrayUtils.contains(superTableOwners, userId)) {
// check the user's permission to execute the sql
this.checkAuthority(context, statement, sparkTempTables);
}
}
/**
* Check whether the user has permission to execute the sql statement.
* Create table:
* 1. check that the user is a member of the data workspace
* Drop table:
* 1. check that the user is the table owner
* Alter table & add column & alter column:
* 1. check that the user is the table owner
* Query table:
* 1. select table: the user is the owner or has the select privilege
* Write table:
* 1. insert into table: the user is the owner or has the insert privilege
* 2. select table: the user is the owner or has the select privilege
*/
@Transactional(rollbackFor = Exception.class)
public void checkAuthority(AuthContext context, Statement statement, String[] sparkTempTables) {
final StatementType statementType = statement.getStatementType();
if (SHOW == statementType) {
return;
}
final PrivilegeType privilegeType = statement.getPrivilegeType();
if (CREATE_TABLE == statementType || CREATE_TABLE_AS_LIKE == statementType) {
CreateTable createTable = (CreateTable) statement;
checkAccessTableAuth(context, createTable.getTableId(), privilegeType);
} else if (CREATE_VIEW == statementType) {
CreateView createView = (CreateView) statement;
checkAccessTableAuth(context, createView.getTableId(), privilegeType);
QueryStmt queryStmt = createView.getQueryStmt();
checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
checkFunctionAuth(context, queryStmt.getFunctionNames());
} else if (CREATE_TABLE_AS_SELECT == statementType) {
CreateTableAsSelect tableAsSelect = (CreateTableAsSelect) statement;
checkAccessTableAuth(context, tableAsSelect.getTableId(), privilegeType);
QueryStmt queryStmt = tableAsSelect.getQueryStmt();
checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
checkFunctionAuth(context, queryStmt.getFunctionNames());
addTabAccessLog(context, queryStmt.getInputTables()); // record the table access for this query
} else if (DROP_TABLE == statementType) {
DropTable dropTable = (DropTable) statement;
TableId tableId = dropTable.getTableId();
checkAccessTableAuth(context, tableId, privilegeType, dropTable.getIfExists());
} else if (DROP_VIEW == statementType) {
DropView view = (DropView) statement;
TableId tableId = view.getTableId();
checkAccessTableAuth(context, tableId, privilegeType, view.getIfExists());
} else if (TRUNCATE_TABLE == statementType) {
TruncateTable truncateTable = (TruncateTable) statement;
TableId tableId = truncateTable.getTableId();
checkAccessTableAuth(context, tableId, privilegeType);
} else if (SELECT == statementType) {
QueryStmt queryStmt = (QueryStmt) statement;
checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
checkFunctionAuth(context, queryStmt.getFunctionNames());
addTabAccessLog(context, queryStmt.getInputTables()); // record the table access for this query
} else if (INSERT == statementType) { // multi-insert (multiple output tables)
InsertTable multiInsertStmt = (InsertTable) statement;
QueryStmt queryStmt = multiInsertStmt.getQueryStmt();
checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
checkFunctionAuth(context, queryStmt.getFunctionNames());
for (TableId tableId : multiInsertStmt.getOutputTables()) {
checkAccessTableAuth(context, tableId, privilegeType);
}
addTabAccessLog(context, queryStmt.getInputTables()); // record the table access for this query
} else if (DELETE == statementType) {
DeleteTable deleteTable = (DeleteTable) statement;
checkAccessTableAuth(context, deleteTable.getTableId(), privilegeType);
} else if (UPDATE == statementType) {
UpdateTable updateTable = (UpdateTable) statement;
checkAccessTableAuth(context, updateTable.getTableId(), privilegeType);
} else if (MERGE == statementType) {
MergeTable mergeIntoTable = (MergeTable) statement;
mergeIntoTable.getInputTables().forEach(tableId -> {
boolean sparkTempTable = isSparkTempTable(context.getCurrentDatabase(), sparkTempTables, tableId);
if (!sparkTempTable) {
checkAccessTableAuth(context, tableId, privilegeType);
}
});
checkAccessTableAuth(context, mergeIntoTable.getTargetTable(), privilegeType);
} else if (EXPORT_TABLE == statementType) { // export table
ExportTable tableData = (ExportTable) statement;
for (TableId tableId : tableData.getInputTables()) {
checkAccessTableAuth(context, tableId, privilegeType);
}
checkFunctionAuth(context, tableData.getFunctionNames());
} else if (DATATUNNEL == statementType && statement instanceof DataTunnelExpr) {
DataTunnelExpr dataTunnelExpr = (DataTunnelExpr) statement;
for (TableId tableId : dataTunnelExpr.getInputTables()) {
checkAccessTableAuth(context, tableId, privilegeType);
}
long tenantId = context.getTenantId();
String regionCode = context.getRegionCode();
String userId = context.getUserId();
checkDatatunnelAuthority(tenantId, regionCode, userId, dataTunnelExpr);
checkFunctionAuth(context, dataTunnelExpr.getFunctionNames());
} else if (CACHE == statementType) { // spark cache
CacheTable cacheTable = (CacheTable) statement;
if (cacheTable.getQueryStmt() != null) {
QueryStmt queryStmt = cacheTable.getQueryStmt();
this.checkAuthority(context, queryStmt, sparkTempTables);
} else {
checkAccessTableAuth(context, cacheTable.getTableId(), privilegeType);
}
} else if (ALTER_TABLE == statementType) {
checkAlterTableAuthority(context, statement, sparkTempTables, privilegeType);
} else if (CALL == statementType) {
CallProcedure procedure = (CallProcedure) statement;
if (procedure.getProperties().containsKey("table")) {
String tableId = procedure.getProperties().get("table");
String[] items = StringUtils.split(tableId, ".");
String databaseName = null;
String tableName = null;
if (items.length == 1) {
tableName = items[0];
} else if (items.length == 2) {
databaseName = items[0];
tableName = items[1];
} else {
throw new SuperiorException("Unsupported identifier " + tableId);
}
checkAdminTableAuthority(context, databaseName, tableName);
}
}
}
private void checkAdminTableAuthority(AuthContext context, String workspaceCode, String tableName) {
long tenantId = context.getTenantId();
String regionCode = context.getRegionCode();
String userId = context.getUserId();
String currentCatalog = context.getCurrentCatalog();
String currentDatabase = context.getCurrentDatabase();
String catalogName = workspaceService.getCatalogName(tenantId, regionCode, workspaceCode, currentCatalog);
String databaseName = CommonUtils.getDatabaseName(workspaceCode, currentDatabase);
tableName = StringUtils.lowerCase(tableName);
TableEntity table = tableService.queryTable(tenantId, regionCode, catalogName, databaseName, tableName);
if (table == null) {
String msg = String.format("table not exist: %s.%s.%s", catalogName, databaseName, tableName);
throw new AccessTableException(msg);
}
if (!tablePrivsService.checkOwner(table, userId)) {
String msg = String.format("%s 不是表owner: %s.%s.%s", userId, catalogName, databaseName, tableName);
throw new AccessTableException(msg);
}
}
private void checkDatatunnelAuthority(
Long tenantId, String regionCode, String userId, DataTunnelExpr dataTunnelExpr) {
Map<String, Object> sourceOptions = dataTunnelExpr.getSourceOptions();
if (sourceOptions.containsKey("datasource")) {
String datasource = (String) sourceOptions.get("datasource");
String catalogName = dataSourceService.queryCatalogName(tenantId, regionCode, datasource);
String schemaName = (String) sourceOptions.getOrDefault("databaseName", null);
if (schemaName == null) {
schemaName = (String) sourceOptions.get("schemaName");
}
String tableName = (String) sourceOptions.get("tableName");
AuthContext context = new AuthContext(tenantId, regionCode, userId, HIVE, catalogName, schemaName);
if (StringUtils.isNotBlank(tableName) && StringUtils.isNotBlank(schemaName)) {
checkAccessTableAuth(context, new TableId(tableName), PrivilegeType.READ);
}
}
Map<String, Object> sinkOptions = dataTunnelExpr.getSinkOptions();
if (sinkOptions.containsKey("datasource")) {
String datasource = (String) sinkOptions.get("datasource");
String catalogName = dataSourceService.queryCatalogName(tenantId, regionCode, datasource);
String schemaName = (String) sinkOptions.getOrDefault("databaseName", null);
if (schemaName == null) {
schemaName = (String) sinkOptions.get("schemaName");
}
String tableName = (String) sinkOptions.get("tableName");
AuthContext context = new AuthContext(tenantId, regionCode, userId, HIVE, catalogName, schemaName);
if (StringUtils.isNotBlank(tableName) && StringUtils.isNotBlank(schemaName)) {
checkAccessTableAuth(context, new TableId(tableName), PrivilegeType.WRITE);
}
}
}
private void checkAlterTableAuthority(
AuthContext context, Statement statement, String[] sparkTempTables, PrivilegeType privilegeType) {
AlterTable alterTable = (AlterTable) statement;
TableId tableId = alterTable.getTableId();
checkAccessTableAuth(context, tableId, privilegeType, alterTable.getIfExists());
AlterActionType alterType = alterTable.getFirstAlterType();
if (AlterActionType.ALTER_VIEW_QUERY == alterType) {
AlterViewAction view = (AlterViewAction) alterTable.firstAction();
QueryStmt queryStmt = view.getQueryStmt();
checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
checkFunctionAuth(context, queryStmt.getFunctionNames());
}
}
/**
* Record which tables the user queried.
*/
private void addTabAccessLog(AuthContext context, List<TableId> inputTables) {
long tenantId = context.getTenantId();
String regionCode = context.getRegionCode();
String userId = context.getUserId();
String currentCatalog = context.getCurrentCatalog();
String currentDatabase = context.getCurrentDatabase();
UserInfoEntity userInfoEntity = userInfoService.queryUser(tenantId, userId);
if (userInfoEntity == null) {
return;
}
for (TableId table : inputTables) {
String userName = userInfoEntity.getCnName();
String catalogName = CommonUtils.getCatalogName(currentCatalog, table.getCatalogName());
String databaseName = CommonUtils.getDatabaseName(currentDatabase, table.getSchemaName());
String tableName = table.getTableName().toLowerCase();
this.tableAccessLogService.update(
tenantId, regionCode, catalogName, databaseName, tableName, userName, userId);
}
}
private void checkSelectTableAuth(AuthContext context, List<TableId> inputTables, String[] sparkTempTables) {
DataSourceType dataSourceType = context.getDataSourceType();
for (TableId tableId : inputTables) {
if (ORACLE == dataSourceType || DAMENG == dataSourceType || OCEANBASE == dataSourceType) {
if (StringUtils.equalsIgnoreCase("dual", tableId.getTableName())) {
continue;
}
}
boolean sparkTempTable = isSparkTempTable(context.getCurrentDatabase(), sparkTempTables, tableId);
if (!sparkTempTable) {
checkAccessTableAuth(context, tableId, PrivilegeType.READ);
}
}
}
private void checkFunctionAuth(AuthContext context, HashSet<FunctionId> functionNames) {
if (functionNames == null) {
return;
}
LOG.info("context: {} function names: {}", context, StringUtils.join(functionNames, ","));
long tenantId = context.getTenantId();
String regionCode = context.getRegionCode();
String userId = context.getUserId();
String currentCatalog = context.getCurrentCatalog();
String currentDatabase = context.getCurrentDatabase();
for (FunctionId functionId : functionNames) {
String database = functionId.getSchemaName();
String funcName = functionId.getFunctionName();
database = database == null ? context.getCurrentDatabase() : database;
database = StringUtils.lowerCase(database);
funcName = StringUtils.lowerCase(funcName);
FunctionEntity function =
functionService.queryFunction(tenantId, regionCode, currentCatalog, database, funcName);
if (function != null) {
if (AUTH_ONESELF.equals(function.getAuthType())) {
if (!userId.equals(function.getCreater())) {
throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:仅个人可用", funcName);
}
} else if (AUTH_WORKSPACE_USERS.equals(function.getAuthType())) {
if (StringUtils.isNotBlank(database) && !currentDatabase.equals(database)) {
throw new AccessFunctionException("无权访问函数: , 函数访问范围:仅项目空间成员可用", funcName);
}
} else if (AUTH_ASSIGN_USERS.equals(function.getAuthType())) {
if (!function.getAuthUsers().contains(userId)) {
throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:指定用户可用", funcName);
}
}
}
}
}
private boolean isSparkTempTable(String currentDatabaseName, String[] sparkTempTables, TableId tableId) {
boolean sparkTempTable = false;
if (sparkTempTables != null && sparkTempTables.length > 0) {
String schemaName = CommonUtils.getDatabaseName(currentDatabaseName, tableId.getSchemaName());
String tableName = StringUtils.lowerCase(tableId.getTableName());
if (StringUtils.equalsIgnoreCase(currentDatabaseName, schemaName)) {
for (String name : sparkTempTables) {
if (StringUtils.equalsIgnoreCase(name, tableName)) {
sparkTempTable = true;
break;
}
}
}
}
return sparkTempTable;
}
@Transactional
public void checkAccessTableAuth(AuthContext authContext, TableId tableId, PrivilegeType privilegeType) {
this.checkAccessTableAuth(authContext, tableId, privilegeType, false);
}
private void checkAccessTableAuth(
AuthContext context, TableId tableId, PrivilegeType privilegeType, boolean ifExists) {
long tenantId = context.getTenantId();
String regionCode = context.getRegionCode();
String userId = context.getUserId();
String catalogName = CommonUtils.getCatalogName(context.getCurrentCatalog(), tableId.getCatalogName());
String databaseName = CommonUtils.getDatabaseName(context.getCurrentDatabase(), tableId.getSchemaName());
String tableName = tableId.getTableName();
// access to a paimon system table, e.g. SELECT * FROM hive_metastore.bigdata.paimon_users_ods$schemas
String paimonSysTableName = StringUtils.substringAfterLast(tableName, "$");
if (StringUtils.isNotBlank(paimonSysTableName)
&& ArrayUtils.contains(PAIMON_SYS_TABLES, paimonSysTableName.toLowerCase())) {
tableName = StringUtils.substringBeforeLast(tableName, "$");
}
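// For non-CREATE statements the table must already exist, and the table owner bypasses the fine-grained privilege checks below.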
if (privilegeType != PrivilegeType.CREATE) {
TableEntity table = tableService.queryTable(tenantId, regionCode, catalogName, databaseName, tableName);
if (table == null) {
if (ifExists) {
return; // "if exists" statement: a missing table is not an error
}
throw new AccessTableException("table not exist: {}.{}.{}", catalogName, databaseName, tableName);
}
if (tablePrivsService.checkOwner(table, userId)) {
return;
}
}
if (StringUtils.isBlank(databaseName)) {
throw new SuperiorException("databaseName can not blank");
}
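// Look up table-level grants first; fall back to database-level grants when none exist.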
List<SecTablePrivsEntity> tablePrivsList =
tablePrivsService.queryTablePrivs(tenantId, userId, catalogName, databaseName, tableName);
List<SecDatabasePrivsEntity> databasePrivsList = null;
if (tablePrivsList.isEmpty()) {
databasePrivsList = databasePrivsService.queryDatabasePrivs(tenantId, userId, catalogName, databaseName);
if (databasePrivsList.isEmpty()) {
throw new AccessTableException(
"{} has not been granted access to table: {}.{}.{}, required privilege: {}", userId, catalogName, databaseName, tableName, privilegeType);
}
}
for (SecTablePrivsEntity tablePrivs : tablePrivsList) { // approved table-level grants
LocalDate currentDate = LocalDate.now();
// mark the grant as expired
if (currentDate.isAfter(tablePrivs.getExpireDate())) {
if (tablePrivs.getStatus() == 1 || tablePrivs.getStatus() == 15) {
tablePrivs.setStatus(9);
tablePrivsService.updateEntity(tablePrivs);
}
}
switch (privilegeType) {
case READ:
if (tablePrivs.isReadPriv()) {
return;
}
break;
case WRITE:
if (tablePrivs.isWritePriv()) {
return;
}
break;
case ALTER:
if (tablePrivs.isAlterPriv()) {
return;
}
break;
case DROP:
if (tablePrivs.isDropPriv()) {
return;
}
break;
default:
throw new AccessTableException("not support " + privilegeType);
}
}
if (databasePrivsList != null && !databasePrivsList.isEmpty()) {
for (SecDatabasePrivsEntity databasePrivs : databasePrivsList) { // approved database-level grants
LocalDate currentDate = LocalDate.now();
// mark the grant as expired
if (currentDate.isAfter(databasePrivs.getExpireDate())) {
if (databasePrivs.getStatus() == 15) {
databasePrivs.setStatus(9);
databasePrivsService.updateEntity(databasePrivs);
}
}
switch (privilegeType) {
case CREATE:
if (databasePrivs.isCreatePriv()) {
return;
}
break;
case READ:
if (databasePrivs.isReadPriv()) {
return;
}
break;
case WRITE:
if (databasePrivs.isWritePriv()) {
return;
}
break;
case ALTER:
if (databasePrivs.isAlterPriv()) {
return;
}
break;
case DROP:
if (databasePrivs.isDropPriv()) {
return;
}
break;
default:
throw new AccessTableException("not support " + privilegeType);
}
}
}
String msg = String.format(
"%s has not been granted access to table: %s.%s.%s, required privilege: %s", userId, catalogName, databaseName, tableName, privilegeType);
throw new AccessTableException(msg);
}
}
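For reference, a minimal sketch of how the service above might be called from a job-submission flow. The AuthContext constructor mirrors the usage inside checkDatatunnelAuthority; the tenant/user values and the sample SQL are illustrative only, not part of the parser library:

// Illustrative caller: tenantId, regionCode, userId, catalogName and databaseName come from your own request context.
AuthContext context = new AuthContext(tenantId, regionCode, userId, HIVE, catalogName, databaseName);
String sql = "INSERT INTO ods.orders SELECT * FROM stg.orders";
// No Spark temp views are registered in this example; pass their names to skip checks on them.
authorizationService.checkAuthority(context, sql, new String[0]);
// checkAuthority throws AccessTableException / AccessFunctionException when the user lacks the required privilege.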