Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Added user_defined impl
  • Loading branch information
g-bauer committed Feb 13, 2024
commit 17c0d85a2bdbb5ad9cace26d331f473b0a632277
246 changes: 158 additions & 88 deletions feos-core/src/python/user_defined.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use crate::si::MolarWeight;
use crate::{
Components, DeBroglieWavelength, DeBroglieWavelengthDual, HelmholtzEnergy, HelmholtzEnergyDual,
IdealGas, Residual, StateHD,
};
use crate::{Components, IdealGas, Residual, StateHD};
use ndarray::Array1;
use num_dual::*;
use numpy::convert::IntoPyArray;
use numpy::{PyArray, PyReadonlyArray1, PyReadonlyArrayDyn};
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use quantity::python::PySIArray1;
use std::any::Any;
use std::convert::TryInto;
use std::fmt;

Expand Down Expand Up @@ -64,8 +62,11 @@ impl Components for PyIdealGas {
}

impl IdealGas for PyIdealGas {
fn ideal_gas_model(&self) -> &dyn DeBroglieWavelength {
self
fn ideal_gas_name(&self) -> String {
unimplemented!()
}
fn ln_lambda3<D: DualNum<f64> + Copy>(&self, temperature: D) -> Array1<D> {
unimplemented!()
}
}

Expand All @@ -78,7 +79,6 @@ impl fmt::Display for PyIdealGas {
/// Struct containing pointer to Python Class that implements Helmholtz energy.
pub struct PyResidual {
obj: Py<PyAny>,
contributions: Vec<Box<dyn HelmholtzEnergy>>,
}

impl fmt::Display for PyResidual {
Expand Down Expand Up @@ -110,10 +110,7 @@ impl PyResidual {
if !attr {
panic!("{}", "Python Class has to have a method 'helmholtz_energy' with signature:\n\tdef helmholtz_energy(self, state: StateHD) -> HD\nwhere 'HD' has to be any of {{float, Dual64, HyperDual64, HyperDualDual64, Dual3Dual64, Dual3_64}}.")
}
Ok(Self {
obj: obj.clone(),
contributions: vec![Box::new(PyHelmholtzEnergy(obj))],
})
Ok(Self { obj: obj.clone() })
})
}
}
Expand Down Expand Up @@ -144,37 +141,79 @@ impl Components for PyResidual {
}
}

impl Residual for PyResidual {
fn compute_max_density(&self, moles: &Array1<f64>) -> f64 {
Python::with_gil(|py| {
let py_result = self
.obj
.as_ref(py)
.call_method1("max_density", (moles.to_owned().into_pyarray(py),))
.unwrap();
py_result.extract().unwrap()
})
}
macro_rules! residual {
($($py_state_id:ident, $py_hd_id:ident, $hd_ty:ty);*) => {
impl Residual for PyResidual {
fn compute_max_density(&self, moles: &Array1<f64>) -> f64 {
Python::with_gil(|py| {
let py_result = self
.obj
.as_ref(py)
.call_method1("max_density", (moles.to_owned().into_pyarray(py),))
.unwrap();
py_result.extract().unwrap()
})
}

fn contributions(&self) -> &[Box<dyn HelmholtzEnergy>] {
&self.contributions
}
fn residual_helmholtz_energy<D: DualNum<f64> + Copy>(&self, state: &StateHD<D>) -> D {
// result to write to
let mut a = D::zero();

$(
if let Some(s) = (state as &dyn Any).downcast_ref::<StateHD<$hd_ty>>() {
let d = (&mut a as &mut dyn Any).downcast_mut::<$hd_ty>().unwrap();
*d = Python::with_gil(|py| {
let py_result = self
.obj
.as_ref(py)
.call_method1("helmholtz_energy", (<$py_state_id>::from(s.clone()),))
.unwrap();
<$hd_ty>::from(py_result.extract::<$py_hd_id>().unwrap())
});
return a
}
)*

// // Dual64
// if let Some(s) = (state as &dyn Any).downcast_ref::<StateHD<Dual64>>() {
// let d = (&mut a as &mut dyn Any).downcast_mut::<Dual64>().unwrap();
// *d = Python::with_gil(|py| {
// let py_result = self
// .obj
// .as_ref(py)
// .call_method1("helmholtz_energy", (<PyStateD>::from(s.clone()),))
// .unwrap();
// <Dual64>::from(py_result.extract::<PyDual64>().unwrap())
// });
// return a
// }
panic!("Something went wrong!.")
}

fn molar_weight(&self) -> MolarWeight<Array1<f64>> {
Python::with_gil(|py| {
let py_result = self.obj.as_ref(py).call_method0("molar_weight").unwrap();
if py_result.get_type().name().unwrap() != "SIArray1" {
panic!(
"Expected an 'SIArray1' for the 'molar_weight' method return type, got {}",
py_result.get_type().name().unwrap()
);
fn residual_helmholtz_energy_contributions<D: DualNum<f64> + Copy>(
&self,
state: &StateHD<D>,
) -> Vec<(String, D)> {
unimplemented!()
}
py_result
.extract::<PySIArray1>()
.unwrap()
.try_into()
.unwrap()
})

fn molar_weight(&self) -> MolarWeight<Array1<f64>> {
Python::with_gil(|py| {
let py_result = self.obj.as_ref(py).call_method0("molar_weight").unwrap();
if py_result.get_type().name().unwrap() != "SIArray1" {
panic!(
"Expected an 'SIArray1' for the 'molar_weight' method return type, got {}",
py_result.get_type().name().unwrap()
);
}
py_result
.extract::<PySIArray1>()
.unwrap()
.try_into()
.unwrap()
})
}
}
}
}

Expand Down Expand Up @@ -249,67 +288,68 @@ macro_rules! dual_number {
};
}

macro_rules! helmholtz_energy {
($py_state_id:ident, $py_hd_id:ident, $hd_ty:ty) => {
impl HelmholtzEnergyDual<$hd_ty> for PyHelmholtzEnergy {
fn helmholtz_energy(&self, state: &StateHD<$hd_ty>) -> $hd_ty {
Python::with_gil(|py| {
let py_result = self
.0
.as_ref(py)
.call_method1("helmholtz_energy", (<$py_state_id>::from(state.clone()),))
.unwrap();
<$hd_ty>::from(py_result.extract::<$py_hd_id>().unwrap())
})
}
}
};
}

macro_rules! de_broglie_wavelength {
($py_hd_id:ident, $hd_ty:ty) => {
impl DeBroglieWavelengthDual<$hd_ty> for PyIdealGas {
fn ln_lambda3(&self, temperature: $hd_ty) -> Array1<$hd_ty> {
Python::with_gil(|py| {
let py_result = self
.0
.as_ref(py)
.call_method1("ln_lambda3", (<$py_hd_id>::from(temperature),))
.unwrap();

// f64
let rr = if let Ok(r) = py_result.extract::<PyReadonlyArray1<f64>>() {
r.to_owned_array()
.mapv(|ri| <$hd_ty>::from(ri))
// anything but f64
} else if let Ok(r) = py_result.extract::<PyReadonlyArray1<PyObject>>() {
r.to_owned_array()
.mapv(|ri| <$hd_ty>::from(ri.extract::<$py_hd_id>(py).unwrap()))
} else {
panic!("ln_lambda3: data type of result must be one-dimensional numpy ndarray")
};
rr
})
}
}
};
}
// macro_rules! helmholtz_energy {
Comment thread
prehner marked this conversation as resolved.
Outdated
// ($py_state_id:ident, $py_hd_id:ident, $hd_ty:ty) => {
// impl HelmholtzEnergyDual<$hd_ty> for PyHelmholtzEnergy {
// fn helmholtz_energy(&self, state: &StateHD<$hd_ty>) -> $hd_ty {
// Python::with_gil(|py| {
// let py_result = self
// .0
// .as_ref(py)
// .call_method1("helmholtz_energy", (<$py_state_id>::from(state.clone()),))
// .unwrap();
// <$hd_ty>::from(py_result.extract::<$py_hd_id>().unwrap())
// })
// }
// }
// };
// }

// macro_rules! de_broglie_wavelength {
// ($py_hd_id:ident, $hd_ty:ty) => {
// impl DeBroglieWavelengthDual<$hd_ty> for PyIdealGas {
// fn ln_lambda3(&self, temperature: $hd_ty) -> Array1<$hd_ty> {
// Python::with_gil(|py| {
// let py_result = self
// .0
// .as_ref(py)
// .call_method1("ln_lambda3", (<$py_hd_id>::from(temperature),))
// .unwrap();

// // f64
// let rr = if let Ok(r) = py_result.extract::<PyReadonlyArray1<f64>>() {
// r.to_owned_array()
// .mapv(|ri| <$hd_ty>::from(ri))
// // anything but f64
// } else if let Ok(r) = py_result.extract::<PyReadonlyArray1<PyObject>>() {
// r.to_owned_array()
// .mapv(|ri| <$hd_ty>::from(ri.extract::<$py_hd_id>(py).unwrap()))
// } else {
// panic!("ln_lambda3: data type of result must be one-dimensional numpy ndarray")
// };
// rr
// })
// }
// }
// };
// }

macro_rules! impl_dual_state_helmholtz_energy {
($py_state_id:ident, $py_hd_id:ident, $hd_ty:ty, $py_field_ty:ty) => {
dual_number!($py_hd_id, $hd_ty, $py_field_ty);
state!($py_state_id, $py_hd_id, $hd_ty);
helmholtz_energy!($py_state_id, $py_hd_id, $hd_ty);
de_broglie_wavelength!($py_hd_id, $hd_ty);
// helmholtz_energy!($py_state_id, $py_hd_id, $hd_ty);
// de_broglie_wavelength!($py_hd_id, $hd_ty);
};
}

// No definition of dual number necessary for f64
state!(PyStateF, f64, f64);
helmholtz_energy!(PyStateF, f64, f64);
de_broglie_wavelength!(f64, f64);
// helmholtz_energy!(PyStateF, f64, f64);
// de_broglie_wavelength!(f64, f64);

impl_dual_state_helmholtz_energy!(PyStateD, PyDual64, Dual64, f64);

dual_number!(PyDualVec3, DualSVec64<3>, f64);
impl_dual_state_helmholtz_energy!(
PyStateDualDualVec3,
Expand Down Expand Up @@ -358,3 +398,33 @@ impl_dual_state_helmholtz_energy!(
Dual3<DualSVec64<3>, f64>,
PyDualVec3
);

residual!(
PyStateF, f64, f64;
PyStateD, PyDual64, Dual64;
PyStateDualDualVec3,
PyDualDualVec3,
Dual<DualSVec64<3>, f64>;
PyStateHD, PyHyperDual64, HyperDual64;
PyStateD2, PyDual2_64, Dual2_64;
PyStateD3, PyDual3_64, Dual3_64;
PyStateHDD, PyHyperDualDual64, HyperDual<Dual64, f64>;
PyStateHDDVec2,
PyHyperDualVec2,
HyperDual<DualSVec64<2>, f64>;
PyStateHDDVec3,
PyHyperDualVec3,
HyperDual<DualSVec64<3>, f64>;
PyStateD2D,
PyDual2Dual64,
Dual2<Dual64, f64>;
PyStateD3D,
PyDual3Dual64,
Dual3<Dual64, f64>;
PyStateD3DVec2,
PyDual3DualVec2,
Dual3<DualSVec64<2>, f64>;
PyStateD3DVec3,
PyDual3DualVec3,
Dual3<DualSVec64<3>, f64>
);