Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/database/Database.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/database/AtomDescriptionByParameters.hpp"
8 : #include "pairinteraction/database/AtomDescriptionByRanges.hpp"
9 : #include "pairinteraction/database/GitHubDownloader.hpp"
10 : #include "pairinteraction/database/ParquetManager.hpp"
11 : #include "pairinteraction/enums/OperatorType.hpp"
12 : #include "pairinteraction/enums/Parity.hpp"
13 : #include "pairinteraction/ket/KetAtom.hpp"
14 : #include "pairinteraction/utils/hash.hpp"
15 : #include "pairinteraction/utils/id_in_database.hpp"
16 : #include "pairinteraction/utils/paths.hpp"
17 : #include "pairinteraction/utils/streamed.hpp"
18 :
19 : #include <cpptrace/cpptrace.hpp>
20 : #include <duckdb.hpp>
21 : #include <fmt/core.h>
22 : #include <fmt/ranges.h>
23 : #include <fstream>
24 : #include <nlohmann/json.hpp>
25 : #include <oneapi/tbb.h>
26 : #include <spdlog/spdlog.h>
27 : #include <system_error>
28 :
29 : namespace pairinteraction {
30 0 : Database::Database() : Database(default_download_missing) {}
31 :
32 0 : Database::Database(bool download_missing)
33 0 : : Database(download_missing, default_use_cache, default_database_dir) {}
34 :
35 0 : Database::Database(std::filesystem::path database_dir)
36 0 : : Database(default_download_missing, default_use_cache, std::move(database_dir)) {}
37 :
38 2 : Database::Database(bool download_missing, bool use_cache, std::filesystem::path database_dir)
39 2 : : download_missing_(download_missing), use_cache_(use_cache),
40 2 : database_dir_(std::move(database_dir)), db(std::make_unique<duckdb::DuckDB>(nullptr)),
41 10 : con(std::make_unique<duckdb::Connection>(*db)) {
42 :
43 2 : if (database_dir_.empty()) {
44 0 : database_dir_ = default_database_dir;
45 : }
46 :
47 : // Ensure the database directory exists
48 2 : if (!std::filesystem::exists(database_dir_)) {
49 0 : std::filesystem::create_directories(database_dir_);
50 : }
51 2 : database_dir_ = std::filesystem::canonical(database_dir_);
52 2 : if (!std::filesystem::is_directory(database_dir_)) {
53 0 : throw std::filesystem::filesystem_error("Cannot access database", database_dir_.string(),
54 0 : std::make_error_code(std::errc::not_a_directory));
55 : }
56 4 : SPDLOG_INFO("Using database directory: {}", database_dir_.string());
57 :
58 : // Ensure that the config directory exists
59 2 : std::filesystem::path configdir = paths::get_config_directory();
60 2 : if (!std::filesystem::exists(configdir)) {
61 1 : std::filesystem::create_directories(configdir);
62 1 : } else if (!std::filesystem::is_directory(configdir)) {
63 0 : throw std::filesystem::filesystem_error("Cannot access config directory ",
64 0 : configdir.string(),
65 0 : std::make_error_code(std::errc::not_a_directory));
66 : }
67 :
68 : // Read in the database_repo_paths if a config file exists, otherwise use the default and
69 : // write it to the config file
70 2 : std::filesystem::path configfile = configdir / "database.json";
71 2 : std::string database_repo_host;
72 2 : std::vector<std::string> database_repo_paths;
73 2 : if (std::filesystem::exists(configfile)) {
74 1 : std::ifstream file(configfile);
75 1 : nlohmann::json doc = nlohmann::json::parse(file, nullptr, false);
76 :
77 2 : if (!doc.is_discarded() && doc.contains("hash") && doc.contains("database_repo_host") &&
78 1 : doc.contains("database_repo_paths")) {
79 1 : database_repo_host = doc["database_repo_host"].get<std::string>();
80 1 : database_repo_paths = doc["database_repo_paths"].get<std::vector<std::string>>();
81 :
82 : // If the values are not equal to the default values but the hash is consistent (i.e.,
83 : // the user has not changed anything manually), clear the values so that they can be
84 : // updated
85 2 : if (database_repo_host != default_database_repo_host ||
86 1 : database_repo_paths != default_database_repo_paths) {
87 0 : std::size_t seed = 0;
88 0 : utils::hash_combine(seed, database_repo_paths);
89 0 : utils::hash_combine(seed, database_repo_host);
90 0 : if (seed == doc["hash"].get<std::size_t>()) {
91 0 : database_repo_host.clear();
92 0 : database_repo_paths.clear();
93 : } else {
94 0 : SPDLOG_INFO("The database repository host and paths have been changed "
95 : "manually. Thus, they will not be updated automatically. To reset "
96 : "them, delete the file '{}'.",
97 : configfile.string());
98 : }
99 : }
100 : }
101 1 : }
102 :
103 : // Read in and store the default values if necessary
104 2 : if (database_repo_host.empty() || database_repo_paths.empty()) {
105 2 : SPDLOG_INFO("Updating the database repository host and paths:");
106 :
107 1 : database_repo_host = default_database_repo_host;
108 1 : database_repo_paths = default_database_repo_paths;
109 1 : std::ofstream file(configfile);
110 1 : nlohmann::json doc;
111 :
112 2 : SPDLOG_INFO("* New host: {}", default_database_repo_host);
113 3 : SPDLOG_INFO("* New paths: {}", fmt::join(default_database_repo_paths, ", "));
114 :
115 1 : doc["database_repo_host"] = default_database_repo_host;
116 1 : doc["database_repo_paths"] = database_repo_paths;
117 :
118 1 : std::size_t seed = 0;
119 1 : utils::hash_combine(seed, default_database_repo_paths);
120 1 : utils::hash_combine(seed, default_database_repo_host);
121 1 : doc["hash"] = seed;
122 :
123 1 : file << doc.dump(4);
124 1 : }
125 :
126 : // Limit the memory usage of duckdb's buffer manager
127 : {
128 2 : auto result = con->Query("PRAGMA max_memory = '8GB';");
129 2 : if (result->HasError()) {
130 0 : throw cpptrace::runtime_error("Error setting the memory limit: " + result->GetError());
131 : }
132 2 : }
133 :
134 : // Instantiate a database manager that provides access to database tables. If a table
135 : // is outdated/not available locally, it will be downloaded if download_missing_ is true.
136 2 : if (!download_missing_) {
137 2 : database_repo_paths.clear();
138 : }
139 2 : downloader = std::make_unique<GitHubDownloader>();
140 4 : manager = std::make_unique<ParquetManager>(database_dir_, *downloader, database_repo_paths,
141 4 : *con, use_cache_);
142 2 : manager->scan_local();
143 2 : manager->scan_remote();
144 :
145 : // Print versions of tables
146 2 : std::istringstream iss(manager->get_versions_info());
147 24 : for (std::string line; std::getline(iss, line);) {
148 22 : SPDLOG_INFO(line);
149 2 : }
150 6 : }
151 :
152 2 : Database::~Database() = default;
153 :
154 101 : std::shared_ptr<const KetAtom> Database::get_ket(const std::string &species,
155 : const AtomDescriptionByParameters &description) {
156 : // Check that the specifications are valid
157 101 : if (!description.quantum_number_m.has_value()) {
158 0 : throw std::invalid_argument("The quantum number m must be specified.");
159 : }
160 122 : if (description.quantum_number_f.has_value() &&
161 21 : 2 * description.quantum_number_f.value() !=
162 21 : std::rint(2 * description.quantum_number_f.value())) {
163 0 : throw std::invalid_argument("The quantum number f must be an integer or half-integer.");
164 : }
165 101 : if (description.quantum_number_f.has_value() && description.quantum_number_f.value() < 0) {
166 0 : throw std::invalid_argument("The quantum number f must be positive.");
167 : }
168 165 : if (description.quantum_number_j.has_value() &&
169 64 : 2 * description.quantum_number_j.value() !=
170 64 : std::rint(2 * description.quantum_number_j.value())) {
171 0 : throw std::invalid_argument("The quantum number j must be an integer or half-integer.");
172 : }
173 101 : if (description.quantum_number_j.has_value() && description.quantum_number_j.value() < 0) {
174 0 : throw std::invalid_argument("The quantum number j must be positive.");
175 : }
176 202 : if (description.quantum_number_m.has_value() &&
177 101 : 2 * description.quantum_number_m.value() !=
178 101 : std::rint(2 * description.quantum_number_m.value())) {
179 0 : throw std::invalid_argument("The quantum number m must be an integer or half-integer.");
180 : }
181 :
182 : // Describe the state
183 101 : std::string where;
184 101 : std::string separator;
185 101 : if (description.energy.has_value()) {
186 : // The following condition derives from demanding that quantum number n that corresponds to
187 : // the energy "E_n = -1/(2*n^2)" is not off by more than 1 from the actual quantum number n,
188 : // i.e., "sqrt(-1/(2*E_n)) - sqrt(-1/(2*E_{n-1})) = 1"
189 0 : where += separator +
190 0 : fmt::format("SQRT(-1/(2*energy)) BETWEEN {} AND {}",
191 0 : std::sqrt(-1 / (2 * description.energy.value())) - 0.5,
192 0 : std::sqrt(-1 / (2 * description.energy.value())) + 0.5);
193 0 : separator = " AND ";
194 : }
195 101 : if (description.quantum_number_f.has_value()) {
196 42 : where += separator + fmt::format("f = {}", description.quantum_number_f.value());
197 21 : separator = " AND ";
198 : }
199 101 : if (description.parity != Parity::UNKNOWN) {
200 0 : where += separator + fmt::format("parity = {}", fmt::streamed(description.parity));
201 0 : separator = " AND ";
202 : }
203 101 : if (description.quantum_number_n.has_value()) {
204 196 : where += separator + fmt::format("n = {}", description.quantum_number_n.value());
205 98 : separator = " AND ";
206 : }
207 101 : if (description.quantum_number_nu.has_value()) {
208 3 : where += separator +
209 9 : fmt::format("nu BETWEEN {} AND {}", description.quantum_number_nu.value() - 0.5,
210 6 : description.quantum_number_nu.value() + 0.5);
211 3 : separator = " AND ";
212 : }
213 101 : if (description.quantum_number_nui.has_value()) {
214 0 : where += separator +
215 0 : fmt::format("exp_nui BETWEEN {} AND {}", description.quantum_number_nui.value() - 0.5,
216 0 : description.quantum_number_nui.value() + 0.5);
217 0 : separator = " AND ";
218 : }
219 101 : if (description.quantum_number_l.has_value()) {
220 101 : where += separator +
221 303 : fmt::format("exp_l BETWEEN {} AND {}", description.quantum_number_l.value() - 0.5,
222 202 : description.quantum_number_l.value() + 0.5);
223 101 : separator = " AND ";
224 : }
225 101 : if (description.quantum_number_s.has_value()) {
226 20 : where += separator +
227 60 : fmt::format("exp_s BETWEEN {} AND {}", description.quantum_number_s.value() - 0.5,
228 40 : description.quantum_number_s.value() + 0.5);
229 20 : separator = " AND ";
230 : }
231 101 : if (description.quantum_number_j.has_value()) {
232 64 : where += separator +
233 192 : fmt::format("exp_j BETWEEN {} AND {}", description.quantum_number_j.value() - 0.5,
234 128 : description.quantum_number_j.value() + 0.5);
235 64 : separator = " AND ";
236 : }
237 101 : if (description.quantum_number_l_ryd.has_value()) {
238 0 : where += separator +
239 0 : fmt::format("exp_l_ryd BETWEEN {} AND {}",
240 0 : description.quantum_number_l_ryd.value() - 0.5,
241 0 : description.quantum_number_l_ryd.value() + 0.5);
242 0 : separator = " AND ";
243 : }
244 101 : if (description.quantum_number_j_ryd.has_value()) {
245 0 : where += separator +
246 0 : fmt::format("exp_j_ryd BETWEEN {} AND {}",
247 0 : description.quantum_number_j_ryd.value() - 0.5,
248 0 : description.quantum_number_j_ryd.value() + 0.5);
249 0 : separator = " AND ";
250 : }
251 101 : if (separator.empty()) {
252 0 : where += "FALSE";
253 : }
254 :
255 101 : std::string orderby;
256 101 : separator = "";
257 101 : if (description.energy.has_value()) {
258 0 : orderby += separator +
259 0 : fmt::format("(SQRT(-1/(2*energy)) - {})^2",
260 0 : std::sqrt(-1 / (2 * description.energy.value())));
261 0 : separator = " + ";
262 : }
263 101 : if (description.quantum_number_nu.has_value()) {
264 6 : orderby += separator + fmt::format("(nu - {})^2", description.quantum_number_nu.value());
265 3 : separator = " + ";
266 : }
267 101 : if (description.quantum_number_nui.has_value()) {
268 : orderby +=
269 0 : separator + fmt::format("(exp_nui - {})^2", description.quantum_number_nui.value());
270 0 : separator = " + ";
271 : }
272 101 : if (description.quantum_number_l.has_value()) {
273 202 : orderby += separator + fmt::format("(exp_l - {})^2", description.quantum_number_l.value());
274 101 : separator = " + ";
275 : }
276 101 : if (description.quantum_number_s.has_value()) {
277 40 : orderby += separator + fmt::format("(exp_s - {})^2", description.quantum_number_s.value());
278 20 : separator = " + ";
279 : }
280 101 : if (description.quantum_number_j.has_value()) {
281 128 : orderby += separator + fmt::format("(exp_j - {})^2", description.quantum_number_j.value());
282 64 : separator = " + ";
283 : }
284 101 : if (description.quantum_number_l_ryd.has_value()) {
285 : orderby +=
286 0 : separator + fmt::format("(exp_l_ryd - {})^2", description.quantum_number_l_ryd.value());
287 0 : separator = " + ";
288 : }
289 101 : if (description.quantum_number_j_ryd.has_value()) {
290 : orderby +=
291 0 : separator + fmt::format("(exp_j_ryd - {})^2", description.quantum_number_j_ryd.value());
292 0 : separator = " + ";
293 : }
294 101 : if (separator.empty()) {
295 0 : orderby += "id";
296 : }
297 :
298 : // Ask the database for the described state
299 101 : auto result = con->Query(fmt::format(
300 : R"(SELECT energy, f, parity, id, n, nu, exp_nui, std_nui, exp_l, std_l, exp_s, std_s,
301 : exp_j, std_j, exp_l_ryd, std_l_ryd, exp_j_ryd, std_j_ryd, is_j_total_momentum, is_calculated_with_mqdt, underspecified_channel_contribution, {} AS order_val FROM '{}' WHERE {} ORDER BY order_val ASC LIMIT 2)",
302 303 : orderby, manager->get_path(species, "states"), where));
303 :
304 101 : if (result->HasError()) {
305 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
306 : }
307 :
308 101 : if (result->RowCount() == 0) {
309 1 : throw std::invalid_argument("No state found.");
310 : }
311 :
312 : // Check the types of the columns
313 100 : const auto &types = result->types;
314 100 : const auto &labels = result->names;
315 : const std::vector<duckdb::LogicalType> ref_types = {
316 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::BIGINT,
317 : duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT, duckdb::LogicalType::DOUBLE,
318 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
319 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
320 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
321 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
322 : duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::DOUBLE,
323 2300 : duckdb::LogicalType::DOUBLE};
324 :
325 2300 : for (size_t i = 0; i < types.size(); i++) {
326 2200 : if (types[i] != ref_types[i]) {
327 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'. Got " +
328 0 : types[i].ToString() + " but expected " +
329 0 : ref_types[i].ToString());
330 : }
331 : }
332 :
333 : // Get the first chunk of the results (the first chunk is sufficient as we need two rows at
334 : // most)
335 100 : auto chunk = result->Fetch();
336 :
337 : // Check that the ket is uniquely specified
338 100 : if (chunk->size() > 1) {
339 5 : auto order_val_0 = duckdb::FlatVector::GetData<double>(chunk->data[21])[0];
340 5 : auto order_val_1 = duckdb::FlatVector::GetData<double>(chunk->data[21])[1];
341 :
342 5 : if (order_val_1 - order_val_0 <= order_val_0) {
343 : // Get a list of possible kets
344 1 : std::vector<KetAtom> kets;
345 1 : kets.reserve(2);
346 3 : for (size_t i = 0; i < 2; ++i) {
347 2 : auto result_quantum_number_m = description.quantum_number_m.value();
348 2 : auto result_energy = duckdb::FlatVector::GetData<double>(chunk->data[0])[i];
349 : auto result_quantum_number_f =
350 2 : duckdb::FlatVector::GetData<double>(chunk->data[1])[i];
351 2 : auto result_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[2])[i];
352 4 : auto result_id = utils::get_linearized_id_in_database(
353 2 : duckdb::FlatVector::GetData<int64_t>(chunk->data[3])[i],
354 2 : result_quantum_number_m);
355 : auto result_quantum_number_n =
356 2 : duckdb::FlatVector::GetData<int64_t>(chunk->data[4])[i];
357 : auto result_quantum_number_nu =
358 2 : duckdb::FlatVector::GetData<double>(chunk->data[5])[i];
359 : auto result_quantum_number_nui_exp =
360 2 : duckdb::FlatVector::GetData<double>(chunk->data[6])[i];
361 : auto result_quantum_number_nui_std =
362 2 : duckdb::FlatVector::GetData<double>(chunk->data[7])[i];
363 : auto result_quantum_number_l_exp =
364 2 : duckdb::FlatVector::GetData<double>(chunk->data[8])[i];
365 : auto result_quantum_number_l_std =
366 2 : duckdb::FlatVector::GetData<double>(chunk->data[9])[i];
367 : auto result_quantum_number_s_exp =
368 2 : duckdb::FlatVector::GetData<double>(chunk->data[10])[i];
369 : auto result_quantum_number_s_std =
370 2 : duckdb::FlatVector::GetData<double>(chunk->data[11])[i];
371 : auto result_quantum_number_j_exp =
372 2 : duckdb::FlatVector::GetData<double>(chunk->data[12])[i];
373 : auto result_quantum_number_j_std =
374 2 : duckdb::FlatVector::GetData<double>(chunk->data[13])[i];
375 : auto result_quantum_number_l_ryd_exp =
376 2 : duckdb::FlatVector::GetData<double>(chunk->data[14])[i];
377 : auto result_quantum_number_l_ryd_std =
378 2 : duckdb::FlatVector::GetData<double>(chunk->data[15])[i];
379 : auto result_quantum_number_j_ryd_exp =
380 2 : duckdb::FlatVector::GetData<double>(chunk->data[16])[i];
381 : auto result_quantum_number_j_ryd_std =
382 2 : duckdb::FlatVector::GetData<double>(chunk->data[17])[i];
383 : auto result_is_j_total_momentum =
384 2 : duckdb::FlatVector::GetData<bool>(chunk->data[18])[i];
385 : auto result_is_calculated_with_mqdt =
386 2 : duckdb::FlatVector::GetData<bool>(chunk->data[19])[i];
387 : auto result_underspecified_channel_contribution =
388 2 : duckdb::FlatVector::GetData<double>(chunk->data[20])[i];
389 0 : kets.emplace_back(typename KetAtom::Private(), result_energy,
390 : result_quantum_number_f, result_quantum_number_m,
391 2 : static_cast<Parity>(result_parity), species,
392 : result_quantum_number_n, result_quantum_number_nu,
393 : result_quantum_number_nui_exp, result_quantum_number_nui_std,
394 : result_quantum_number_l_exp, result_quantum_number_l_std,
395 : result_quantum_number_s_exp, result_quantum_number_s_std,
396 : result_quantum_number_j_exp, result_quantum_number_j_std,
397 : result_quantum_number_l_ryd_exp, result_quantum_number_l_ryd_std,
398 : result_quantum_number_j_ryd_exp, result_quantum_number_j_ryd_std,
399 : result_is_j_total_momentum, result_is_calculated_with_mqdt,
400 : result_underspecified_channel_contribution, *this, result_id);
401 : }
402 :
403 : // Throw an error with the possible kets
404 2 : throw std::invalid_argument(
405 1 : fmt::format("The ket is not uniquely specified. Possible kets are:\n{}\n{}",
406 2 : fmt::streamed(kets[0]), fmt::streamed(kets[1])));
407 1 : }
408 : }
409 :
410 : // Construct the state
411 99 : auto result_quantum_number_m = description.quantum_number_m.value();
412 99 : auto result_energy = duckdb::FlatVector::GetData<double>(chunk->data[0])[0];
413 99 : auto result_quantum_number_f = duckdb::FlatVector::GetData<double>(chunk->data[1])[0];
414 99 : auto result_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[2])[0];
415 198 : auto result_id = utils::get_linearized_id_in_database(
416 99 : duckdb::FlatVector::GetData<int64_t>(chunk->data[3])[0], result_quantum_number_m);
417 99 : auto result_quantum_number_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[4])[0];
418 99 : auto result_quantum_number_nu = duckdb::FlatVector::GetData<double>(chunk->data[5])[0];
419 99 : auto result_quantum_number_nui_exp = duckdb::FlatVector::GetData<double>(chunk->data[6])[0];
420 99 : auto result_quantum_number_nui_std = duckdb::FlatVector::GetData<double>(chunk->data[7])[0];
421 99 : auto result_quantum_number_l_exp = duckdb::FlatVector::GetData<double>(chunk->data[8])[0];
422 99 : auto result_quantum_number_l_std = duckdb::FlatVector::GetData<double>(chunk->data[9])[0];
423 99 : auto result_quantum_number_s_exp = duckdb::FlatVector::GetData<double>(chunk->data[10])[0];
424 99 : auto result_quantum_number_s_std = duckdb::FlatVector::GetData<double>(chunk->data[11])[0];
425 99 : auto result_quantum_number_j_exp = duckdb::FlatVector::GetData<double>(chunk->data[12])[0];
426 99 : auto result_quantum_number_j_std = duckdb::FlatVector::GetData<double>(chunk->data[13])[0];
427 99 : auto result_quantum_number_l_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[14])[0];
428 99 : auto result_quantum_number_l_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[15])[0];
429 99 : auto result_quantum_number_j_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[16])[0];
430 99 : auto result_quantum_number_j_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[17])[0];
431 99 : auto result_is_j_total_momentum = duckdb::FlatVector::GetData<bool>(chunk->data[18])[0];
432 99 : auto result_is_calculated_with_mqdt = duckdb::FlatVector::GetData<bool>(chunk->data[19])[0];
433 : auto result_underspecified_channel_contribution =
434 99 : duckdb::FlatVector::GetData<double>(chunk->data[20])[0];
435 :
436 : // Check the quantum number m
437 99 : if (std::abs(result_quantum_number_m) > result_quantum_number_f) {
438 1 : throw std::invalid_argument(
439 2 : "The absolute value of the quantum number m must be less than or equal to f.");
440 : }
441 98 : if (result_quantum_number_f + result_quantum_number_m !=
442 98 : std::rint(result_quantum_number_f + result_quantum_number_m)) {
443 0 : throw std::invalid_argument(
444 0 : "The quantum numbers f and m must be both either integers or half-integers.");
445 : }
446 :
447 : #ifndef NDEBUG
448 : // Check database consistency
449 98 : if (result_is_j_total_momentum && result_quantum_number_f != result_quantum_number_j_exp) {
450 0 : throw std::runtime_error("If j is the total momentum, f must be equal to j.");
451 : }
452 : #endif
453 :
454 : return std::make_shared<const KetAtom>(
455 0 : typename KetAtom::Private(), result_energy, result_quantum_number_f,
456 0 : result_quantum_number_m, static_cast<Parity>(result_parity), species,
457 : result_quantum_number_n, result_quantum_number_nu, result_quantum_number_nui_exp,
458 : result_quantum_number_nui_std, result_quantum_number_l_exp, result_quantum_number_l_std,
459 : result_quantum_number_s_exp, result_quantum_number_s_std, result_quantum_number_j_exp,
460 : result_quantum_number_j_std, result_quantum_number_l_ryd_exp,
461 : result_quantum_number_l_ryd_std, result_quantum_number_j_ryd_exp,
462 : result_quantum_number_j_ryd_std, result_is_j_total_momentum, result_is_calculated_with_mqdt,
463 196 : result_underspecified_channel_contribution, *this, result_id);
464 214 : }
465 :
466 : template <typename Scalar>
467 : std::shared_ptr<const BasisAtom<Scalar>>
468 88 : Database::get_basis(const std::string &species, const AtomDescriptionByRanges &description,
469 : std::vector<size_t> additional_ket_ids) {
470 : // Describe the states
471 88 : std::string where = "(";
472 88 : std::string separator;
473 88 : if (description.parity != Parity::UNKNOWN) {
474 0 : where += separator + fmt::format("parity = {}", fmt::streamed(description.parity));
475 0 : separator = " AND ";
476 : }
477 88 : if (description.range_energy.is_finite()) {
478 27 : where += separator +
479 27 : fmt::format("energy BETWEEN {} AND {}", description.range_energy.min(),
480 27 : description.range_energy.max());
481 27 : separator = " AND ";
482 : }
483 88 : if (description.range_quantum_number_f.is_finite()) {
484 0 : where += separator +
485 0 : fmt::format("f BETWEEN {} AND {}", description.range_quantum_number_f.min(),
486 0 : description.range_quantum_number_f.max());
487 0 : separator = " AND ";
488 : }
489 88 : if (description.range_quantum_number_m.is_finite()) {
490 31 : where += separator +
491 31 : fmt::format("m BETWEEN {} AND {}", description.range_quantum_number_m.min(),
492 31 : description.range_quantum_number_m.max());
493 31 : separator = " AND ";
494 : }
495 88 : if (description.range_quantum_number_n.is_finite()) {
496 72 : where += separator +
497 72 : fmt::format("n BETWEEN {} AND {}", description.range_quantum_number_n.min(),
498 72 : description.range_quantum_number_n.max());
499 72 : separator = " AND ";
500 : }
501 88 : if (description.range_quantum_number_nu.is_finite()) {
502 3 : where += separator +
503 3 : fmt::format("nu BETWEEN {} AND {}", description.range_quantum_number_nu.min(),
504 3 : description.range_quantum_number_nu.max());
505 3 : separator = " AND ";
506 : }
507 88 : if (description.range_quantum_number_nui.is_finite()) {
508 0 : where += separator +
509 : fmt::format("exp_nui BETWEEN {}-2*std_nui AND {}+2*std_nui",
510 0 : description.range_quantum_number_nui.min(),
511 0 : description.range_quantum_number_nui.max());
512 0 : separator = " AND ";
513 : }
514 88 : if (description.range_quantum_number_l.is_finite()) {
515 74 : where += separator +
516 : fmt::format("exp_l BETWEEN {}-2*std_l AND {}+2*std_l",
517 74 : description.range_quantum_number_l.min(),
518 74 : description.range_quantum_number_l.max());
519 74 : separator = " AND ";
520 : }
521 88 : if (description.range_quantum_number_s.is_finite()) {
522 0 : where += separator +
523 : fmt::format("exp_s BETWEEN {}-2*std_s AND {}+2*std_s",
524 0 : description.range_quantum_number_s.min(),
525 0 : description.range_quantum_number_s.max());
526 0 : separator = " AND ";
527 : }
528 88 : if (description.range_quantum_number_j.is_finite()) {
529 0 : where += separator +
530 : fmt::format("exp_j BETWEEN {}-2*std_j AND {}+2*std_j",
531 0 : description.range_quantum_number_j.min(),
532 0 : description.range_quantum_number_j.max());
533 0 : separator = " AND ";
534 : }
535 88 : if (description.range_quantum_number_l_ryd.is_finite()) {
536 0 : where += separator +
537 : fmt::format("exp_l_ryd BETWEEN {}-2*std_l_ryd AND {}+2*std_l_ryd",
538 0 : description.range_quantum_number_l_ryd.min(),
539 0 : description.range_quantum_number_l_ryd.max());
540 0 : separator = " AND ";
541 : }
542 88 : if (description.range_quantum_number_j_ryd.is_finite()) {
543 0 : where += separator +
544 : fmt::format("exp_j_ryd BETWEEN {}-2*std_j_ryd AND {}+2*std_j_ryd",
545 0 : description.range_quantum_number_j_ryd.min(),
546 0 : description.range_quantum_number_j_ryd.max());
547 0 : separator = " AND ";
548 : }
549 88 : if (separator.empty()) {
550 13 : where += "FALSE";
551 : }
552 88 : where += ")";
553 88 : if (!additional_ket_ids.empty()) {
554 35 : where += fmt::format(" OR {} IN ({})", utils::SQL_TERM_FOR_LINEARIZED_ID_IN_DATABASE,
555 70 : fmt::join(additional_ket_ids, ","));
556 : }
557 :
558 : // Create a table containing the described states
559 88 : std::string id_of_kets;
560 : {
561 88 : auto result = con->Query(R"(SELECT UUID()::varchar)");
562 88 : if (result->HasError()) {
563 0 : throw cpptrace::runtime_error("Error selecting id_of_kets: " + result->GetError());
564 : }
565 88 : id_of_kets =
566 176 : duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
567 88 : }
568 : {
569 176 : auto result = con->Query(fmt::format(
570 : R"(CREATE TEMP TABLE '{}' AS SELECT *, {} AS ketid FROM (
571 : SELECT *,
572 : UNNEST(list_transform(generate_series(0,(2*f)::bigint),
573 : x -> x::double-f)) AS m FROM '{}'
574 : ) WHERE {})",
575 : id_of_kets, utils::SQL_TERM_FOR_LINEARIZED_ID_IN_DATABASE,
576 : manager->get_path(species, "states"), where));
577 :
578 88 : if (result->HasError()) {
579 0 : throw cpptrace::runtime_error("Error creating table: " + result->GetError());
580 : }
581 88 : }
582 :
583 : // Ask the table for the extreme values of the quantum numbers
584 : {
585 88 : std::string select;
586 88 : std::string separator;
587 88 : if (description.range_energy.is_finite()) {
588 27 : select += separator + "MIN(energy) AS min_energy, MAX(energy) AS max_energy";
589 27 : separator = ", ";
590 : }
591 88 : if (description.range_quantum_number_f.is_finite()) {
592 0 : select += separator + "MIN(f) AS min_f, MAX(f) AS max_f";
593 0 : separator = ", ";
594 : }
595 88 : if (description.range_quantum_number_m.is_finite()) {
596 31 : select += separator + "MIN(m) AS min_m, MAX(m) AS max_m";
597 31 : separator = ", ";
598 : }
599 88 : if (description.range_quantum_number_n.is_finite()) {
600 72 : select += separator + "MIN(n) AS min_n, MAX(n) AS max_n";
601 72 : separator = ", ";
602 : }
603 88 : if (description.range_quantum_number_nu.is_finite()) {
604 3 : select += separator + "MIN(nu) AS min_nu, MAX(nu) AS max_nu";
605 3 : separator = ", ";
606 : }
607 88 : if (description.range_quantum_number_nui.is_finite()) {
608 0 : select += separator + "MIN(exp_nui) AS min_nui, MAX(exp_nui) AS max_nui";
609 0 : separator = ", ";
610 : }
611 88 : if (description.range_quantum_number_l.is_finite()) {
612 74 : select += separator + "MIN(exp_l) AS min_l, MAX(exp_l) AS max_l";
613 74 : separator = ", ";
614 : }
615 88 : if (description.range_quantum_number_s.is_finite()) {
616 0 : select += separator + "MIN(exp_s) AS min_s, MAX(exp_s) AS max_s";
617 0 : separator = ", ";
618 : }
619 88 : if (description.range_quantum_number_j.is_finite()) {
620 0 : select += separator + "MIN(exp_j) AS min_j, MAX(exp_j) AS max_j";
621 0 : separator = ", ";
622 : }
623 88 : if (description.range_quantum_number_l_ryd.is_finite()) {
624 0 : select += separator + "MIN(exp_l_ryd) AS min_l_ryd, MAX(exp_l_ryd) AS max_l_ryd";
625 0 : separator = ", ";
626 : }
627 88 : if (description.range_quantum_number_j_ryd.is_finite()) {
628 0 : select += separator + "MIN(exp_j_ryd) AS min_j_ryd, MAX(exp_j_ryd) AS max_j_ryd";
629 0 : separator = ", ";
630 : }
631 :
632 88 : if (!separator.empty()) {
633 150 : auto result = con->Query(fmt::format(R"(SELECT {} FROM '{}')", select, id_of_kets));
634 :
635 75 : if (result->HasError()) {
636 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
637 : }
638 :
639 75 : auto chunk = result->Fetch();
640 :
641 489 : for (size_t i = 0; i < chunk->ColumnCount(); i++) {
642 414 : if (duckdb::FlatVector::IsNull(chunk->data[i], 0)) {
643 0 : throw std::invalid_argument("No state found.");
644 : }
645 : }
646 :
647 75 : size_t idx = 0;
648 75 : if (description.range_energy.is_finite()) {
649 27 : auto min_energy = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
650 54 : if (std::sqrt(-1 / (2 * min_energy)) - 1 >
651 27 : std::sqrt(-1 / (2 * description.range_energy.min()))) {
652 0 : SPDLOG_DEBUG("No state found with the requested minimum energy. Requested: {}, "
653 : "found: {}.",
654 : description.range_energy.min(), min_energy);
655 : }
656 27 : auto max_energy = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
657 54 : if (std::sqrt(-1 / (2 * max_energy)) + 1 <
658 27 : std::sqrt(-1 / (2 * description.range_energy.max()))) {
659 0 : SPDLOG_DEBUG("No state found with the requested maximum energy. Requested: {}, "
660 : "found: {}.",
661 : description.range_energy.max(), max_energy);
662 : }
663 : }
664 75 : if (description.range_quantum_number_f.is_finite()) {
665 0 : auto min_f = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
666 0 : if (min_f > description.range_quantum_number_f.min()) {
667 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number f. "
668 : "Requested: {}, found: {}.",
669 : description.range_quantum_number_f.min(), min_f);
670 : }
671 0 : auto max_f = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
672 0 : if (max_f < description.range_quantum_number_f.max()) {
673 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number f. "
674 : "Requested: {}, found: {}.",
675 : description.range_quantum_number_f.max(), max_f);
676 : }
677 : }
678 75 : if (description.range_quantum_number_m.is_finite()) {
679 31 : auto min_m = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
680 31 : if (min_m > description.range_quantum_number_m.min()) {
681 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number m. "
682 : "Requested: {}, found: {}.",
683 : description.range_quantum_number_m.min(), min_m);
684 : }
685 31 : auto max_m = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
686 31 : if (max_m < description.range_quantum_number_m.max()) {
687 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number m. "
688 : "Requested: {}, found: {}.",
689 : description.range_quantum_number_m.max(), max_m);
690 : }
691 : }
692 75 : if (description.range_quantum_number_n.is_finite()) {
693 72 : auto min_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[idx++])[0];
694 72 : if (min_n > description.range_quantum_number_n.min()) {
695 44 : SPDLOG_DEBUG("No state found with the requested minimum quantum number n. "
696 : "Requested: {}, found: {}.",
697 : description.range_quantum_number_n.min(), min_n);
698 : }
699 72 : auto max_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[idx++])[0];
700 72 : if (max_n < description.range_quantum_number_n.max()) {
701 42 : SPDLOG_DEBUG("No state found with the requested maximum quantum number n. "
702 : "Requested: {}, found: {}.",
703 : description.range_quantum_number_n.max(), max_n);
704 : }
705 : }
706 75 : if (description.range_quantum_number_nu.is_finite()) {
707 3 : auto min_nu = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
708 3 : if (min_nu - 1 > description.range_quantum_number_nu.min()) {
709 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number nu. "
710 : "Requested: {}, found: {}.",
711 : description.range_quantum_number_nu.min(), min_nu);
712 : }
713 3 : auto max_nu = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
714 3 : if (max_nu + 1 < description.range_quantum_number_nu.max()) {
715 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number nu. "
716 : "Requested: {}, found: {}.",
717 : description.range_quantum_number_nu.max(), max_nu);
718 : }
719 : }
720 75 : if (description.range_quantum_number_nui.is_finite()) {
721 0 : auto min_nui = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
722 0 : if (min_nui - 1 > description.range_quantum_number_nui.min()) {
723 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number nui. "
724 : "Requested: {}, found: {}.",
725 : description.range_quantum_number_nui.min(), min_nui);
726 : }
727 0 : auto max_nui = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
728 0 : if (max_nui + 1 < description.range_quantum_number_nui.max()) {
729 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number nui. "
730 : "Requested: {}, found: {}.",
731 : description.range_quantum_number_nui.max(), max_nui);
732 : }
733 : }
734 75 : if (description.range_quantum_number_l.is_finite()) {
735 74 : auto min_l = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
736 74 : if (min_l - 1 > description.range_quantum_number_l.min()) {
737 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number l. "
738 : "Requested: {}, found: {}.",
739 : description.range_quantum_number_l.min(), min_l);
740 : }
741 74 : auto max_l = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
742 74 : if (max_l + 1 < description.range_quantum_number_l.max()) {
743 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number l. "
744 : "Requested: {}, found: {}.",
745 : description.range_quantum_number_l.max(), max_l);
746 : }
747 : }
748 75 : if (description.range_quantum_number_s.is_finite()) {
749 0 : auto min_s = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
750 0 : if (min_s - 1 > description.range_quantum_number_s.min()) {
751 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number s. "
752 : "Requested: {}, found: {}.",
753 : description.range_quantum_number_s.min(), min_s);
754 : }
755 0 : auto max_s = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
756 0 : if (max_s + 1 < description.range_quantum_number_s.max()) {
757 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number s. "
758 : "Requested: {}, found: {}.",
759 : description.range_quantum_number_s.max(), max_s);
760 : }
761 : }
762 75 : if (description.range_quantum_number_j.is_finite()) {
763 0 : auto min_j = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
764 0 : if (min_j - 1 > description.range_quantum_number_j.min()) {
765 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number j. "
766 : "Requested: {}, found: {}.",
767 : description.range_quantum_number_j.min(), min_j);
768 : }
769 0 : auto max_j = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
770 0 : if (max_j + 1 < description.range_quantum_number_j.max()) {
771 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number j. "
772 : "Requested: {}, found: {}.",
773 : description.range_quantum_number_j.max(), max_j);
774 : }
775 : }
776 75 : if (description.range_quantum_number_l_ryd.is_finite()) {
777 0 : auto min_l_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
778 0 : if (min_l_ryd - 1 > description.range_quantum_number_l_ryd.min()) {
779 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number l_ryd. "
780 : "Requested: {}, found: {}.",
781 : description.range_quantum_number_l_ryd.min(), min_l_ryd);
782 : }
783 0 : auto max_l_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
784 0 : if (max_l_ryd + 1 < description.range_quantum_number_l_ryd.max()) {
785 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number l_ryd. "
786 : "Requested: {}, found: {}.",
787 : description.range_quantum_number_l_ryd.max(), max_l_ryd);
788 : }
789 : }
790 75 : if (description.range_quantum_number_j_ryd.is_finite()) {
791 0 : auto min_j_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
792 0 : if (min_j_ryd - 1 > description.range_quantum_number_j_ryd.min()) {
793 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number j_ryd. "
794 : "Requested: {}, found: {}.",
795 : description.range_quantum_number_j_ryd.min(), min_j_ryd);
796 : }
797 0 : auto max_j_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
798 0 : if (max_j_ryd + 1 < description.range_quantum_number_j_ryd.max()) {
799 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number j_ryd. "
800 : "Requested: {}, found: {}.",
801 : description.range_quantum_number_j_ryd.max(), max_j_ryd);
802 : }
803 : }
804 75 : }
805 88 : }
806 :
807 : // Ask the table for the described states
808 176 : auto result = con->Query(fmt::format(
809 : R"(SELECT energy, f, m, parity, ketid, n, nu, exp_nui, std_nui, exp_l, std_l,
810 : exp_s, std_s, exp_j, std_j, exp_l_ryd, std_l_ryd, exp_j_ryd, std_j_ryd, is_j_total_momentum, is_calculated_with_mqdt, underspecified_channel_contribution FROM '{}' ORDER BY ketid ASC)",
811 : id_of_kets));
812 :
813 88 : if (result->HasError()) {
814 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
815 : }
816 :
817 88 : if (result->RowCount() == 0) {
818 0 : throw std::invalid_argument("No state found.");
819 : }
820 :
821 : // Check the types of the columns
822 88 : const auto &types = result->types;
823 88 : const auto &labels = result->names;
824 2024 : const std::vector<duckdb::LogicalType> ref_types = {
825 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
826 : duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT,
827 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
828 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
829 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
830 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
831 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::BOOLEAN,
832 : duckdb::LogicalType::DOUBLE};
833 :
834 2024 : for (size_t i = 0; i < types.size(); i++) {
835 1936 : if (types[i] != ref_types[i]) {
836 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'. Got " +
837 0 : types[i].ToString() + " but expected " +
838 0 : ref_types[i].ToString());
839 : }
840 : }
841 :
842 : // Construct the states
843 88 : std::vector<std::shared_ptr<const KetAtom>> kets;
844 88 : kets.reserve(result->RowCount());
845 : #ifndef NDEBUG
846 88 : double last_energy = std::numeric_limits<double>::lowest();
847 : #endif
848 :
849 176 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
850 :
851 88 : auto *chunk_energy = duckdb::FlatVector::GetData<double>(chunk->data[0]);
852 88 : auto *chunk_quantum_number_f = duckdb::FlatVector::GetData<double>(chunk->data[1]);
853 88 : auto *chunk_quantum_number_m = duckdb::FlatVector::GetData<double>(chunk->data[2]);
854 88 : auto *chunk_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[3]);
855 88 : auto *chunk_id = duckdb::FlatVector::GetData<int64_t>(chunk->data[4]);
856 88 : auto *chunk_quantum_number_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[5]);
857 88 : auto *chunk_quantum_number_nu = duckdb::FlatVector::GetData<double>(chunk->data[6]);
858 88 : auto *chunk_quantum_number_nui_exp = duckdb::FlatVector::GetData<double>(chunk->data[7]);
859 88 : auto *chunk_quantum_number_nui_std = duckdb::FlatVector::GetData<double>(chunk->data[8]);
860 88 : auto *chunk_quantum_number_l_exp = duckdb::FlatVector::GetData<double>(chunk->data[9]);
861 88 : auto *chunk_quantum_number_l_std = duckdb::FlatVector::GetData<double>(chunk->data[10]);
862 88 : auto *chunk_quantum_number_s_exp = duckdb::FlatVector::GetData<double>(chunk->data[11]);
863 88 : auto *chunk_quantum_number_s_std = duckdb::FlatVector::GetData<double>(chunk->data[12]);
864 88 : auto *chunk_quantum_number_j_exp = duckdb::FlatVector::GetData<double>(chunk->data[13]);
865 88 : auto *chunk_quantum_number_j_std = duckdb::FlatVector::GetData<double>(chunk->data[14]);
866 88 : auto *chunk_quantum_number_l_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[15]);
867 88 : auto *chunk_quantum_number_l_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[16]);
868 88 : auto *chunk_quantum_number_j_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[17]);
869 88 : auto *chunk_quantum_number_j_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[18]);
870 88 : auto *chunk_is_j_total_momentum = duckdb::FlatVector::GetData<bool>(chunk->data[19]);
871 88 : auto *chunk_is_calculated_with_mqdt = duckdb::FlatVector::GetData<bool>(chunk->data[20]);
872 : auto *chunk_underspecified_channel_contribution =
873 88 : duckdb::FlatVector::GetData<double>(chunk->data[21]);
874 :
875 9364 : for (size_t i = 0; i < chunk->size(); i++) {
876 :
877 : #ifndef NDEBUG
878 : // Check database consistency
879 9276 : if (chunk_is_j_total_momentum[i] &&
880 8930 : chunk_quantum_number_f[i] != chunk_quantum_number_j_exp[i]) {
881 0 : throw std::runtime_error("If j is the total momentum, f must be equal to j.");
882 : }
883 9276 : if (chunk_energy[i] < last_energy) {
884 0 : throw std::runtime_error("The states are not sorted by energy.");
885 : }
886 9276 : last_energy = chunk_energy[i];
887 : #endif
888 :
889 : // Append a new state
890 9276 : kets.push_back(std::make_shared<const KetAtom>(
891 0 : typename KetAtom::Private(), chunk_energy[i], chunk_quantum_number_f[i],
892 9276 : chunk_quantum_number_m[i], static_cast<Parity>(chunk_parity[i]), species,
893 9276 : chunk_quantum_number_n[i], chunk_quantum_number_nu[i],
894 9276 : chunk_quantum_number_nui_exp[i], chunk_quantum_number_nui_std[i],
895 9276 : chunk_quantum_number_l_exp[i], chunk_quantum_number_l_std[i],
896 9276 : chunk_quantum_number_s_exp[i], chunk_quantum_number_s_std[i],
897 9276 : chunk_quantum_number_j_exp[i], chunk_quantum_number_j_std[i],
898 9276 : chunk_quantum_number_l_ryd_exp[i], chunk_quantum_number_l_ryd_std[i],
899 9276 : chunk_quantum_number_j_ryd_exp[i], chunk_quantum_number_j_ryd_std[i],
900 9276 : chunk_is_j_total_momentum[i], chunk_is_calculated_with_mqdt[i],
901 9276 : chunk_underspecified_channel_contribution[i], *this, chunk_id[i]));
902 : }
903 : }
904 :
905 0 : return std::make_shared<const BasisAtom<Scalar>>(typename BasisAtom<Scalar>::Private(),
906 176 : std::move(kets), std::move(id_of_kets), *this);
907 176 : }
908 :
909 : template <typename Scalar>
910 : Eigen::SparseMatrix<Scalar, Eigen::RowMajor>
911 736 : Database::get_matrix_elements(std::shared_ptr<const BasisAtom<Scalar>> initial_basis,
912 : std::shared_ptr<const BasisAtom<Scalar>> final_basis,
913 : OperatorType type, int q) {
914 : using real_t = typename traits::NumTraits<Scalar>::real_t;
915 :
916 736 : std::string specifier;
917 736 : int kappa{};
918 736 : switch (type) {
919 571 : case OperatorType::ELECTRIC_DIPOLE:
920 571 : specifier = "matrix_elements_d";
921 571 : kappa = 1;
922 571 : break;
923 111 : case OperatorType::ELECTRIC_QUADRUPOLE:
924 111 : specifier = "matrix_elements_q";
925 111 : kappa = 2;
926 111 : break;
927 27 : case OperatorType::ELECTRIC_QUADRUPOLE_ZERO:
928 27 : specifier = "matrix_elements_q0";
929 27 : kappa = 0;
930 27 : break;
931 0 : case OperatorType::ELECTRIC_OCTUPOLE:
932 0 : specifier = "matrix_elements_o";
933 0 : kappa = 3;
934 0 : break;
935 19 : case OperatorType::MAGNETIC_DIPOLE:
936 19 : specifier = "matrix_elements_mu";
937 19 : kappa = 1;
938 19 : break;
939 4 : case OperatorType::ENERGY:
940 4 : specifier = "energy";
941 4 : kappa = 0;
942 4 : break;
943 4 : case OperatorType::IDENTITY:
944 4 : specifier = "identity";
945 4 : kappa = 0;
946 4 : break;
947 0 : default:
948 0 : throw std::invalid_argument("Unknown operator type.");
949 : }
950 :
951 736 : if (initial_basis->get_id_of_kets() != final_basis->get_id_of_kets()) {
952 0 : throw std::invalid_argument(
953 : "The initial and final basis must be expressed using the same kets.");
954 : }
955 736 : std::string id_of_kets = initial_basis->get_id_of_kets();
956 0 : std::string cache_key = fmt::format("{}_{}_{}", specifier, q, id_of_kets);
957 :
958 736 : if (!get_matrix_elements_cache().contains(cache_key)) {
959 223 : Eigen::Index dim = initial_basis->get_number_of_kets();
960 :
961 223 : std::vector<int> outerIndexPtr;
962 223 : std::vector<int> innerIndices;
963 223 : std::vector<real_t> values;
964 :
965 223 : if (specifier == "identity") {
966 1 : outerIndexPtr.reserve(dim + 1);
967 1 : innerIndices.reserve(dim);
968 1 : values.reserve(dim);
969 :
970 91 : for (int i = 0; i < dim; i++) {
971 90 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
972 90 : innerIndices.push_back(i);
973 90 : values.push_back(1);
974 : }
975 1 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
976 :
977 : } else {
978 : // Check that the specifications are valid
979 222 : if (std::abs(q) > kappa) {
980 0 : throw std::invalid_argument("Invalid q.");
981 : }
982 :
983 : // Ask the database for the operator
984 222 : std::string species = initial_basis->get_species();
985 222 : duckdb::unique_ptr<duckdb::MaterializedQueryResult> result;
986 222 : if (specifier != "energy") {
987 440 : result = con->Query(fmt::format(
988 : R"(WITH s AS (
989 : SELECT id, f, m, ketid FROM '{}'
990 : ),
991 : b AS (
992 : SELECT MIN(f) AS min_f, MAX(f) AS max_f,
993 : MIN(id) AS min_id, MAX(id) AS max_id
994 : FROM s
995 : ),
996 : w_filtered AS (
997 : SELECT *
998 : FROM '{}'
999 : WHERE kappa = {} AND q = {} AND
1000 : f_initial BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b) AND
1001 : f_final BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b)
1002 : ),
1003 : e_filtered AS (
1004 : SELECT *
1005 : FROM '{}'
1006 : WHERE
1007 : id_initial BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b) AND
1008 : id_final BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b)
1009 : )
1010 : SELECT
1011 : s2.ketid AS row,
1012 : s1.ketid AS col,
1013 : e.val*w.val AS val
1014 : FROM e_filtered AS e
1015 : JOIN s AS s1 ON e.id_initial = s1.id
1016 : JOIN s AS s2 ON e.id_final = s2.id
1017 : JOIN w_filtered AS w ON
1018 : w.f_initial = s1.f AND w.m_initial = s1.m AND
1019 : w.f_final = s2.f AND w.m_final = s2.m
1020 : ORDER BY row ASC, col ASC)",
1021 : id_of_kets, manager->get_path("misc", "wigner"), kappa, q,
1022 : manager->get_path(species, specifier)));
1023 : } else {
1024 4 : result = con->Query(fmt::format(
1025 : R"(SELECT ketid as row, ketid as col, energy as val FROM '{}' ORDER BY row ASC)",
1026 : id_of_kets));
1027 : }
1028 :
1029 222 : if (result->HasError()) {
1030 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
1031 : }
1032 :
1033 : // Check the types of the columns
1034 222 : const auto &types = result->types;
1035 222 : const auto &labels = result->names;
1036 888 : const std::vector<duckdb::LogicalType> ref_types = {duckdb::LogicalType::BIGINT,
1037 : duckdb::LogicalType::BIGINT,
1038 : duckdb::LogicalType::DOUBLE};
1039 888 : for (size_t i = 0; i < types.size(); i++) {
1040 666 : if (types[i] != ref_types[i]) {
1041 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'.");
1042 : }
1043 : }
1044 :
1045 : // Construct the matrix
1046 222 : int num_entries = static_cast<int>(result->RowCount());
1047 222 : outerIndexPtr.reserve(dim + 1);
1048 222 : innerIndices.reserve(num_entries);
1049 222 : values.reserve(num_entries);
1050 :
1051 222 : int last_row = -1;
1052 :
1053 635 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
1054 :
1055 413 : auto *chunk_row = duckdb::FlatVector::GetData<int64_t>(chunk->data[0]);
1056 413 : auto *chunk_col = duckdb::FlatVector::GetData<int64_t>(chunk->data[1]);
1057 413 : auto *chunk_val = duckdb::FlatVector::GetData<double>(chunk->data[2]);
1058 :
1059 543004 : for (size_t i = 0; i < chunk->size(); i++) {
1060 542591 : int row = final_basis->get_ket_index_from_id(chunk_row[i]);
1061 542591 : if (row != last_row) {
1062 22996 : if (row < last_row) {
1063 0 : throw std::runtime_error("The rows are not sorted.");
1064 : }
1065 51221 : for (; last_row < row; last_row++) {
1066 28225 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1067 : }
1068 : }
1069 542591 : innerIndices.push_back(initial_basis->get_ket_index_from_id(chunk_col[i]));
1070 542591 : values.push_back(chunk_val[i]);
1071 : }
1072 : }
1073 :
1074 767 : for (; last_row < dim + 1; last_row++) {
1075 545 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1076 : }
1077 222 : }
1078 :
1079 223 : Eigen::Map<const Eigen::SparseMatrix<real_t, Eigen::RowMajor>> matrix_map(
1080 223 : dim, dim, values.size(), outerIndexPtr.data(), innerIndices.data(), values.data());
1081 :
1082 : // Cache the matrix
1083 223 : get_matrix_elements_cache()[cache_key] = matrix_map;
1084 223 : }
1085 :
1086 : // Construct the operator and return it
1087 736 : return final_basis->get_coefficients().adjoint() *
1088 1358 : get_matrix_elements_cache()[cache_key].template cast<Scalar>() *
1089 2322 : initial_basis->get_coefficients();
1090 958 : }
1091 :
1092 2 : bool Database::get_download_missing() const { return download_missing_; }
1093 :
1094 0 : bool Database::get_use_cache() const { return use_cache_; }
1095 :
1096 0 : std::filesystem::path Database::get_database_dir() const { return database_dir_; }
1097 :
1098 : oneapi::tbb::concurrent_unordered_map<std::string, Eigen::SparseMatrix<double, Eigen::RowMajor>> &
1099 1695 : Database::get_matrix_elements_cache() {
1100 : static oneapi::tbb::concurrent_unordered_map<std::string,
1101 : Eigen::SparseMatrix<double, Eigen::RowMajor>>
1102 1695 : matrix_elements_cache;
1103 1695 : return matrix_elements_cache;
1104 : }
1105 :
1106 31 : Database &Database::get_global_instance() {
1107 62 : return get_global_instance_without_checks(default_download_missing, default_use_cache,
1108 62 : default_database_dir);
1109 : }
1110 :
1111 0 : Database &Database::get_global_instance(bool download_missing) {
1112 0 : Database &database = get_global_instance_without_checks(download_missing, default_use_cache,
1113 : default_database_dir);
1114 0 : if (download_missing != database.download_missing_) {
1115 0 : throw std::invalid_argument(
1116 0 : "The 'download_missing' argument must not change between calls to the method.");
1117 : }
1118 0 : return database;
1119 : }
1120 :
1121 0 : Database &Database::get_global_instance(std::filesystem::path database_dir) {
1122 0 : if (database_dir.empty()) {
1123 0 : database_dir = default_database_dir;
1124 : }
1125 0 : Database &database = get_global_instance_without_checks(default_download_missing,
1126 : default_use_cache, database_dir);
1127 0 : if (!std::filesystem::exists(database_dir) ||
1128 0 : std::filesystem::canonical(database_dir) != database.database_dir_) {
1129 0 : throw std::invalid_argument(
1130 0 : "The 'database_dir' argument must not change between calls to the method.");
1131 : }
1132 0 : return database;
1133 : }
1134 :
1135 1 : Database &Database::get_global_instance(bool download_missing, bool use_cache,
1136 : std::filesystem::path database_dir) {
1137 1 : if (database_dir.empty()) {
1138 0 : database_dir = default_database_dir;
1139 : }
1140 : Database &database =
1141 1 : get_global_instance_without_checks(download_missing, use_cache, database_dir);
1142 1 : if (download_missing != database.download_missing_ || use_cache != database.use_cache_ ||
1143 3 : !std::filesystem::exists(database_dir) ||
1144 2 : std::filesystem::canonical(database_dir) != database.database_dir_) {
1145 0 : throw std::invalid_argument(
1146 : "The 'download_missing', 'use_cache' and 'database_dir' arguments must not "
1147 0 : "change between calls to the method.");
1148 : }
1149 1 : return database;
1150 : }
1151 :
1152 32 : Database &Database::get_global_instance_without_checks(bool download_missing, bool use_cache,
1153 : std::filesystem::path database_dir) {
1154 32 : static Database database(download_missing, use_cache, std::move(database_dir));
1155 32 : return database;
1156 : }
1157 :
1158 : struct database_dir_noexcept : std::filesystem::path {
1159 2 : explicit database_dir_noexcept() noexcept try : std
1160 2 : ::filesystem::path(paths::get_cache_directory() / "database") {}
1161 0 : catch (...) {
1162 0 : SPDLOG_ERROR("Error getting the pairinteraction cache directory.");
1163 0 : std::terminate();
1164 2 : }
1165 : };
1166 :
1167 : const std::filesystem::path Database::default_database_dir = database_dir_noexcept();
1168 :
1169 : // Explicit instantiations
1170 : // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
1171 : #define INSTANTIATE_GETTERS(SCALAR) \
1172 : template std::shared_ptr<const BasisAtom<SCALAR>> Database::get_basis<SCALAR>( \
1173 : const std::string &species, const AtomDescriptionByRanges &description, \
1174 : std::vector<size_t> additional_ket_ids); \
1175 : template Eigen::SparseMatrix<SCALAR, Eigen::RowMajor> Database::get_matrix_elements<SCALAR>( \
1176 : std::shared_ptr<const BasisAtom<SCALAR>> initial_basis, \
1177 : std::shared_ptr<const BasisAtom<SCALAR>> final_basis, OperatorType type, int q);
1178 : // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
1179 :
1180 : INSTANTIATE_GETTERS(double)
1181 : INSTANTIATE_GETTERS(std::complex<double>)
1182 :
1183 : #undef INSTANTIATE_GETTERS
1184 : } // namespace pairinteraction
|