Skip to content

Commit a6841bf

Browse files
authored
Merge pull request robinhood-unofficial#149 from anthonykrivonos/master
✅ Fixes Every OAuth Issue + Major Improvements and Fixes
2 parents d21b190 + b88d34c commit a6841bf

6 files changed

Lines changed: 138 additions & 108 deletions

File tree

CONTRIBUTORS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ Bharath Lohray (@lordloh)
88
John Purcell (@lockefox)
99
Greg Oberifeld (@gregoberfield)
1010
Eric Evans (@ciresnave)
11+
Anthony Krivonos (@anthonykrivonos)
12+
Aaron Mazie (@aamazie)

Robinhood/Robinhood.py

100644100755
Lines changed: 98 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
import dateutil
1818

1919
#Application-specific imports
20-
from . import exceptions as RH_exception
21-
from . import endpoints
22-
20+
import exceptions as RH_exception
21+
import endpoints
2322

2423
class Bounds(Enum):
2524
"""Enum for bounds in `historicals` endpoint """
@@ -43,11 +42,13 @@ class Robinhood:
4342
password = None
4443
headers = None
4544
auth_token = None
46-
oauth_token = None
45+
refresh_token = None
4746

4847
logger = logging.getLogger('Robinhood')
4948
logger.addHandler(logging.NullHandler())
5049

50+
client_id = "c82SH0WZOsabOXGP2sxqcj34FxkvfnWRZBKlBjFS"
51+
5152

5253
###########################################################################
5354
# Logging in and initializing
@@ -104,12 +105,13 @@ def login(self,
104105
self.password = password
105106
payload = {
106107
'password': self.password,
107-
'username': self.username
108+
'username': self.username,
109+
'grant_type': 'password',
110+
'client_id': self.client_id
108111
}
109112

110113
if mfa_code:
111114
payload['mfa_code'] = mfa_code
112-
113115
try:
114116
res = self.session.post(endpoints.login(), data=payload, timeout=15)
115117
res.raise_for_status()
@@ -120,9 +122,10 @@ def login(self,
120122
if 'mfa_required' in data.keys(): # pragma: no cover
121123
raise RH_exception.TwoFactorRequired() # requires a second call to enable 2FA
122124

123-
if 'token' in data.keys():
124-
self.auth_token = data['token']
125-
self.headers['Authorization'] = 'Token ' + self.auth_token
125+
if 'access_token' in data.keys() and 'refresh_token' in data.keys():
126+
self.auth_token = data['access_token']
127+
self.refresh_token = data['refresh_token']
128+
self.headers['Authorization'] = 'Bearer ' + self.auth_token
126129
return True
127130

128131
return False
@@ -137,7 +140,11 @@ def logout(self):
137140
"""
138141

139142
try:
140-
req = self.session.post(endpoints.logout(), timeout=15)
143+
payload = {
144+
'client_id': self.client_id,
145+
'token': self.refresh_token
146+
}
147+
req = self.session.post(endpoints.logout(), data=payload, timeout=15)
141148
req.raise_for_status()
142149
except requests.exceptions.HTTPError as err_msg:
143150
warnings.warn('Failed to log out ' + repr(err_msg))
@@ -192,7 +199,7 @@ def instrument(self, id):
192199
Returns:
193200
(:obj:`dict`): JSON dict of instrument
194201
"""
195-
url = str(endpoints.instruments()) + str(id) + "/"
202+
url = str(endpoints.instruments()) + "?symbol=" + str(id)
196203

197204
try:
198205
req = requests.get(url, timeout=15)
@@ -201,7 +208,7 @@ def instrument(self, id):
201208
except requests.exceptions.HTTPError:
202209
raise RH_exception.InvalidInstrumentId()
203210

204-
return data
211+
return data['results']
205212

206213

207214
def quote_data(self, stock=''):
@@ -308,7 +315,7 @@ def get_quote(self, stock=''):
308315
"""Wrapper for quote_data """
309316

310317
data = self.quote_data(stock)
311-
return data["symbol"]
318+
return data
312319

313320
def get_historical_quotes(self, stock, interval, span, bounds=Bounds.REGULAR):
314321
"""Fetch historical data for stock
@@ -334,15 +341,10 @@ def get_historical_quotes(self, stock, interval, span, bounds=Bounds.REGULAR):
334341
if isinstance(bounds, str): # recast to Enum
335342
bounds = Bounds(bounds)
336343

337-
params = {
338-
'symbols': ','.join(stock).upper(),
339-
'interval': interval,
340-
'span': span,
341-
'bounds': bounds.name.lower()
342-
}
344+
historicals = endpoints.historicals() + "/?symbols=" + ','.join(stock).upper() + "&interval=" + interval + "&span=" + span + "&bounds=" + bounds.name.lower()
343345

344-
res = self.session.get(endpoints.historicals(), params=params, timeout=15)
345-
return res.json()
346+
res = self.session.get(historicals, timeout=15)
347+
return res.json()['results'][0]
346348

347349

348350
def get_news(self, stock):
@@ -369,7 +371,6 @@ def print_quote(self, stock=''): # pragma: no cover
369371
data = self.get_quote_list(stock, 'symbol,last_trade_price')
370372
for item in data:
371373
quote_str = item[0] + ": $" + item[1]
372-
print(quote_str)
373374
self.logger.info(quote_str)
374375

375376

@@ -634,12 +635,12 @@ def get_options(self, stock, expiration_dates, option_type):
634635
Returns:
635636
Options Contracts (List): a list (chain) of contracts for a given underlying equity instrument
636637
"""
637-
instrumentid = self.get_url(self.quote_data(stock)["instrument"])["id"]
638-
if(type(expiration_dates) == list):
639-
_expiration_dates_string = expiration_dates.join(",")
638+
instrument_id = self.get_url(self.quote_data(stock)["instrument"])["id"]
639+
if (type(expiration_dates) == list):
640+
_expiration_dates_string = ",".join(expiration_dates)
640641
else:
641642
_expiration_dates_string = expiration_dates
642-
chain_id = self.get_url(endpoints.chain(instrumentid))["results"][0]["id"]
643+
chain_id = self.get_url(endpoints.chain(instrument_id))["results"][0]["id"]
643644
return [contract for contract in self.get_url(endpoints.options(chain_id, _expiration_dates_string, option_type))["results"]]
644645

645646
@login_required
@@ -650,13 +651,12 @@ def get_option_market_data(self, optionid):
650651
651652
Returns: dictionary of options market data.
652653
"""
653-
if not self.oauth_token:
654-
res = self.session.post(endpoints.convert_token(), timeout=15)
655-
res.raise_for_status()
656-
res = res.json()
657-
self.oauth_token = res["access_token"]
658-
self.headers['Authorization'] = 'Bearer ' + self.oauth_token
659-
return self.get_url(endpoints.market_data(optionid))
654+
market_data = {}
655+
try:
656+
market_data = self.get_url(endpoints.market_data(optionid)) or {}
657+
except requests.exceptions.HTTPError:
658+
raise RH_exception.InvalidOptionId()
659+
return market_data
660660

661661

662662
###########################################################################
@@ -858,7 +858,7 @@ def securities_owned(self):
858858
def place_order(self,
859859
instrument,
860860
quantity=1,
861-
bid_price=0.0,
861+
price=0.0,
862862
transaction=None,
863863
trigger='immediate',
864864
order='market',
@@ -888,13 +888,12 @@ def place_order(self,
888888
if isinstance(transaction, str):
889889
transaction = Transaction(transaction)
890890

891-
if not bid_price:
892-
bid_price = self.quote_data(instrument['symbol'])['bid_price']
891+
if not price:
892+
price = self.quote_data(instrument['symbol'])['bid_price']
893893

894894
payload = {
895895
'account': self.get_account()['url'],
896896
'instrument': unquote(instrument['url']),
897-
'price': float(bid_price),
898897
'quantity': quantity,
899898
'side': transaction.name.lower(),
900899
'symbol': instrument['symbol'],
@@ -903,14 +902,10 @@ def place_order(self,
903902
'type': order.lower()
904903
}
905904

906-
#data = 'account=%s&instrument=%s&price=%f&quantity=%d&side=%s&symbol=%s#&time_in_force=gfd&trigger=immediate&type=market' % (
907-
# self.get_account()['url'],
908-
# urllib.parse.unquote(instrument['url']),
909-
# float(bid_price),
910-
# quantity,
911-
# transaction,
912-
# instrument['symbol']
913-
#)
905+
if order.lower() == "stop":
906+
payload['stop_price'] = float(price)
907+
else:
908+
payload['price'] = float(price)
914909

915910
res = self.session.post(endpoints.orders(), data=payload, timeout=15)
916911
res.raise_for_status()
@@ -1251,6 +1246,11 @@ def submit_order(self,
12511246
(:obj:`requests.request`): result from `orders` put command
12521247
"""
12531248

1249+
# Used for default price input
1250+
# Price is required, so we use the current bid price if it is not specified
1251+
current_quote = self.get_quote(symbol)
1252+
current_bid_price = current_quote['bid_price']
1253+
12541254
# Start with some parameter checks. I'm paranoid about $.
12551255
if(instrument_URL is None):
12561256
if(symbol is None):
@@ -1302,8 +1302,9 @@ def submit_order(self,
13021302
if(price is not None):
13031303
if(order_type.lower() == 'market'):
13041304
raise(ValueError('Market order has price limit in call to submit_order'))
1305-
1306-
price = float(price)
1305+
price = float(price)
1306+
else:
1307+
price = current_bid_price # default to current bid price
13071308

13081309
if(quantity is None):
13091310
raise(ValueError('No quantity specified in call to submit_order'))
@@ -1316,20 +1317,22 @@ def submit_order(self,
13161317
payload = {}
13171318

13181319
for field, value in [
1319-
('account', self.get_account()['url']),
1320-
('instrument', instrument_URL),
1321-
('symbol', symbol),
1322-
('type', order_type),
1323-
('time_in_force', time_in_force),
1324-
('trigger', trigger),
1325-
('price', price),
1326-
('stop_price', stop_price),
1327-
('quantity', quantity),
1328-
('side', side)
1329-
]:
1320+
('account', self.get_account()['url']),
1321+
('instrument', instrument_URL),
1322+
('symbol', symbol),
1323+
('type', order_type),
1324+
('time_in_force', time_in_force),
1325+
('trigger', trigger),
1326+
('price', price),
1327+
('stop_price', stop_price),
1328+
('quantity', quantity),
1329+
('side', side)
1330+
]:
13301331
if(value is not None):
13311332
payload[field] = value
13321333

1334+
print(payload)
1335+
13331336
res = self.session.post(endpoints.orders(), data=payload, timeout=15)
13341337
res.raise_for_status()
13351338

@@ -1341,38 +1344,55 @@ def submit_order(self,
13411344

13421345
def cancel_order(
13431346
self,
1344-
order_id
1345-
):
1347+
order_id):
13461348
"""
1347-
Cancels specified order and returns the response (results from `orders` command).
1349+
Cancels specified order and returns the response (results from `orders` command).
13481350
If order cannot be cancelled, `None` is returned.
1349-
13501351
Args:
1351-
order_id (str): Order ID that is to be cancelled or order dict returned from
1352+
order_id (str or dict): Order ID string that is to be cancelled or open order dict returned from
13521353
order get.
13531354
Returns:
13541355
(:obj:`requests.request`): result from `orders` put command
13551356
"""
1356-
if order_id is str:
1357+
if isinstance(order_id, str):
13571358
try:
1358-
order = self.session.get(self.endpoints['orders'] + order_id, timeout=15).json()
1359+
order = self.session.get(endpoints.orders() + order_id, timeout=15).json()
13591360
except (requests.exceptions.HTTPError) as err_msg:
13601361
raise ValueError('Failed to get Order for ID: ' + order_id
13611362
+ '\n Error message: '+ repr(err_msg))
1362-
else:
1363-
raise ValueError('Cancelling orders requires a valid order_id string')
13641363

1365-
if order.get('cancel') is not None:
1366-
try:
1367-
res = self.session.post(order['cancel'], timeout=15)
1368-
res.raise_for_status()
1364+
if order.get('cancel') is not None:
1365+
try:
1366+
res = self.session.post(order['cancel'], timeout=15)
1367+
res.raise_for_status()
1368+
return res
1369+
except (requests.exceptions.HTTPError) as err_msg:
1370+
raise ValueError('Failed to cancel order ID: ' + order_id
1371+
+ '\n Error message: '+ repr(err_msg))
1372+
return None
1373+
1374+
if isinstance(order_id, dict):
1375+
order_id = order_id['id']
1376+
try:
1377+
order = self.session.get(endpoints.orders() + order_id, timeout=15).json()
13691378
except (requests.exceptions.HTTPError) as err_msg:
1370-
raise ValueError('Failed to cancel order ID: ' + order_id
1371-
+ '\n Error message: '+ repr(err_msg))
1372-
return None
1373-
1379+
raise ValueError('Failed to get Order for ID: ' + order_id
1380+
+ '\n Error message: '+ repr(err_msg))
1381+
1382+
if order.get('cancel') is not None:
1383+
try:
1384+
res = self.session.post(order['cancel'], timeout=15)
1385+
res.raise_for_status()
1386+
return res
1387+
except (requests.exceptions.HTTPError) as err_msg:
1388+
raise ValueError('Failed to cancel order ID: ' + order_id
1389+
+ '\n Error message: '+ repr(err_msg))
1390+
return None
1391+
1392+
elif not isinstance(order_id, str) or not isinstance(order_id, dict):
1393+
raise ValueError('Cancelling orders requires a valid order_id string or open order dictionary')
1394+
1395+
13741396
# Order type cannot be cancelled without a valid cancel link
1375-
else:
1397+
else:
13761398
raise ValueError('Unable to cancel order ID: ' + order_id)
1377-
1378-
return res

Robinhood/__init__.py

100644100755
File mode changed.

0 commit comments

Comments
 (0)