Skip to content

Commit 69d21f3

Browse files
committed
Allow parameter import from SQL
1 parent d54297b commit 69d21f3

File tree

7 files changed

+168
-4
lines changed

7 files changed

+168
-4
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ gauss-quad = "0.2"
4343
approx = "0.5"
4444
criterion = "0.8"
4545
paste = "1.0"
46+
rusqlite = "0.39"
4647

4748
feos-core = { version = "0.9", path = "crates/feos-core" }
4849
feos-dft = { version = "0.9", path = "crates/feos-dft" }

crates/feos-core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ indexmap = { workspace = true, features = ["serde"] }
2727
rayon = { workspace = true, optional = true }
2828
typenum = { workspace = true }
2929
itertools = { workspace = true }
30+
rusqlite = { workspace = true, features = ["bundled"], optional = true }
3031

3132
[dev-dependencies]
3233
approx = { workspace = true }

crates/feos-core/src/errors.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ pub enum FeosError {
6262
#[cfg(feature = "ndarray")]
6363
#[error(transparent)]
6464
ShapeError(#[from] ndarray::ShapeError),
65+
#[cfg(feature = "rusqlite")]
66+
#[error(transparent)]
67+
RusqLiteError(#[from] rusqlite::Error),
6568
}
6669

6770
/// Convenience type for `Result<T, FeosError>`.
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
use super::{
2+
BinaryAssociationRecord, BinaryRecord, CombiningRule, IdentifierOption, Parameters, PureRecord,
3+
};
4+
use crate::FeosResult;
5+
use rusqlite::{Connection, ToSql, params_from_iter};
6+
use serde::de::DeserializeOwned;
7+
use std::path::Path;
8+
9+
impl<P: Clone, B: Clone, A: CombiningRule<P> + Clone> Parameters<P, B, A> {
10+
pub fn from_database<F, S>(
11+
substances: &[S],
12+
file: F,
13+
identifier_option: IdentifierOption,
14+
) -> FeosResult<Self>
15+
where
16+
F: AsRef<Path>,
17+
S: ToSql,
18+
P: DeserializeOwned + Clone,
19+
B: DeserializeOwned + Clone,
20+
A: DeserializeOwned + Clone,
21+
{
22+
let conn = Connection::open(file)?;
23+
let pure_records = PureRecord::from_database(substances, &conn, identifier_option)?;
24+
let binary_records = BinaryRecord::from_database(substances, &conn, identifier_option)?;
25+
Self::new(pure_records, binary_records)
26+
}
27+
}
28+
29+
impl<M, A> PureRecord<M, A> {
30+
pub fn from_database<S>(
31+
substances: &[S],
32+
connection: &Connection,
33+
identifier_option: IdentifierOption,
34+
) -> FeosResult<Vec<Self>>
35+
where
36+
S: ToSql,
37+
M: DeserializeOwned,
38+
A: DeserializeOwned,
39+
{
40+
let values = (0..substances.len())
41+
.map(|i| format!("({i},?)"))
42+
.collect::<Vec<_>>()
43+
.join(",");
44+
45+
let query = format!(
46+
"
47+
WITH input(idx, ident) AS (
48+
VALUES {values}
49+
)
50+
SELECT pr.pure_record
51+
FROM input
52+
JOIN pure_records pr
53+
ON pr.{identifier_option} = input.ident
54+
"
55+
);
56+
let mut stmt = connection.prepare(&query)?;
57+
stmt.query_and_then(params_from_iter(substances), |r| {
58+
Ok(serde_json::from_str(&r.get::<_, String>("pure_record")?)?)
59+
})?
60+
.collect()
61+
}
62+
}
63+
64+
impl<B, A> BinaryRecord<usize, B, A> {
65+
pub fn from_database<S>(
66+
substances: &[S],
67+
connection: &Connection,
68+
identifier_option: IdentifierOption,
69+
) -> FeosResult<Vec<Self>>
70+
where
71+
S: ToSql,
72+
B: DeserializeOwned,
73+
A: DeserializeOwned,
74+
{
75+
let values = (0..substances.len())
76+
.map(|i| format!("({i},?)"))
77+
.collect::<Vec<_>>()
78+
.join(",");
79+
80+
let query = format!(
81+
"
82+
WITH input(idx, ident) AS (
83+
VALUES {values}
84+
)
85+
SELECT i1.idx AS comp1, i2.idx AS comp2, br.model_record, br.association_sites
86+
FROM binary_records br
87+
JOIN pure_records p1 ON br.id1 = p1.id
88+
JOIN pure_records p2 ON br.id2 = p2.id
89+
JOIN input i1 ON p1.{identifier_option} = i1.ident
90+
JOIN input i2 ON p2.{identifier_option} = i2.ident
91+
"
92+
);
93+
let mut stmt = connection.prepare(&query)?;
94+
stmt.query_and_then(params_from_iter(substances), |r| {
95+
let mut id1: i32 = r.get("comp1")?;
96+
let mut id2: i32 = r.get("comp2")?;
97+
let model_record = serde_json::from_str(&r.get::<_, String>("model_record")?)?;
98+
let mut association_sites: Vec<BinaryAssociationRecord<_>> =
99+
serde_json::from_str(&r.get::<_, String>("association_sites")?)?;
100+
if id1 > id2 {
101+
association_sites
102+
.iter_mut()
103+
.for_each(|a| std::mem::swap(&mut a.id1, &mut a.id2));
104+
std::mem::swap(&mut id1, &mut id2);
105+
};
106+
107+
Ok(BinaryRecord::with_association(
108+
id1 as usize,
109+
id2 as usize,
110+
model_record,
111+
association_sites,
112+
))
113+
})?
114+
.collect()
115+
}
116+
}

crates/feos-core/src/parameter/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ use std::path::Path;
1515

1616
mod association;
1717
mod chemical_record;
18+
#[cfg(feature = "rusqlite")]
19+
mod database;
1820
mod identifier;
1921
mod model_record;
2022

@@ -349,9 +351,7 @@ impl<P: Clone, B: Clone, A: CombiningRule<P> + Clone> Parameters<P, B, A> {
349351
let records = PureRecord::from_multiple_json(input, identifier_option)?;
350352

351353
let binary_records = if let Some(path) = file_binary {
352-
let file = File::open(path)?;
353-
let reader = BufReader::new(file);
354-
serde_json::from_reader(reader)?
354+
BinaryRecord::from_json(path)?
355355
} else {
356356
Vec::new()
357357
};

py-feos/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ rayon = { workspace = true, optional = true }
3737
itertools = { workspace = true }
3838
typenum = { workspace = true }
3939
paste = { workspace = true }
40+
rusqlite = { workspace = true, features = ["bundled"]}
4041

4142
feos = { workspace = true }
42-
feos-core = { workspace = true, features = ["ndarray"] }
43+
feos-core = { workspace = true, features = ["ndarray", "rusqlite"] }
4344
feos-derive = { workspace = true }
4445
feos-dft = { workspace = true, optional = true }
4546

py-feos/src/parameter/mod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::error::PyFeosError;
2+
use feos_core::FeosError;
23
use feos_core::parameter::*;
34
use indexmap::IndexSet;
45
use pyo3::prelude::*;
@@ -222,6 +223,42 @@ impl PyParameters {
222223
})
223224
}
224225

226+
/// Creates parameters from a database file.
227+
///
228+
/// Parameters
229+
/// ----------
230+
/// substances : List[str]
231+
/// The substances to search.
232+
/// path : str
233+
/// Path to database file.
234+
/// identifier_option : IdentifierOption, optional, defaults to IdentifierOption.Name
235+
/// Identifier that is used to search substance.
236+
#[staticmethod]
237+
#[pyo3(
238+
signature = (substances, path, identifier_option=PyIdentifierOption::Name),
239+
text_signature = "(substances, path, identifier_option=IdentifierOption.Name)"
240+
)]
241+
fn from_database(
242+
substances: Vec<String>,
243+
path: String,
244+
identifier_option: PyIdentifierOption,
245+
) -> PyResult<Self> {
246+
let identifier_option = IdentifierOption::from(identifier_option);
247+
let conn = rusqlite::Connection::open(path)
248+
.map_err(FeosError::from)
249+
.map_err(PyFeosError::from)?;
250+
251+
let pure_records = PureRecord::from_database(&substances, &conn, identifier_option)
252+
.map_err(PyFeosError::from)?;
253+
let binary_records = BinaryRecord::from_database(&substances, &conn, identifier_option)
254+
.map_err(PyFeosError::from)?;
255+
256+
Ok(Self {
257+
pure_records,
258+
binary_records,
259+
})
260+
}
261+
225262
/// Generates JSON-formatted string for pure and binary records (if initialized).
226263
///
227264
/// Parameters
@@ -488,6 +525,10 @@ impl PyParameters {
488525
for &p in &params {
489526
format_optional(o, model_record.get(p));
490527
}
528+
} else {
529+
for _ in &params {
530+
format_optional(o, None);
531+
}
491532
}
492533
if !r.association_sites.is_empty() {
493534
let s = &r.association_sites[0];
@@ -557,6 +598,7 @@ impl PyGcParameters {
557598
.map_err(PyFeosError::from)?)
558599
}
559600

601+
#[cfg(feature = "gc_pcsaft")]
560602
pub fn try_convert_heterosegmented<P, B, A, C: GroupCount + Default>(
561603
self,
562604
) -> PyResult<GcParameters<P, B, A, (), C>>

0 commit comments

Comments
 (0)