Skip to content

Commit

Permalink
download_files
Browse files Browse the repository at this point in the history
  • Loading branch information
bleudev committed Jun 28, 2024
1 parent 5f903d8 commit 80f6f49
Showing 1 changed file with 26 additions and 8 deletions.
34 changes: 26 additions & 8 deletions ufpy/github/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import os
import warnings
from shutil import copy, copytree, rmtree
from tempfile import gettempdir
from typing import Iterable
from zipfile import ZipFile

from requests import get

from ufpy.path import UOpen
from ufpy.typ import Empty

__all__ = (
'file',
Expand Down Expand Up @@ -114,15 +115,18 @@ def repo(repo: str, download_path: str, branch_name: str = 'main'):
rmtree(f'{download_path}/{main_directory_name}')


def format_paths(*paths: str) -> list[str] | str | Empty[list]:
def format_paths(*paths: str | list[str]) -> list[str] | list[list[str]] | list[str | list[str]] | str:
new_paths = []
for path in paths:
path = path.replace('\\', '/')
if isinstance(path, list):
path = format_paths(*path)
else:
path = path.replace('\\', '/')

if path.startswith('/'):
path = path[1:]
if path.endswith('/'):
path = path[:-1]
if path.startswith('/'):
path = path[1:]
if path.endswith('/'):
path = path[:-1]

new_paths.append(path)
return new_paths[0] if len(new_paths) <= 1 else new_paths
Expand All @@ -137,13 +141,22 @@ def __init__(self, repo: str, base_download_path: str = 'C:/', branch_name: str
def __enter__(self):
url = f'https://github.com/{self.__repo}/archive/{self.__branch}.zip'
self.__zip = ZipFile(io.BytesIO(get(url).content))

temp_dir = format_paths(gettempdir())
self.__zip.extractall(temp_dir)
repo_name = self.__repo.split('/')[-1]
self.__repo_path = f'{temp_dir}/{repo_name}-{self.__branch}'
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.__zip.close()
if os.path.exists(self.__repo_path):
rmtree(self.__repo_path)

def __del__(self):
self.__zip.close()
if os.path.exists(self.__repo_path):
rmtree(self.__repo_path)

def download_file(self, file_path: str, download_path: str = ''):
file_path, download_path = format_paths(file_path, download_path)
Expand All @@ -156,11 +169,16 @@ def download_file(self, file_path: str, download_path: str = ''):
raise Exception(
"Error with getting file from GitHub. Check that repo is public and that file path is correct.")

path = f'{download_path}{file_path}'
path = f'{download_path}/{file_path}'

with UOpen(path, 'w+') as f:
f.write(r.text)

def download_files(self, file_paths: Iterable[str], download_path: str = ''):
file_paths, download_path = format_paths(list(file_paths), download_path)
for file_path in file_paths:
self.download_file(file_path, download_path)

def download_folder(self, folder_path: str | list[str], download_path: str):
folder(self.__repo, folder_path, download_path, self.__branch)

Expand Down

0 comments on commit 80f6f49

Please sign in to comment.