From 495ffc4777522e40941753e3b1b79c02f84b25b4 Mon Sep 17 00:00:00 2001 From: Grafting Rayman <156515434+GraftingRayman@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:00:30 +0000 Subject: Add files via upload --- r_basicsr/utils/download_util.py | 99 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 r_basicsr/utils/download_util.py (limited to 'r_basicsr/utils/download_util.py') diff --git a/r_basicsr/utils/download_util.py b/r_basicsr/utils/download_util.py new file mode 100644 index 0000000..59b2621 --- /dev/null +++ b/r_basicsr/utils/download_util.py @@ -0,0 +1,99 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + + Args: + file_id (str): File id. + save_path (str): Save path. + """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file -- cgit v1.2.3