From 69d21f34ed61d99814cd7c3bf8f5d923a7a40266 Mon Sep 17 00:00:00 2001 From: Philipp Rehner Date: Sun, 15 Mar 2026 10:49:57 +0100 Subject: [PATCH] Allow parameter import from SQL --- Cargo.toml | 1 + crates/feos-core/Cargo.toml | 1 + crates/feos-core/src/errors.rs | 3 + crates/feos-core/src/parameter/database.rs | 116 +++++++++++++++++++++ crates/feos-core/src/parameter/mod.rs | 6 +- py-feos/Cargo.toml | 3 +- py-feos/src/parameter/mod.rs | 42 ++++++++ 7 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 crates/feos-core/src/parameter/database.rs diff --git a/Cargo.toml b/Cargo.toml index 5f4bd6586..daca710d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ gauss-quad = "0.2" approx = "0.5" criterion = "0.8" paste = "1.0" +rusqlite = "0.39" feos-core = { version = "0.9", path = "crates/feos-core" } feos-dft = { version = "0.9", path = "crates/feos-dft" } diff --git a/crates/feos-core/Cargo.toml b/crates/feos-core/Cargo.toml index 94717bdb3..15cfd525f 100644 --- a/crates/feos-core/Cargo.toml +++ b/crates/feos-core/Cargo.toml @@ -27,6 +27,7 @@ indexmap = { workspace = true, features = ["serde"] } rayon = { workspace = true, optional = true } typenum = { workspace = true } itertools = { workspace = true } +rusqlite = { workspace = true, features = ["bundled"], optional = true } [dev-dependencies] approx = { workspace = true } diff --git a/crates/feos-core/src/errors.rs b/crates/feos-core/src/errors.rs index 24834480f..80c73b309 100644 --- a/crates/feos-core/src/errors.rs +++ b/crates/feos-core/src/errors.rs @@ -62,6 +62,9 @@ pub enum FeosError { #[cfg(feature = "ndarray")] #[error(transparent)] ShapeError(#[from] ndarray::ShapeError), + #[cfg(feature = "rusqlite")] + #[error(transparent)] + RusqLiteError(#[from] rusqlite::Error), } /// Convenience type for `Result`. diff --git a/crates/feos-core/src/parameter/database.rs b/crates/feos-core/src/parameter/database.rs new file mode 100644 index 000000000..9a6edabe8 --- /dev/null +++ b/crates/feos-core/src/parameter/database.rs @@ -0,0 +1,116 @@ +use super::{ + BinaryAssociationRecord, BinaryRecord, CombiningRule, IdentifierOption, Parameters, PureRecord, +}; +use crate::FeosResult; +use rusqlite::{Connection, ToSql, params_from_iter}; +use serde::de::DeserializeOwned; +use std::path::Path; + +impl + Clone> Parameters { + pub fn from_database( + substances: &[S], + file: F, + identifier_option: IdentifierOption, + ) -> FeosResult + where + F: AsRef, + S: ToSql, + P: DeserializeOwned + Clone, + B: DeserializeOwned + Clone, + A: DeserializeOwned + Clone, + { + let conn = Connection::open(file)?; + let pure_records = PureRecord::from_database(substances, &conn, identifier_option)?; + let binary_records = BinaryRecord::from_database(substances, &conn, identifier_option)?; + Self::new(pure_records, binary_records) + } +} + +impl PureRecord { + pub fn from_database( + substances: &[S], + connection: &Connection, + identifier_option: IdentifierOption, + ) -> FeosResult> + where + S: ToSql, + M: DeserializeOwned, + A: DeserializeOwned, + { + let values = (0..substances.len()) + .map(|i| format!("({i},?)")) + .collect::>() + .join(","); + + let query = format!( + " + WITH input(idx, ident) AS ( + VALUES {values} + ) + SELECT pr.pure_record + FROM input + JOIN pure_records pr + ON pr.{identifier_option} = input.ident + " + ); + let mut stmt = connection.prepare(&query)?; + stmt.query_and_then(params_from_iter(substances), |r| { + Ok(serde_json::from_str(&r.get::<_, String>("pure_record")?)?) + })? + .collect() + } +} + +impl BinaryRecord { + pub fn from_database( + substances: &[S], + connection: &Connection, + identifier_option: IdentifierOption, + ) -> FeosResult> + where + S: ToSql, + B: DeserializeOwned, + A: DeserializeOwned, + { + let values = (0..substances.len()) + .map(|i| format!("({i},?)")) + .collect::>() + .join(","); + + let query = format!( + " + WITH input(idx, ident) AS ( + VALUES {values} + ) + SELECT i1.idx AS comp1, i2.idx AS comp2, br.model_record, br.association_sites + FROM binary_records br + JOIN pure_records p1 ON br.id1 = p1.id + JOIN pure_records p2 ON br.id2 = p2.id + JOIN input i1 ON p1.{identifier_option} = i1.ident + JOIN input i2 ON p2.{identifier_option} = i2.ident + " + ); + let mut stmt = connection.prepare(&query)?; + stmt.query_and_then(params_from_iter(substances), |r| { + let mut id1: i32 = r.get("comp1")?; + let mut id2: i32 = r.get("comp2")?; + let model_record = serde_json::from_str(&r.get::<_, String>("model_record")?)?; + let mut association_sites: Vec> = + serde_json::from_str(&r.get::<_, String>("association_sites")?)?; + if id1 > id2 { + association_sites + .iter_mut() + .for_each(|a| std::mem::swap(&mut a.id1, &mut a.id2)); + std::mem::swap(&mut id1, &mut id2); + }; + + Ok(BinaryRecord::with_association( + id1 as usize, + id2 as usize, + model_record, + association_sites, + )) + })? + .collect() + } +} diff --git a/crates/feos-core/src/parameter/mod.rs b/crates/feos-core/src/parameter/mod.rs index 578109c68..350773c71 100644 --- a/crates/feos-core/src/parameter/mod.rs +++ b/crates/feos-core/src/parameter/mod.rs @@ -15,6 +15,8 @@ use std::path::Path; mod association; mod chemical_record; +#[cfg(feature = "rusqlite")] +mod database; mod identifier; mod model_record; @@ -349,9 +351,7 @@ impl + Clone> Parameters { let records = PureRecord::from_multiple_json(input, identifier_option)?; let binary_records = if let Some(path) = file_binary { - let file = File::open(path)?; - let reader = BufReader::new(file); - serde_json::from_reader(reader)? + BinaryRecord::from_json(path)? } else { Vec::new() }; diff --git a/py-feos/Cargo.toml b/py-feos/Cargo.toml index 0ab58879b..c68f5dd04 100644 --- a/py-feos/Cargo.toml +++ b/py-feos/Cargo.toml @@ -37,9 +37,10 @@ rayon = { workspace = true, optional = true } itertools = { workspace = true } typenum = { workspace = true } paste = { workspace = true } +rusqlite = { workspace = true, features = ["bundled"]} feos = { workspace = true } -feos-core = { workspace = true, features = ["ndarray"] } +feos-core = { workspace = true, features = ["ndarray", "rusqlite"] } feos-derive = { workspace = true } feos-dft = { workspace = true, optional = true } diff --git a/py-feos/src/parameter/mod.rs b/py-feos/src/parameter/mod.rs index 115f8f79a..503d41042 100644 --- a/py-feos/src/parameter/mod.rs +++ b/py-feos/src/parameter/mod.rs @@ -1,4 +1,5 @@ use crate::error::PyFeosError; +use feos_core::FeosError; use feos_core::parameter::*; use indexmap::IndexSet; use pyo3::prelude::*; @@ -222,6 +223,42 @@ impl PyParameters { }) } + /// Creates parameters from a database file. + /// + /// Parameters + /// ---------- + /// substances : List[str] + /// The substances to search. + /// path : str + /// Path to database file. + /// identifier_option : IdentifierOption, optional, defaults to IdentifierOption.Name + /// Identifier that is used to search substance. + #[staticmethod] + #[pyo3( + signature = (substances, path, identifier_option=PyIdentifierOption::Name), + text_signature = "(substances, path, identifier_option=IdentifierOption.Name)" + )] + fn from_database( + substances: Vec, + path: String, + identifier_option: PyIdentifierOption, + ) -> PyResult { + let identifier_option = IdentifierOption::from(identifier_option); + let conn = rusqlite::Connection::open(path) + .map_err(FeosError::from) + .map_err(PyFeosError::from)?; + + let pure_records = PureRecord::from_database(&substances, &conn, identifier_option) + .map_err(PyFeosError::from)?; + let binary_records = BinaryRecord::from_database(&substances, &conn, identifier_option) + .map_err(PyFeosError::from)?; + + Ok(Self { + pure_records, + binary_records, + }) + } + /// Generates JSON-formatted string for pure and binary records (if initialized). /// /// Parameters @@ -488,6 +525,10 @@ impl PyParameters { for &p in ¶ms { format_optional(o, model_record.get(p)); } + } else { + for _ in ¶ms { + format_optional(o, None); + } } if !r.association_sites.is_empty() { let s = &r.association_sites[0]; @@ -557,6 +598,7 @@ impl PyGcParameters { .map_err(PyFeosError::from)?) } + #[cfg(feature = "gc_pcsaft")] pub fn try_convert_heterosegmented( self, ) -> PyResult>