Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

这个项目支持校验某个用户对某个表是否具有增删改权限吗? #17

Open
KuanKuanya opened this issue Jul 25, 2024 · 1 comment

Comments

@KuanKuanya
Copy link

No description provided.

@melin
Copy link
Owner

melin commented Jul 25, 2024

本项目只负责解析 SQL,不做权限校验;可以结合解析出来的信息(语句类型、输入/输出表、使用的函数等)自行实现权限校验。以下是我们的实现示例:

/**
 * Authorization checks for SQL statements, driven by the metadata the SQL parser extracts
 * (statement type, input/output tables, referenced functions).
 *
 * <p>Table access is granted when the user owns the table, or holds an unexpired table-level
 * or database-level grant for the requested {@link PrivilegeType}. Function access is governed
 * by each function's auth type.
 *
 * <p>TODO:
 * 1. Support authorization checks for stored procedures.
 * 2. Support authorization checks for Flink CDC.
 */
@Service
public class AuthorizationService {

    private static final Logger LOG = LoggerFactory.getLogger(AuthorizationService.class);

    @Autowired
    private TableAccessLogService tableAccessLogService;

    @Autowired
    private UserInfoService userInfoService;

    @Autowired
    private SuperiorBeeConfigClient configClient;

    @Autowired
    private FunctionService functionService;

    @Autowired
    private DataSourceService dataSourceService;

    @Autowired
    private TableService tableService;

    @Autowired
    private WorkspaceService workspaceService;

    @Autowired
    private SecTablePrivsService tablePrivsService;

    @Autowired
    private SecDatabasePrivsService databasePrivsService;

    /**
     * Parses {@code sql} and verifies that the context user may execute it.
     *
     * <p>The statement must always parse to a supported type. Users listed in the
     * {@code SUPERIOR_SKIP_TABLE_AUTH_CHECK_USERS} setting skip the table and function checks.
     *
     * @param context         caller identity plus current catalog/database
     * @param sql             SQL text to check
     * @param sparkTempTables Spark temporary table names; matching tables in the current database
     *                        are not checked
     * @throws IllegalArgumentException if the context carries no user id
     * @throws SQLParserException       if the statement type is not supported
     */
    @Transactional
    public void checkAuthority(AuthContext context, String sql, String[] sparkTempTables) {
        String userId = context.getUserId();
        if (StringUtils.isBlank(userId)) {
            throw new IllegalArgumentException("userId can not empty");
        }

        Statement statement = SparkSqlHelper.parseStatement(sql);
        boolean supportedSql = SparkSqlHelper.checkSupportedSQL(statement.getStatementType());
        if (!supportedSql) {
            throw new SQLParserException("not support sql: " + sql);
        }

        String[] superTableOwners = configClient.getStringArray(SuperiorConf.SUPERIOR_SKIP_TABLE_AUTH_CHECK_USERS);
        if (!ArrayUtils.contains(superTableOwners, userId)) {
            // NOTE(review): this is a self-invocation, so with proxy-based Spring AOP the
            // @Transactional(rollbackFor = ...) on the overload below is not applied here;
            // this method's own @Transactional settings govern the call.
            this.checkAuthority(context, statement, sparkTempTables);
        }
    }

    /**
     * Verifies that the context user may execute an already-parsed statement.
     *
     * <p>SHOW statements are always allowed. Target tables are checked with the privilege type
     * reported by {@link Statement#getPrivilegeType()}. Tables read by a statement's query part
     * are checked for {@link PrivilegeType#READ}, skipping Spark temporary tables. Functions
     * referenced by the query part are checked as well. For SELECT, INSERT and CTAS, the read
     * tables are recorded in the table access log after all checks pass. For CALL, a
     * {@code table} procedure property requires table ownership.
     *
     * @param context         caller identity plus current catalog/database
     * @param statement       parsed statement
     * @param sparkTempTables Spark temporary table names; may be null
     * @throws AccessTableException    if a table check fails
     * @throws AccessFunctionException if a function check fails
     */
    @Transactional(rollbackFor = Exception.class)
    public void checkAuthority(AuthContext context, Statement statement, String[] sparkTempTables) {
        final StatementType statementType = statement.getStatementType();
        if (SHOW == statementType) {
            return;
        }
        final PrivilegeType privilegeType = statement.getPrivilegeType();

        if (CREATE_TABLE == statementType || CREATE_TABLE_AS_LIKE == statementType) {
            CreateTable createTable = (CreateTable) statement;
            checkAccessTableAuth(context, createTable.getTableId(), privilegeType);
        } else if (CREATE_VIEW == statementType) {
            CreateView createView = (CreateView) statement;
            checkAccessTableAuth(context, createView.getTableId(), privilegeType);
            QueryStmt queryStmt = createView.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
        } else if (CREATE_TABLE_AS_SELECT == statementType) {
            CreateTableAsSelect tableAsSelect = (CreateTableAsSelect) statement;
            checkAccessTableAuth(context, tableAsSelect.getTableId(), privilegeType);
            QueryStmt queryStmt = tableAsSelect.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
            addTabAccessLog(context, queryStmt.getInputTables()); // record the query access
        } else if (DROP_TABLE == statementType) {
            DropTable dropTable = (DropTable) statement;
            checkAccessTableAuth(context, dropTable.getTableId(), privilegeType, dropTable.getIfExists());
        } else if (DROP_VIEW == statementType) {
            DropView view = (DropView) statement;
            checkAccessTableAuth(context, view.getTableId(), privilegeType, view.getIfExists());
        } else if (TRUNCATE_TABLE == statementType) {
            TruncateTable truncateTable = (TruncateTable) statement;
            checkAccessTableAuth(context, truncateTable.getTableId(), privilegeType);
        } else if (SELECT == statementType) {
            QueryStmt queryStmt = (QueryStmt) statement;
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
            addTabAccessLog(context, queryStmt.getInputTables()); // record the query access
        } else if (INSERT == statementType) { // possibly a multi-insert with several output tables
            InsertTable multiInsertStmt = (InsertTable) statement;
            QueryStmt queryStmt = multiInsertStmt.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());

            for (TableId tableId : multiInsertStmt.getOutputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }

            addTabAccessLog(context, queryStmt.getInputTables()); // record the query access
        } else if (DELETE == statementType) {
            DeleteTable deleteTable = (DeleteTable) statement;
            checkAccessTableAuth(context, deleteTable.getTableId(), privilegeType);
        } else if (UPDATE == statementType) {
            UpdateTable updateTable = (UpdateTable) statement;
            checkAccessTableAuth(context, updateTable.getTableId(), privilegeType);
        } else if (MERGE == statementType) {
            MergeTable mergeIntoTable = (MergeTable) statement;
            // NOTE(review): MERGE source tables are checked with the statement's privilege type,
            // not READ — confirm this is intended.
            mergeIntoTable.getInputTables().forEach(tableId -> {
                if (!isSparkTempTable(context.getCurrentDatabase(), sparkTempTables, tableId)) {
                    checkAccessTableAuth(context, tableId, privilegeType);
                }
            });

            checkAccessTableAuth(context, mergeIntoTable.getTargetTable(), privilegeType);
        } else if (EXPORT_TABLE == statementType) {
            ExportTable tableData = (ExportTable) statement;
            for (TableId tableId : tableData.getInputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }
            checkFunctionAuth(context, tableData.getFunctionNames());
        } else if (DATATUNNEL == statementType && statement instanceof DataTunnelExpr) {
            DataTunnelExpr dataTunnelExpr = (DataTunnelExpr) statement;
            for (TableId tableId : dataTunnelExpr.getInputTables()) {
                checkAccessTableAuth(context, tableId, privilegeType);
            }

            checkDatatunnelAuthority(
                    context.getTenantId(), context.getRegionCode(), context.getUserId(), dataTunnelExpr);
            checkFunctionAuth(context, dataTunnelExpr.getFunctionNames());
        } else if (CACHE == statementType) { // Spark CACHE TABLE
            CacheTable cacheTable = (CacheTable) statement;
            if (cacheTable.getQueryStmt() != null) {
                // CACHE TABLE ... AS SELECT: authorize the embedded query as a statement of its own.
                this.checkAuthority(context, cacheTable.getQueryStmt(), sparkTempTables);
            } else {
                checkAccessTableAuth(context, cacheTable.getTableId(), privilegeType);
            }
        } else if (ALTER_TABLE == statementType) {
            checkAlterTableAuthority(context, statement, sparkTempTables, privilegeType);
        } else if (CALL == statementType) {
            checkCallProcedureAuthority(context, (CallProcedure) statement);
        }
    }

    /**
     * For a CALL whose properties contain a {@code table} entry ({@code table} or {@code db.table}),
     * requires the user to own that table.
     *
     * @throws SuperiorException if the identifier has more than two dot-separated parts (or none)
     */
    private void checkCallProcedureAuthority(AuthContext context, CallProcedure procedure) {
        if (!procedure.getProperties().containsKey("table")) {
            return;
        }

        String tableId = procedure.getProperties().get("table");
        String[] items = StringUtils.split(tableId, ".");
        String databaseName;
        String tableName;
        if (items.length == 1) {
            databaseName = null;
            tableName = items[0];
        } else if (items.length == 2) {
            databaseName = items[0];
            tableName = items[1];
        } else {
            throw new SuperiorException("Unsupported identifier " + tableId);
        }
        checkAdminTableAuthority(context, databaseName, tableName);
    }

    /**
     * Requires the user to be the owner of the given table.
     *
     * @param workspaceCode explicit database/workspace name from the identifier, or null to use
     *                      the context's current database
     * @param tableName     table name; lowercased before lookup
     * @throws AccessTableException if the table does not exist or the user is not its owner
     */
    private void checkAdminTableAuthority(AuthContext context, String workspaceCode, String tableName) {
        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();

        String catalogName = workspaceService.getCatalogName(tenantId, regionCode, workspaceCode, currentCatalog);
        // Same (current, explicit) argument order as every other getDatabaseName call in this class,
        // so an explicit "db.table" is not overridden by the session's current database.
        String databaseName = CommonUtils.getDatabaseName(currentDatabase, workspaceCode);
        tableName = StringUtils.lowerCase(tableName);

        TableEntity table = tableService.queryTable(tenantId, regionCode, catalogName, databaseName, tableName);
        if (table == null) {
            String msg = String.format("table not exist: %s.%s.%s", catalogName, databaseName, tableName);
            throw new AccessTableException(msg);
        }
        if (!tablePrivsService.checkOwner(table, userId)) {
            String msg = String.format("%s 不是表owner: %s.%s.%s", userId, catalogName, databaseName, tableName);
            throw new AccessTableException(msg);
        }
    }

    /**
     * Checks the DataTunnel source (READ) and sink (WRITE) tables that are addressed through a
     * registered {@code datasource} option.
     */
    private void checkDatatunnelAuthority(
            Long tenantId, String regionCode, String userId, DataTunnelExpr dataTunnelExpr) {
        checkDatatunnelEndpoint(tenantId, regionCode, userId, dataTunnelExpr.getSourceOptions(), PrivilegeType.READ);
        checkDatatunnelEndpoint(tenantId, regionCode, userId, dataTunnelExpr.getSinkOptions(), PrivilegeType.WRITE);
    }

    /**
     * Checks one DataTunnel endpoint. The check only runs when the options contain
     * {@code datasource} and both a table name and a schema name ({@code databaseName}, falling
     * back to {@code schemaName}) are non-blank.
     */
    private void checkDatatunnelEndpoint(
            Long tenantId,
            String regionCode,
            String userId,
            Map<String, Object> options,
            PrivilegeType privilegeType) {
        if (!options.containsKey("datasource")) {
            return;
        }

        String datasource = (String) options.get("datasource");
        String catalogName = dataSourceService.queryCatalogName(tenantId, regionCode, datasource);
        String schemaName = (String) options.get("databaseName");
        if (schemaName == null) {
            schemaName = (String) options.get("schemaName");
        }
        String tableName = (String) options.get("tableName");

        AuthContext context = new AuthContext(tenantId, regionCode, userId, HIVE, catalogName, schemaName);
        if (StringUtils.isNotBlank(tableName) && StringUtils.isNotBlank(schemaName)) {
            checkAccessTableAuth(context, new TableId(tableName), privilegeType);
        }
    }

    /**
     * Checks ALTER TABLE/VIEW. The altered table needs the statement's privilege. For
     * {@code ALTER VIEW ... AS <query>}, the query's input tables and functions are checked too.
     */
    private void checkAlterTableAuthority(
            AuthContext context, Statement statement, String[] sparkTempTables, PrivilegeType privilegeType) {
        AlterTable alterTable = (AlterTable) statement;
        checkAccessTableAuth(context, alterTable.getTableId(), privilegeType, alterTable.getIfExists());

        if (AlterActionType.ALTER_VIEW_QUERY == alterTable.getFirstAlterType()) {
            AlterViewAction view = (AlterViewAction) alterTable.firstAction();
            QueryStmt queryStmt = view.getQueryStmt();
            checkSelectTableAuth(context, queryStmt.getInputTables(), sparkTempTables);
            checkFunctionAuth(context, queryStmt.getFunctionNames());
        }
    }

    /**
     * Records that the context user queried each of {@code inputTables}. Does nothing if the user
     * is not found in the user-info service.
     */
    private void addTabAccessLog(AuthContext context, List<TableId> inputTables) {
        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();

        UserInfoEntity userInfoEntity = userInfoService.queryUser(tenantId, userId);
        if (userInfoEntity == null) {
            return;
        }

        String userName = userInfoEntity.getCnName();
        for (TableId table : inputTables) {
            String catalogName = CommonUtils.getCatalogName(currentCatalog, table.getCatalogName());
            String databaseName = CommonUtils.getDatabaseName(currentDatabase, table.getSchemaName());
            String tableName = table.getTableName().toLowerCase();
            this.tableAccessLogService.update(
                    tenantId, regionCode, catalogName, databaseName, tableName, userName, userId);
        }
    }

    /**
     * Requires READ on every input table. Two kinds of tables are skipped: Spark temporary tables,
     * and the pseudo table {@code dual} on Oracle/Dameng/OceanBase data sources.
     */
    private void checkSelectTableAuth(AuthContext context, List<TableId> inputTables, String[] sparkTempTables) {
        DataSourceType dataSourceType = context.getDataSourceType();
        boolean hasDualTable = ORACLE == dataSourceType || DAMENG == dataSourceType || OCEANBASE == dataSourceType;

        for (TableId tableId : inputTables) {
            if (hasDualTable && StringUtils.equalsIgnoreCase("dual", tableId.getTableName())) {
                continue;
            }
            if (!isSparkTempTable(context.getCurrentDatabase(), sparkTempTables, tableId)) {
                checkAccessTableAuth(context, tableId, PrivilegeType.READ);
            }
        }
    }

    /**
     * Checks every referenced function that is registered in the function service. Functions not
     * found there are allowed.
     *
     * <ul>
     *   <li>{@code AUTH_ONESELF}: only the creator may use it.</li>
     *   <li>{@code AUTH_WORKSPACE_USERS}: only usable from the function's own database.</li>
     *   <li>{@code AUTH_ASSIGN_USERS}: only users in the function's auth-user list.</li>
     * </ul>
     *
     * @param functionNames referenced functions; null means nothing to check
     * @throws AccessFunctionException if a function may not be used
     */
    private void checkFunctionAuth(AuthContext context, HashSet<FunctionId> functionNames) {
        if (functionNames == null) {
            return;
        }

        LOG.info("context: {} function names: {}", context, StringUtils.join(functionNames, ","));

        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String currentCatalog = context.getCurrentCatalog();
        String currentDatabase = context.getCurrentDatabase();

        for (FunctionId functionId : functionNames) {
            String database = functionId.getSchemaName();
            if (database == null) {
                database = currentDatabase;
            }
            database = StringUtils.lowerCase(database);
            String funcName = StringUtils.lowerCase(functionId.getFunctionName());

            FunctionEntity function =
                    functionService.queryFunction(tenantId, regionCode, currentCatalog, database, funcName);
            if (function == null) {
                continue;
            }

            if (AUTH_ONESELF.equals(function.getAuthType())) {
                if (!userId.equals(function.getCreater())) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:仅个人可用", funcName);
                }
            } else if (AUTH_WORKSPACE_USERS.equals(function.getAuthType())) {
                // Null-safe and case-insensitive: `database` was lowercased above, while
                // currentDatabase may be null or mixed case.
                if (StringUtils.isNotBlank(database) && !StringUtils.equalsIgnoreCase(currentDatabase, database)) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:仅项目空间成员可用", funcName);
                }
            } else if (AUTH_ASSIGN_USERS.equals(function.getAuthType())) {
                // NOTE(review): if getAuthUsers() returns a delimited String rather than a collection,
                // contains() is a substring match (e.g. "bob" matches "bobby") — confirm its type.
                if (!function.getAuthUsers().contains(userId)) {
                    throw new AccessFunctionException("无权访问函数: {}, 函数访问范围:指定用户可用", funcName);
                }
            }
        }
    }

    /**
     * Returns true when {@code tableId} resolves to the current database and its name matches
     * (case-insensitively) one of {@code sparkTempTables}.
     */
    private boolean isSparkTempTable(String currentDatabaseName, String[] sparkTempTables, TableId tableId) {
        if (sparkTempTables == null || sparkTempTables.length == 0) {
            return false;
        }

        String schemaName = CommonUtils.getDatabaseName(currentDatabaseName, tableId.getSchemaName());
        if (!StringUtils.equalsIgnoreCase(currentDatabaseName, schemaName)) {
            return false;
        }

        String tableName = tableId.getTableName();
        for (String name : sparkTempTables) {
            if (StringUtils.equalsIgnoreCase(name, tableName)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Requires {@code privilegeType} on {@code tableId}. For non-CREATE privileges, a missing
     * table is an error.
     *
     * @throws AccessTableException if the table is missing or access is not granted
     */
    @Transactional
    public void checkAccessTableAuth(AuthContext authContext, TableId tableId, PrivilegeType privilegeType) {
        this.checkAccessTableAuth(authContext, tableId, privilegeType, false);
    }

    /**
     * Core table access check. Access is granted by any of the following:
     * <ol>
     *   <li>Table ownership (not applicable to CREATE, since the table need not exist yet).</li>
     *   <li>An unexpired table-level grant carrying the privilege.</li>
     *   <li>An unexpired database-level grant carrying the privilege.</li>
     * </ol>
     * Table-level and database-level grants are additive. Expired grants never authorize;
     * approved ones are marked as expired (status 9) when encountered.
     *
     * @param ifExists when true and the table does not exist, return silently
     *                 (e.g. {@code DROP/ALTER ... IF EXISTS})
     * @throws AccessTableException if the table is missing or access is not granted
     * @throws SuperiorException    if the database name cannot be resolved
     */
    private void checkAccessTableAuth(
            AuthContext context, TableId tableId, PrivilegeType privilegeType, boolean ifExists) {

        long tenantId = context.getTenantId();
        String regionCode = context.getRegionCode();
        String userId = context.getUserId();
        String catalogName = CommonUtils.getCatalogName(context.getCurrentCatalog(), tableId.getCatalogName());
        String databaseName = CommonUtils.getDatabaseName(context.getCurrentDatabase(), tableId.getSchemaName());
        String tableName = stripPaimonSysTableSuffix(tableId.getTableName());

        if (privilegeType != PrivilegeType.CREATE) {
            TableEntity table = tableService.queryTable(tenantId, regionCode, catalogName, databaseName, tableName);
            if (table == null) {
                if (ifExists) {
                    return; // "... IF EXISTS" on a missing table is not an error
                }
                throw new AccessTableException("table not exist: {}.{}.{}", catalogName, databaseName, tableName);
            }
            if (tablePrivsService.checkOwner(table, userId)) {
                return;
            }
        }

        if (StringUtils.isBlank(databaseName)) {
            throw new SuperiorException("databaseName can not blank");
        }

        // NOTE(review): the expiry status updates below run in the caller's transaction. If this
        // method then throws and that transaction rolls back, the updates are lost — confirm the
        // intended rollback rules.
        LocalDate today = LocalDate.now();

        List<SecTablePrivsEntity> tablePrivsList =
                tablePrivsService.queryTablePrivs(tenantId, userId, catalogName, databaseName, tableName);
        for (SecTablePrivsEntity tablePrivs : tablePrivsList) {
            if (today.isAfter(tablePrivs.getExpireDate())) {
                if (tablePrivs.getStatus() == 1 || tablePrivs.getStatus() == 15) {
                    tablePrivs.setStatus(9);
                    tablePrivsService.updateEntity(tablePrivs);
                }
                continue; // an expired grant must not authorize access
            }
            if (grantsTablePrivilege(tablePrivs, privilegeType)) {
                return;
            }
        }

        List<SecDatabasePrivsEntity> databasePrivsList =
                databasePrivsService.queryDatabasePrivs(tenantId, userId, catalogName, databaseName);
        for (SecDatabasePrivsEntity databasePrivs : databasePrivsList) {
            if (today.isAfter(databasePrivs.getExpireDate())) {
                if (databasePrivs.getStatus() == 15) {
                    databasePrivs.setStatus(9);
                    databasePrivsService.updateEntity(databasePrivs);
                }
                continue; // an expired grant must not authorize access
            }
            if (grantsDatabasePrivilege(databasePrivs, privilegeType)) {
                return;
            }
        }

        String msg = String.format(
                "%s 没有申请表: %s.%s.%s 使用权限: %s", userId, catalogName, databaseName, tableName, privilegeType);
        throw new AccessTableException(msg);
    }

    /**
     * Maps a Paimon system table reference (e.g. {@code paimon_users_ods$schemas}) to its base table
     * name. Other names are returned unchanged.
     */
    private static String stripPaimonSysTableSuffix(String tableName) {
        String paimonSysTableName = StringUtils.substringAfterLast(tableName, "$");
        if (StringUtils.isNotBlank(paimonSysTableName)
                && ArrayUtils.contains(PAIMON_SYS_TABLES, paimonSysTableName.toLowerCase())) {
            return StringUtils.substringBeforeLast(tableName, "$");
        }
        return tableName;
    }

    /**
     * Whether a table-level grant carries {@code privilegeType}. CREATE is never granted per table;
     * it is only checked at database level.
     */
    private static boolean grantsTablePrivilege(SecTablePrivsEntity tablePrivs, PrivilegeType privilegeType) {
        switch (privilegeType) {
            case READ:
                return tablePrivs.isReadPriv();
            case WRITE:
                return tablePrivs.isWritePriv();
            case ALTER:
                return tablePrivs.isAlterPriv();
            case DROP:
                return tablePrivs.isDropPriv();
            case CREATE:
                return false;
            default:
                throw new AccessTableException("not support " + privilegeType);
        }
    }

    /** Whether a database-level grant carries {@code privilegeType}. */
    private static boolean grantsDatabasePrivilege(
            SecDatabasePrivsEntity databasePrivs, PrivilegeType privilegeType) {
        switch (privilegeType) {
            case CREATE:
                return databasePrivs.isCreatePriv();
            case READ:
                return databasePrivs.isReadPriv();
            case WRITE:
                return databasePrivs.isWritePriv();
            case ALTER:
                return databasePrivs.isAlterPriv();
            case DROP:
                return databasePrivs.isDropPriv();
            default:
                throw new AccessTableException("not support " + privilegeType);
        }
    }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants