diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 029ccfac66..be8a5702ec 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1955,7 +1955,7 @@ def user_type_registered(self, keyspace, user_type, klass): def encode(val): return '{ %s }' % ' , '.join('%s : %s' % ( - field_name, + field_name.encode('utf-8') if six.PY2 and isinstance(field_name, six.text_type) else field_name, self.encoder.cql_encode_all_types(getattr(val, field_name, None)) ) for field_name in type_meta.field_names) diff --git a/cassandra/cqlengine/columns.py b/cassandra/cqlengine/columns.py index 11481bcd81..815540a8cf 100644 --- a/cassandra/cqlengine/columns.py +++ b/cassandra/cqlengine/columns.py @@ -18,6 +18,8 @@ import six import warnings +import msgpack + from cassandra import util from cassandra.cqltypes import DateType, SimpleDateType from cassandra.cqlengine import ValidationError @@ -195,6 +197,8 @@ def to_python(self, value): Converts data from the database into python values raises a ValidationError if the value can't be converted """ + if value is None and self.has_default: + return self.get_default() return value def to_database(self, value): @@ -333,6 +337,11 @@ def validate(self, value): raise ValidationError('{0} is shorter than {1} characters'.format(self.column_name, self.min_length)) return value + def to_python(self, value): + value = self.validate(value) + if value is None and self.has_default: + return self.get_default() + return value class Integer(Column): """ @@ -351,7 +360,10 @@ def validate(self, value): raise ValidationError("{0} {1} can't be converted to integral value".format(self.column_name, value)) def to_python(self, value): - return self.validate(value) + value = self.validate(value) + if value is None and self.has_default: + return self.get_default() + return value def to_database(self, value): return self.validate(value) @@ -835,10 +847,11 @@ def sub_columns(self): class UDTValueManager(BaseValueManager): @property def changed(self): - return self.value != self.previous_value or self.value.has_changed_fields() + return self.value != self.previous_value or (self.value is not None and self.value.has_changed_fields()) def reset_previous_value(self): - self.value.reset_changed_fields() + if self.value is not None: + self.value.reset_changed_fields() self.previous_value = copy(self.value) @@ -867,6 +880,13 @@ def __init__(self, user_type, **kwargs): def sub_columns(self): return list(self.user_type._fields.values()) + def to_database(self, value): + ret = deepcopy(value) + for k, v in ret.items(): + col = self.user_type._fields[k] + ret[k] = col.to_database(v) + return ret + def resolve_udts(col_def, out_list): for col in col_def.sub_columns: @@ -897,3 +917,42 @@ def to_database(self, value): def get_cql(self): return "token({0})".format(", ".join(c.cql for c in self.partition_columns)) + + +def decode_datetime(obj): + if b'__datetime__' in obj: + obj = datetime.strptime(obj["as_str"], "%Y%m%dT%H:%M:%S.%f") + return obj + + +def encode_datetime(obj): + if isinstance(obj, datetime): + return {'__datetime__': True, 'as_str': obj.strftime("%Y%m%dT%H:%M:%S.%f")} + return obj + + +class Json(Blob): + def to_database(self, value): + if not value: + return None + value = Column.to_database(self, value) + ret = msgpack.packb(value, use_bin_type=True, default=encode_datetime) + return bytearray(ret) + + def to_python(self, value): + if value is None and self.has_default: + return self.get_default() + if not isinstance(value, (six.binary_type, bytearray)): + return value + return msgpack.unpackb(value, object_hook=decode_datetime, encoding='utf-8') + + +def to_python(self, value): + ret = {} + for col_name, col in self.user_type._fields.items(): + val = value.get(col_name) if value is not None else None + ret[col_name] = col.to_python(val) + return self.user_type(**ret) + + +UserDefinedType.to_python = to_python diff --git a/cassandra/cqlengine/models.py b/cassandra/cqlengine/models.py index ddb7945995..f9e5ae5807 100644 --- a/cassandra/cqlengine/models.py +++ b/cassandra/cqlengine/models.py @@ -259,6 +259,8 @@ def __get__(self, instance, owner): """ try: return instance._values[self.column.column_name].getval() + except KeyError: + return self.column.get_default() except AttributeError: return self.query_evaluator @@ -268,6 +270,8 @@ def __set__(self, instance, value): TODO: use None instance to create update statements """ if instance: + if instance._values.get(self.column.column_name, None) is None: + instance._values[self.column.column_name] = self.column.value_manager(instance, self.column, value) return instance._values[self.column.column_name].setval(value) else: raise AttributeError('cannot reassign column values') @@ -361,11 +365,9 @@ def __init__(self, **values): self._ttl = self.__default_ttl__ self._timestamp = None self._transaction = None - for name, column in self._columns.items(): value = values.get(name, None) - if value is not None or isinstance(column, columns.BaseContainerColumn): - value = column.to_python(value) + value = column.to_python(value) value_mngr = column.value_manager(self, column, value) if name in values: value_mngr.explicit = True @@ -390,6 +392,12 @@ def __str__(self): return '{0} <{1}>'.format(self.__class__.__name__, ', '.join('{0}={1}'.format(k, getattr(self, k)) for k in self._primary_keys.keys())) + def __setstate__(self, state): + # register when unpickle from cache, avoid property missing + state.update(self.__dict__) # pylint: disable=E0203 + self.__dict__ = state + self._timeout = connection.NOT_SET + @classmethod def _discover_polymorphic_submodels(cls): if not cls._is_polymorphic_base: diff --git a/cassandra/cqlengine/usertype.py b/cassandra/cqlengine/usertype.py index 88ec033ba8..6c98bd4bed 100644 --- a/cassandra/cqlengine/usertype.py +++ b/cassandra/cqlengine/usertype.py @@ -27,11 +27,9 @@ class BaseUserType(object): def __init__(self, **values): self._values = {} - for name, field in self._fields.items(): value = values.get(name, None) - if value is not None or isinstance(field, columns.BaseContainerColumn): - value = field.to_python(value) + value = field.to_python(value) value_mngr = field.value_manager(self, field, value) if name in values: value_mngr.explicit = True @@ -56,7 +54,13 @@ def __ne__(self, other): return not self.__eq__(other) def __str__(self): - return "{{{0}}}".format(', '.join("'{0}': {1}".format(k, getattr(self, k)) for k, v in six.iteritems(self._values))) + lst = [] + for k, v in six.iteritems(self._values): + val = getattr(self, k) + if six.PY2 and isinstance(val, six.text_type): + val = val.encode('utf-8') + lst.append("'{0}': {1}".format(k, val)) + return "{{{0}}}".format(', '.join(lst)) def has_changed_fields(self): return any(v.changed for v in self._values.values()) @@ -87,7 +91,7 @@ def __len__(self): try: return self._len except: - self._len = len(self._columns.keys()) + self._len = len(self._fields.keys()) return self._len def keys(self): @@ -145,7 +149,7 @@ class UserTypeMetaClass(type): def __new__(cls, name, bases, attrs): field_dict = OrderedDict() - field_defs = [(k, v) for k, v in attrs.items() if isinstance(v, columns.Column)] + field_defs = [(k.decode('utf-8'), v) for k, v in attrs.items() if isinstance(v, columns.Column)] field_defs = sorted(field_defs, key=lambda x: x[1].position) def _transform_column(field_name, field_obj): diff --git a/cassandra/io/geventreactor.py b/cassandra/io/geventreactor.py index f26e61523c..0340a0c4d6 100644 --- a/cassandra/io/geventreactor.py +++ b/cassandra/io/geventreactor.py @@ -36,7 +36,8 @@ def is_timeout(err): return ( err in (EINPROGRESS, EALREADY, EWOULDBLOCK) or - (err == EINVAL and os.name in ('nt', 'ce')) + (err == EINVAL and os.name in ('nt', 'ce')) or + isinstance(err, socket.timeout) ) @@ -152,7 +153,7 @@ def handle_read(self): if len(buf) < self.in_buffer_size: break except socket.error as err: - if not is_timeout(err): + if not is_timeout(err) and len(buf) != self.in_buffer_size: log.debug("Exception during socket recv for %s: %s", self, err) self.defunct(err) return # leave the read loop