Skip to content

Commit

Permalink
Merge pull request #14 from bosch-aisecurity-aishield/fixing_security…
Browse files Browse the repository at this point in the history
…_issues

fixing security level issues
  • Loading branch information
DeepakByrappa authored Mar 4, 2024
2 parents 523c53a + a95df66 commit bdefa31
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 37 deletions.
19 changes: 11 additions & 8 deletions src/modules/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from tqdm import tqdm
from utils import helper, github_util, report_util

import os
from modules import notebook_inspector, model_inspector


Expand Down Expand Up @@ -72,6 +72,7 @@ def orchestrator(repo_type: str = 'github', repo_url: str = None, github_clone_d
scanning_status = True
failed_scan_files = list()
scanned_report_dictionary = {}
base_path = str(os.getcwd())

try:
print("Scanning Started ...")
Expand All @@ -85,7 +86,8 @@ def orchestrator(repo_type: str = 'github', repo_url: str = None, github_clone_d
scanning_id=scanning_id,
path=path,
branch_name=branch_name,
depth=depth)
depth=depth,
base_path = base_path)

# iterate to get response from each files
for file in tqdm(to_be_scanned_files):
Expand Down Expand Up @@ -139,12 +141,13 @@ def orchestrator(repo_type: str = 'github', repo_url: str = None, github_clone_d
except Exception as e:
scanning_status = False
print("Scanning Failed due to {}".format(str(e)))

# Clean up the local cloned directory either scanning failed or Completed
if repo_type.lower() == "github":
github_util.delete_github_repo(repo_dir=save_dir)
elif repo_type.lower() not in ["file", "folder"]:
helper.delete_directory([save_dir])

if save_dir != None:
# Clean up the local cloned directory either scanning failed or Completed
if repo_type.lower() == "github":
github_util.delete_github_repo(repo_dir=save_dir)
elif repo_type.lower() not in ["file", "folder"]:
helper.delete_directory(base_path,[save_dir])

return report_path, scanning_status

Expand Down
2 changes: 1 addition & 1 deletion src/utils/github_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ def delete_github_repo(repo_dir):
print("Locally cloned repository has been successfully removed")

except Exception as e:
print("{} Failed to remove due to {}".format(repo_dir, str(e)))
print("{} Failed to remove due to {}, it is recommended to delete the directory manually".format(repo_dir, str(e)))
50 changes: 33 additions & 17 deletions src/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def fetch_scanning_files(repo_type: str, scanning_id: str, repo_url: str = None, github_clone_dir: str = None,
aws_access_key_id: str = None, aws_secret_access_key: str = None, region: str = None,
bucket_name: str = None, s3_download_dir: str = None,path: str = None, branch_name: str = 'main',depth: int=1):
bucket_name: str = None, s3_download_dir: str = None,path: str = None, branch_name: str = 'main',depth: int=1,base_path: str=None):
"""
Fetches files to be scanned based on the repository type and scanning ID.
Expand Down Expand Up @@ -53,31 +53,33 @@ def fetch_scanning_files(repo_type: str, scanning_id: str, repo_url: str = None,

save_dir = github_clone_dir
# Clone the gitHub repository in the local

# if not os.path.exists(save_dir):
# os.makedirs(save_dir)

# print(os.path.dirname(save_dir))
if repo_type.lower() == 'github':
repo_url = repo_url
elif repo_type.lower() == 'huggingface':
if "https://huggingface.co/" not in repo_url:
repo_url = f'https://huggingface.co/{repo_url}'



github_util.clone_github_repo(repo_url, save_dir,branch_name,depth)

# get all h5 files
h5_files = search_files(github_clone_dir, '.h5')
h5_files = search_files(base_path,github_clone_dir, '.h5')

# get all .pb files
pb_files = search_files(github_clone_dir, '.pb')
pb_files = search_files(base_path,github_clone_dir, '.pb')

# get all .pkl files
pkl_files = search_files(github_clone_dir, '.pkl')
pkl_files = search_files(base_path,github_clone_dir, '.pkl')

# get all ipynb files
ipynb_files = search_files(github_clone_dir, '.ipynb')
ipynb_files = search_files(base_path,github_clone_dir, '.ipynb')

# get requirements files
requirement_files = search_files(github_clone_dir, 'requirements.txt')
requirement_files = search_files(base_path,github_clone_dir, 'requirements.txt')

to_be_scanned_files = h5_files + ipynb_files + pb_files + pkl_files + requirement_files

Expand All @@ -90,7 +92,7 @@ def fetch_scanning_files(repo_type: str, scanning_id: str, repo_url: str = None,
# Ensure local directory exists
if not os.path.exists(save_dir):
os.makedirs(save_dir)

# create s3 object to interact with s3 buckets
s3_object = aws_s3_util.AIShieldWatchtowerS3(aws_access_key_id, aws_secret_access_key,
region, bucket_name, save_dir)
Expand All @@ -108,18 +110,19 @@ def fetch_scanning_files(repo_type: str, scanning_id: str, repo_url: str = None,

if repo_type.lower() == 'folder':
tar_dir = path # Assuming file_path is the path to the folder
h5_files = search_files(tar_dir, '.h5')
pb_files = search_files(tar_dir, '.pb')
pkl_files = search_files(tar_dir, '.pkl')
ipynb_files = search_files(tar_dir, '.ipynb')
requirement_files = search_files(tar_dir, 'requirements.txt')
folder_base_path = os.path.dirname(tar_dir)
h5_files = search_files(folder_base_path,tar_dir, '.h5')
pb_files = search_files(folder_base_path,tar_dir, '.pb')
pkl_files = search_files(folder_base_path,tar_dir, '.pkl')
ipynb_files = search_files(folder_base_path,tar_dir, '.ipynb')
requirement_files = search_files(folder_base_path,tar_dir, 'requirements.txt')

to_be_scanned_files = h5_files + ipynb_files + pb_files + pkl_files + requirement_files

return to_be_scanned_files, save_dir


def search_files(target_dir: str, file_extensions: str):
def search_files(base_path:str, target_dir: str, file_extensions: str):
"""
Finds all the files ending with a given extension in the specified directory and its sub-folders.
Expand All @@ -130,7 +133,14 @@ def search_files(target_dir: str, file_extensions: str):
Returns:
- List of paths to files with the specified extension.
"""
if not target_dir:
raise Exception("Target directory is empty")

# Normalize the target directory and check if it is within the base_path
full_target_dir = os.path.normpath(os.path.join(base_path, target_dir))

if not os.path.abspath(full_target_dir).startswith(os.path.abspath(base_path)):
raise Exception("Target directory is outside the base path")
# List to hold the paths of all files matching the given extension
matching_files = []

Expand Down Expand Up @@ -166,7 +176,7 @@ def make_directory(path):
print("{} created successfully".format(path))


def delete_directory(directory):
def delete_directory(base_path :str,directory):
"""
delete directory
Expand All @@ -181,6 +191,12 @@ def delete_directory(directory):
"""

for d in directory:

full_path = os.path.normpath(os.path.join(base_path, d))
if not os.path.abspath(full_path).startswith(os.path.abspath(base_path)):
print("Path '{}' is outside the base folder '{}' and cannot be deleted.".format(full_path, base_path))
return

try:
if os.path.isdir(d):
try:
Expand Down
22 changes: 11 additions & 11 deletions src/utils/report_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,18 +334,18 @@ def whisper_output_parser(output: str):
output = eval(output)
if len(output) != 0:
for out in output:
key = out['severity'].lower()
if key == "info" or key == "minor":
key = "Low"
elif key == "major":
key = "Medium"
elif key == "blocker" or key == "critical":
key = "High"
out['vulnerability_severity'] = key
if key in vulnerability_severity_map:
vulnerability_severity_map[key] = int(vulnerability_severity_map[key]) + 1
whisper_sev_key = out['severity'].lower()
if whisper_sev_key == "info" or whisper_sev_key == "minor":
vul_sev_key = "Low"
elif whisper_sev_key == "major":
vul_sev_key = "Medium"
elif whisper_sev_key == "blocker" or whisper_sev_key == "critical":
vul_sev_key = "High"
out['vulnerability_severity'] = vul_sev_key
if vul_sev_key in vulnerability_severity_map:
vulnerability_severity_map[vul_sev_key] = int(vulnerability_severity_map[vul_sev_key]) + 1
else:
vulnerability_severity_map[key] = 1
vulnerability_severity_map[vul_sev_key] = 1

except Exception as e:
print("Failed to parse whisper_output {}".format(e))
Expand Down

0 comments on commit bdefa31

Please sign in to comment.