Skip to content

Commit a7f8936

Browse files
msaroufimfacebook-github-bot
authored andcommitted
add support for **kwargs in HttpReader (meta-pytorch#392)
Summary: Please read through our [contribution guide](https://github.com/pytorch/data/blob/main/CONTRIBUTING.md) prior to creating your pull request. - Note that there is a section on requirements related to adding a new DataPipe. Fixes meta-pytorch#391 ### Changes - Added support for `**kwarg` for `request` parameters in `HttpReader` Although it doesn't look like the HTTP dataset pipe is tested anywhere so I'll take a closer look and add one Pull Request resolved: meta-pytorch#392 Reviewed By: NivekT, ejguan Differential Revision: D36287860 Pulled By: msaroufim fbshipit-source-id: daac6c6e3a57608928c8e262f7cdc4d5a349ee03
1 parent e0d0aa9 commit a7f8936

3 files changed

Lines changed: 22 additions & 9 deletions

File tree

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ conda install pytorch -c pytorch-nightly
5252
git clone https://github.com/pytorch/data.git
5353
cd data
5454
python setup.py develop
55-
pip install flake8 typing mypy pytest
55+
pip install flake8 typing mypy pytest expecttest
5656
```
5757

5858
## Pull Requests

test/test_remote_io.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def test_http_reader_iterdatapipe(self):
4141
file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
4242
expected_file_name = "LICENSE"
4343
expected_MD5_hash = "bb9675028dd39d2dd2bf71002b93e66c"
44-
http_reader_dp = HttpReader(IterableWrapper([file_url]))
44+
query_params = {"auth": ("fake_username", "fake_password"), "allow_redirects": True}
45+
timeout = 120
46+
http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, **query_params)
4547

4648
# Functional Test: test if the Http Reader can download and read properly
4749
reader_dp = http_reader_dp.readlines()

torchdata/datapipes/iter/load/online.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import re
88
import urllib
99

10-
from typing import Dict, Iterator, Optional, Tuple
10+
from typing import Any, Dict, Iterator, Optional, Tuple
1111

1212
import requests
1313

@@ -33,14 +33,16 @@ def _get_proxies() -> Optional[Dict[str, str]]:
3333
return None
3434

3535

36-
def _get_response_from_http(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
36+
def _get_response_from_http(
37+
url: str, *, timeout: Optional[float], **query_params: Optional[Dict[str, Any]]
38+
) -> Tuple[str, StreamWrapper]:
3739
try:
3840
with requests.Session() as session:
3941
proxies = _get_proxies()
4042
if timeout is None:
41-
r = session.get(url, stream=True, proxies=proxies)
43+
r = session.get(url, stream=True, proxies=proxies, **query_params)
4244
else:
43-
r = session.get(url, timeout=timeout, stream=True, proxies=proxies)
45+
r = session.get(url, timeout=timeout, stream=True, proxies=proxies, **query_params)
4446
return url, StreamWrapper(r.raw)
4547
except HTTPError as e:
4648
raise Exception(f"Could not get the file. [HTTP Error] {e.response}.")
@@ -59,11 +61,14 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
5961
Args:
6062
source_datapipe: a DataPipe that contains URLs
6163
timeout: timeout in seconds for HTTP request
64+
**kwargs: a Dictionary to pass optional arguments that requests takes. For the full list check out https://docs.python-requests.org/en/master/api/
6265
6366
Example:
6467
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
6568
>>> file_url = "https://raw.githubusercontent.com/pytorch/data/main/LICENSE"
66-
>>> http_reader_dp = HttpReader(IterableWrapper([file_url]))
69+
>>> query_params = {"auth" : ("fake_username", "fake_password"), "allow_redirects" : True}
70+
>>> timeout = 120
71+
>>> http_reader_dp = HttpReader(IterableWrapper([file_url]), timeout=timeout, query_params)
6772
>>> reader_dp = http_reader_dp.readlines()
6873
>>> it = iter(reader_dp)
6974
>>> path, line = next(it)
@@ -73,13 +78,19 @@ class HTTPReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
7378
b'BSD 3-Clause License'
7479
"""
7580

76-
def __init__(self, source_datapipe: IterDataPipe[str], timeout: Optional[float] = None) -> None:
81+
def __init__(
82+
self, source_datapipe: IterDataPipe[str], timeout: Optional[float] = None, **kwargs: Optional[Dict[str, Any]]
83+
) -> None:
7784
self.source_datapipe: IterDataPipe[str] = source_datapipe
7885
self.timeout = timeout
86+
self.query_params = kwargs
7987

8088
def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
8189
for url in self.source_datapipe:
82-
yield _get_response_from_http(url, timeout=self.timeout)
90+
if self.query_params:
91+
yield _get_response_from_http(url, timeout=self.timeout, **self.query_params)
92+
else:
93+
yield _get_response_from_http(url, timeout=self.timeout)
8394

8495
def __len__(self) -> int:
8596
return len(self.source_datapipe)

0 commit comments

Comments
 (0)