Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/database/Database.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/database/AtomDescriptionByParameters.hpp"
8 : #include "pairinteraction/database/AtomDescriptionByRanges.hpp"
9 : #include "pairinteraction/database/GitHubDownloader.hpp"
10 : #include "pairinteraction/database/ParquetManager.hpp"
11 : #include "pairinteraction/enums/OperatorType.hpp"
12 : #include "pairinteraction/enums/Parity.hpp"
13 : #include "pairinteraction/ket/KetAtom.hpp"
14 : #include "pairinteraction/utils/hash.hpp"
15 : #include "pairinteraction/utils/id_in_database.hpp"
16 : #include "pairinteraction/utils/paths.hpp"
17 : #include "pairinteraction/utils/streamed.hpp"
18 :
19 : #include <cpptrace/cpptrace.hpp>
20 : #include <duckdb.hpp>
21 : #include <fmt/core.h>
22 : #include <fmt/ranges.h>
23 : #include <fstream>
24 : #include <nlohmann/json.hpp>
25 : #include <oneapi/tbb.h>
26 : #include <spdlog/spdlog.h>
27 : #include <system_error>
28 :
29 : namespace pairinteraction {
30 :
// Verify that the quantum numbers f, m and j describing a state are mutually consistent.
//
// Parameters:
//   is_j_total_momentum  - if true, f is required to be exactly equal to quantum_number_j_exp
//   quantum_number_f     - quantum number f; must be a finite integer or half-integer
//   quantum_number_m     - quantum number m; must be an integer or half-integer of the same kind
//                          as f and satisfy |m| <= f
//   quantum_number_j_exp - value that f is compared against when is_j_total_momentum is set
//
// Throws:
//   std::runtime_error    - f differs from j although j is the total momentum, m or f is not an
//                           integer or half-integer, or f is not finite
//   std::invalid_argument - f and m are not both integers or both half-integers, or |m| > f
void ensure_consistent_quantum_numbers(bool is_j_total_momentum, double quantum_number_f,
                                       double quantum_number_m, double quantum_number_j_exp) {
    if (is_j_total_momentum && quantum_number_f != quantum_number_j_exp) {
        throw std::runtime_error("If j is the total momentum, f must be equal to j.");
    }

    // A value x is an integer or half-integer exactly if 2*x is integral. NaN fails these
    // comparisons as well because NaN compares unequal to everything.
    if (2 * quantum_number_m != std::rint(2 * quantum_number_m)) {
        throw std::runtime_error("The quantum number m must be an integer or half-integer.");
    }
    if (2 * quantum_number_f != std::rint(2 * quantum_number_f)) {
        throw std::runtime_error("The quantum number f must be an integer or half-integer.");
    }

    // f + m is integral only if f and m are both integers or both half-integers
    if (quantum_number_f + quantum_number_m != std::rint(quantum_number_f + quantum_number_m)) {
        throw std::invalid_argument(
            "The quantum numbers f and m must be both either integers or half-integers.");
    }
    if (std::abs(quantum_number_m) > quantum_number_f) {
        throw std::invalid_argument(
            "The absolute value of the quantum number m must be less than or equal to f.");
    }

    // f = +infinity satisfies every comparison above (infinity equals its own rounded value and
    // nothing is larger than it), so it has to be rejected explicitly. The check comes last so
    // that inputs which were already rejected keep their original exception.
    if (!std::isfinite(quantum_number_f)) {
        throw std::runtime_error("The quantum number f must be finite.");
    }
}
51 :
52 0 : Database::Database() : Database(default_download_missing) {}
53 :
// Convenience constructor: takes only the download_missing flag and delegates to the
// three-argument constructor, passing default_use_cache and default_database_dir for the rest.
54 0 : Database::Database(bool download_missing)
55 0 : : Database(download_missing, default_use_cache, default_database_dir) {}
56 :
// Convenience constructor: takes only the database directory (moved into the delegated call) and
// delegates to the three-argument constructor with default_download_missing and
// default_use_cache.
57 0 : Database::Database(std::filesystem::path database_dir)
58 0 : : Database(default_download_missing, default_use_cache, std::move(database_dir)) {}
59 :
60 5 : Database::Database(bool download_missing, bool use_cache, std::filesystem::path database_dir)
61 5 : : download_missing_(download_missing), use_cache_(use_cache),
62 5 : database_dir_(std::move(database_dir)), db(std::make_unique<duckdb::DuckDB>(nullptr)),
63 25 : con(std::make_unique<duckdb::Connection>(*db)) {
64 :
65 5 : if (database_dir_.empty()) {
66 0 : database_dir_ = default_database_dir;
67 : }
68 :
69 : // Ensure the database directory exists
70 5 : if (!std::filesystem::exists(database_dir_)) {
71 0 : std::filesystem::create_directories(database_dir_);
72 : }
73 5 : database_dir_ = std::filesystem::canonical(database_dir_);
74 5 : if (!std::filesystem::is_directory(database_dir_)) {
75 0 : throw std::filesystem::filesystem_error("Cannot access database", database_dir_.string(),
76 0 : std::make_error_code(std::errc::not_a_directory));
77 : }
78 10 : SPDLOG_INFO("Using database directory: {}", database_dir_.string());
79 :
80 : // Ensure that the config directory exists
81 5 : std::filesystem::path configdir = paths::get_config_directory();
82 5 : if (!std::filesystem::exists(configdir)) {
83 1 : std::filesystem::create_directories(configdir);
84 4 : } else if (!std::filesystem::is_directory(configdir)) {
85 0 : throw std::filesystem::filesystem_error("Cannot access config directory ",
86 0 : configdir.string(),
87 0 : std::make_error_code(std::errc::not_a_directory));
88 : }
89 :
90 : // Read in the database_repo_paths if a config file exists, otherwise use the default and
91 : // write it to the config file
92 5 : std::filesystem::path configfile = configdir / "database.json";
93 5 : std::string database_repo_host;
94 5 : std::vector<std::string> database_repo_paths;
95 5 : if (std::filesystem::exists(configfile)) {
96 4 : std::ifstream file(configfile);
97 4 : nlohmann::json doc = nlohmann::json::parse(file, nullptr, false);
98 :
99 8 : if (!doc.is_discarded() && doc.contains("hash") && doc.contains("database_repo_host") &&
100 4 : doc.contains("database_repo_paths")) {
101 4 : database_repo_host = doc["database_repo_host"].get<std::string>();
102 4 : database_repo_paths = doc["database_repo_paths"].get<std::vector<std::string>>();
103 :
104 : // If the values are not equal to the default values but the hash is consistent (i.e.,
105 : // the user has not changed anything manually), clear the values so that they can be
106 : // updated
107 8 : if (database_repo_host != default_database_repo_host ||
108 4 : database_repo_paths != default_database_repo_paths) {
109 0 : std::size_t seed = 0;
110 0 : utils::hash_combine(seed, database_repo_paths);
111 0 : utils::hash_combine(seed, database_repo_host);
112 0 : if (seed == doc["hash"].get<std::size_t>()) {
113 0 : database_repo_host.clear();
114 0 : database_repo_paths.clear();
115 : } else {
116 0 : SPDLOG_INFO("The database repository host and paths have been changed "
117 : "manually. Thus, they will not be updated automatically. To reset "
118 : "them, delete the file '{}'.",
119 : configfile.string());
120 : }
121 : }
122 : }
123 4 : }
124 :
125 : // Read in and store the default values if necessary
126 5 : if (database_repo_host.empty() || database_repo_paths.empty()) {
127 2 : SPDLOG_INFO("Updating the database repository host and paths:");
128 :
129 1 : database_repo_host = default_database_repo_host;
130 1 : database_repo_paths = default_database_repo_paths;
131 1 : std::ofstream file(configfile);
132 1 : nlohmann::json doc;
133 :
134 2 : SPDLOG_INFO("* New host: {}", default_database_repo_host);
135 3 : SPDLOG_INFO("* New paths: {}", fmt::join(default_database_repo_paths, ", "));
136 :
137 1 : doc["database_repo_host"] = default_database_repo_host;
138 1 : doc["database_repo_paths"] = database_repo_paths;
139 :
140 1 : std::size_t seed = 0;
141 1 : utils::hash_combine(seed, default_database_repo_paths);
142 1 : utils::hash_combine(seed, default_database_repo_host);
143 1 : doc["hash"] = seed;
144 :
145 1 : file << doc.dump(4);
146 1 : }
147 :
148 : // Limit the memory usage of duckdb's buffer manager
149 : {
150 5 : auto result = con->Query("PRAGMA max_memory = '8GB';");
151 5 : if (result->HasError()) {
152 0 : throw cpptrace::runtime_error("Error setting the memory limit: " + result->GetError());
153 : }
154 5 : }
155 :
156 : // Instantiate a database manager that provides access to database tables. If a table
157 : // is outdated/not available locally, it will be downloaded if download_missing_ is true.
158 5 : if (!download_missing_) {
159 5 : database_repo_paths.clear();
160 : }
161 5 : downloader = std::make_unique<GitHubDownloader>();
162 10 : manager = std::make_unique<ParquetManager>(database_dir_, *downloader, database_repo_paths,
163 10 : *con, use_cache_);
164 5 : manager->scan_local();
165 5 : manager->scan_remote();
166 :
167 : // Print versions of tables
168 5 : std::istringstream iss(manager->get_versions_info());
169 60 : for (std::string line; std::getline(iss, line);) {
170 55 : SPDLOG_INFO(line);
171 5 : }
172 15 : }
173 :
// Defaulted destructor, defined out of line in this translation unit.
// NOTE(review): presumably placed here so that the std::unique_ptr members (db, con, downloader,
// manager) are destroyed where their pointee types are complete - confirm against the header.
174 5 : Database::~Database() = default;
175 :
176 31 : std::shared_ptr<const KetAtom> Database::get_ket(const std::string &species,
177 : const AtomDescriptionByParameters &description) {
178 : // Check that the specifications are valid
179 31 : if (!description.quantum_number_m.has_value()) {
180 0 : throw std::invalid_argument("The quantum number m must be specified.");
181 : }
182 39 : if (description.quantum_number_f.has_value() &&
183 8 : 2 * description.quantum_number_f.value() !=
184 8 : std::rint(2 * description.quantum_number_f.value())) {
185 0 : throw std::invalid_argument("The quantum number f must be an integer or half-integer.");
186 : }
187 31 : if (description.quantum_number_f.has_value() && description.quantum_number_f.value() < 0) {
188 0 : throw std::invalid_argument("The quantum number f must be positive.");
189 : }
190 51 : if (description.quantum_number_j.has_value() &&
191 20 : 2 * description.quantum_number_j.value() !=
192 20 : std::rint(2 * description.quantum_number_j.value())) {
193 0 : throw std::invalid_argument("The quantum number j must be an integer or half-integer.");
194 : }
195 31 : if (description.quantum_number_j.has_value() && description.quantum_number_j.value() < 0) {
196 0 : throw std::invalid_argument("The quantum number j must be positive.");
197 : }
198 62 : if (description.quantum_number_m.has_value() &&
199 31 : 2 * description.quantum_number_m.value() !=
200 31 : std::rint(2 * description.quantum_number_m.value())) {
201 0 : throw std::invalid_argument("The quantum number m must be an integer or half-integer.");
202 : }
203 :
204 : // Describe the state
205 31 : std::string where;
206 31 : std::string separator;
207 31 : if (description.energy.has_value()) {
208 : // The following condition derives from demanding that quantum number n that corresponds to
209 : // the energy "E_n = -1/(2*n^2)" is not off by more than 1 from the actual quantum number n,
210 : // i.e., "sqrt(-1/(2*E_n)) - sqrt(-1/(2*E_{n-1})) = 1"
211 0 : where += separator +
212 0 : fmt::format("SQRT(-1/(2*energy)) BETWEEN {} AND {}",
213 0 : std::sqrt(-1 / (2 * description.energy.value())) - 0.5,
214 0 : std::sqrt(-1 / (2 * description.energy.value())) + 0.5);
215 0 : separator = " AND ";
216 : }
217 31 : if (description.quantum_number_f.has_value()) {
218 16 : where += separator + fmt::format("f = {}", description.quantum_number_f.value());
219 8 : separator = " AND ";
220 : }
221 31 : if (description.parity != Parity::UNKNOWN) {
222 0 : where += separator + fmt::format("parity = {}", fmt::streamed(description.parity));
223 0 : separator = " AND ";
224 : }
225 31 : if (description.quantum_number_n.has_value()) {
226 62 : where += separator + fmt::format("n = {}", description.quantum_number_n.value());
227 31 : separator = " AND ";
228 : }
229 31 : if (description.quantum_number_nu.has_value()) {
230 0 : where += separator +
231 0 : fmt::format("nu BETWEEN {} AND {}", description.quantum_number_nu.value() - 0.5,
232 0 : description.quantum_number_nu.value() + 0.5);
233 0 : separator = " AND ";
234 : }
235 31 : if (description.quantum_number_nui.has_value()) {
236 0 : where += separator +
237 0 : fmt::format("exp_nui BETWEEN {} AND {}", description.quantum_number_nui.value() - 0.5,
238 0 : description.quantum_number_nui.value() + 0.5);
239 0 : separator = " AND ";
240 : }
241 31 : if (description.quantum_number_l.has_value()) {
242 31 : where += separator +
243 93 : fmt::format("exp_l BETWEEN {} AND {}", description.quantum_number_l.value() - 0.5,
244 62 : description.quantum_number_l.value() + 0.5);
245 31 : separator = " AND ";
246 : }
247 31 : if (description.quantum_number_s.has_value()) {
248 8 : where += separator +
249 24 : fmt::format("exp_s BETWEEN {} AND {}", description.quantum_number_s.value() - 0.5,
250 16 : description.quantum_number_s.value() + 0.5);
251 8 : separator = " AND ";
252 : }
253 31 : if (description.quantum_number_j.has_value()) {
254 20 : where += separator +
255 60 : fmt::format("exp_j BETWEEN {} AND {}", description.quantum_number_j.value() - 0.5,
256 40 : description.quantum_number_j.value() + 0.5);
257 20 : separator = " AND ";
258 : }
259 31 : if (description.quantum_number_l_ryd.has_value()) {
260 0 : where += separator +
261 0 : fmt::format("exp_l_ryd BETWEEN {} AND {}",
262 0 : description.quantum_number_l_ryd.value() - 0.5,
263 0 : description.quantum_number_l_ryd.value() + 0.5);
264 0 : separator = " AND ";
265 : }
266 31 : if (description.quantum_number_j_ryd.has_value()) {
267 0 : where += separator +
268 0 : fmt::format("exp_j_ryd BETWEEN {} AND {}",
269 0 : description.quantum_number_j_ryd.value() - 0.5,
270 0 : description.quantum_number_j_ryd.value() + 0.5);
271 0 : separator = " AND ";
272 : }
273 31 : if (separator.empty()) {
274 0 : where += "FALSE";
275 : }
276 :
277 31 : std::string orderby;
278 31 : separator = "";
279 31 : if (description.energy.has_value()) {
280 0 : orderby += separator +
281 0 : fmt::format("(SQRT(-1/(2*energy)) - {})^2",
282 0 : std::sqrt(-1 / (2 * description.energy.value())));
283 0 : separator = " + ";
284 : }
285 31 : if (description.quantum_number_nu.has_value()) {
286 0 : orderby += separator + fmt::format("(nu - {})^2", description.quantum_number_nu.value());
287 0 : separator = " + ";
288 : }
289 31 : if (description.quantum_number_nui.has_value()) {
290 : orderby +=
291 0 : separator + fmt::format("(exp_nui - {})^2", description.quantum_number_nui.value());
292 0 : separator = " + ";
293 : }
294 31 : if (description.quantum_number_l.has_value()) {
295 62 : orderby += separator + fmt::format("(exp_l - {})^2", description.quantum_number_l.value());
296 31 : separator = " + ";
297 : }
298 31 : if (description.quantum_number_s.has_value()) {
299 16 : orderby += separator + fmt::format("(exp_s - {})^2", description.quantum_number_s.value());
300 8 : separator = " + ";
301 : }
302 31 : if (description.quantum_number_j.has_value()) {
303 40 : orderby += separator + fmt::format("(exp_j - {})^2", description.quantum_number_j.value());
304 20 : separator = " + ";
305 : }
306 31 : if (description.quantum_number_l_ryd.has_value()) {
307 : orderby +=
308 0 : separator + fmt::format("(exp_l_ryd - {})^2", description.quantum_number_l_ryd.value());
309 0 : separator = " + ";
310 : }
311 31 : if (description.quantum_number_j_ryd.has_value()) {
312 : orderby +=
313 0 : separator + fmt::format("(exp_j_ryd - {})^2", description.quantum_number_j_ryd.value());
314 0 : separator = " + ";
315 : }
316 31 : if (separator.empty()) {
317 0 : orderby += "id";
318 : }
319 :
320 : // Ask the database for the described state
321 31 : auto result = con->Query(fmt::format(
322 : R"(SELECT energy, f, parity, id, n, nu, exp_nui, std_nui, exp_l, std_l, exp_s, std_s,
323 : exp_j, std_j, exp_l_ryd, std_l_ryd, exp_j_ryd, std_j_ryd, is_j_total_momentum, is_calculated_with_mqdt, underspecified_channel_contribution, {} AS order_val FROM '{}' WHERE {} ORDER BY order_val ASC LIMIT 2)",
324 93 : orderby, manager->get_path(species, "states"), where));
325 :
326 31 : if (result->HasError()) {
327 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
328 : }
329 :
330 31 : if (result->RowCount() == 0) {
331 0 : throw std::invalid_argument("No state found.");
332 : }
333 :
334 : // Check the types of the columns
335 31 : const auto &types = result->types;
336 31 : const auto &labels = result->names;
337 : const std::vector<duckdb::LogicalType> ref_types = {
338 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::BIGINT,
339 : duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT, duckdb::LogicalType::DOUBLE,
340 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
341 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
342 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
343 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
344 : duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::DOUBLE,
345 713 : duckdb::LogicalType::DOUBLE};
346 :
347 713 : for (size_t i = 0; i < types.size(); i++) {
348 682 : if (types[i] != ref_types[i]) {
349 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'. Got " +
350 0 : types[i].ToString() + " but expected " +
351 0 : ref_types[i].ToString());
352 : }
353 : }
354 :
355 : // Get the first chunk of the results (the first chunk is sufficient as we need two rows at
356 : // most)
357 31 : auto chunk = result->Fetch();
358 :
359 : // Check that the ket is uniquely specified
360 31 : if (chunk->size() > 1) {
361 1 : auto order_val_0 = duckdb::FlatVector::GetData<double>(chunk->data[21])[0];
362 1 : auto order_val_1 = duckdb::FlatVector::GetData<double>(chunk->data[21])[1];
363 :
364 1 : if (order_val_1 - order_val_0 <= order_val_0) {
365 : // Get a list of possible kets
366 1 : std::vector<KetAtom> kets;
367 1 : kets.reserve(2);
368 3 : for (size_t i = 0; i < 2; ++i) {
369 2 : auto result_quantum_number_m = description.quantum_number_m.value();
370 2 : auto result_energy = duckdb::FlatVector::GetData<double>(chunk->data[0])[i];
371 : auto result_quantum_number_f =
372 2 : duckdb::FlatVector::GetData<double>(chunk->data[1])[i];
373 2 : auto result_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[2])[i];
374 4 : auto result_id = utils::get_linearized_id_in_database(
375 2 : duckdb::FlatVector::GetData<int64_t>(chunk->data[3])[i],
376 2 : result_quantum_number_m);
377 : auto result_quantum_number_n =
378 2 : duckdb::FlatVector::GetData<int64_t>(chunk->data[4])[i];
379 : auto result_quantum_number_nu =
380 2 : duckdb::FlatVector::GetData<double>(chunk->data[5])[i];
381 : auto result_quantum_number_nui_exp =
382 2 : duckdb::FlatVector::GetData<double>(chunk->data[6])[i];
383 : auto result_quantum_number_nui_std =
384 2 : duckdb::FlatVector::GetData<double>(chunk->data[7])[i];
385 : auto result_quantum_number_l_exp =
386 2 : duckdb::FlatVector::GetData<double>(chunk->data[8])[i];
387 : auto result_quantum_number_l_std =
388 2 : duckdb::FlatVector::GetData<double>(chunk->data[9])[i];
389 : auto result_quantum_number_s_exp =
390 2 : duckdb::FlatVector::GetData<double>(chunk->data[10])[i];
391 : auto result_quantum_number_s_std =
392 2 : duckdb::FlatVector::GetData<double>(chunk->data[11])[i];
393 : auto result_quantum_number_j_exp =
394 2 : duckdb::FlatVector::GetData<double>(chunk->data[12])[i];
395 : auto result_quantum_number_j_std =
396 2 : duckdb::FlatVector::GetData<double>(chunk->data[13])[i];
397 : auto result_quantum_number_l_ryd_exp =
398 2 : duckdb::FlatVector::GetData<double>(chunk->data[14])[i];
399 : auto result_quantum_number_l_ryd_std =
400 2 : duckdb::FlatVector::GetData<double>(chunk->data[15])[i];
401 : auto result_quantum_number_j_ryd_exp =
402 2 : duckdb::FlatVector::GetData<double>(chunk->data[16])[i];
403 : auto result_quantum_number_j_ryd_std =
404 2 : duckdb::FlatVector::GetData<double>(chunk->data[17])[i];
405 : auto result_is_j_total_momentum =
406 2 : duckdb::FlatVector::GetData<bool>(chunk->data[18])[i];
407 : auto result_is_calculated_with_mqdt =
408 2 : duckdb::FlatVector::GetData<bool>(chunk->data[19])[i];
409 : auto result_underspecified_channel_contribution =
410 2 : duckdb::FlatVector::GetData<double>(chunk->data[20])[i];
411 0 : kets.emplace_back(typename KetAtom::Private(), result_energy,
412 : result_quantum_number_f, result_quantum_number_m,
413 2 : static_cast<Parity>(result_parity), species,
414 : result_quantum_number_n, result_quantum_number_nu,
415 : result_quantum_number_nui_exp, result_quantum_number_nui_std,
416 : result_quantum_number_l_exp, result_quantum_number_l_std,
417 : result_quantum_number_s_exp, result_quantum_number_s_std,
418 : result_quantum_number_j_exp, result_quantum_number_j_std,
419 : result_quantum_number_l_ryd_exp, result_quantum_number_l_ryd_std,
420 : result_quantum_number_j_ryd_exp, result_quantum_number_j_ryd_std,
421 : result_is_j_total_momentum, result_is_calculated_with_mqdt,
422 : result_underspecified_channel_contribution, *this, result_id);
423 : }
424 :
425 : // Throw an error with the possible kets
426 2 : throw std::invalid_argument(
427 1 : fmt::format("The ket is not uniquely specified. Possible kets are:\n{}\n{}",
428 2 : fmt::streamed(kets[0]), fmt::streamed(kets[1])));
429 1 : }
430 : }
431 :
432 : // Construct the state
433 30 : auto result_quantum_number_m = description.quantum_number_m.value();
434 30 : auto result_energy = duckdb::FlatVector::GetData<double>(chunk->data[0])[0];
435 30 : auto result_quantum_number_f = duckdb::FlatVector::GetData<double>(chunk->data[1])[0];
436 30 : auto result_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[2])[0];
437 60 : auto result_id = utils::get_linearized_id_in_database(
438 30 : duckdb::FlatVector::GetData<int64_t>(chunk->data[3])[0], result_quantum_number_m);
439 30 : auto result_quantum_number_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[4])[0];
440 30 : auto result_quantum_number_nu = duckdb::FlatVector::GetData<double>(chunk->data[5])[0];
441 30 : auto result_quantum_number_nui_exp = duckdb::FlatVector::GetData<double>(chunk->data[6])[0];
442 30 : auto result_quantum_number_nui_std = duckdb::FlatVector::GetData<double>(chunk->data[7])[0];
443 30 : auto result_quantum_number_l_exp = duckdb::FlatVector::GetData<double>(chunk->data[8])[0];
444 30 : auto result_quantum_number_l_std = duckdb::FlatVector::GetData<double>(chunk->data[9])[0];
445 30 : auto result_quantum_number_s_exp = duckdb::FlatVector::GetData<double>(chunk->data[10])[0];
446 30 : auto result_quantum_number_s_std = duckdb::FlatVector::GetData<double>(chunk->data[11])[0];
447 30 : auto result_quantum_number_j_exp = duckdb::FlatVector::GetData<double>(chunk->data[12])[0];
448 30 : auto result_quantum_number_j_std = duckdb::FlatVector::GetData<double>(chunk->data[13])[0];
449 30 : auto result_quantum_number_l_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[14])[0];
450 30 : auto result_quantum_number_l_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[15])[0];
451 30 : auto result_quantum_number_j_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[16])[0];
452 30 : auto result_quantum_number_j_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[17])[0];
453 30 : auto result_is_j_total_momentum = duckdb::FlatVector::GetData<bool>(chunk->data[18])[0];
454 30 : auto result_is_calculated_with_mqdt = duckdb::FlatVector::GetData<bool>(chunk->data[19])[0];
455 : auto result_underspecified_channel_contribution =
456 30 : duckdb::FlatVector::GetData<double>(chunk->data[20])[0];
457 :
458 : // Check database consistency
459 30 : ensure_consistent_quantum_numbers(result_is_j_total_momentum, result_quantum_number_f,
460 : result_quantum_number_m, result_quantum_number_j_exp);
461 :
462 : return std::make_shared<const KetAtom>(
463 0 : typename KetAtom::Private(), result_energy, result_quantum_number_f,
464 0 : result_quantum_number_m, static_cast<Parity>(result_parity), species,
465 : result_quantum_number_n, result_quantum_number_nu, result_quantum_number_nui_exp,
466 : result_quantum_number_nui_std, result_quantum_number_l_exp, result_quantum_number_l_std,
467 : result_quantum_number_s_exp, result_quantum_number_s_std, result_quantum_number_j_exp,
468 : result_quantum_number_j_std, result_quantum_number_l_ryd_exp,
469 : result_quantum_number_l_ryd_std, result_quantum_number_j_ryd_exp,
470 : result_quantum_number_j_ryd_std, result_is_j_total_momentum, result_is_calculated_with_mqdt,
471 58 : result_underspecified_channel_contribution, *this, result_id);
472 72 : }
473 :
474 : template <typename Scalar>
475 : std::shared_ptr<const BasisAtom<Scalar>>
476 28 : Database::get_basis(const std::string &species, const AtomDescriptionByRanges &description,
477 : std::vector<size_t> additional_ket_ids) {
478 : // Describe the states
479 28 : std::string where = "(";
480 28 : std::string separator;
481 28 : if (description.parity != Parity::UNKNOWN) {
482 0 : where += separator + fmt::format("parity = {}", fmt::streamed(description.parity));
483 0 : separator = " AND ";
484 : }
485 28 : if (description.range_energy.is_finite()) {
486 0 : where += separator +
487 0 : fmt::format("energy BETWEEN {} AND {}", description.range_energy.min(),
488 0 : description.range_energy.max());
489 0 : separator = " AND ";
490 : }
491 28 : if (description.range_quantum_number_f.is_finite()) {
492 0 : where += separator +
493 0 : fmt::format("f BETWEEN {} AND {}", description.range_quantum_number_f.min(),
494 0 : description.range_quantum_number_f.max());
495 0 : separator = " AND ";
496 : }
497 28 : if (description.range_quantum_number_m.is_finite()) {
498 6 : where += separator +
499 6 : fmt::format("m BETWEEN {} AND {}", description.range_quantum_number_m.min(),
500 6 : description.range_quantum_number_m.max());
501 6 : separator = " AND ";
502 : }
503 28 : if (description.range_quantum_number_n.is_finite()) {
504 24 : where += separator +
505 24 : fmt::format("n BETWEEN {} AND {}", description.range_quantum_number_n.min(),
506 24 : description.range_quantum_number_n.max());
507 24 : separator = " AND ";
508 : }
509 28 : if (description.range_quantum_number_nu.is_finite()) {
510 1 : where += separator +
511 1 : fmt::format("nu BETWEEN {} AND {}", description.range_quantum_number_nu.min(),
512 1 : description.range_quantum_number_nu.max());
513 1 : separator = " AND ";
514 : }
515 28 : if (description.range_quantum_number_nui.is_finite()) {
516 0 : where += separator +
517 : fmt::format("exp_nui BETWEEN {}-2*std_nui AND {}+2*std_nui",
518 0 : description.range_quantum_number_nui.min(),
519 0 : description.range_quantum_number_nui.max());
520 0 : separator = " AND ";
521 : }
522 28 : if (description.range_quantum_number_l.is_finite()) {
523 25 : where += separator +
524 : fmt::format("exp_l BETWEEN {}-2*std_l AND {}+2*std_l",
525 25 : description.range_quantum_number_l.min(),
526 25 : description.range_quantum_number_l.max());
527 25 : separator = " AND ";
528 : }
529 28 : if (description.range_quantum_number_s.is_finite()) {
530 0 : where += separator +
531 : fmt::format("exp_s BETWEEN {}-2*std_s AND {}+2*std_s",
532 0 : description.range_quantum_number_s.min(),
533 0 : description.range_quantum_number_s.max());
534 0 : separator = " AND ";
535 : }
536 28 : if (description.range_quantum_number_j.is_finite()) {
537 0 : where += separator +
538 : fmt::format("exp_j BETWEEN {}-2*std_j AND {}+2*std_j",
539 0 : description.range_quantum_number_j.min(),
540 0 : description.range_quantum_number_j.max());
541 0 : separator = " AND ";
542 : }
543 28 : if (description.range_quantum_number_l_ryd.is_finite()) {
544 0 : where += separator +
545 : fmt::format("exp_l_ryd BETWEEN {}-2*std_l_ryd AND {}+2*std_l_ryd",
546 0 : description.range_quantum_number_l_ryd.min(),
547 0 : description.range_quantum_number_l_ryd.max());
548 0 : separator = " AND ";
549 : }
550 28 : if (description.range_quantum_number_j_ryd.is_finite()) {
551 0 : where += separator +
552 : fmt::format("exp_j_ryd BETWEEN {}-2*std_j_ryd AND {}+2*std_j_ryd",
553 0 : description.range_quantum_number_j_ryd.min(),
554 0 : description.range_quantum_number_j_ryd.max());
555 0 : separator = " AND ";
556 : }
557 28 : if (separator.empty()) {
558 3 : where += "FALSE";
559 : }
560 28 : where += ")";
561 28 : if (!additional_ket_ids.empty()) {
562 3 : where += fmt::format(" OR {} IN ({})", utils::SQL_TERM_FOR_LINEARIZED_ID_IN_DATABASE,
563 6 : fmt::join(additional_ket_ids, ","));
564 : }
565 :
566 : // Create a table containing the described states
567 28 : std::string id_of_kets;
568 : {
569 28 : auto result = con->Query(R"(SELECT UUID()::varchar)");
570 28 : if (result->HasError()) {
571 0 : throw cpptrace::runtime_error("Error selecting id_of_kets: " + result->GetError());
572 : }
573 28 : id_of_kets =
574 56 : duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
575 28 : }
576 : {
577 56 : auto result = con->Query(fmt::format(
578 : R"(CREATE TEMP TABLE '{}' AS SELECT *, {} AS ketid FROM (
579 : SELECT *,
580 : UNNEST(list_transform(generate_series(0,(2*f)::bigint),
581 : x -> x::double-f)) AS m FROM '{}'
582 : ) WHERE {})",
583 : id_of_kets, utils::SQL_TERM_FOR_LINEARIZED_ID_IN_DATABASE,
584 : manager->get_path(species, "states"), where));
585 :
586 28 : if (result->HasError()) {
587 0 : throw cpptrace::runtime_error("Error creating table: " + result->GetError());
588 : }
589 28 : }
590 :
591 : // Ask the table for the extreme values of the quantum numbers
592 : {
593 28 : std::string select;
594 28 : std::string separator;
595 28 : if (description.range_energy.is_finite()) {
596 0 : select += separator + "MIN(energy) AS min_energy, MAX(energy) AS max_energy";
597 0 : separator = ", ";
598 : }
599 28 : if (description.range_quantum_number_f.is_finite()) {
600 0 : select += separator + "MIN(f) AS min_f, MAX(f) AS max_f";
601 0 : separator = ", ";
602 : }
603 28 : if (description.range_quantum_number_m.is_finite()) {
604 6 : select += separator + "MIN(m) AS min_m, MAX(m) AS max_m";
605 6 : separator = ", ";
606 : }
607 28 : if (description.range_quantum_number_n.is_finite()) {
608 24 : select += separator + "MIN(n) AS min_n, MAX(n) AS max_n";
609 24 : separator = ", ";
610 : }
611 28 : if (description.range_quantum_number_nu.is_finite()) {
612 1 : select += separator + "MIN(nu) AS min_nu, MAX(nu) AS max_nu";
613 1 : separator = ", ";
614 : }
615 28 : if (description.range_quantum_number_nui.is_finite()) {
616 0 : select += separator + "MIN(exp_nui) AS min_nui, MAX(exp_nui) AS max_nui";
617 0 : separator = ", ";
618 : }
619 28 : if (description.range_quantum_number_l.is_finite()) {
620 25 : select += separator + "MIN(exp_l) AS min_l, MAX(exp_l) AS max_l";
621 25 : separator = ", ";
622 : }
623 28 : if (description.range_quantum_number_s.is_finite()) {
624 0 : select += separator + "MIN(exp_s) AS min_s, MAX(exp_s) AS max_s";
625 0 : separator = ", ";
626 : }
627 28 : if (description.range_quantum_number_j.is_finite()) {
628 0 : select += separator + "MIN(exp_j) AS min_j, MAX(exp_j) AS max_j";
629 0 : separator = ", ";
630 : }
631 28 : if (description.range_quantum_number_l_ryd.is_finite()) {
632 0 : select += separator + "MIN(exp_l_ryd) AS min_l_ryd, MAX(exp_l_ryd) AS max_l_ryd";
633 0 : separator = ", ";
634 : }
635 28 : if (description.range_quantum_number_j_ryd.is_finite()) {
636 0 : select += separator + "MIN(exp_j_ryd) AS min_j_ryd, MAX(exp_j_ryd) AS max_j_ryd";
637 0 : separator = ", ";
638 : }
639 :
640 28 : if (!separator.empty()) {
641 50 : auto result = con->Query(fmt::format(R"(SELECT {} FROM '{}')", select, id_of_kets));
642 :
643 25 : if (result->HasError()) {
644 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
645 : }
646 :
647 25 : auto chunk = result->Fetch();
648 :
649 137 : for (size_t i = 0; i < chunk->ColumnCount(); i++) {
650 112 : if (duckdb::FlatVector::IsNull(chunk->data[i], 0)) {
651 0 : throw std::invalid_argument("No state found.");
652 : }
653 : }
654 :
655 25 : size_t idx = 0;
656 25 : if (description.range_energy.is_finite()) {
657 0 : auto min_energy = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
658 0 : if (std::sqrt(-1 / (2 * min_energy)) - 1 >
659 0 : std::sqrt(-1 / (2 * description.range_energy.min()))) {
660 0 : SPDLOG_DEBUG("No state found with the requested minimum energy. Requested: {}, "
661 : "found: {}.",
662 : description.range_energy.min(), min_energy);
663 : }
664 0 : auto max_energy = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
665 0 : if (std::sqrt(-1 / (2 * max_energy)) + 1 <
666 0 : std::sqrt(-1 / (2 * description.range_energy.max()))) {
667 0 : SPDLOG_DEBUG("No state found with the requested maximum energy. Requested: {}, "
668 : "found: {}.",
669 : description.range_energy.max(), max_energy);
670 : }
671 : }
672 25 : if (description.range_quantum_number_f.is_finite()) {
673 0 : auto min_f = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
674 0 : if (min_f > description.range_quantum_number_f.min()) {
675 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number f. "
676 : "Requested: {}, found: {}.",
677 : description.range_quantum_number_f.min(), min_f);
678 : }
679 0 : auto max_f = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
680 0 : if (max_f < description.range_quantum_number_f.max()) {
681 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number f. "
682 : "Requested: {}, found: {}.",
683 : description.range_quantum_number_f.max(), max_f);
684 : }
685 : }
686 25 : if (description.range_quantum_number_m.is_finite()) {
687 6 : auto min_m = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
688 6 : if (min_m > description.range_quantum_number_m.min()) {
689 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number m. "
690 : "Requested: {}, found: {}.",
691 : description.range_quantum_number_m.min(), min_m);
692 : }
693 6 : auto max_m = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
694 6 : if (max_m < description.range_quantum_number_m.max()) {
695 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number m. "
696 : "Requested: {}, found: {}.",
697 : description.range_quantum_number_m.max(), max_m);
698 : }
699 : }
700 25 : if (description.range_quantum_number_n.is_finite()) {
701 24 : auto min_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[idx++])[0];
702 24 : if (min_n > description.range_quantum_number_n.min()) {
703 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number n. "
704 : "Requested: {}, found: {}.",
705 : description.range_quantum_number_n.min(), min_n);
706 : }
707 24 : auto max_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[idx++])[0];
708 24 : if (max_n < description.range_quantum_number_n.max()) {
709 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number n. "
710 : "Requested: {}, found: {}.",
711 : description.range_quantum_number_n.max(), max_n);
712 : }
713 : }
714 25 : if (description.range_quantum_number_nu.is_finite()) {
715 1 : auto min_nu = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
716 1 : if (min_nu - 1 > description.range_quantum_number_nu.min()) {
717 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number nu. "
718 : "Requested: {}, found: {}.",
719 : description.range_quantum_number_nu.min(), min_nu);
720 : }
721 1 : auto max_nu = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
722 1 : if (max_nu + 1 < description.range_quantum_number_nu.max()) {
723 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number nu. "
724 : "Requested: {}, found: {}.",
725 : description.range_quantum_number_nu.max(), max_nu);
726 : }
727 : }
728 25 : if (description.range_quantum_number_nui.is_finite()) {
729 0 : auto min_nui = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
730 0 : if (min_nui - 1 > description.range_quantum_number_nui.min()) {
731 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number nui. "
732 : "Requested: {}, found: {}.",
733 : description.range_quantum_number_nui.min(), min_nui);
734 : }
735 0 : auto max_nui = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
736 0 : if (max_nui + 1 < description.range_quantum_number_nui.max()) {
737 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number nui. "
738 : "Requested: {}, found: {}.",
739 : description.range_quantum_number_nui.max(), max_nui);
740 : }
741 : }
742 25 : if (description.range_quantum_number_l.is_finite()) {
743 25 : auto min_l = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
744 25 : if (min_l - 1 > description.range_quantum_number_l.min()) {
745 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number l. "
746 : "Requested: {}, found: {}.",
747 : description.range_quantum_number_l.min(), min_l);
748 : }
749 25 : auto max_l = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
750 25 : if (max_l + 1 < description.range_quantum_number_l.max()) {
751 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number l. "
752 : "Requested: {}, found: {}.",
753 : description.range_quantum_number_l.max(), max_l);
754 : }
755 : }
756 25 : if (description.range_quantum_number_s.is_finite()) {
757 0 : auto min_s = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
758 0 : if (min_s - 1 > description.range_quantum_number_s.min()) {
759 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number s. "
760 : "Requested: {}, found: {}.",
761 : description.range_quantum_number_s.min(), min_s);
762 : }
763 0 : auto max_s = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
764 0 : if (max_s + 1 < description.range_quantum_number_s.max()) {
765 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number s. "
766 : "Requested: {}, found: {}.",
767 : description.range_quantum_number_s.max(), max_s);
768 : }
769 : }
770 25 : if (description.range_quantum_number_j.is_finite()) {
771 0 : auto min_j = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
772 0 : if (min_j - 1 > description.range_quantum_number_j.min()) {
773 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number j. "
774 : "Requested: {}, found: {}.",
775 : description.range_quantum_number_j.min(), min_j);
776 : }
777 0 : auto max_j = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
778 0 : if (max_j + 1 < description.range_quantum_number_j.max()) {
779 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number j. "
780 : "Requested: {}, found: {}.",
781 : description.range_quantum_number_j.max(), max_j);
782 : }
783 : }
784 25 : if (description.range_quantum_number_l_ryd.is_finite()) {
785 0 : auto min_l_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
786 0 : if (min_l_ryd - 1 > description.range_quantum_number_l_ryd.min()) {
787 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number l_ryd. "
788 : "Requested: {}, found: {}.",
789 : description.range_quantum_number_l_ryd.min(), min_l_ryd);
790 : }
791 0 : auto max_l_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
792 0 : if (max_l_ryd + 1 < description.range_quantum_number_l_ryd.max()) {
793 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number l_ryd. "
794 : "Requested: {}, found: {}.",
795 : description.range_quantum_number_l_ryd.max(), max_l_ryd);
796 : }
797 : }
798 25 : if (description.range_quantum_number_j_ryd.is_finite()) {
799 0 : auto min_j_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
800 0 : if (min_j_ryd - 1 > description.range_quantum_number_j_ryd.min()) {
801 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number j_ryd. "
802 : "Requested: {}, found: {}.",
803 : description.range_quantum_number_j_ryd.min(), min_j_ryd);
804 : }
805 0 : auto max_j_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
806 0 : if (max_j_ryd + 1 < description.range_quantum_number_j_ryd.max()) {
807 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number j_ryd. "
808 : "Requested: {}, found: {}.",
809 : description.range_quantum_number_j_ryd.max(), max_j_ryd);
810 : }
811 : }
812 25 : }
813 28 : }
814 :
815 : // Ask the table for the described states
816 56 : auto result = con->Query(fmt::format(
817 : R"(SELECT energy, f, m, parity, ketid, n, nu, exp_nui, std_nui, exp_l, std_l,
818 : exp_s, std_s, exp_j, std_j, exp_l_ryd, std_l_ryd, exp_j_ryd, std_j_ryd, is_j_total_momentum, is_calculated_with_mqdt, underspecified_channel_contribution FROM '{}' ORDER BY ketid ASC)",
819 : id_of_kets));
820 :
821 28 : if (result->HasError()) {
822 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
823 : }
824 :
825 28 : if (result->RowCount() == 0) {
826 0 : throw std::invalid_argument("No state found.");
827 : }
828 :
829 : // Check the types of the columns
830 28 : const auto &types = result->types;
831 28 : const auto &labels = result->names;
832 644 : const std::vector<duckdb::LogicalType> ref_types = {
833 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
834 : duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT,
835 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
836 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
837 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
838 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
839 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::BOOLEAN,
840 : duckdb::LogicalType::DOUBLE};
841 :
842 644 : for (size_t i = 0; i < types.size(); i++) {
843 616 : if (types[i] != ref_types[i]) {
844 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'. Got " +
845 0 : types[i].ToString() + " but expected " +
846 0 : ref_types[i].ToString());
847 : }
848 : }
849 :
850 : // Construct the states
851 28 : std::vector<std::shared_ptr<const KetAtom>> kets;
852 28 : kets.reserve(result->RowCount());
853 28 : double last_energy = std::numeric_limits<double>::lowest();
854 28 : bool is_calculated_with_mqdt = false;
855 28 : double min_quantum_number_nu = std::numeric_limits<double>::max();
856 :
857 56 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
858 :
859 28 : auto *chunk_energy = duckdb::FlatVector::GetData<double>(chunk->data[0]);
860 28 : auto *chunk_quantum_number_f = duckdb::FlatVector::GetData<double>(chunk->data[1]);
861 28 : auto *chunk_quantum_number_m = duckdb::FlatVector::GetData<double>(chunk->data[2]);
862 28 : auto *chunk_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[3]);
863 28 : auto *chunk_id = duckdb::FlatVector::GetData<int64_t>(chunk->data[4]);
864 28 : auto *chunk_quantum_number_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[5]);
865 28 : auto *chunk_quantum_number_nu = duckdb::FlatVector::GetData<double>(chunk->data[6]);
866 28 : auto *chunk_quantum_number_nui_exp = duckdb::FlatVector::GetData<double>(chunk->data[7]);
867 28 : auto *chunk_quantum_number_nui_std = duckdb::FlatVector::GetData<double>(chunk->data[8]);
868 28 : auto *chunk_quantum_number_l_exp = duckdb::FlatVector::GetData<double>(chunk->data[9]);
869 28 : auto *chunk_quantum_number_l_std = duckdb::FlatVector::GetData<double>(chunk->data[10]);
870 28 : auto *chunk_quantum_number_s_exp = duckdb::FlatVector::GetData<double>(chunk->data[11]);
871 28 : auto *chunk_quantum_number_s_std = duckdb::FlatVector::GetData<double>(chunk->data[12]);
872 28 : auto *chunk_quantum_number_j_exp = duckdb::FlatVector::GetData<double>(chunk->data[13]);
873 28 : auto *chunk_quantum_number_j_std = duckdb::FlatVector::GetData<double>(chunk->data[14]);
874 28 : auto *chunk_quantum_number_l_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[15]);
875 28 : auto *chunk_quantum_number_l_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[16]);
876 28 : auto *chunk_quantum_number_j_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[17]);
877 28 : auto *chunk_quantum_number_j_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[18]);
878 28 : auto *chunk_is_j_total_momentum = duckdb::FlatVector::GetData<bool>(chunk->data[19]);
879 28 : auto *chunk_is_calculated_with_mqdt = duckdb::FlatVector::GetData<bool>(chunk->data[20]);
880 : auto *chunk_underspecified_channel_contribution =
881 28 : duckdb::FlatVector::GetData<double>(chunk->data[21]);
882 :
883 1250 : for (size_t i = 0; i < chunk->size(); i++) {
884 :
885 : // Check database consistency
886 1222 : ensure_consistent_quantum_numbers(chunk_is_j_total_momentum[i],
887 1222 : chunk_quantum_number_f[i], chunk_quantum_number_m[i],
888 1222 : chunk_quantum_number_j_exp[i]);
889 1222 : if (chunk_energy[i] < last_energy) {
890 0 : throw std::runtime_error("The states are not sorted by energy.");
891 : }
892 1222 : last_energy = chunk_energy[i];
893 :
894 1222 : is_calculated_with_mqdt |= chunk_is_calculated_with_mqdt[i];
895 1222 : min_quantum_number_nu = std::min(min_quantum_number_nu, chunk_quantum_number_nu[i]);
896 :
897 : // Append a new state
898 1222 : kets.push_back(std::make_shared<const KetAtom>(
899 0 : typename KetAtom::Private(), chunk_energy[i], chunk_quantum_number_f[i],
900 1222 : chunk_quantum_number_m[i], static_cast<Parity>(chunk_parity[i]), species,
901 1222 : chunk_quantum_number_n[i], chunk_quantum_number_nu[i],
902 1222 : chunk_quantum_number_nui_exp[i], chunk_quantum_number_nui_std[i],
903 1222 : chunk_quantum_number_l_exp[i], chunk_quantum_number_l_std[i],
904 1222 : chunk_quantum_number_s_exp[i], chunk_quantum_number_s_std[i],
905 1222 : chunk_quantum_number_j_exp[i], chunk_quantum_number_j_std[i],
906 1222 : chunk_quantum_number_l_ryd_exp[i], chunk_quantum_number_l_ryd_std[i],
907 1222 : chunk_quantum_number_j_ryd_exp[i], chunk_quantum_number_j_ryd_std[i],
908 1222 : chunk_is_j_total_momentum[i], chunk_is_calculated_with_mqdt[i],
909 1222 : chunk_underspecified_channel_contribution[i], *this, chunk_id[i]));
910 : }
911 : }
912 :
913 : // Show a warning for low-lying states
914 28 : if (min_quantum_number_nu < 25) {
915 0 : if (is_calculated_with_mqdt) {
916 0 : SPDLOG_WARN("The multi-channel quantum defect theory might produce inaccurate results "
917 : "for effective principal quantum numbers < 25. The models get increasingly "
918 : "unreliable for small principal quantum numbers, leading to inaccurate "
919 : "matrix elements and energies. Due to missing data, even some states might "
920 : "not be present.");
921 : } else {
922 0 : SPDLOG_WARN(
923 : "The single-channel quantum defect theory can be inaccurate for effective "
924 : "principal quantum numbers < 25. This can lead to inaccurate matrix elements.");
925 : }
926 : }
927 :
928 0 : return std::make_shared<const BasisAtom<Scalar>>(typename BasisAtom<Scalar>::Private(),
929 56 : std::move(kets), std::move(id_of_kets), *this);
930 56 : }
931 :
932 : template <typename Scalar>
933 : Eigen::SparseMatrix<Scalar, Eigen::RowMajor>
934 196 : Database::get_matrix_elements(std::shared_ptr<const BasisAtom<Scalar>> initial_basis,
935 : std::shared_ptr<const BasisAtom<Scalar>> final_basis,
936 : OperatorType type, int q) {
937 : using real_t = typename traits::NumTraits<Scalar>::real_t;
938 :
939 196 : std::string specifier;
940 197 : int kappa{};
941 197 : switch (type) {
942 188 : case OperatorType::ELECTRIC_DIPOLE:
943 188 : specifier = "matrix_elements_d";
944 188 : kappa = 1;
945 188 : break;
946 1 : case OperatorType::ELECTRIC_QUADRUPOLE:
947 1 : specifier = "matrix_elements_q";
948 1 : kappa = 2;
949 1 : break;
950 0 : case OperatorType::ELECTRIC_QUADRUPOLE_ZERO:
951 0 : specifier = "matrix_elements_q0";
952 0 : kappa = 0;
953 0 : break;
954 0 : case OperatorType::ELECTRIC_OCTUPOLE:
955 0 : specifier = "matrix_elements_o";
956 0 : kappa = 3;
957 0 : break;
958 0 : case OperatorType::MAGNETIC_DIPOLE:
959 0 : specifier = "matrix_elements_mu";
960 0 : kappa = 1;
961 0 : break;
962 4 : case OperatorType::ENERGY:
963 4 : specifier = "energy";
964 4 : kappa = 0;
965 4 : break;
966 4 : case OperatorType::IDENTITY:
967 4 : specifier = "identity";
968 4 : kappa = 0;
969 4 : break;
970 0 : default:
971 0 : throw std::invalid_argument("Unknown operator type.");
972 : }
973 :
974 197 : if (initial_basis->get_id_of_kets() != final_basis->get_id_of_kets()) {
975 0 : throw std::invalid_argument(
976 : "The initial and final basis must be expressed using the same kets.");
977 : }
978 197 : std::string id_of_kets = initial_basis->get_id_of_kets();
979 0 : std::string cache_key = fmt::format("{}_{}_{}", specifier, q, id_of_kets);
980 :
981 197 : if (!get_matrix_elements_cache().contains(cache_key)) {
982 59 : Eigen::Index dim = initial_basis->get_number_of_kets();
983 :
984 59 : std::vector<int> outerIndexPtr;
985 59 : std::vector<int> innerIndices;
986 59 : std::vector<real_t> values;
987 :
988 58 : if (specifier == "identity") {
989 1 : outerIndexPtr.reserve(dim + 1);
990 1 : innerIndices.reserve(dim);
991 1 : values.reserve(dim);
992 :
993 91 : for (int i = 0; i < dim; i++) {
994 90 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
995 90 : innerIndices.push_back(i);
996 90 : values.push_back(1);
997 : }
998 1 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
999 :
1000 : } else {
1001 : // Check that the specifications are valid
1002 57 : if (std::abs(q) > kappa) {
1003 0 : throw std::invalid_argument("Invalid q.");
1004 : }
1005 :
1006 : // Ask the database for the operator
1007 57 : std::string species = initial_basis->get_species();
1008 58 : duckdb::unique_ptr<duckdb::MaterializedQueryResult> result;
1009 58 : if (specifier != "energy") {
1010 112 : result = con->Query(fmt::format(
1011 : R"(WITH s AS (
1012 : SELECT id, f, m, ketid FROM '{}'
1013 : ),
1014 : b AS (
1015 : SELECT MIN(f) AS min_f, MAX(f) AS max_f,
1016 : MIN(id) AS min_id, MAX(id) AS max_id
1017 : FROM s
1018 : ),
1019 : w_filtered AS (
1020 : SELECT *
1021 : FROM '{}'
1022 : WHERE kappa = {} AND q = {} AND
1023 : f_initial BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b) AND
1024 : f_final BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b)
1025 : ),
1026 : e_filtered AS (
1027 : SELECT *
1028 : FROM '{}'
1029 : WHERE
1030 : id_initial BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b) AND
1031 : id_final BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b)
1032 : )
1033 : SELECT
1034 : s2.ketid AS row,
1035 : s1.ketid AS col,
1036 : e.val*w.val AS val
1037 : FROM e_filtered AS e
1038 : JOIN s AS s1 ON e.id_initial = s1.id
1039 : JOIN s AS s2 ON e.id_final = s2.id
1040 : JOIN w_filtered AS w ON
1041 : w.f_initial = s1.f AND w.m_initial = s1.m AND
1042 : w.f_final = s2.f AND w.m_final = s2.m
1043 : ORDER BY row ASC, col ASC)",
1044 : id_of_kets, manager->get_path("misc", "wigner"), kappa, q,
1045 : manager->get_path(species, specifier)));
1046 : } else {
1047 4 : result = con->Query(fmt::format(
1048 : R"(SELECT ketid as row, ketid as col, energy as val FROM '{}' ORDER BY row ASC)",
1049 : id_of_kets));
1050 : }
1051 :
1052 58 : if (result->HasError()) {
1053 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
1054 : }
1055 :
1056 : // Check the types of the columns
1057 58 : const auto &types = result->types;
1058 58 : const auto &labels = result->names;
1059 232 : const std::vector<duckdb::LogicalType> ref_types = {duckdb::LogicalType::BIGINT,
1060 : duckdb::LogicalType::BIGINT,
1061 : duckdb::LogicalType::DOUBLE};
1062 232 : for (size_t i = 0; i < types.size(); i++) {
1063 174 : if (types[i] != ref_types[i]) {
1064 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'.");
1065 : }
1066 : }
1067 :
1068 : // Construct the matrix
1069 58 : int num_entries = static_cast<int>(result->RowCount());
1070 58 : outerIndexPtr.reserve(dim + 1);
1071 58 : innerIndices.reserve(num_entries);
1072 58 : values.reserve(num_entries);
1073 :
1074 58 : int last_row = -1;
1075 :
1076 116 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
1077 :
1078 58 : auto *chunk_row = duckdb::FlatVector::GetData<int64_t>(chunk->data[0]);
1079 58 : auto *chunk_col = duckdb::FlatVector::GetData<int64_t>(chunk->data[1]);
1080 58 : auto *chunk_val = duckdb::FlatVector::GetData<double>(chunk->data[2]);
1081 :
1082 20710 : for (size_t i = 0; i < chunk->size(); i++) {
1083 20652 : int row = final_basis->get_ket_index_from_id(chunk_row[i]);
1084 20652 : if (row != last_row) {
1085 2713 : if (row < last_row) {
1086 0 : throw std::runtime_error("The rows are not sorted.");
1087 : }
1088 5907 : for (; last_row < row; last_row++) {
1089 3194 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1090 : }
1091 : }
1092 20652 : innerIndices.push_back(initial_basis->get_ket_index_from_id(chunk_col[i]));
1093 20652 : values.push_back(chunk_val[i]);
1094 : }
1095 : }
1096 :
1097 221 : for (; last_row < dim + 1; last_row++) {
1098 163 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1099 : }
1100 58 : }
1101 :
1102 59 : Eigen::Map<const Eigen::SparseMatrix<real_t, Eigen::RowMajor>> matrix_map(
1103 59 : dim, dim, values.size(), outerIndexPtr.data(), innerIndices.data(), values.data());
1104 :
1105 : // Cache the matrix
1106 59 : get_matrix_elements_cache()[cache_key] = matrix_map;
1107 59 : }
1108 :
1109 : // Construct the operator and return it
1110 197 : return final_basis->get_coefficients().adjoint() *
1111 323 : get_matrix_elements_cache()[cache_key].template cast<Scalar>() *
1112 662 : initial_basis->get_coefficients();
1113 255 : }
1114 :
1115 2 : bool Database::get_download_missing() const { return download_missing_; }
1116 :
1117 0 : bool Database::get_use_cache() const { return use_cache_; }
1118 :
1119 0 : std::filesystem::path Database::get_database_dir() const { return database_dir_; }
1120 :
1121 : oneapi::tbb::concurrent_unordered_map<std::string, Eigen::SparseMatrix<double, Eigen::RowMajor>> &
1122 453 : Database::get_matrix_elements_cache() {
1123 : static oneapi::tbb::concurrent_unordered_map<std::string,
1124 : Eigen::SparseMatrix<double, Eigen::RowMajor>>
1125 453 : matrix_elements_cache;
1126 452 : return matrix_elements_cache;
1127 : }
1128 :
1129 32 : Database &Database::get_global_instance() {
1130 64 : return get_global_instance_without_checks(default_download_missing, default_use_cache,
1131 64 : default_database_dir);
1132 : }
1133 :
1134 0 : Database &Database::get_global_instance(bool download_missing) {
1135 0 : Database &database = get_global_instance_without_checks(download_missing, default_use_cache,
1136 : default_database_dir);
1137 0 : if (download_missing != database.download_missing_) {
1138 0 : throw std::invalid_argument(
1139 0 : "The 'download_missing' argument must not change between calls to the method.");
1140 : }
1141 0 : return database;
1142 : }
1143 :
1144 0 : Database &Database::get_global_instance(std::filesystem::path database_dir) {
1145 0 : if (database_dir.empty()) {
1146 0 : database_dir = default_database_dir;
1147 : }
1148 0 : Database &database = get_global_instance_without_checks(default_download_missing,
1149 : default_use_cache, database_dir);
1150 0 : if (!std::filesystem::exists(database_dir) ||
1151 0 : std::filesystem::canonical(database_dir) != database.database_dir_) {
1152 0 : throw std::invalid_argument(
1153 0 : "The 'database_dir' argument must not change between calls to the method.");
1154 : }
1155 0 : return database;
1156 : }
1157 :
1158 1 : Database &Database::get_global_instance(bool download_missing, bool use_cache,
1159 : std::filesystem::path database_dir) {
1160 1 : if (database_dir.empty()) {
1161 0 : database_dir = default_database_dir;
1162 : }
1163 : Database &database =
1164 1 : get_global_instance_without_checks(download_missing, use_cache, database_dir);
1165 1 : if (download_missing != database.download_missing_ || use_cache != database.use_cache_ ||
1166 3 : !std::filesystem::exists(database_dir) ||
1167 2 : std::filesystem::canonical(database_dir) != database.database_dir_) {
1168 0 : throw std::invalid_argument(
1169 : "The 'download_missing', 'use_cache' and 'database_dir' arguments must not "
1170 0 : "change between calls to the method.");
1171 : }
1172 1 : return database;
1173 : }
1174 :
1175 33 : Database &Database::get_global_instance_without_checks(bool download_missing, bool use_cache,
1176 : std::filesystem::path database_dir) {
1177 33 : static Database database(download_missing, use_cache, std::move(database_dir));
1178 33 : return database;
1179 : }
1180 :
1181 : struct database_dir_noexcept : std::filesystem::path {
1182 7 : explicit database_dir_noexcept() noexcept try : std
1183 7 : ::filesystem::path(paths::get_cache_directory() / "database") {}
1184 0 : catch (...) {
1185 0 : SPDLOG_ERROR("Error getting the PairInteraction cache directory.");
1186 0 : std::terminate();
1187 7 : }
1188 : };
1189 :
1190 : const std::filesystem::path Database::default_database_dir = database_dir_noexcept();
1191 :
1192 : // Explicit instantiations
1193 : // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
1194 : #define INSTANTIATE_GETTERS(SCALAR) \
1195 : template std::shared_ptr<const BasisAtom<SCALAR>> Database::get_basis<SCALAR>( \
1196 : const std::string &species, const AtomDescriptionByRanges &description, \
1197 : std::vector<size_t> additional_ket_ids); \
1198 : template Eigen::SparseMatrix<SCALAR, Eigen::RowMajor> Database::get_matrix_elements<SCALAR>( \
1199 : std::shared_ptr<const BasisAtom<SCALAR>> initial_basis, \
1200 : std::shared_ptr<const BasisAtom<SCALAR>> final_basis, OperatorType type, int q);
1201 : // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
1202 :
1203 : INSTANTIATE_GETTERS(double)
1204 : INSTANTIATE_GETTERS(std::complex<double>)
1205 :
1206 : #undef INSTANTIATE_GETTERS
1207 : } // namespace pairinteraction
|