Skip to content

Commit 90f5e35

Browse files
committed
Merge branch 'dse'
2 parents c43c77a + d958010 commit 90f5e35

13 files changed

Lines changed: 522 additions & 15 deletions

File tree

cassandra/auth.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
except ImportError:
1616
SASLClient = None
1717

18-
1918
class AuthProvider(object):
2019
"""
2120
An abstract class that defines the interface that will be used for
@@ -63,6 +62,9 @@ class Authenticator(object):
6362
.. versionadded:: 2.0.0
6463
"""
6564

65+
server_authenticator_class = None
66+
""" Set during the connection AUTHENTICATE phase """
67+
6668
def initial_response(self):
6769
"""
6870
Returns an message to send to the server to initiate the SASL handshake.

cassandra/cluster.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -895,11 +895,14 @@ def shutdown(self):
895895

896896
def _new_session(self):
897897
session = Session(self, self.metadata.all_hosts())
898+
self._session_register_user_types(session)
899+
self.sessions.add(session)
900+
return session
901+
902+
def _session_register_user_types(self, session):
898903
for keyspace, type_map in six.iteritems(self._user_types):
899904
for udt_name, klass in six.iteritems(type_map):
900905
session.user_type_registered(keyspace, udt_name, klass)
901-
self.sessions.add(session)
902-
return session
903906

904907
def _cleanup_failed_on_up_handling(self, host):
905908
self.load_balancing_policy.on_down(host)

cassandra/connection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,15 +694,14 @@ def _handle_startup_response(self, startup_response, did_authenticate=False):
694694
if self.authenticator is None:
695695
raise AuthenticationFailed('Remote end requires authentication.')
696696

697-
self.authenticator_class = startup_response.authenticator
698-
699697
if isinstance(self.authenticator, dict):
700698
log.debug("Sending credentials-based auth response on %s", self)
701699
cm = CredentialsMessage(creds=self.authenticator)
702700
callback = partial(self._handle_startup_response, did_authenticate=True)
703701
self.send_msg(cm, self.get_request_id(), cb=callback)
704702
else:
705703
log.debug("Sending SASL-based auth response on %s", self)
704+
self.authenticator.server_authenticator_class = startup_response.authenticator
706705
initial_response = self.authenticator.initial_response()
707706
initial_response = "" if initial_response is None else initial_response
708707
self.send_msg(AuthResponseMessage(initial_response), self.get_request_id(), self._handle_auth_response)

cassandra/protocol.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,11 +491,15 @@ def __init__(self, query, consistency_level, serial_consistency_level=None,
491491
self.fetch_size = fetch_size
492492
self.paging_state = paging_state
493493
self.timestamp = timestamp
494+
self._query_params = None # only used internally. May be set to a list of native-encoded values to have them sent with the request.
494495

495496
def send_body(self, f, protocol_version):
496497
write_longstring(f, self.query)
497498
write_consistency_level(f, self.consistency_level)
498499
flags = 0x00
500+
if self._query_params is not None:
501+
flags |= _VALUES_FLAG # also v2+, but we're only setting params internally right now
502+
499503
if self.serial_consistency_level:
500504
if protocol_version >= 2:
501505
flags |= _WITH_SERIAL_CONSISTENCY_FLAG
@@ -525,6 +529,12 @@ def send_body(self, f, protocol_version):
525529
flags |= _PROTOCOL_TIMESTAMP
526530

527531
write_byte(f, flags)
532+
533+
if self._query_params is not None:
534+
write_short(f, len(self._query_params))
535+
for param in self._query_params:
536+
write_value(f, param)
537+
528538
if self.fetch_size:
529539
write_int(f, self.fetch_size)
530540
if self.paging_state:

cassandra/query.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,6 @@ def __str__(self):
338338
(self.query_string, consistency))
339339
__repr__ = __str__
340340

341-
342341
class PreparedStatement(object):
343342
"""
344343
A statement that has been prepared against at least one Cassandra node.

dse/__init__.py

Whitespace-only changes.

dse/auth.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from cassandra.auth import AuthProvider, Authenticator
2+
3+
try:
4+
import kerberos
5+
_have_kerberos = True
6+
except ImportError:
7+
_have_kerberos = False
8+
9+
try:
10+
from puresasl.client import SASLClient
11+
_have_puresasl = True
12+
except ImportError:
13+
_have_puresasl = False
14+
15+
16+
class DSEPlainTextAuthProvider(AuthProvider):
17+
def __init__(self, username=None, password=None):
18+
self.username = username
19+
self.password = password
20+
21+
def new_authenticator(self, host):
22+
return PlainTextAuthenticator(self.username, self.password)
23+
24+
25+
class DSEGSSAPIAuthProvider(AuthProvider):
26+
def __init__(self, service=None, qops=None):
27+
if not _have_puresasl:
28+
raise ImportError('The puresasl library has not been installed')
29+
if not _have_kerberos:
30+
raise ImportError('The kerberos library has not been installed')
31+
self.service = service
32+
self.qops = qops
33+
34+
def new_authenticator(self, host):
35+
return GSSAPIAuthenticator(host, self.service, self.qops)
36+
37+
38+
class BaseDSEAuthenticator(Authenticator):
39+
def get_mechanism(self):
40+
raise NotImplementedError("get_mechanism not implemented")
41+
42+
def get_initial_challenge(self):
43+
raise NotImplementedError("get_initial_challenge not implemented")
44+
45+
def initial_response(self):
46+
if self.server_authenticator_class == "com.datastax.bdp.cassandra.auth.DseAuthenticator":
47+
return self.get_mechanism()
48+
else:
49+
return self.evaluate_challenge(self.get_initial_challenge())
50+
51+
52+
class PlainTextAuthenticator(BaseDSEAuthenticator):
53+
def __init__(self, username, password):
54+
self.username = username
55+
self.password = password
56+
57+
def get_mechanism(self):
58+
return "PLAIN"
59+
60+
def get_initial_challenge(self):
61+
return "PLAIN-START"
62+
63+
def evaluate_challenge(self, challenge):
64+
if challenge == 'PLAIN-START':
65+
return "\x00%s\x00%s" % (self.username, self.password)
66+
raise Exception('Did not receive a valid challenge response from server')
67+
68+
69+
class GSSAPIAuthenticator(BaseDSEAuthenticator):
70+
def __init__(self, host, service, qops):
71+
self.sasl = SASLClient(host, service, 'GSSAPI', authorization_id=None, callback=None, qops=qops)
72+
73+
def get_mechanism(self):
74+
return "GSSAPI"
75+
76+
def get_initial_challenge(self):
77+
return "GSSAPI-START"
78+
79+
def evaluate_challenge(self, challenge):
80+
if challenge == 'GSSAPI-START':
81+
return self.sasl.process()
82+
else:
83+
return self.sasl.process(challenge)

dse/cluster.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from cassandra.cluster import Cluster, Session
2+
import dse.cqltypes # unsued here, imported to cause type registration
3+
from dse.util import Point, Circle, LineString, Polygon
4+
5+
6+
class Cluster(Cluster):
7+
8+
def _new_session(self):
9+
session = Session(self, self.metadata.all_hosts())
10+
self._session_register_user_types(session)
11+
self.sessions.add(session)
12+
return session
13+
14+
15+
class Session(Session):
16+
17+
def __init__(self, cluster, hosts):
18+
super(Session, self).__init__(cluster, hosts)
19+
20+
def cql_encode_str_quoted(val):
21+
return "'%s'" % val
22+
23+
for typ in (Point, Circle, LineString, Polygon):
24+
self.encoder.mapping[typ] = cql_encode_str_quoted

dse/cqltypes.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import io
2+
from itertools import chain
3+
from six.moves import range
4+
import struct
5+
from cassandra.cqltypes import CassandraType
6+
from cassandra.util import is_little_endian as _platform_is_le
7+
from dse.marshal import point_be, point_le, circle_be, circle_le
8+
from dse.util import Point, Circle, LineString, Polygon
9+
10+
_endian_flag = 1 if _platform_is_le else 0
11+
12+
13+
class WKBGeometryType(object):
14+
POINT = 1
15+
LINESTRING = 2
16+
POLYGON = 3
17+
CIRCLE = 101 # DSE custom
18+
19+
20+
class PointType(CassandraType):
21+
typename = 'PointType'
22+
23+
_platform_point = point_le if _platform_is_le else point_be
24+
_type = struct.pack('=BI', _endian_flag, WKBGeometryType.POINT)
25+
26+
@staticmethod
27+
def serialize(val, protocol_version):
28+
return PointType._type + PointType._platform_point.pack(val.x, val.y)
29+
30+
@staticmethod
31+
def deserialize(byts, protocol_version):
32+
is_little_endian = bool(byts[0])
33+
point = point_le if is_little_endian else point_be
34+
return Point(*point.unpack_from(byts, 5)) # ofs = endian byte + int type
35+
36+
37+
class CircleType(CassandraType):
38+
typename = 'CircleType'
39+
40+
_platform_circle = circle_le if _platform_is_le else circle_be
41+
_type = struct.pack('=BI', _endian_flag, WKBGeometryType.CIRCLE)
42+
43+
@staticmethod
44+
def serialize(val, protocol_version):
45+
return CircleType._type + CircleType._platform_circle.pack(val.x, val.y, val.r)
46+
47+
@staticmethod
48+
def deserialize(byts, protocol_version):
49+
is_little_endian = bool(byts[0])
50+
circle = circle_le if is_little_endian else circle_be
51+
return Circle(*circle.unpack_from(byts, 5))
52+
53+
54+
class LineStringType(CassandraType):
55+
typename = 'LineStringType'
56+
57+
_type = struct.pack('=BI', _endian_flag, WKBGeometryType.LINESTRING)
58+
59+
@staticmethod
60+
def serialize(val, protocol_version):
61+
num_points = len(val.coords)
62+
return LineStringType._type + struct.pack('=I' + 'dd' * num_points, num_points, *(d for coords in val.coords for d in coords))
63+
64+
@staticmethod
65+
def deserialize(byts, protocol_version):
66+
is_little_endian = bool(byts[0])
67+
point = point_le if is_little_endian else point_be
68+
coords = ((point.unpack_from(byts, offset) for offset in range(1 + 4 + 4, len(byts), point.size))) # start = endian + int type + int count
69+
return LineString(coords)
70+
71+
72+
class PolygonType(CassandraType):
73+
typename = 'PolygonType'
74+
75+
_type = struct.pack('=BI', _endian_flag, WKBGeometryType.POLYGON)
76+
_platform_ring_count = struct.Struct('=I').pack
77+
78+
@staticmethod
79+
def serialize(val, protocol_version):
80+
buf = io.BytesIO(PolygonType._type)
81+
buf.seek(0, 2)
82+
83+
num_rings = 1 + len(val.interiors)
84+
buf.write(PolygonType._platform_ring_count(num_rings))
85+
for ring in chain((val.exterior,), val.interiors):
86+
num_points = len(ring.coords)
87+
buf.write(struct.pack('=I' + 'dd' * num_points, num_points, *(d for coord in ring.coords for d in coord)))
88+
return buf.getvalue()
89+
90+
@staticmethod
91+
def deserialize(byts, protocol_version):
92+
is_little_endian = bool(byts[0])
93+
if is_little_endian:
94+
int_fmt = '<i'
95+
point = point_le
96+
else:
97+
int_fmt = '>i'
98+
point = point_be
99+
p = 5
100+
ring_count = struct.unpack_from(int_fmt, byts, p)[0]
101+
p += 4
102+
rings = []
103+
for _ in range(ring_count):
104+
point_count = struct.unpack_from(int_fmt, byts, p)[0]
105+
p += 4
106+
end = p + point_count * point.size
107+
rings.append([point.unpack_from(byts, offset) for offset in range(p, end, point.size)])
108+
p = end
109+
return Polygon(exterior=rings[0], interiors=rings[1:])

0 commit comments

Comments
 (0)