Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ matrix:
env: PY=3
env:
global:
- MINCONDA_VERSION="latest"
- MINCONDA_VERSION="4.3.21"
- MINCONDA_LINUX="Linux-x86_64"
- MINCONDA_OSX="MacOSX-x86_64"
before_install:
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ set(XTENSOR_PYTHON_HEADERS
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pytensor.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/pyvectorize.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_python_config.hpp
${XTENSOR_PYTHON_INCLUDE_DIR}/xtensor-python/xtensor_type_caster_base.hpp
)

OPTION(BUILD_TESTS "xtensor test suite" OFF)
Expand Down
24 changes: 24 additions & 0 deletions include/xtensor-python/pyarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "pycontainer.hpp"
#include "pystrides_adaptor.hpp"
#include "xtensor_type_caster_base.hpp"

namespace xt
{
Expand Down Expand Up @@ -69,6 +70,29 @@ namespace pybind11

PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};

// Type caster for casting ndarray to xexpression<pyarray>
template<typename T>
struct type_caster<xt::xexpression<xt::pyarray<T>>> : pyobject_caster<xt::pyarray<T>>
{
using Type = xt::xexpression<xt::pyarray<T>>;

operator Type&()
{
return this->value;
}

operator const Type&()
{
return this->value;
}
};

// Type caster for casting xarray to ndarray
template<class T>
struct type_caster<xt::xarray<T>> : xtensor_type_caster_base<xt::xarray<T>>
{
};
}
}

Expand Down
24 changes: 24 additions & 0 deletions include/xtensor-python/pytensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "pycontainer.hpp"
#include "pystrides_adaptor.hpp"
#include "xtensor_type_caster_base.hpp"

namespace xt
{
Expand Down Expand Up @@ -71,6 +72,29 @@ namespace pybind11

PYBIND11_TYPE_CASTER(type, handle_type_name<type>::name());
};

// Type caster for casting ndarray to xexpression<pytensor>
template<class T, std::size_t N>
struct type_caster<xt::xexpression<xt::pytensor<T, N>>> : pyobject_caster<xt::pytensor<T, N>>
{
using Type = xt::xexpression<xt::pytensor<T, N>>;

operator Type&()
{
return this->value;
}

operator const Type&()
{
return this->value;
}
};

// Type caster for casting xt::xtensor to ndarray
template<class T, std::size_t N>
struct type_caster<xt::xtensor<T, N>> : xtensor_type_caster_base<xt::xtensor<T, N>>
{
};
}
}

Expand Down
162 changes: 162 additions & 0 deletions include/xtensor-python/xtensor_type_caster_base.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
xtensor-python/xtensor_type_caster.hpp: Transparent conversion for xtensor and xarray

This code is based on the following code written by Wenzei Jakob

pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices

Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE file.
*/


#ifndef XTENSOR_TYPE_CASTER_HPP
#define XTENSOR_TYPE_CASTER_HPP

#include "xtensor/xtensor.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>

namespace pybind11
{
namespace detail
{
// Casts an xtensor (or xarray) type to numpy array. If given a base, the numpy array references the src data,
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
template<typename Type>
handle xtensor_array_cast(Type const &src, handle base = handle(), bool writeable = true)
{
std::vector<size_t> python_strides(src.strides().size());
std::transform(src.strides().begin(), src.strides().end(), python_strides.begin(),
[](auto v) { return sizeof(typename Type::value_type) * v; });

std::vector<size_t> python_shape(src.shape().size());
std::copy(src.shape().begin(), src.shape().end(), python_shape.begin());

array a(python_shape, python_strides, src.begin(), base);

if (!writeable)
{
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;
}

return a.release();
}

// Takes an lvalue ref to some xtensor (or xarray) type and a (python) base object, creating a numpy array that
// reference the xtensor object's data with `base` as the python-registered base class (if omitted,
// the base will be set to None, and lifetime management is up to the caller). The numpy array is
// non-writeable if the given type is const.
template <typename Type, typename CType>
handle xtensor_ref_array(CType &src, handle parent = none())
{
return xtensor_array_cast<Type>(src, parent, !std::is_const<CType>::value);
}

// Takes a pointer to xtensor (or xarray), builds a capsule around it, then returns a numpy
// array that references the encapsulated data with a python-side reference to the capsule to tie
// its destruction to that of any dependent python objects. Const-ness is determined by whether or
// not the CType of the pointer given is const.
template <typename Type, typename CType>
handle xtensor_encapsulate(CType *src)
{
capsule base(src, [](void *o) { delete static_cast<CType *>(o); });
return xtensor_ref_array<Type>(*src, base);
}

// Base class of type_caster for xtensor and xarray
template<class Type>
struct xtensor_type_caster_base
{
bool load(handle src, bool)
{
return false;
}

private:

// Cast implementation
template <typename CType>
static handle cast_impl(CType *src, return_value_policy policy, handle parent)
{
switch (policy)
{
case return_value_policy::take_ownership:
case return_value_policy::automatic:
return xtensor_encapsulate<Type>(src);
case return_value_policy::move:
return xtensor_encapsulate<Type>(new CType(std::move(*src)));
case return_value_policy::copy:
return xtensor_array_cast<Type>(*src);
case return_value_policy::reference:
case return_value_policy::automatic_reference:
return xtensor_ref_array<Type>(*src);
case return_value_policy::reference_internal:
return xtensor_ref_array<Type>(*src, parent);
default:
throw cast_error("unhandled return_value_policy: should not happen!");
};
}

public:

// Normal returned non-reference, non-const value:
static handle cast(Type &&src, return_value_policy /* policy */, handle parent)
{
return cast_impl(&src, return_value_policy::move, parent);
}

// If you return a non-reference const, we mark the numpy array readonly:
static handle cast(const Type &&src, return_value_policy /* policy */, handle parent)
{
return cast_impl(&src, return_value_policy::move, parent);
}

// lvalue reference return; default (automatic) becomes copy
static handle cast(Type &src, return_value_policy policy, handle parent)
{
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
{
policy = return_value_policy::copy;
}

return cast_impl(&src, policy, parent);
}

// const lvalue reference return; default (automatic) becomes copy
static handle cast(const Type &src, return_value_policy policy, handle parent)
{
if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference)
{
policy = return_value_policy::copy;
}

return cast(&src, policy, parent);
}

// non-const pointer return
static handle cast(Type *src, return_value_policy policy, handle parent)
{
return cast_impl(src, policy, parent);
}

// const pointer return
static handle cast(const Type *src, return_value_policy policy, handle parent)
{
return cast_impl(src, policy, parent);
}

static PYBIND11_DESCR name()
{
return _("xt::xtensor");
}

template <typename T>
using cast_op_type = cast_op_type<T>;
};
}
}

#endif