Skip to content

Commit fbf6808

Browse files
committed
Allow customizing requests Session settings
Added a new `session_factory` argument to `CASClient` constructors, which allows customizing many requests behaviors such as HTTP headers, proxies, hooks and more.
1 parent e68a2a7 commit fbf6808

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

cas.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,16 @@ class CASClientBase(object):
7070

7171
def __init__(self, service_url=None, server_url=None,
7272
extra_login_params=None, renew=False,
73-
username_attribute=None, verify_ssl_certificate=True):
73+
username_attribute=None, verify_ssl_certificate=True,
74+
session=None):
7475

7576
self.service_url = service_url
7677
self.server_url = server_url
7778
self.extra_login_params = extra_login_params or {}
7879
self.renew = renew
7980
self.username_attribute = username_attribute
8081
self.verify_ssl_certificate = verify_ssl_certificate
81-
pass
82+
self.session = session or requests.Session()
8283

8384
def verify_ticket(self, ticket):
8485
"""Verify ticket.
@@ -136,7 +137,7 @@ def get_proxy_ticket(self, pgt):
136137
Raises:
137138
CASError: Non 200 http code or bad XML body.
138139
"""
139-
response = requests.get(self.get_proxy_url(pgt), verify=self.verify_ssl_certificate)
140+
response = self.session.get(self.get_proxy_url(pgt), verify=self.verify_ssl_certificate)
140141
if response.status_code == 200:
141142
from lxml import etree
142143
root = etree.fromstring(response.content)
@@ -168,7 +169,7 @@ def verify_ticket(self, ticket):
168169
params = [('ticket', ticket), ('service', self.service_url)]
169170
url = (urllib_parse.urljoin(self.server_url, 'validate') + '?' +
170171
urllib_parse.urlencode(params))
171-
page = requests.get(
172+
page = self.session.get(
172173
url,
173174
stream=True,
174175
verify=self.verify_ssl_certificate
@@ -208,7 +209,7 @@ def get_verification_response(self, ticket):
208209
if self.proxy_callback:
209210
params.update({'pgtUrl': self.proxy_callback})
210211
base_url = urllib_parse.urljoin(self.server_url, self.url_suffix)
211-
page = requests.get(
212+
page = self.session.get(
212213
base_url,
213214
params=params,
214215
verify=self.verify_ssl_certificate
@@ -376,7 +377,7 @@ def fetch_saml_validation(self, ticket):
376377
saml_validate_url = urllib_parse.urljoin(
377378
self.server_url, 'samlValidate',
378379
)
379-
return requests.post(
380+
return self.session.post(
380381
saml_validate_url,
381382
self.get_saml_assertion(ticket),
382383
params=params,

tests/test_cas.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import absolute_import
33

44
import sys
5+
from unittest.mock import Mock, MagicMock
56

67
import cas
78
import pytest
@@ -189,6 +190,25 @@ def test_can_saml_assertion_is_encoded():
189190
else:
190191
assert ticket in saml
191192

193+
# Test session= constructor argument with a mock session
194+
def test_v3_custom_session():
195+
response = Mock()
196+
response.content = SUCCESS_RESPONSE
197+
session = Mock()
198+
session.get = Mock(return_value=response)
199+
200+
client = cas.CASClient(
201+
version='3',
202+
server_url='https://cas.example.com/cas/',
203+
service_url='https://example.com/login',
204+
session=session)
205+
user, attributes, pgtiou = client.verify_ticket('ABC123')
206+
207+
assert user == 'user@example.com'
208+
assert not attributes
209+
assert not pgtiou
210+
211+
192212

193213
@fixture
194214
def client_v2():

0 commit comments

Comments
 (0)