Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ gauss-quad = "0.2"
approx = "0.5"
criterion = "0.8"
paste = "1.0"
rusqlite = "0.39"

feos-core = { version = "0.9", path = "crates/feos-core" }
feos-dft = { version = "0.9", path = "crates/feos-dft" }
Expand Down
1 change: 1 addition & 0 deletions crates/feos-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ indexmap = { workspace = true, features = ["serde"] }
rayon = { workspace = true, optional = true }
typenum = { workspace = true }
itertools = { workspace = true }
rusqlite = { workspace = true, features = ["bundled"], optional = true }

[dev-dependencies]
approx = { workspace = true }
Expand Down
3 changes: 3 additions & 0 deletions crates/feos-core/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ pub enum FeosError {
#[cfg(feature = "ndarray")]
#[error(transparent)]
ShapeError(#[from] ndarray::ShapeError),
#[cfg(feature = "rusqlite")]
#[error(transparent)]
RusqLiteError(#[from] rusqlite::Error),
}

/// Convenience type for `Result<T, FeosError>`.
Expand Down
116 changes: 116 additions & 0 deletions crates/feos-core/src/parameter/database.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use super::{
BinaryAssociationRecord, BinaryRecord, CombiningRule, IdentifierOption, Parameters, PureRecord,
};
use crate::FeosResult;
use rusqlite::{Connection, ToSql, params_from_iter};
use serde::de::DeserializeOwned;
use std::path::Path;

impl<P: Clone, B: Clone, A: CombiningRule<P> + Clone> Parameters<P, B, A> {
pub fn from_database<F, S>(
substances: &[S],
file: F,
identifier_option: IdentifierOption,
) -> FeosResult<Self>
where
F: AsRef<Path>,
S: ToSql,
P: DeserializeOwned + Clone,
B: DeserializeOwned + Clone,
A: DeserializeOwned + Clone,
{
let conn = Connection::open(file)?;
let pure_records = PureRecord::from_database(substances, &conn, identifier_option)?;
let binary_records = BinaryRecord::from_database(substances, &conn, identifier_option)?;
Self::new(pure_records, binary_records)
}
}

impl<M, A> PureRecord<M, A> {
pub fn from_database<S>(
substances: &[S],
connection: &Connection,
identifier_option: IdentifierOption,
) -> FeosResult<Vec<Self>>
where
S: ToSql,
M: DeserializeOwned,
A: DeserializeOwned,
{
let values = (0..substances.len())
.map(|i| format!("({i},?)"))
.collect::<Vec<_>>()
.join(",");

let query = format!(
"
WITH input(idx, ident) AS (
VALUES {values}
)
SELECT pr.pure_record
FROM input
JOIN pure_records pr
ON pr.{identifier_option} = input.ident
"
);
let mut stmt = connection.prepare(&query)?;
stmt.query_and_then(params_from_iter(substances), |r| {
Ok(serde_json::from_str(&r.get::<_, String>("pure_record")?)?)
})?
.collect()
}
}

impl<B, A> BinaryRecord<usize, B, A> {
pub fn from_database<S>(
substances: &[S],
connection: &Connection,
identifier_option: IdentifierOption,
) -> FeosResult<Vec<Self>>
where
S: ToSql,
B: DeserializeOwned,
A: DeserializeOwned,
{
let values = (0..substances.len())
.map(|i| format!("({i},?)"))
.collect::<Vec<_>>()
.join(",");

let query = format!(
"
WITH input(idx, ident) AS (
VALUES {values}
)
SELECT i1.idx AS comp1, i2.idx AS comp2, br.model_record, br.association_sites
FROM binary_records br
JOIN pure_records p1 ON br.id1 = p1.id
JOIN pure_records p2 ON br.id2 = p2.id
JOIN input i1 ON p1.{identifier_option} = i1.ident
JOIN input i2 ON p2.{identifier_option} = i2.ident
"
);
let mut stmt = connection.prepare(&query)?;
stmt.query_and_then(params_from_iter(substances), |r| {
let mut id1: i32 = r.get("comp1")?;
let mut id2: i32 = r.get("comp2")?;
let model_record = serde_json::from_str(&r.get::<_, String>("model_record")?)?;
let mut association_sites: Vec<BinaryAssociationRecord<_>> =
serde_json::from_str(&r.get::<_, String>("association_sites")?)?;
if id1 > id2 {
association_sites
.iter_mut()
.for_each(|a| std::mem::swap(&mut a.id1, &mut a.id2));
std::mem::swap(&mut id1, &mut id2);
};

Ok(BinaryRecord::with_association(
id1 as usize,
id2 as usize,
model_record,
association_sites,
))
})?
.collect()
}
}
6 changes: 3 additions & 3 deletions crates/feos-core/src/parameter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use std::path::Path;

mod association;
mod chemical_record;
#[cfg(feature = "rusqlite")]
mod database;
mod identifier;
mod model_record;

Expand Down Expand Up @@ -349,9 +351,7 @@ impl<P: Clone, B: Clone, A: CombiningRule<P> + Clone> Parameters<P, B, A> {
let records = PureRecord::from_multiple_json(input, identifier_option)?;

let binary_records = if let Some(path) = file_binary {
let file = File::open(path)?;
let reader = BufReader::new(file);
serde_json::from_reader(reader)?
BinaryRecord::from_json(path)?
} else {
Vec::new()
};
Expand Down
3 changes: 2 additions & 1 deletion py-feos/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ rayon = { workspace = true, optional = true }
itertools = { workspace = true }
typenum = { workspace = true }
paste = { workspace = true }
rusqlite = { workspace = true, features = ["bundled"]}

feos = { workspace = true }
feos-core = { workspace = true, features = ["ndarray"] }
feos-core = { workspace = true, features = ["ndarray", "rusqlite"] }
feos-derive = { workspace = true }
feos-dft = { workspace = true, optional = true }

Expand Down
42 changes: 42 additions & 0 deletions py-feos/src/parameter/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::error::PyFeosError;
use feos_core::FeosError;
use feos_core::parameter::*;
use indexmap::IndexSet;
use pyo3::prelude::*;
Expand Down Expand Up @@ -222,6 +223,42 @@ impl PyParameters {
})
}

/// Creates parameters from a database file.
///
/// Parameters
/// ----------
/// substances : List[str]
/// The substances to search.
/// path : str
/// Path to database file.
/// identifier_option : IdentifierOption, optional, defaults to IdentifierOption.Name
/// Identifier that is used to search substance.
#[staticmethod]
#[pyo3(
signature = (substances, path, identifier_option=PyIdentifierOption::Name),
text_signature = "(substances, path, identifier_option=IdentifierOption.Name)"
)]
fn from_database(
substances: Vec<String>,
path: String,
identifier_option: PyIdentifierOption,
) -> PyResult<Self> {
let identifier_option = IdentifierOption::from(identifier_option);
let conn = rusqlite::Connection::open(path)
.map_err(FeosError::from)
.map_err(PyFeosError::from)?;

let pure_records = PureRecord::from_database(&substances, &conn, identifier_option)
.map_err(PyFeosError::from)?;
let binary_records = BinaryRecord::from_database(&substances, &conn, identifier_option)
.map_err(PyFeosError::from)?;

Ok(Self {
pure_records,
binary_records,
})
}

/// Generates JSON-formatted string for pure and binary records (if initialized).
///
/// Parameters
Expand Down Expand Up @@ -488,6 +525,10 @@ impl PyParameters {
for &p in &params {
format_optional(o, model_record.get(p));
}
} else {
for _ in &params {
format_optional(o, None);
}
}
if !r.association_sites.is_empty() {
let s = &r.association_sites[0];
Expand Down Expand Up @@ -557,6 +598,7 @@ impl PyGcParameters {
.map_err(PyFeosError::from)?)
}

#[cfg(feature = "gc_pcsaft")]
pub fn try_convert_heterosegmented<P, B, A, C: GroupCount + Default>(
self,
) -> PyResult<GcParameters<P, B, A, (), C>>
Expand Down
Loading