Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dataset downloads from Google Drive #10963

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 47 additions & 11 deletions tensorflow_datasets/core/download/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from etils import epath
from tensorflow_datasets.core import units
from tensorflow_datasets.core import utils
from tensorflow_datasets.core import lazy_imports_lib
from tensorflow_datasets.core.download import checksums as checksums_lib
from tensorflow_datasets.core.download import resource as resource_lib
from tensorflow_datasets.core.download import util as download_utils_lib
Expand Down Expand Up @@ -130,6 +131,44 @@ def _get_filename(response: Response) -> str:
return _basename_from_url(response.url)


def _process_gdrive_confirmation(original_url: str, contents: str) -> str:
"""Process Google Drive confirmation page.

Extracts the download link from a Google Drive confirmation page.

Args:
original_url: The URL the confirmation page was originally
retrieved from.
contents: The confirmation page's HTML.

Returns:
download_url: The URL for downloading the file.
"""
bs4 = lazy_imports_lib.lazy_imports.bs4
soup = bs4.BeautifulSoup(contents, 'html.parser')
form = soup.find('form')
if not form:
raise ValueError(
f'Failed to obtain confirmation link for GDrive URL {original_url}.'
)
action = form.get('action', '')
if not action:
raise ValueError(
f'Failed to obtain confirmation link for GDrive URL {original_url}.'
)
# Find the <input>s named 'uuid', 'export', 'id' and 'confirm'
input_names = ['uuid', 'export', 'id', 'confirm']
params = {}
for name in input_names:
input_tag = form.find('input', {'name': name})
if input_tag:
params[name] = input_tag.get('value', '')
query_string = urllib.parse.urlencode(params)
download_url = f'{action}?{query_string}' if query_string else action
download_url = urllib.parse.urljoin(original_url, download_url)
return download_url


class _Downloader:
"""Class providing async download API with checksum validation.

Expand Down Expand Up @@ -318,11 +357,15 @@ def _open_with_requests(
session.mount(
'https://', requests.adapters.HTTPAdapter(max_retries=retries)
)
if _DRIVE_URL.match(url):
url = _normalize_drive_url(url)
with session.get(url, stream=True, **kwargs) as response:
_assert_status(response)
yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE))
if _DRIVE_URL.match(url) and 'Content-Disposition' not in response.headers:
download_url = _process_gdrive_confirmation(url, response.text)
with session.get(download_url, stream=True, **kwargs) as download_response:
_assert_status(download_response)
yield (download_response, download_response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE))
else:
_assert_status(response)
yield (response, response.iter_content(chunk_size=io.DEFAULT_BUFFER_SIZE))


@contextlib.contextmanager
Expand All @@ -338,13 +381,6 @@ def _open_with_urllib(
)


def _normalize_drive_url(url: str) -> str:
"""Returns Google Drive url with confirmation token."""
# This bypasses the "Google Drive can't scan this file for viruses" warning
# when dowloading large files.
return url + '&confirm=t'


def _assert_status(response: requests.Response) -> None:
"""Ensure the URL response is 200."""
if response.status_code != 200:
Expand Down