diff --git a/safety/safety.py b/safety/safety.py index 1c29d4c..ae6d1fc 100644 --- a/safety/safety.py +++ b/safety/safety.py @@ -100,20 +100,26 @@ def write_to_cache(db_name, data): if exc.errno != errno.EEXIST: raise - with open(DB_CACHE_FILE, "r") as f: - try: - cache = json.loads(f.read()) - except json.JSONDecodeError: - LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.') + cache_file_lock = f"{DB_CACHE_FILE}.lock" + lock = FileLock(cache_file_lock, timeout=10) + with lock: + if os.path.exists(DB_CACHE_FILE): + with open(DB_CACHE_FILE, "r") as f: + try: + cache = json.loads(f.read()) + except json.JSONDecodeError: + LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.') + cache = {} + else: cache = {} - with open(DB_CACHE_FILE, "w") as f: - cache[db_name] = { - "cached_at": time.time(), - "db": data - } - f.write(json.dumps(cache)) - LOG.debug('Safety updated the cache file for %s database.', db_name) + with open(DB_CACHE_FILE, "w") as f: + cache[db_name] = { + "cached_at": time.time(), + "db": data + } + f.write(json.dumps(cache)) + LOG.debug('Safety updated the cache file for %s database.', db_name) def fetch_database_url(session, mirror, db_name, cached, telemetry=True, diff --git a/tests/test_safety.py b/tests/test_safety.py index d829d6b..630704e 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -171,6 +171,9 @@ def test_check_live(self): def test_check_live_cached(self): from safety.constants import DB_CACHE_FILE + # Ensure the cache directory and file exist + os.makedirs(os.path.dirname(DB_CACHE_FILE), exist_ok=True) + # lets clear the cache first try: with open(DB_CACHE_FILE, 'w') as f: