Skip to content

Commit 9869c2a

Browse files
authored
PYTHON-751: Add IPV4Address/IPV6Address support for inet type (apache#828)
* Add IPV4Address/IPV6Address support for inet type
1 parent a549ca6 commit 9869c2a

10 files changed

Lines changed: 103 additions & 44 deletions

File tree

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Features
55
--------
66
* Send keyspace in QUERY, PREPARE, and BATCH messages (PYTHON-678)
7+
* Add IPv4Address/IPv6Address support for inet types (PYTHON-751)
78

89
Bug Fixes
910
---------

cassandra/cqltypes.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
from uuid import UUID
4444
import warnings
4545

46+
if six.PY3:
47+
import ipaddress
4648

4749
from cassandra.marshal import (int8_pack, int8_unpack, int16_pack, int16_unpack,
4850
uint16_pack, uint16_unpack, uint32_pack, uint32_unpack,
@@ -517,12 +519,17 @@ def deserialize(byts, protocol_version):
517519

518520
@staticmethod
519521
def serialize(addr, protocol_version):
520-
if ':' in addr:
521-
return util.inet_pton(socket.AF_INET6, addr)
522-
else:
523-
# util.inet_pton could also handle, but this is faster
524-
# since we've already determined the AF
525-
return socket.inet_aton(addr)
522+
try:
523+
if ':' in addr:
524+
return util.inet_pton(socket.AF_INET6, addr)
525+
else:
526+
# util.inet_pton could also handle, but this is faster
527+
# since we've already determined the AF
528+
return socket.inet_aton(addr)
529+
except:
530+
if six.PY3 and isinstance(addr, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
531+
return addr.packed
532+
raise ValueError("can't interpret %r as an inet address" % (addr,))
526533

527534

528535
class CounterColumnType(LongType):

cassandra/encoder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
from uuid import UUID
3030
import six
3131

32+
if six.PY3:
33+
import ipaddress
34+
3235
from cassandra.util import (OrderedDict, OrderedMap, OrderedMapSerializedKey,
3336
sortedset, Time, Date)
3437

@@ -103,6 +106,8 @@ def __init__(self):
103106
memoryview: self.cql_encode_bytes,
104107
bytes: self.cql_encode_bytes,
105108
type(None): self.cql_encode_none,
109+
ipaddress.IPv4Address: self.cql_encode_ipaddress,
110+
ipaddress.IPv6Address: self.cql_encode_ipaddress
106111
})
107112

108113
def cql_encode_none(self, val):
@@ -225,3 +230,11 @@ def cql_encode_all_types(self, val):
225230
if :attr:`~Encoder.mapping` does not contain an entry for the type.
226231
"""
227232
return self.mapping.get(type(val), self.cql_encode_object)(val)
233+
234+
if six.PY3:
235+
def cql_encode_ipaddress(self, val):
236+
"""
237+
Converts an ipaddress (IPV4Address, IPV6Address) to a CQL string. This
238+
is suitable for ``inet`` type columns.
239+
"""
240+
return "'%s'" % val.compressed

tests/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,33 @@ def is_monkey_patched():
5959
return is_gevent_monkey_patched() or is_eventlet_monkey_patched()
6060

6161

62+
EVENT_LOOP_MANAGER = os.getenv('EVENT_LOOP_MANAGER', "libev")
63+
if "gevent" in EVENT_LOOP_MANAGER:
64+
import gevent.monkey
65+
gevent.monkey.patch_all()
66+
from cassandra.io.geventreactor import GeventConnection
67+
connection_class = GeventConnection
68+
elif "eventlet" in EVENT_LOOP_MANAGER:
69+
from eventlet import monkey_patch
70+
monkey_patch()
71+
72+
from cassandra.io.eventletreactor import EventletConnection
73+
connection_class = EventletConnection
74+
elif "async" in EVENT_LOOP_MANAGER:
75+
from cassandra.io.asyncorereactor import AsyncoreConnection
76+
connection_class = AsyncoreConnection
77+
elif "twisted" in EVENT_LOOP_MANAGER:
78+
from cassandra.io.twistedreactor import TwistedConnection
79+
connection_class = TwistedConnection
80+
81+
else:
82+
try:
83+
from cassandra.io.libevreactor import LibevConnection
84+
connection_class = LibevConnection
85+
except ImportError:
86+
connection_class = None
87+
88+
6289
MONKEY_PATCH_LOOP = bool(os.getenv('MONKEY_PATCH_LOOP', False))
6390

6491
notwindows = unittest.skipUnless(not "Windows" in platform.system(), "This test is not adequate for windows")

tests/integration/__init__.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,31 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15-
16-
EVENT_LOOP_MANAGER = os.getenv('EVENT_LOOP_MANAGER', "libev")
17-
if "gevent" in EVENT_LOOP_MANAGER:
18-
import gevent.monkey
19-
gevent.monkey.patch_all()
20-
from cassandra.io.geventreactor import GeventConnection
21-
connection_class = GeventConnection
22-
elif "eventlet" in EVENT_LOOP_MANAGER:
23-
from eventlet import monkey_patch
24-
monkey_patch()
25-
26-
from cassandra.io.eventletreactor import EventletConnection
27-
connection_class = EventletConnection
28-
elif "async" in EVENT_LOOP_MANAGER:
29-
from cassandra.io.asyncorereactor import AsyncoreConnection
30-
connection_class = AsyncoreConnection
31-
elif "twisted" in EVENT_LOOP_MANAGER:
32-
from cassandra.io.twistedreactor import TwistedConnection
33-
connection_class = TwistedConnection
34-
35-
else:
36-
from cassandra.io.libevreactor import LibevConnection
37-
connection_class = LibevConnection
38-
3915
from cassandra.cluster import Cluster
16+
17+
from tests import connection_class, EVENT_LOOP_MANAGER
4018
Cluster.connection_class = connection_class
4119

4220
try:

tests/integration/cqlengine/columns/test_validation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,6 @@ def test_default_zero_fields_validate(self):
529529

530530

531531
class TestAscii(BaseCassEngTestCase):
532-
533532
def test_min_length(self):
534533
""" Test arbitrary minimal lengths requirements. """
535534

tests/integration/datatype_utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from decimal import Decimal
1616
from datetime import datetime, date, time
1717
from uuid import uuid1, uuid4
18+
import six
1819

1920
from cassandra.util import OrderedMap, Date, Time, sortedset, Duration
2021

@@ -90,7 +91,11 @@ def get_sample_data():
9091
sample_data[datatype] = 3.4028234663852886e+38
9192

9293
elif datatype == 'inet':
93-
sample_data[datatype] = '123.123.123.123'
94+
sample_data[datatype] = ('123.123.123.123', '2001:db8:85a3:8d3:1319:8a2e:370:7348')
95+
if six.PY3:
96+
import ipaddress
97+
sample_data[datatype] += (ipaddress.IPv4Address("123.123.123.123"),
98+
ipaddress.IPv6Address('2001:db8:85a3:8d3:1319:8a2e:370:7348'))
9499

95100
elif datatype == 'int':
96101
sample_data[datatype] = 2147483647
@@ -140,10 +145,20 @@ def get_sample(datatype):
140145
"""
141146
Helper method to access created sample data for primitive types
142147
"""
143-
148+
if isinstance(SAMPLE_DATA[datatype], tuple):
149+
return SAMPLE_DATA[datatype][0]
144150
return SAMPLE_DATA[datatype]
145151

146152

153+
def get_all_samples(datatype):
154+
"""
155+
Helper method to access created sample data for primitive types
156+
"""
157+
if isinstance(SAMPLE_DATA[datatype], tuple):
158+
return SAMPLE_DATA[datatype]
159+
return SAMPLE_DATA[datatype],
160+
161+
147162
def get_collection_sample(collection_type, datatype):
148163
"""
149164
Helper method to access created sample data for collection types

tests/integration/standard/test_cluster.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from cassandra.pool import Host
3535
from cassandra.query import SimpleStatement, TraceUnavailable, tuple_factory
3636

37-
37+
from tests import notwindows
3838
from tests.integration import use_singledc, PROTOCOL_VERSION, get_server_versions, CASSANDRA_VERSION, \
3939
execute_until_pass, execute_with_long_wait_retry, get_node, MockLoggingHandler, get_unsupported_lower_protocol, \
4040
get_unsupported_upper_protocol, protocolv5, local, CASSANDRA_IP
@@ -476,6 +476,7 @@ def test_refresh_schema_type(self):
476476
cluster.shutdown()
477477

478478
@local
479+
@notwindows
479480
def test_refresh_schema_no_wait(self):
480481
contact_points = [CASSANDRA_IP]
481482
cluster = Cluster(protocol_version=PROTOCOL_VERSION, max_schema_agreement_wait=10,

tests/integration/standard/test_types.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from tests.integration import use_singledc, PROTOCOL_VERSION, execute_until_pass, notprotocolv1, \
3434
BasicSharedKeyspaceUnitTestCase, greaterthancass21, lessthancass30, greaterthanorequalcass3_10
3535
from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, PRIMITIVE_DATATYPES_KEYS, \
36-
get_sample, get_collection_sample
36+
get_sample, get_all_samples, get_collection_sample
3737

3838

3939
def setup_module():
@@ -161,8 +161,30 @@ def test_can_insert_primitive_datatypes(self):
161161
for expected, actual in zip(params, results):
162162
self.assertEqual(actual, expected)
163163

164+
# try the same thing sending one insert at the time
165+
s.execute("TRUNCATE alltypes;")
166+
for i, datatype in enumerate(PRIMITIVE_DATATYPES):
167+
single_col_name = chr(start_index + i)
168+
single_col_names = ["zz", single_col_name]
169+
placeholders = ','.join(["%s"] * len(single_col_names))
170+
single_columns_string = ', '.join(single_col_names)
171+
for j, data_sample in enumerate(get_all_samples(datatype)):
172+
key = i + 1000 * j
173+
single_params = (key, data_sample)
174+
s.execute("INSERT INTO alltypes ({0}) VALUES ({1})".format(single_columns_string, placeholders),
175+
single_params)
176+
# verify data
177+
result = s.execute("SELECT {0} FROM alltypes WHERE zz=%s".format(single_columns_string), (key,))[0][1]
178+
compare_value = data_sample
179+
if six.PY3:
180+
import ipaddress
181+
if isinstance(data_sample, ipaddress.IPv4Address) or isinstance(data_sample, ipaddress.IPv6Address):
182+
compare_value = str(data_sample)
183+
self.assertEqual(result, compare_value)
184+
164185
# try the same thing with a prepared statement
165186
placeholders = ','.join(["?"] * len(col_names))
187+
s.execute("TRUNCATE alltypes;")
166188
insert = s.prepare("INSERT INTO alltypes ({0}) VALUES ({1})".format(columns_string, placeholders))
167189
s.execute(insert.bind(params))
168190

tests/unit/test_cluster.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@
2727
from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory
2828
from cassandra.pool import Host
2929
from tests.unit.utils import mock_session_pools
30-
31-
try:
32-
from cassandra.io.libevreactor import LibevConnection
33-
except ImportError:
34-
LibevConnection = None # noqa
30+
from tests import connection_class
3531

3632

3733
class ExceptionTypeTest(unittest.TestCase):
@@ -129,9 +125,9 @@ def test_event_delay_timing(self, *_):
129125

130126
class SessionTest(unittest.TestCase):
131127
def setUp(self):
132-
if LibevConnection is None:
128+
if connection_class is None:
133129
raise unittest.SkipTest('libev does not appear to be installed correctly')
134-
LibevConnection.initialize_reactor()
130+
connection_class.initialize_reactor()
135131

136132
# TODO: this suite could be expanded; for now just adding a test covering a PR
137133
@mock_session_pools
@@ -164,9 +160,9 @@ def test_default_serial_consistency_level(self, *_):
164160

165161
class ExecutionProfileTest(unittest.TestCase):
166162
def setUp(self):
167-
if LibevConnection is None:
163+
if connection_class is None:
168164
raise unittest.SkipTest('libev does not appear to be installed correctly')
169-
LibevConnection.initialize_reactor()
165+
connection_class.initialize_reactor()
170166

171167
def _verify_response_future_profile(self, rf, prof):
172168
self.assertEqual(rf._load_balancer, prof.load_balancing_policy)

0 commit comments

Comments
 (0)