Skip to content

Commit 0a93483

Browse files
committed
Use Session for making requests & allow customization
All requests made by python-cas now use a Session object, which enables keep-alive HTTP connections. The session can also be customized by passing a `session=` argument to `CASClient` constructors, to change behaviors such as HTTP headers, proxies, hooks and more. My use case requires making all CAS requests through an HTTP proxy.
1 parent e68a2a7 commit 0a93483

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

cas.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,16 @@ class CASClientBase(object):
7070

7171
def __init__(self, service_url=None, server_url=None,
7272
extra_login_params=None, renew=False,
73-
username_attribute=None, verify_ssl_certificate=True):
73+
username_attribute=None, verify_ssl_certificate=True,
74+
session=None):
7475

7576
self.service_url = service_url
7677
self.server_url = server_url
7778
self.extra_login_params = extra_login_params or {}
7879
self.renew = renew
7980
self.username_attribute = username_attribute
8081
self.verify_ssl_certificate = verify_ssl_certificate
81-
pass
82+
self.session = session or requests.Session()
8283

8384
def verify_ticket(self, ticket):
8485
"""Verify ticket.
@@ -136,7 +137,7 @@ def get_proxy_ticket(self, pgt):
136137
Raises:
137138
CASError: Non 200 http code or bad XML body.
138139
"""
139-
response = requests.get(self.get_proxy_url(pgt), verify=self.verify_ssl_certificate)
140+
response = self.session.get(self.get_proxy_url(pgt), verify=self.verify_ssl_certificate)
140141
if response.status_code == 200:
141142
from lxml import etree
142143
root = etree.fromstring(response.content)
@@ -168,7 +169,7 @@ def verify_ticket(self, ticket):
168169
params = [('ticket', ticket), ('service', self.service_url)]
169170
url = (urllib_parse.urljoin(self.server_url, 'validate') + '?' +
170171
urllib_parse.urlencode(params))
171-
page = requests.get(
172+
page = self.session.get(
172173
url,
173174
stream=True,
174175
verify=self.verify_ssl_certificate
@@ -208,7 +209,7 @@ def get_verification_response(self, ticket):
208209
if self.proxy_callback:
209210
params.update({'pgtUrl': self.proxy_callback})
210211
base_url = urllib_parse.urljoin(self.server_url, self.url_suffix)
211-
page = requests.get(
212+
page = self.session.get(
212213
base_url,
213214
params=params,
214215
verify=self.verify_ssl_certificate
@@ -376,7 +377,7 @@ def fetch_saml_validation(self, ticket):
376377
saml_validate_url = urllib_parse.urljoin(
377378
self.server_url, 'samlValidate',
378379
)
379-
return requests.post(
380+
return self.session.post(
380381
saml_validate_url,
381382
self.get_saml_assertion(ticket),
382383
params=params,

tests/test_cas.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
from __future__ import absolute_import
33

44
import sys
5+
try:
6+
from unittest.mock import Mock
7+
except ImportError:
8+
# Python 2.7 fallback
9+
from mock import Mock
510

611
import cas
712
import pytest
@@ -189,6 +194,25 @@ def test_can_saml_assertion_is_encoded():
189194
else:
190195
assert ticket in saml
191196

197+
# Test session= constructor argument with a mock session
198+
def test_v3_custom_session():
199+
response = Mock()
200+
response.content = SUCCESS_RESPONSE
201+
session = Mock()
202+
session.get = Mock(return_value=response)
203+
204+
client = cas.CASClient(
205+
version='3',
206+
server_url='https://cas.example.com/cas/',
207+
service_url='https://example.com/login',
208+
session=session)
209+
user, attributes, pgtiou = client.verify_ticket('ABC123')
210+
211+
assert user == 'user@example.com'
212+
assert not attributes
213+
assert not pgtiou
214+
215+
192216

193217
@fixture
194218
def client_v2():

0 commit comments

Comments
 (0)