Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/database/Database.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/database/AtomDescriptionByParameters.hpp"
8 : #include "pairinteraction/database/AtomDescriptionByRanges.hpp"
9 : #include "pairinteraction/database/GitHubDownloader.hpp"
10 : #include "pairinteraction/database/ParquetManager.hpp"
11 : #include "pairinteraction/enums/OperatorType.hpp"
12 : #include "pairinteraction/enums/Parity.hpp"
13 : #include "pairinteraction/ket/KetAtom.hpp"
14 : #include "pairinteraction/utils/TaskControl.hpp"
15 : #include "pairinteraction/utils/hash.hpp"
16 : #include "pairinteraction/utils/id_in_database.hpp"
17 : #include "pairinteraction/utils/paths.hpp"
18 : #include "pairinteraction/utils/streamed.hpp"
19 :
20 : #include <cpptrace/cpptrace.hpp>
21 : #include <duckdb.hpp>
22 : #include <fmt/core.h>
23 : #include <fmt/ranges.h>
24 : #include <fstream>
25 : #include <nlohmann/json.hpp>
26 : #include <oneapi/tbb.h>
27 : #include <spdlog/spdlog.h>
28 : #include <system_error>
29 :
30 : namespace pairinteraction {
31 :
/**
 * @brief Validate that a set of quantum numbers read for a ket is self-consistent.
 *
 * The checks are performed in this order; the first violated one throws:
 *  1. if j is flagged as the total momentum, f must equal the expectation value of j,
 *  2. 2*m must be a whole number,
 *  3. 2*f must be a whole number,
 *  4. f + m must be a whole number (both integer or both half-integer),
 *  5. |m| must not exceed f.
 *
 * @param is_j_total_momentum Whether j is the total momentum.
 * @param quantum_number_f Quantum number f.
 * @param quantum_number_m Quantum number m.
 * @param quantum_number_j_exp Expectation value of j.
 *
 * @throws std::runtime_error if one of the checks 1-3 fails.
 * @throws std::invalid_argument if one of the checks 4-5 fails.
 */
void ensure_consistent_quantum_numbers(bool is_j_total_momentum, double quantum_number_f,
                                       double quantum_number_m, double quantum_number_j_exp) {
    // True if the value has no fractional part (NaN is never whole).
    const auto is_whole = [](double value) { return value == std::rint(value); };

    const double twice_m = 2 * quantum_number_m;
    const double twice_f = 2 * quantum_number_f;
    const double f_plus_m = quantum_number_f + quantum_number_m;

    if (is_j_total_momentum && !(quantum_number_f == quantum_number_j_exp)) {
        throw std::runtime_error("If j is the total momentum, f must be equal to j.");
    }

    if (!is_whole(twice_m)) {
        throw std::runtime_error("The quantum number m must be an integer or half-integer.");
    }

    if (!is_whole(twice_f)) {
        throw std::runtime_error("The quantum number f must be an integer or half-integer.");
    }

    if (!is_whole(f_plus_m)) {
        throw std::invalid_argument(
            "The quantum numbers f and m must be both either integers or half-integers.");
    }

    if (std::abs(quantum_number_m) > quantum_number_f) {
        throw std::invalid_argument(
            "The absolute value of the quantum number m must be less than or equal to f.");
    }
}
52 :
// Construct a database with all settings at their defaults; delegates to the
// Database(bool) overload with `default_download_missing`.
53 0 : Database::Database() : Database(default_download_missing) {}
54 :
// Construct a database with an explicit download policy; the cache setting and the
// database directory are taken from `default_use_cache` and `default_database_dir`.
55 0 : Database::Database(bool download_missing)
56 0 : : Database(download_missing, default_use_cache, default_database_dir) {}
57 :
// Construct a database located in `database_dir`; the download policy and the cache
// setting are taken from `default_download_missing` and `default_use_cache`.
58 0 : Database::Database(std::filesystem::path database_dir)
59 0 : : Database(default_download_missing, default_use_cache, std::move(database_dir)) {}
60 :
61 2 : Database::Database(bool download_missing, bool use_cache, std::filesystem::path database_dir)
62 2 : : download_missing_(download_missing), use_cache_(use_cache),
63 2 : database_dir_(std::move(database_dir)), db(std::make_unique<duckdb::DuckDB>(nullptr)),
64 10 : con(std::make_unique<duckdb::Connection>(*db)) {
65 :
66 2 : if (database_dir_.empty()) {
67 0 : database_dir_ = default_database_dir;
68 : }
69 :
70 : // Ensure the database directory exists
71 2 : if (!std::filesystem::exists(database_dir_)) {
72 0 : std::filesystem::create_directories(database_dir_);
73 : }
74 2 : database_dir_ = std::filesystem::canonical(database_dir_);
75 2 : if (!std::filesystem::is_directory(database_dir_)) {
76 0 : throw std::filesystem::filesystem_error("Cannot access database", database_dir_.string(),
77 0 : std::make_error_code(std::errc::not_a_directory));
78 : }
79 2 : SPDLOG_INFO("Using database directory: {}", database_dir_.string());
80 :
81 : // Ensure that the config directory exists
82 2 : std::filesystem::path configdir = paths::get_config_directory();
83 2 : if (!std::filesystem::exists(configdir)) {
84 1 : std::filesystem::create_directories(configdir);
85 1 : } else if (!std::filesystem::is_directory(configdir)) {
86 0 : throw std::filesystem::filesystem_error("Cannot access config directory ",
87 0 : configdir.string(),
88 0 : std::make_error_code(std::errc::not_a_directory));
89 : }
90 :
91 : // Read in the database_repo_paths if a config file exists, otherwise use the default and
92 : // write it to the config file
93 2 : std::filesystem::path configfile = configdir / "database.json";
94 2 : std::string database_repo_host;
95 2 : std::vector<std::string> database_repo_paths;
96 2 : if (std::filesystem::exists(configfile)) {
97 1 : std::ifstream file(configfile);
98 1 : nlohmann::json doc = nlohmann::json::parse(file, nullptr, false);
99 :
100 2 : if (!doc.is_discarded() && doc.contains("hash") && doc.contains("database_repo_host") &&
101 1 : doc.contains("database_repo_paths")) {
102 1 : database_repo_host = doc["database_repo_host"].get<std::string>();
103 1 : database_repo_paths = doc["database_repo_paths"].get<std::vector<std::string>>();
104 :
105 : // If the values are not equal to the default values but the hash is consistent (i.e.,
106 : // the user has not changed anything manually), clear the values so that they can be
107 : // updated
108 2 : if (database_repo_host != default_database_repo_host ||
109 1 : database_repo_paths != default_database_repo_paths) {
110 0 : std::size_t seed = 0;
111 0 : utils::hash_combine(seed, database_repo_paths);
112 0 : utils::hash_combine(seed, database_repo_host);
113 0 : if (seed == doc["hash"].get<std::size_t>()) {
114 0 : database_repo_host.clear();
115 0 : database_repo_paths.clear();
116 : } else {
117 0 : SPDLOG_INFO("The database repository host and paths have been changed "
118 : "manually. Thus, they will not be updated automatically. To reset "
119 : "them, delete the file '{}'.",
120 : configfile.string());
121 : }
122 : }
123 : }
124 1 : }
125 :
126 : // Read in and store the default values if necessary
127 2 : if (database_repo_host.empty() || database_repo_paths.empty()) {
128 2 : SPDLOG_INFO("Updating the database repository host and paths:");
129 :
130 1 : database_repo_host = default_database_repo_host;
131 1 : database_repo_paths = default_database_repo_paths;
132 1 : std::ofstream file(configfile);
133 1 : nlohmann::json doc;
134 :
135 1 : SPDLOG_INFO("* New host: {}", default_database_repo_host);
136 2 : SPDLOG_INFO("* New paths: {}", fmt::join(default_database_repo_paths, ", "));
137 :
138 1 : doc["database_repo_host"] = default_database_repo_host;
139 1 : doc["database_repo_paths"] = database_repo_paths;
140 :
141 1 : std::size_t seed = 0;
142 1 : utils::hash_combine(seed, default_database_repo_paths);
143 1 : utils::hash_combine(seed, default_database_repo_host);
144 1 : doc["hash"] = seed;
145 :
146 1 : file << doc.dump(4);
147 1 : }
148 :
149 : // Limit the memory usage of duckdb's buffer manager
150 : {
151 2 : auto result = con->Query("PRAGMA max_memory = '8GB';");
152 2 : if (result->HasError()) {
153 0 : throw cpptrace::runtime_error("Error setting the memory limit: " + result->GetError());
154 : }
155 2 : }
156 :
157 : // Instantiate a database manager that provides access to database tables. If a table
158 : // is outdated/not available locally, it will be downloaded if download_missing_ is true.
159 2 : if (!download_missing_) {
160 2 : database_repo_paths.clear();
161 : }
162 2 : downloader = std::make_unique<GitHubDownloader>();
163 4 : manager = std::make_unique<ParquetManager>(database_dir_, *downloader, database_repo_paths,
164 4 : *con, use_cache_);
165 2 : manager->scan_local();
166 2 : manager->scan_remote();
167 :
168 : // Print versions of tables
169 2 : std::istringstream iss(manager->get_versions_info());
170 24 : for (std::string line; std::getline(iss, line);) {
171 22 : SPDLOG_INFO(line);
172 2 : }
173 6 : }
174 :
// Defaulted destructor, defined in this translation unit.
// NOTE(review): presumably kept out of the header so that smart-pointer members of types
// that are only forward-declared there can be destroyed - confirm against Database.hpp.
175 1 : Database::~Database() = default;
176 :
177 450 : std::shared_ptr<const KetAtom> Database::get_ket(const std::string &species,
178 : const AtomDescriptionByParameters &description) {
179 : // Check that the specifications are valid
180 450 : if (!description.quantum_number_m.has_value()) {
181 0 : throw std::invalid_argument("The quantum number m must be specified.");
182 : }
183 483 : if (description.quantum_number_f.has_value() &&
184 33 : 2 * description.quantum_number_f.value() !=
185 33 : std::rint(2 * description.quantum_number_f.value())) {
186 0 : throw std::invalid_argument("The quantum number f must be an integer or half-integer.");
187 : }
188 450 : if (description.quantum_number_f.has_value() && description.quantum_number_f.value() < 0) {
189 0 : throw std::invalid_argument("The quantum number f must be positive.");
190 : }
191 826 : if (description.quantum_number_j.has_value() &&
192 376 : 2 * description.quantum_number_j.value() !=
193 376 : std::rint(2 * description.quantum_number_j.value())) {
194 0 : throw std::invalid_argument("The quantum number j must be an integer or half-integer.");
195 : }
196 450 : if (description.quantum_number_j.has_value() && description.quantum_number_j.value() < 0) {
197 0 : throw std::invalid_argument("The quantum number j must be positive.");
198 : }
199 900 : if (description.quantum_number_m.has_value() &&
200 450 : 2 * description.quantum_number_m.value() !=
201 450 : std::rint(2 * description.quantum_number_m.value())) {
202 0 : throw std::invalid_argument("The quantum number m must be an integer or half-integer.");
203 : }
204 :
205 : // Describe the state
206 450 : std::string where;
207 450 : std::string separator;
208 450 : if (description.energy.has_value()) {
209 : // The following condition derives from demanding that quantum number n that corresponds to
210 : // the energy "E_n = -1/(2*n^2)" is not off by more than 1 from the actual quantum number n,
211 : // i.e., "sqrt(-1/(2*E_n)) - sqrt(-1/(2*E_{n-1})) = 1"
212 0 : where += separator +
213 0 : fmt::format("SQRT(-1/(2*energy)) BETWEEN {} AND {}",
214 0 : std::sqrt(-1 / (2 * description.energy.value())) - 0.5,
215 0 : std::sqrt(-1 / (2 * description.energy.value())) + 0.5);
216 0 : separator = " AND ";
217 : }
218 450 : if (description.quantum_number_f.has_value()) {
219 66 : where += separator + fmt::format("f = {}", description.quantum_number_f.value());
220 33 : separator = " AND ";
221 : }
222 450 : if (description.parity != Parity::UNKNOWN) {
223 0 : where += separator + fmt::format("parity = {}", fmt::streamed(description.parity));
224 0 : separator = " AND ";
225 : }
226 450 : if (description.quantum_number_n.has_value()) {
227 890 : where += separator + fmt::format("n = {}", description.quantum_number_n.value());
228 445 : separator = " AND ";
229 : }
230 450 : if (description.quantum_number_nu.has_value()) {
231 5 : where += separator +
232 15 : fmt::format("nu BETWEEN {} AND {}", description.quantum_number_nu.value() - 0.5,
233 10 : description.quantum_number_nu.value() + 0.5);
234 5 : separator = " AND ";
235 : }
236 450 : if (description.quantum_number_nui.has_value()) {
237 0 : where += separator +
238 0 : fmt::format("exp_nui BETWEEN {} AND {}", description.quantum_number_nui.value() - 0.5,
239 0 : description.quantum_number_nui.value() + 0.5);
240 0 : separator = " AND ";
241 : }
242 450 : if (description.quantum_number_l.has_value()) {
243 450 : where += separator +
244 1350 : fmt::format("exp_l BETWEEN {} AND {}", description.quantum_number_l.value() - 0.5,
245 900 : description.quantum_number_l.value() + 0.5);
246 450 : separator = " AND ";
247 : }
248 450 : if (description.quantum_number_s.has_value()) {
249 32 : where += separator +
250 96 : fmt::format("exp_s BETWEEN {} AND {}", description.quantum_number_s.value() - 0.5,
251 64 : description.quantum_number_s.value() + 0.5);
252 32 : separator = " AND ";
253 : }
254 450 : if (description.quantum_number_j.has_value()) {
255 376 : where += separator +
256 1128 : fmt::format("exp_j BETWEEN {} AND {}", description.quantum_number_j.value() - 0.5,
257 752 : description.quantum_number_j.value() + 0.5);
258 376 : separator = " AND ";
259 : }
260 450 : if (description.quantum_number_l_ryd.has_value()) {
261 0 : where += separator +
262 0 : fmt::format("exp_l_ryd BETWEEN {} AND {}",
263 0 : description.quantum_number_l_ryd.value() - 0.5,
264 0 : description.quantum_number_l_ryd.value() + 0.5);
265 0 : separator = " AND ";
266 : }
267 450 : if (description.quantum_number_j_ryd.has_value()) {
268 0 : where += separator +
269 0 : fmt::format("exp_j_ryd BETWEEN {} AND {}",
270 0 : description.quantum_number_j_ryd.value() - 0.5,
271 0 : description.quantum_number_j_ryd.value() + 0.5);
272 0 : separator = " AND ";
273 : }
274 450 : if (separator.empty()) {
275 0 : where += "FALSE";
276 : }
277 :
278 450 : std::string orderby;
279 450 : separator = "";
280 450 : if (description.energy.has_value()) {
281 0 : orderby += separator +
282 0 : fmt::format("(SQRT(-1/(2*energy)) - {})^2",
283 0 : std::sqrt(-1 / (2 * description.energy.value())));
284 0 : separator = " + ";
285 : }
286 450 : if (description.quantum_number_nu.has_value()) {
287 10 : orderby += separator + fmt::format("(nu - {})^2", description.quantum_number_nu.value());
288 5 : separator = " + ";
289 : }
290 450 : if (description.quantum_number_nui.has_value()) {
291 : orderby +=
292 0 : separator + fmt::format("(exp_nui - {})^2", description.quantum_number_nui.value());
293 0 : separator = " + ";
294 : }
295 450 : if (description.quantum_number_l.has_value()) {
296 900 : orderby += separator + fmt::format("(exp_l - {})^2", description.quantum_number_l.value());
297 450 : separator = " + ";
298 : }
299 450 : if (description.quantum_number_s.has_value()) {
300 64 : orderby += separator + fmt::format("(exp_s - {})^2", description.quantum_number_s.value());
301 32 : separator = " + ";
302 : }
303 450 : if (description.quantum_number_j.has_value()) {
304 752 : orderby += separator + fmt::format("(exp_j - {})^2", description.quantum_number_j.value());
305 376 : separator = " + ";
306 : }
307 450 : if (description.quantum_number_l_ryd.has_value()) {
308 : orderby +=
309 0 : separator + fmt::format("(exp_l_ryd - {})^2", description.quantum_number_l_ryd.value());
310 0 : separator = " + ";
311 : }
312 450 : if (description.quantum_number_j_ryd.has_value()) {
313 : orderby +=
314 0 : separator + fmt::format("(exp_j_ryd - {})^2", description.quantum_number_j_ryd.value());
315 0 : separator = " + ";
316 : }
317 450 : if (separator.empty()) {
318 0 : orderby += "id";
319 : }
320 :
321 : // Ask the database for the described state
322 450 : set_task_status("Loading atomic ket from database...");
323 900 : auto result = con->Query(fmt::format(
324 : R"(SELECT energy, f, parity, id, n, nu, exp_nui, std_nui, exp_l, std_l, exp_s, std_s,
325 : exp_j, std_j, exp_l_ryd, std_l_ryd, exp_j_ryd, std_j_ryd, is_j_total_momentum, is_calculated_with_mqdt, underspecified_channel_contribution, {} AS order_val FROM '{}' WHERE {} ORDER BY order_val ASC LIMIT 2)",
326 1350 : orderby, manager->get_path(species, "states"), where));
327 :
328 450 : if (result->HasError()) {
329 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
330 : }
331 :
332 450 : if (result->RowCount() == 0) {
333 126 : throw std::invalid_argument("No state found.");
334 : }
335 :
336 : // Check the types of the columns
337 324 : const auto &types = result->types;
338 324 : const auto &labels = result->names;
339 : const std::vector<duckdb::LogicalType> ref_types = {
340 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::BIGINT,
341 : duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT, duckdb::LogicalType::DOUBLE,
342 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
343 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
344 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
345 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
346 : duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::DOUBLE,
347 7452 : duckdb::LogicalType::DOUBLE};
348 :
349 7452 : for (size_t i = 0; i < types.size(); i++) {
350 7128 : if (types[i] != ref_types[i]) {
351 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'. Got " +
352 0 : types[i].ToString() + " but expected " +
353 0 : ref_types[i].ToString());
354 : }
355 : }
356 :
357 : // Get the first chunk of the results (the first chunk is sufficient as we need two rows at
358 : // most)
359 324 : auto chunk = result->Fetch();
360 :
361 : // Check that the ket is uniquely specified
362 324 : if (chunk->size() > 1) {
363 4 : auto order_val_0 = duckdb::FlatVector::GetData<double>(chunk->data[21])[0];
364 4 : auto order_val_1 = duckdb::FlatVector::GetData<double>(chunk->data[21])[1];
365 :
366 4 : if (order_val_1 - order_val_0 <= order_val_0) {
367 : // Get a list of possible kets
368 1 : std::vector<KetAtom> kets;
369 1 : kets.reserve(2);
370 3 : for (size_t i = 0; i < 2; ++i) {
371 2 : auto result_quantum_number_m = description.quantum_number_m.value();
372 2 : auto result_energy = duckdb::FlatVector::GetData<double>(chunk->data[0])[i];
373 : auto result_quantum_number_f =
374 2 : duckdb::FlatVector::GetData<double>(chunk->data[1])[i];
375 2 : auto result_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[2])[i];
376 4 : auto result_id = utils::get_linearized_id_in_database(
377 2 : duckdb::FlatVector::GetData<int64_t>(chunk->data[3])[i],
378 2 : result_quantum_number_m);
379 : auto result_quantum_number_n =
380 2 : duckdb::FlatVector::GetData<int64_t>(chunk->data[4])[i];
381 : auto result_quantum_number_nu =
382 2 : duckdb::FlatVector::GetData<double>(chunk->data[5])[i];
383 : auto result_quantum_number_nui_exp =
384 2 : duckdb::FlatVector::GetData<double>(chunk->data[6])[i];
385 : auto result_quantum_number_nui_std =
386 2 : duckdb::FlatVector::GetData<double>(chunk->data[7])[i];
387 : auto result_quantum_number_l_exp =
388 2 : duckdb::FlatVector::GetData<double>(chunk->data[8])[i];
389 : auto result_quantum_number_l_std =
390 2 : duckdb::FlatVector::GetData<double>(chunk->data[9])[i];
391 : auto result_quantum_number_s_exp =
392 2 : duckdb::FlatVector::GetData<double>(chunk->data[10])[i];
393 : auto result_quantum_number_s_std =
394 2 : duckdb::FlatVector::GetData<double>(chunk->data[11])[i];
395 : auto result_quantum_number_j_exp =
396 2 : duckdb::FlatVector::GetData<double>(chunk->data[12])[i];
397 : auto result_quantum_number_j_std =
398 2 : duckdb::FlatVector::GetData<double>(chunk->data[13])[i];
399 : auto result_quantum_number_l_ryd_exp =
400 2 : duckdb::FlatVector::GetData<double>(chunk->data[14])[i];
401 : auto result_quantum_number_l_ryd_std =
402 2 : duckdb::FlatVector::GetData<double>(chunk->data[15])[i];
403 : auto result_quantum_number_j_ryd_exp =
404 2 : duckdb::FlatVector::GetData<double>(chunk->data[16])[i];
405 : auto result_quantum_number_j_ryd_std =
406 2 : duckdb::FlatVector::GetData<double>(chunk->data[17])[i];
407 : auto result_is_j_total_momentum =
408 2 : duckdb::FlatVector::GetData<bool>(chunk->data[18])[i];
409 : auto result_is_calculated_with_mqdt =
410 2 : duckdb::FlatVector::GetData<bool>(chunk->data[19])[i];
411 : auto result_underspecified_channel_contribution =
412 2 : duckdb::FlatVector::GetData<double>(chunk->data[20])[i];
413 0 : kets.emplace_back(typename KetAtom::Private(), result_energy,
414 : result_quantum_number_f, result_quantum_number_m,
415 2 : static_cast<Parity>(result_parity), species,
416 : result_quantum_number_n, result_quantum_number_nu,
417 : result_quantum_number_nui_exp, result_quantum_number_nui_std,
418 : result_quantum_number_l_exp, result_quantum_number_l_std,
419 : result_quantum_number_s_exp, result_quantum_number_s_std,
420 : result_quantum_number_j_exp, result_quantum_number_j_std,
421 : result_quantum_number_l_ryd_exp, result_quantum_number_l_ryd_std,
422 : result_quantum_number_j_ryd_exp, result_quantum_number_j_ryd_std,
423 : result_is_j_total_momentum, result_is_calculated_with_mqdt,
424 : result_underspecified_channel_contribution, *this, result_id);
425 : }
426 :
427 : // Throw an error with the possible kets
428 2 : throw std::invalid_argument(
429 2 : fmt::format("The ket is not uniquely specified. Possible kets are:\n{}\n{}",
430 2 : fmt::streamed(kets[0]), fmt::streamed(kets[1])));
431 1 : }
432 : }
433 :
434 : // Construct the state
435 323 : auto result_quantum_number_m = description.quantum_number_m.value();
436 323 : auto result_energy = duckdb::FlatVector::GetData<double>(chunk->data[0])[0];
437 323 : auto result_quantum_number_f = duckdb::FlatVector::GetData<double>(chunk->data[1])[0];
438 323 : auto result_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[2])[0];
439 646 : auto result_id = utils::get_linearized_id_in_database(
440 323 : duckdb::FlatVector::GetData<int64_t>(chunk->data[3])[0], result_quantum_number_m);
441 323 : auto result_quantum_number_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[4])[0];
442 323 : auto result_quantum_number_nu = duckdb::FlatVector::GetData<double>(chunk->data[5])[0];
443 323 : auto result_quantum_number_nui_exp = duckdb::FlatVector::GetData<double>(chunk->data[6])[0];
444 323 : auto result_quantum_number_nui_std = duckdb::FlatVector::GetData<double>(chunk->data[7])[0];
445 323 : auto result_quantum_number_l_exp = duckdb::FlatVector::GetData<double>(chunk->data[8])[0];
446 323 : auto result_quantum_number_l_std = duckdb::FlatVector::GetData<double>(chunk->data[9])[0];
447 323 : auto result_quantum_number_s_exp = duckdb::FlatVector::GetData<double>(chunk->data[10])[0];
448 323 : auto result_quantum_number_s_std = duckdb::FlatVector::GetData<double>(chunk->data[11])[0];
449 323 : auto result_quantum_number_j_exp = duckdb::FlatVector::GetData<double>(chunk->data[12])[0];
450 323 : auto result_quantum_number_j_std = duckdb::FlatVector::GetData<double>(chunk->data[13])[0];
451 323 : auto result_quantum_number_l_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[14])[0];
452 323 : auto result_quantum_number_l_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[15])[0];
453 323 : auto result_quantum_number_j_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[16])[0];
454 323 : auto result_quantum_number_j_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[17])[0];
455 323 : auto result_is_j_total_momentum = duckdb::FlatVector::GetData<bool>(chunk->data[18])[0];
456 323 : auto result_is_calculated_with_mqdt = duckdb::FlatVector::GetData<bool>(chunk->data[19])[0];
457 : auto result_underspecified_channel_contribution =
458 323 : duckdb::FlatVector::GetData<double>(chunk->data[20])[0];
459 :
460 : // Check database consistency
461 323 : ensure_consistent_quantum_numbers(result_is_j_total_momentum, result_quantum_number_f,
462 : result_quantum_number_m, result_quantum_number_j_exp);
463 :
464 : return std::make_shared<const KetAtom>(
465 0 : typename KetAtom::Private(), result_energy, result_quantum_number_f,
466 0 : result_quantum_number_m, static_cast<Parity>(result_parity), species,
467 : result_quantum_number_n, result_quantum_number_nu, result_quantum_number_nui_exp,
468 : result_quantum_number_nui_std, result_quantum_number_l_exp, result_quantum_number_l_std,
469 : result_quantum_number_s_exp, result_quantum_number_s_std, result_quantum_number_j_exp,
470 : result_quantum_number_j_std, result_quantum_number_l_ryd_exp,
471 : result_quantum_number_l_ryd_std, result_quantum_number_j_ryd_exp,
472 : result_quantum_number_j_ryd_std, result_is_j_total_momentum, result_is_calculated_with_mqdt,
473 644 : result_underspecified_channel_contribution, *this, result_id);
474 1162 : }
475 :
476 : template <typename Scalar>
477 : std::shared_ptr<const BasisAtom<Scalar>>
478 233 : Database::get_basis(const std::string &species, const AtomDescriptionByRanges &description,
479 : std::vector<size_t> additional_ket_ids) {
480 : // Describe the states
481 233 : std::string where = "(";
482 233 : std::string separator;
483 233 : if (description.parity != Parity::UNKNOWN) {
484 0 : where += separator + fmt::format("parity = {}", fmt::streamed(description.parity));
485 0 : separator = " AND ";
486 : }
487 233 : if (description.range_energy.is_finite()) {
488 63 : where += separator +
489 63 : fmt::format("energy BETWEEN {} AND {}", description.range_energy.min(),
490 63 : description.range_energy.max());
491 63 : separator = " AND ";
492 : }
493 233 : if (description.range_quantum_number_f.is_finite()) {
494 2 : where += separator +
495 2 : fmt::format("f BETWEEN {} AND {}", description.range_quantum_number_f.min(),
496 2 : description.range_quantum_number_f.max());
497 2 : separator = " AND ";
498 : }
499 233 : if (description.range_quantum_number_m.is_finite()) {
500 80 : where += separator +
501 80 : fmt::format("m BETWEEN {} AND {}", description.range_quantum_number_m.min(),
502 80 : description.range_quantum_number_m.max());
503 80 : separator = " AND ";
504 : }
505 233 : if (description.range_quantum_number_n.is_finite()) {
506 196 : where += separator +
507 196 : fmt::format("n BETWEEN {} AND {}", description.range_quantum_number_n.min(),
508 196 : description.range_quantum_number_n.max());
509 196 : separator = " AND ";
510 : }
511 233 : if (description.range_quantum_number_nu.is_finite()) {
512 7 : where += separator +
513 7 : fmt::format("nu BETWEEN {} AND {}", description.range_quantum_number_nu.min(),
514 7 : description.range_quantum_number_nu.max());
515 7 : separator = " AND ";
516 : }
517 233 : if (description.range_quantum_number_nui.is_finite()) {
518 2 : where += separator +
519 : fmt::format("exp_nui BETWEEN {}-2*std_nui AND {}+2*std_nui",
520 2 : description.range_quantum_number_nui.min(),
521 2 : description.range_quantum_number_nui.max());
522 2 : separator = " AND ";
523 : }
524 233 : if (description.range_quantum_number_l.is_finite()) {
525 185 : where += separator +
526 : fmt::format("exp_l BETWEEN {}-2*std_l AND {}+2*std_l",
527 185 : description.range_quantum_number_l.min(),
528 185 : description.range_quantum_number_l.max());
529 185 : separator = " AND ";
530 : }
531 233 : if (description.range_quantum_number_s.is_finite()) {
532 2 : where += separator +
533 : fmt::format("exp_s BETWEEN {}-2*std_s AND {}+2*std_s",
534 2 : description.range_quantum_number_s.min(),
535 2 : description.range_quantum_number_s.max());
536 2 : separator = " AND ";
537 : }
538 233 : if (description.range_quantum_number_j.is_finite()) {
539 2 : where += separator +
540 : fmt::format("exp_j BETWEEN {}-2*std_j AND {}+2*std_j",
541 2 : description.range_quantum_number_j.min(),
542 2 : description.range_quantum_number_j.max());
543 2 : separator = " AND ";
544 : }
545 233 : if (description.range_quantum_number_l_ryd.is_finite()) {
546 2 : where += separator +
547 : fmt::format("exp_l_ryd BETWEEN {}-2*std_l_ryd AND {}+2*std_l_ryd",
548 2 : description.range_quantum_number_l_ryd.min(),
549 2 : description.range_quantum_number_l_ryd.max());
550 2 : separator = " AND ";
551 : }
552 233 : if (description.range_quantum_number_j_ryd.is_finite()) {
553 2 : where += separator +
554 : fmt::format("exp_j_ryd BETWEEN {}-2*std_j_ryd AND {}+2*std_j_ryd",
555 2 : description.range_quantum_number_j_ryd.min(),
556 2 : description.range_quantum_number_j_ryd.max());
557 2 : separator = " AND ";
558 : }
559 233 : if (separator.empty()) {
560 32 : where += "FALSE";
561 : }
562 233 : where += ")";
563 233 : if (!additional_ket_ids.empty()) {
564 156 : where += fmt::format(" OR {} IN ({})", utils::SQL_TERM_FOR_LINEARIZED_ID_IN_DATABASE,
565 156 : fmt::join(additional_ket_ids, ","));
566 : }
567 :
568 : // Create a table containing the described states
569 233 : std::string id_of_kets;
570 : {
571 233 : auto result = con->Query(R"(SELECT UUID()::varchar)");
572 233 : if (result->HasError()) {
573 0 : throw cpptrace::runtime_error("Error selecting id_of_kets: " + result->GetError());
574 : }
575 233 : id_of_kets =
576 466 : duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
577 233 : }
578 : {
579 233 : set_task_status("Selecting atomic basis states...");
580 699 : auto result = con->Query(fmt::format(
581 : R"(CREATE TEMP TABLE '{}' AS SELECT *, {} AS ketid FROM (
582 : SELECT *,
583 : UNNEST(list_transform(generate_series(0,(2*f)::bigint),
584 : x -> x::double-f)) AS m FROM '{}'
585 : ) WHERE {})",
586 : id_of_kets, utils::SQL_TERM_FOR_LINEARIZED_ID_IN_DATABASE,
587 466 : manager->get_path(species, "states"), where));
588 :
589 233 : if (result->HasError()) {
590 0 : throw cpptrace::runtime_error("Error creating table: " + result->GetError());
591 : }
592 233 : }
593 :
594 : // Ask the table for the extreme values of the quantum numbers
595 : {
596 233 : set_task_status("Validating atomic basis coverage...");
597 233 : std::string select;
598 233 : std::string separator;
599 233 : if (description.range_energy.is_finite()) {
600 63 : select += separator + "MIN(energy) AS min_energy, MAX(energy) AS max_energy";
601 63 : separator = ", ";
602 : }
603 233 : if (description.range_quantum_number_f.is_finite()) {
604 2 : select += separator + "MIN(f) AS min_f, MAX(f) AS max_f";
605 2 : separator = ", ";
606 : }
607 233 : if (description.range_quantum_number_m.is_finite()) {
608 80 : select += separator + "MIN(m) AS min_m, MAX(m) AS max_m";
609 80 : separator = ", ";
610 : }
611 233 : if (description.range_quantum_number_n.is_finite()) {
612 196 : select += separator + "MIN(n) AS min_n, MAX(n) AS max_n";
613 196 : separator = ", ";
614 : }
615 233 : if (description.range_quantum_number_nu.is_finite()) {
616 7 : select += separator + "MIN(nu) AS min_nu, MAX(nu) AS max_nu";
617 7 : separator = ", ";
618 : }
619 233 : if (description.range_quantum_number_nui.is_finite()) {
620 2 : select += separator + "MIN(exp_nui) AS min_nui, MAX(exp_nui) AS max_nui";
621 2 : separator = ", ";
622 : }
623 233 : if (description.range_quantum_number_l.is_finite()) {
624 185 : select += separator + "MIN(exp_l) AS min_l, MAX(exp_l) AS max_l";
625 185 : separator = ", ";
626 : }
627 233 : if (description.range_quantum_number_s.is_finite()) {
628 2 : select += separator + "MIN(exp_s) AS min_s, MAX(exp_s) AS max_s";
629 2 : separator = ", ";
630 : }
631 233 : if (description.range_quantum_number_j.is_finite()) {
632 2 : select += separator + "MIN(exp_j) AS min_j, MAX(exp_j) AS max_j";
633 2 : separator = ", ";
634 : }
635 233 : if (description.range_quantum_number_l_ryd.is_finite()) {
636 2 : select += separator + "MIN(exp_l_ryd) AS min_l_ryd, MAX(exp_l_ryd) AS max_l_ryd";
637 2 : separator = ", ";
638 : }
639 233 : if (description.range_quantum_number_j_ryd.is_finite()) {
640 2 : select += separator + "MIN(exp_j_ryd) AS min_j_ryd, MAX(exp_j_ryd) AS max_j_ryd";
641 2 : separator = ", ";
642 : }
643 :
644 233 : if (!separator.empty()) {
645 402 : auto result = con->Query(fmt::format(R"(SELECT {} FROM '{}')", select, id_of_kets));
646 :
647 201 : if (result->HasError()) {
648 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
649 : }
650 :
651 201 : auto chunk = result->Fetch();
652 :
653 1287 : for (size_t i = 0; i < chunk->ColumnCount(); i++) {
654 1086 : if (duckdb::FlatVector::IsNull(chunk->data[i], 0)) {
655 0 : throw std::invalid_argument("No state found.");
656 : }
657 : }
658 :
659 201 : size_t idx = 0;
660 201 : if (description.range_energy.is_finite()) {
661 63 : auto min_energy = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
662 126 : if (std::sqrt(-1 / (2 * min_energy)) - 1 >
663 63 : std::sqrt(-1 / (2 * description.range_energy.min()))) {
664 0 : SPDLOG_DEBUG("No state found with the requested minimum energy. Requested: {}, "
665 : "found: {}.",
666 : description.range_energy.min(), min_energy);
667 : }
668 63 : auto max_energy = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
669 126 : if (std::sqrt(-1 / (2 * max_energy)) + 1 <
670 63 : std::sqrt(-1 / (2 * description.range_energy.max()))) {
671 0 : SPDLOG_DEBUG("No state found with the requested maximum energy. Requested: {}, "
672 : "found: {}.",
673 : description.range_energy.max(), max_energy);
674 : }
675 : }
676 201 : if (description.range_quantum_number_f.is_finite()) {
677 2 : auto min_f = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
678 2 : if (min_f > description.range_quantum_number_f.min()) {
679 2 : SPDLOG_DEBUG("No state found with the requested minimum quantum number f. "
680 : "Requested: {}, found: {}.",
681 : description.range_quantum_number_f.min(), min_f);
682 : }
683 2 : auto max_f = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
684 2 : if (max_f < description.range_quantum_number_f.max()) {
685 2 : SPDLOG_DEBUG("No state found with the requested maximum quantum number f. "
686 : "Requested: {}, found: {}.",
687 : description.range_quantum_number_f.max(), max_f);
688 : }
689 : }
690 201 : if (description.range_quantum_number_m.is_finite()) {
691 80 : auto min_m = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
692 80 : if (min_m > description.range_quantum_number_m.min()) {
693 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number m. "
694 : "Requested: {}, found: {}.",
695 : description.range_quantum_number_m.min(), min_m);
696 : }
697 80 : auto max_m = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
698 80 : if (max_m < description.range_quantum_number_m.max()) {
699 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number m. "
700 : "Requested: {}, found: {}.",
701 : description.range_quantum_number_m.max(), max_m);
702 : }
703 : }
704 201 : if (description.range_quantum_number_n.is_finite()) {
705 196 : auto min_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[idx++])[0];
706 196 : if (min_n > description.range_quantum_number_n.min()) {
707 46 : SPDLOG_DEBUG("No state found with the requested minimum quantum number n. "
708 : "Requested: {}, found: {}.",
709 : description.range_quantum_number_n.min(), min_n);
710 : }
711 196 : auto max_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[idx++])[0];
712 196 : if (max_n < description.range_quantum_number_n.max()) {
713 36 : SPDLOG_DEBUG("No state found with the requested maximum quantum number n. "
714 : "Requested: {}, found: {}.",
715 : description.range_quantum_number_n.max(), max_n);
716 : }
717 : }
718 201 : if (description.range_quantum_number_nu.is_finite()) {
719 7 : auto min_nu = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
720 7 : if (min_nu - 1 > description.range_quantum_number_nu.min()) {
721 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number nu. "
722 : "Requested: {}, found: {}.",
723 : description.range_quantum_number_nu.min(), min_nu);
724 : }
725 7 : auto max_nu = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
726 7 : if (max_nu + 1 < description.range_quantum_number_nu.max()) {
727 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number nu. "
728 : "Requested: {}, found: {}.",
729 : description.range_quantum_number_nu.max(), max_nu);
730 : }
731 : }
732 201 : if (description.range_quantum_number_nui.is_finite()) {
733 2 : auto min_nui = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
734 2 : if (min_nui - 1 > description.range_quantum_number_nui.min()) {
735 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number nui. "
736 : "Requested: {}, found: {}.",
737 : description.range_quantum_number_nui.min(), min_nui);
738 : }
739 2 : auto max_nui = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
740 2 : if (max_nui + 1 < description.range_quantum_number_nui.max()) {
741 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number nui. "
742 : "Requested: {}, found: {}.",
743 : description.range_quantum_number_nui.max(), max_nui);
744 : }
745 : }
746 201 : if (description.range_quantum_number_l.is_finite()) {
747 185 : auto min_l = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
748 185 : if (min_l - 1 > description.range_quantum_number_l.min()) {
749 29 : SPDLOG_DEBUG("No state found with the requested minimum quantum number l. "
750 : "Requested: {}, found: {}.",
751 : description.range_quantum_number_l.min(), min_l);
752 : }
753 185 : auto max_l = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
754 185 : if (max_l + 1 < description.range_quantum_number_l.max()) {
755 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number l. "
756 : "Requested: {}, found: {}.",
757 : description.range_quantum_number_l.max(), max_l);
758 : }
759 : }
760 201 : if (description.range_quantum_number_s.is_finite()) {
761 2 : auto min_s = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
762 2 : if (min_s - 1 > description.range_quantum_number_s.min()) {
763 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number s. "
764 : "Requested: {}, found: {}.",
765 : description.range_quantum_number_s.min(), min_s);
766 : }
767 2 : auto max_s = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
768 2 : if (max_s + 1 < description.range_quantum_number_s.max()) {
769 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number s. "
770 : "Requested: {}, found: {}.",
771 : description.range_quantum_number_s.max(), max_s);
772 : }
773 : }
774 201 : if (description.range_quantum_number_j.is_finite()) {
775 2 : auto min_j = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
776 2 : if (min_j - 1 > description.range_quantum_number_j.min()) {
777 2 : SPDLOG_DEBUG("No state found with the requested minimum quantum number j. "
778 : "Requested: {}, found: {}.",
779 : description.range_quantum_number_j.min(), min_j);
780 : }
781 2 : auto max_j = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
782 2 : if (max_j + 1 < description.range_quantum_number_j.max()) {
783 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number j. "
784 : "Requested: {}, found: {}.",
785 : description.range_quantum_number_j.max(), max_j);
786 : }
787 : }
788 201 : if (description.range_quantum_number_l_ryd.is_finite()) {
789 2 : auto min_l_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
790 2 : if (min_l_ryd - 1 > description.range_quantum_number_l_ryd.min()) {
791 2 : SPDLOG_DEBUG("No state found with the requested minimum quantum number l_ryd. "
792 : "Requested: {}, found: {}.",
793 : description.range_quantum_number_l_ryd.min(), min_l_ryd);
794 : }
795 2 : auto max_l_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
796 2 : if (max_l_ryd + 1 < description.range_quantum_number_l_ryd.max()) {
797 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number l_ryd. "
798 : "Requested: {}, found: {}.",
799 : description.range_quantum_number_l_ryd.max(), max_l_ryd);
800 : }
801 : }
802 201 : if (description.range_quantum_number_j_ryd.is_finite()) {
803 2 : auto min_j_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
804 2 : if (min_j_ryd - 1 > description.range_quantum_number_j_ryd.min()) {
805 2 : SPDLOG_DEBUG("No state found with the requested minimum quantum number j_ryd. "
806 : "Requested: {}, found: {}.",
807 : description.range_quantum_number_j_ryd.min(), min_j_ryd);
808 : }
809 2 : auto max_j_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
810 2 : if (max_j_ryd + 1 < description.range_quantum_number_j_ryd.max()) {
811 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number j_ryd. "
812 : "Requested: {}, found: {}.",
813 : description.range_quantum_number_j_ryd.max(), max_j_ryd);
814 : }
815 : }
816 201 : }
817 233 : }
818 :
819 : // Ask the table for the described states
820 233 : set_task_status("Loading atomic basis states...");
821 466 : auto result = con->Query(fmt::format(
822 : R"(SELECT energy, f, m, parity, ketid, n, nu, exp_nui, std_nui, exp_l, std_l,
823 : exp_s, std_s, exp_j, std_j, exp_l_ryd, std_l_ryd, exp_j_ryd, std_j_ryd, is_j_total_momentum, is_calculated_with_mqdt, underspecified_channel_contribution FROM '{}' ORDER BY ketid ASC)",
824 : id_of_kets));
825 :
826 233 : if (result->HasError()) {
827 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
828 : }
829 :
830 233 : if (result->RowCount() == 0) {
831 0 : throw std::invalid_argument("No state found.");
832 : }
833 :
834 : // Check the types of the columns
835 233 : const auto &types = result->types;
836 233 : const auto &labels = result->names;
837 5359 : const std::vector<duckdb::LogicalType> ref_types = {
838 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
839 : duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT,
840 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
841 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
842 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
843 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
844 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::BOOLEAN,
845 : duckdb::LogicalType::DOUBLE};
846 :
847 5359 : for (size_t i = 0; i < types.size(); i++) {
848 5126 : if (types[i] != ref_types[i]) {
849 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'. Got " +
850 0 : types[i].ToString() + " but expected " +
851 0 : ref_types[i].ToString());
852 : }
853 : }
854 :
855 : // Construct the states
856 233 : std::vector<std::shared_ptr<const KetAtom>> kets;
857 233 : kets.reserve(result->RowCount());
858 233 : double last_energy = std::numeric_limits<double>::lowest();
859 233 : bool is_calculated_with_mqdt = false;
860 233 : double min_quantum_number_nu = std::numeric_limits<double>::max();
861 :
862 466 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
863 233 : set_task_status("Constructing atomic basis...");
864 :
865 233 : auto *chunk_energy = duckdb::FlatVector::GetData<double>(chunk->data[0]);
866 233 : auto *chunk_quantum_number_f = duckdb::FlatVector::GetData<double>(chunk->data[1]);
867 233 : auto *chunk_quantum_number_m = duckdb::FlatVector::GetData<double>(chunk->data[2]);
868 233 : auto *chunk_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[3]);
869 233 : auto *chunk_id = duckdb::FlatVector::GetData<int64_t>(chunk->data[4]);
870 233 : auto *chunk_quantum_number_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[5]);
871 233 : auto *chunk_quantum_number_nu = duckdb::FlatVector::GetData<double>(chunk->data[6]);
872 233 : auto *chunk_quantum_number_nui_exp = duckdb::FlatVector::GetData<double>(chunk->data[7]);
873 233 : auto *chunk_quantum_number_nui_std = duckdb::FlatVector::GetData<double>(chunk->data[8]);
874 233 : auto *chunk_quantum_number_l_exp = duckdb::FlatVector::GetData<double>(chunk->data[9]);
875 233 : auto *chunk_quantum_number_l_std = duckdb::FlatVector::GetData<double>(chunk->data[10]);
876 233 : auto *chunk_quantum_number_s_exp = duckdb::FlatVector::GetData<double>(chunk->data[11]);
877 233 : auto *chunk_quantum_number_s_std = duckdb::FlatVector::GetData<double>(chunk->data[12]);
878 233 : auto *chunk_quantum_number_j_exp = duckdb::FlatVector::GetData<double>(chunk->data[13]);
879 233 : auto *chunk_quantum_number_j_std = duckdb::FlatVector::GetData<double>(chunk->data[14]);
880 233 : auto *chunk_quantum_number_l_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[15]);
881 233 : auto *chunk_quantum_number_l_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[16]);
882 233 : auto *chunk_quantum_number_j_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[17]);
883 233 : auto *chunk_quantum_number_j_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[18]);
884 233 : auto *chunk_is_j_total_momentum = duckdb::FlatVector::GetData<bool>(chunk->data[19]);
885 233 : auto *chunk_is_calculated_with_mqdt = duckdb::FlatVector::GetData<bool>(chunk->data[20]);
886 : auto *chunk_underspecified_channel_contribution =
887 233 : duckdb::FlatVector::GetData<double>(chunk->data[21]);
888 :
889 27352 : for (size_t i = 0; i < chunk->size(); i++) {
890 :
891 : // Check database consistency
892 27119 : ensure_consistent_quantum_numbers(chunk_is_j_total_momentum[i],
893 27119 : chunk_quantum_number_f[i], chunk_quantum_number_m[i],
894 27119 : chunk_quantum_number_j_exp[i]);
895 27119 : if (chunk_energy[i] < last_energy) {
896 0 : throw std::runtime_error("The states are not sorted by energy.");
897 : }
898 27119 : last_energy = chunk_energy[i];
899 :
900 27119 : is_calculated_with_mqdt |= chunk_is_calculated_with_mqdt[i];
901 27119 : min_quantum_number_nu = std::min(min_quantum_number_nu, chunk_quantum_number_nu[i]);
902 :
903 : // Append a new state
904 27119 : kets.push_back(std::make_shared<const KetAtom>(
905 0 : typename KetAtom::Private(), chunk_energy[i], chunk_quantum_number_f[i],
906 27119 : chunk_quantum_number_m[i], static_cast<Parity>(chunk_parity[i]), species,
907 27119 : chunk_quantum_number_n[i], chunk_quantum_number_nu[i],
908 27119 : chunk_quantum_number_nui_exp[i], chunk_quantum_number_nui_std[i],
909 27119 : chunk_quantum_number_l_exp[i], chunk_quantum_number_l_std[i],
910 27119 : chunk_quantum_number_s_exp[i], chunk_quantum_number_s_std[i],
911 27119 : chunk_quantum_number_j_exp[i], chunk_quantum_number_j_std[i],
912 27119 : chunk_quantum_number_l_ryd_exp[i], chunk_quantum_number_l_ryd_std[i],
913 27119 : chunk_quantum_number_j_ryd_exp[i], chunk_quantum_number_j_ryd_std[i],
914 27119 : chunk_is_j_total_momentum[i], chunk_is_calculated_with_mqdt[i],
915 27119 : chunk_underspecified_channel_contribution[i], *this, chunk_id[i]));
916 : }
917 : }
918 :
919 : // Show a warning for low-lying states
920 233 : if (min_quantum_number_nu < 25) {
921 36 : if (is_calculated_with_mqdt) {
922 16 : SPDLOG_WARN("The multi-channel quantum defect theory might produce inaccurate results "
923 : "for effective principal quantum numbers < 25. The models get increasingly "
924 : "unreliable for small principal quantum numbers, leading to inaccurate "
925 : "matrix elements and energies. Due to missing data, even some states might "
926 : "not be present.");
927 : } else {
928 56 : SPDLOG_WARN(
929 : "The single-channel quantum defect theory can be inaccurate for effective "
930 : "principal quantum numbers < 25. This can lead to inaccurate matrix elements.");
931 : }
932 : }
933 :
934 0 : return std::make_shared<const BasisAtom<Scalar>>(typename BasisAtom<Scalar>::Private(),
935 466 : std::move(kets), std::move(id_of_kets), *this);
936 466 : }
937 :
938 : template <typename Scalar>
939 4013 : Eigen::SparseMatrix<Scalar, Eigen::RowMajor> Database::get_matrix_elements_in_canonical_basis(
940 : std::shared_ptr<const BasisAtom<Scalar>> initial_basis,
941 : std::shared_ptr<const BasisAtom<Scalar>> final_basis, OperatorType type, int q) {
942 : using real_t = typename traits::NumTraits<Scalar>::real_t;
943 : using cached_matrix_ptr_t = std::shared_ptr<const cached_matrix_t>;
944 :
945 4013 : std::string specifier;
946 4013 : int kappa{};
947 4013 : switch (type) {
948 3640 : case OperatorType::ELECTRIC_DIPOLE:
949 3640 : specifier = "matrix_elements_d";
950 3640 : kappa = 1;
951 3640 : break;
952 257 : case OperatorType::ELECTRIC_QUADRUPOLE:
953 257 : specifier = "matrix_elements_q";
954 257 : kappa = 2;
955 257 : break;
956 49 : case OperatorType::ELECTRIC_QUADRUPOLE_ZERO:
957 49 : specifier = "matrix_elements_q0";
958 49 : kappa = 0;
959 49 : break;
960 0 : case OperatorType::ELECTRIC_OCTUPOLE:
961 0 : specifier = "matrix_elements_o";
962 0 : kappa = 3;
963 0 : break;
964 59 : case OperatorType::MAGNETIC_DIPOLE:
965 59 : specifier = "matrix_elements_mu";
966 59 : kappa = 1;
967 59 : break;
968 4 : case OperatorType::ENERGY:
969 4 : specifier = "energy";
970 4 : kappa = 0;
971 4 : break;
972 4 : case OperatorType::IDENTITY:
973 4 : specifier = "identity";
974 4 : kappa = 0;
975 4 : break;
976 0 : default:
977 0 : throw std::invalid_argument("Unknown operator type.");
978 : }
979 :
980 4013 : if (initial_basis->get_id_of_kets() != final_basis->get_id_of_kets()) {
981 0 : throw std::invalid_argument(
982 : "The initial and final basis must be expressed using the same kets.");
983 : }
984 4013 : std::string id_of_kets = initial_basis->get_id_of_kets();
985 4013 : std::string cache_key = fmt::format("{}_{}_{}", specifier, q, id_of_kets);
986 4013 : auto &matrix_elements_cache = get_matrix_elements_cache();
987 4013 : std::promise<cached_matrix_ptr_t> matrix_promise;
988 4013 : auto [cache_it, inserted] =
989 : matrix_elements_cache.insert({cache_key, matrix_promise.get_future().share()});
990 :
991 4013 : if (inserted) {
992 : try {
993 505 : Eigen::Index dim = initial_basis->get_number_of_kets();
994 :
995 505 : std::vector<int> outerIndexPtr;
996 505 : std::vector<int> innerIndices;
997 505 : std::vector<real_t> values;
998 :
999 505 : if (specifier == "identity") {
1000 1 : outerIndexPtr.reserve(dim + 1);
1001 1 : innerIndices.reserve(dim);
1002 1 : values.reserve(dim);
1003 :
1004 91 : for (int i = 0; i < dim; i++) {
1005 90 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1006 90 : innerIndices.push_back(i);
1007 90 : values.push_back(1);
1008 : }
1009 1 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1010 :
1011 : } else {
1012 : // Check that the specifications are valid
1013 504 : if (std::abs(q) > kappa) {
1014 0 : throw std::invalid_argument("Invalid q.");
1015 : }
1016 :
1017 : // Ask the database for the operator
1018 504 : set_task_status("Loading matrix elements from database...");
1019 504 : std::string species = initial_basis->get_species();
1020 504 : duckdb::unique_ptr<duckdb::MaterializedQueryResult> result;
1021 504 : if (specifier != "energy") {
1022 2008 : result = con->Query(fmt::format(
1023 : R"(WITH s AS (
1024 : SELECT id, f, m, ketid FROM '{}'
1025 : ),
1026 : b AS (
1027 : SELECT MIN(f) AS min_f, MAX(f) AS max_f,
1028 : MIN(id) AS min_id, MAX(id) AS max_id
1029 : FROM s
1030 : ),
1031 : w_filtered AS (
1032 : SELECT *
1033 : FROM '{}'
1034 : WHERE kappa = {} AND q = {} AND
1035 : f_initial BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b) AND
1036 : f_final BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b)
1037 : ),
1038 : e_filtered AS (
1039 : SELECT *
1040 : FROM '{}'
1041 : WHERE
1042 : id_initial BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b) AND
1043 : id_final BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b)
1044 : )
1045 : SELECT
1046 : s2.ketid AS row,
1047 : s1.ketid AS col,
1048 : e.val*w.val AS val
1049 : FROM e_filtered AS e
1050 : JOIN s AS s1 ON e.id_initial = s1.id
1051 : JOIN s AS s2 ON e.id_final = s2.id
1052 : JOIN w_filtered AS w ON
1053 : w.f_initial = s1.f AND w.m_initial = s1.m AND
1054 : w.f_final = s2.f AND w.m_final = s2.m
1055 : ORDER BY row ASC, col ASC)",
1056 1004 : id_of_kets, manager->get_path("misc", "wigner"), kappa, q,
1057 : manager->get_path(species, specifier)));
1058 : } else {
1059 4 : result = con->Query(fmt::format(
1060 : R"(SELECT ketid as row, ketid as col, energy as val FROM '{}' ORDER BY row ASC)",
1061 : id_of_kets));
1062 : }
1063 :
1064 504 : if (result->HasError()) {
1065 0 : throw cpptrace::runtime_error("Error querying the database: " +
1066 0 : result->GetError());
1067 : }
1068 :
1069 : // Check the types of the columns
1070 504 : const auto &types = result->types;
1071 504 : const auto &labels = result->names;
1072 2016 : const std::vector<duckdb::LogicalType> ref_types = {duckdb::LogicalType::BIGINT,
1073 : duckdb::LogicalType::BIGINT,
1074 : duckdb::LogicalType::DOUBLE};
1075 2016 : for (size_t i = 0; i < types.size(); i++) {
1076 1512 : if (types[i] != ref_types[i]) {
1077 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'.");
1078 : }
1079 : }
1080 :
1081 504 : set_task_status("Constructing matrix elements...");
1082 :
1083 : // Construct the matrix
1084 504 : int num_entries = static_cast<int>(result->RowCount());
1085 504 : outerIndexPtr.reserve(dim + 1);
1086 504 : innerIndices.reserve(num_entries);
1087 504 : values.reserve(num_entries);
1088 :
1089 504 : int last_row = -1;
1090 :
1091 1655 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
1092 1151 : auto *chunk_row = duckdb::FlatVector::GetData<int64_t>(chunk->data[0]);
1093 1151 : auto *chunk_col = duckdb::FlatVector::GetData<int64_t>(chunk->data[1]);
1094 1151 : auto *chunk_val = duckdb::FlatVector::GetData<double>(chunk->data[2]);
1095 :
1096 1686913 : for (size_t i = 0; i < chunk->size(); i++) {
1097 1685762 : int row = final_basis->get_ket_index_from_id(chunk_row[i]);
1098 1685762 : if (row != last_row) {
1099 63114 : if (row < last_row) {
1100 0 : throw std::runtime_error("The rows are not sorted.");
1101 : }
1102 138780 : for (; last_row < row; last_row++) {
1103 75666 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1104 : }
1105 : }
1106 1685762 : innerIndices.push_back(initial_basis->get_ket_index_from_id(chunk_col[i]));
1107 1685762 : values.push_back(chunk_val[i]);
1108 : }
1109 : }
1110 :
1111 1772 : for (; last_row < dim + 1; last_row++) {
1112 1268 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1113 : }
1114 504 : }
1115 :
1116 505 : Eigen::Map<const cached_matrix_t> matrix_map(
1117 505 : dim, dim, values.size(), outerIndexPtr.data(), innerIndices.data(), values.data());
1118 :
1119 505 : auto cached_matrix = std::make_shared<const cached_matrix_t>(matrix_map);
1120 505 : matrix_promise.set_value(std::move(cached_matrix));
1121 505 : } catch (...) {
1122 0 : matrix_promise.set_exception(std::current_exception());
1123 0 : throw;
1124 : }
1125 : }
1126 :
1127 4013 : set_task_status("Returning matrix elements in canonical basis...");
1128 :
1129 8026 : return cache_it->second.get()->template cast<Scalar>();
1130 4517 : }
1131 :
1132 3 : bool Database::get_download_missing() const { return download_missing_; }
1133 :
1134 1 : bool Database::get_use_cache() const { return use_cache_; }
1135 :
1136 5 : std::filesystem::path Database::get_database_dir() const { return database_dir_; }
1137 :
1138 0 : std::string Database::get_versions_info() const { return manager->get_versions_info(); }
1139 :
1140 4013 : Database::matrix_elements_cache_t &Database::get_matrix_elements_cache() {
1141 4013 : static matrix_elements_cache_t matrix_elements_cache;
1142 4013 : return matrix_elements_cache;
1143 : }
1144 :
1145 34 : Database &Database::get_global_instance() {
1146 68 : return get_global_instance_without_checks(default_download_missing, default_use_cache,
1147 68 : default_database_dir);
1148 : }
1149 :
1150 0 : Database &Database::get_global_instance(bool download_missing) {
1151 0 : Database &database = get_global_instance_without_checks(download_missing, default_use_cache,
1152 : default_database_dir);
1153 0 : if (download_missing != database.download_missing_) {
1154 0 : throw std::invalid_argument(
1155 0 : "The 'download_missing' argument must not change between calls to the method.");
1156 : }
1157 0 : return database;
1158 : }
1159 :
1160 0 : Database &Database::get_global_instance(std::filesystem::path database_dir) {
1161 0 : if (database_dir.empty()) {
1162 0 : database_dir = default_database_dir;
1163 : }
1164 0 : Database &database = get_global_instance_without_checks(default_download_missing,
1165 : default_use_cache, database_dir);
1166 0 : if (!std::filesystem::exists(database_dir) ||
1167 0 : std::filesystem::canonical(database_dir) != database.database_dir_) {
1168 0 : throw std::invalid_argument(
1169 0 : "The 'database_dir' argument must not change between calls to the method.");
1170 : }
1171 0 : return database;
1172 : }
1173 :
1174 1 : Database &Database::get_global_instance(bool download_missing, bool use_cache,
1175 : std::filesystem::path database_dir) {
1176 1 : if (database_dir.empty()) {
1177 0 : database_dir = default_database_dir;
1178 : }
1179 : Database &database =
1180 1 : get_global_instance_without_checks(download_missing, use_cache, database_dir);
1181 1 : if (download_missing != database.download_missing_ || use_cache != database.use_cache_ ||
1182 3 : !std::filesystem::exists(database_dir) ||
1183 2 : std::filesystem::canonical(database_dir) != database.database_dir_) {
1184 0 : throw std::invalid_argument(
1185 : "The 'download_missing', 'use_cache' and 'database_dir' arguments must not "
1186 0 : "change between calls to the method.");
1187 : }
1188 1 : return database;
1189 : }
1190 :
1191 35 : Database &Database::get_global_instance_without_checks(bool download_missing, bool use_cache,
1192 : std::filesystem::path database_dir) {
1193 35 : static Database database(download_missing, use_cache, std::move(database_dir));
1194 35 : return database;
1195 : }
1196 :
1197 : struct database_dir_noexcept : std::filesystem::path {
1198 2 : explicit database_dir_noexcept() noexcept try : std
1199 2 : ::filesystem::path(paths::get_cache_directory() / "database") {}
1200 0 : catch (...) {
1201 0 : SPDLOG_ERROR("Error getting the PairInteraction cache directory.");
1202 0 : std::terminate();
1203 2 : }
1204 : };
1205 :
1206 : const std::filesystem::path Database::default_database_dir = database_dir_noexcept();
1207 :
1208 : // Explicit instantiations
1209 : // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
1210 : #define INSTANTIATE_GETTERS(SCALAR) \
1211 : template std::shared_ptr<const BasisAtom<SCALAR>> Database::get_basis<SCALAR>( \
1212 : const std::string &species, const AtomDescriptionByRanges &description, \
1213 : std::vector<size_t> additional_ket_ids); \
1214 : template Eigen::SparseMatrix<SCALAR, Eigen::RowMajor> \
1215 : Database::get_matrix_elements_in_canonical_basis<SCALAR>( \
1216 : std::shared_ptr<const BasisAtom<SCALAR>> initial_basis, \
1217 : std::shared_ptr<const BasisAtom<SCALAR>> final_basis, OperatorType type, int q);
1218 : // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
1219 :
1220 : INSTANTIATE_GETTERS(double)
1221 : INSTANTIATE_GETTERS(std::complex<double>)
1222 :
1223 : #undef INSTANTIATE_GETTERS
1224 : } // namespace pairinteraction
|