Line data Source code
1 : // SPDX-FileCopyrightText: 2024 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/database/Database.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/database/AtomDescriptionByParameters.hpp"
8 : #include "pairinteraction/database/AtomDescriptionByRanges.hpp"
9 : #include "pairinteraction/database/GitHubDownloader.hpp"
10 : #include "pairinteraction/database/ParquetManager.hpp"
11 : #include "pairinteraction/enums/OperatorType.hpp"
12 : #include "pairinteraction/enums/Parity.hpp"
13 : #include "pairinteraction/ket/KetAtom.hpp"
14 : #include "pairinteraction/utils/hash.hpp"
15 : #include "pairinteraction/utils/id_in_database.hpp"
16 : #include "pairinteraction/utils/paths.hpp"
17 : #include "pairinteraction/utils/streamed.hpp"
18 :
19 : #include <cpptrace/cpptrace.hpp>
20 : #include <duckdb.hpp>
21 : #include <fmt/core.h>
22 : #include <fmt/ranges.h>
23 : #include <fstream>
24 : #include <nlohmann/json.hpp>
25 : #include <oneapi/tbb.h>
26 : #include <spdlog/spdlog.h>
27 : #include <system_error>
28 :
29 : namespace pairinteraction {
30 0 : Database::Database() : Database(default_download_missing) {}
31 :
32 0 : Database::Database(bool download_missing)
33 0 : : Database(download_missing, default_use_cache, default_database_dir) {}
34 :
35 0 : Database::Database(std::filesystem::path database_dir)
36 0 : : Database(default_download_missing, default_use_cache, std::move(database_dir)) {}
37 :
38 5 : Database::Database(bool download_missing, bool use_cache, std::filesystem::path database_dir)
39 5 : : download_missing_(download_missing), use_cache_(use_cache),
40 5 : database_dir_(std::move(database_dir)), db(std::make_unique<duckdb::DuckDB>(nullptr)),
41 25 : con(std::make_unique<duckdb::Connection>(*db)) {
42 :
43 5 : if (database_dir_.empty()) {
44 0 : database_dir_ = default_database_dir;
45 : }
46 :
47 : // Ensure the database directory exists
48 5 : if (!std::filesystem::exists(database_dir_)) {
49 0 : std::filesystem::create_directories(database_dir_);
50 : }
51 5 : database_dir_ = std::filesystem::canonical(database_dir_);
52 5 : if (!std::filesystem::is_directory(database_dir_)) {
53 0 : throw std::filesystem::filesystem_error("Cannot access database", database_dir_.string(),
54 0 : std::make_error_code(std::errc::not_a_directory));
55 : }
56 10 : SPDLOG_INFO("Using database directory: {}", database_dir_.string());
57 :
58 : // Ensure that the config directory exists
59 5 : std::filesystem::path configdir = paths::get_config_directory();
60 5 : if (!std::filesystem::exists(configdir)) {
61 1 : std::filesystem::create_directories(configdir);
62 4 : } else if (!std::filesystem::is_directory(configdir)) {
63 0 : throw std::filesystem::filesystem_error("Cannot access config directory ",
64 0 : configdir.string(),
65 0 : std::make_error_code(std::errc::not_a_directory));
66 : }
67 :
68 : // Read in the database_repo_paths if a config file exists, otherwise use the default and
69 : // write it to the config file
70 5 : std::filesystem::path configfile = configdir / "database.json";
71 5 : std::string database_repo_host;
72 5 : std::vector<std::string> database_repo_paths;
73 5 : if (std::filesystem::exists(configfile)) {
74 4 : std::ifstream file(configfile);
75 4 : nlohmann::json doc = nlohmann::json::parse(file, nullptr, false);
76 :
77 8 : if (!doc.is_discarded() && doc.contains("hash") && doc.contains("database_repo_host") &&
78 4 : doc.contains("database_repo_paths")) {
79 4 : database_repo_host = doc["database_repo_host"].get<std::string>();
80 4 : database_repo_paths = doc["database_repo_paths"].get<std::vector<std::string>>();
81 :
82 : // If the values are not equal to the default values but the hash is consistent (i.e.,
83 : // the user has not changed anything manually), clear the values so that they can be
84 : // updated
85 8 : if (database_repo_host != default_database_repo_host ||
86 4 : database_repo_paths != default_database_repo_paths) {
87 0 : std::size_t seed = 0;
88 0 : utils::hash_combine(seed, database_repo_paths);
89 0 : utils::hash_combine(seed, database_repo_host);
90 0 : if (seed == doc["hash"].get<std::size_t>()) {
91 0 : database_repo_host.clear();
92 0 : database_repo_paths.clear();
93 : } else {
94 0 : SPDLOG_INFO("The database repository host and paths have been changed "
95 : "manually. Thus, they will not be updated automatically. To reset "
96 : "them, delete the file '{}'.",
97 : configfile.string());
98 : }
99 : }
100 : }
101 4 : }
102 :
103 : // Read in and store the default values if necessary
104 5 : if (database_repo_host.empty() || database_repo_paths.empty()) {
105 2 : SPDLOG_INFO("Updating the database repository host and paths:");
106 :
107 1 : database_repo_host = default_database_repo_host;
108 1 : database_repo_paths = default_database_repo_paths;
109 1 : std::ofstream file(configfile);
110 1 : nlohmann::json doc;
111 :
112 2 : SPDLOG_INFO("* New host: {}", default_database_repo_host);
113 3 : SPDLOG_INFO("* New paths: {}", fmt::join(default_database_repo_paths, ", "));
114 :
115 1 : doc["database_repo_host"] = default_database_repo_host;
116 1 : doc["database_repo_paths"] = database_repo_paths;
117 :
118 1 : std::size_t seed = 0;
119 1 : utils::hash_combine(seed, default_database_repo_paths);
120 1 : utils::hash_combine(seed, default_database_repo_host);
121 1 : doc["hash"] = seed;
122 :
123 1 : file << doc.dump(4);
124 1 : }
125 :
126 : // Limit the memory usage of duckdb's buffer manager
127 : {
128 5 : auto result = con->Query("PRAGMA max_memory = '8GB';");
129 5 : if (result->HasError()) {
130 0 : throw cpptrace::runtime_error("Error setting the memory limit: " + result->GetError());
131 : }
132 5 : }
133 :
134 : // Instantiate a database manager that provides access to database tables. If a table
135 : // is outdated/not available locally, it will be downloaded if download_missing_ is true.
136 5 : if (!download_missing_) {
137 5 : database_repo_paths.clear();
138 : }
139 5 : downloader = std::make_unique<GitHubDownloader>();
140 10 : manager = std::make_unique<ParquetManager>(database_dir_, *downloader, database_repo_paths,
141 10 : *con, use_cache_);
142 5 : manager->scan_local();
143 5 : manager->scan_remote();
144 :
145 : // Print versions of tables
146 5 : std::istringstream iss(manager->get_versions_info());
147 60 : for (std::string line; std::getline(iss, line);) {
148 55 : SPDLOG_INFO(line);
149 5 : }
150 15 : }
151 :
152 5 : Database::~Database() = default;
153 :
154 30 : std::shared_ptr<const KetAtom> Database::get_ket(const std::string &species,
155 : const AtomDescriptionByParameters &description) {
156 : // Check that the specifications are valid
157 30 : if (!description.quantum_number_m.has_value()) {
158 0 : throw std::invalid_argument("The quantum number m must be specified.");
159 : }
160 38 : if (description.quantum_number_f.has_value() &&
161 8 : 2 * description.quantum_number_f.value() !=
162 8 : std::rint(2 * description.quantum_number_f.value())) {
163 0 : throw std::invalid_argument("The quantum number f must be an integer or half-integer.");
164 : }
165 30 : if (description.quantum_number_f.has_value() && description.quantum_number_f.value() < 0) {
166 0 : throw std::invalid_argument("The quantum number f must be positive.");
167 : }
168 49 : if (description.quantum_number_j.has_value() &&
169 19 : 2 * description.quantum_number_j.value() !=
170 19 : std::rint(2 * description.quantum_number_j.value())) {
171 0 : throw std::invalid_argument("The quantum number j must be an integer or half-integer.");
172 : }
173 30 : if (description.quantum_number_j.has_value() && description.quantum_number_j.value() < 0) {
174 0 : throw std::invalid_argument("The quantum number j must be positive.");
175 : }
176 60 : if (description.quantum_number_m.has_value() &&
177 30 : 2 * description.quantum_number_m.value() !=
178 30 : std::rint(2 * description.quantum_number_m.value())) {
179 0 : throw std::invalid_argument("The quantum number m must be an integer or half-integer.");
180 : }
181 :
182 : // Describe the state
183 30 : std::string where;
184 30 : std::string separator;
185 30 : if (description.energy.has_value()) {
186 : // The following condition derives from demanding that quantum number n that corresponds to
187 : // the energy "E_n = -1/(2*n^2)" is not off by more than 1 from the actual quantum number n,
188 : // i.e., "sqrt(-1/(2*E_n)) - sqrt(-1/(2*E_{n-1})) = 1"
189 0 : where += separator +
190 0 : fmt::format("SQRT(-1/(2*energy)) BETWEEN {} AND {}",
191 0 : std::sqrt(-1 / (2 * description.energy.value())) - 0.5,
192 0 : std::sqrt(-1 / (2 * description.energy.value())) + 0.5);
193 0 : separator = " AND ";
194 : }
195 30 : if (description.quantum_number_f.has_value()) {
196 16 : where += separator + fmt::format("f = {}", description.quantum_number_f.value());
197 8 : separator = " AND ";
198 : }
199 30 : if (description.parity != Parity::UNKNOWN) {
200 0 : where += separator + fmt::format("parity = {}", fmt::streamed(description.parity));
201 0 : separator = " AND ";
202 : }
203 30 : if (description.quantum_number_n.has_value()) {
204 60 : where += separator + fmt::format("n = {}", description.quantum_number_n.value());
205 30 : separator = " AND ";
206 : }
207 30 : if (description.quantum_number_nu.has_value()) {
208 0 : where += separator +
209 0 : fmt::format("nu BETWEEN {} AND {}", description.quantum_number_nu.value() - 0.5,
210 0 : description.quantum_number_nu.value() + 0.5);
211 0 : separator = " AND ";
212 : }
213 30 : if (description.quantum_number_nui.has_value()) {
214 0 : where += separator +
215 0 : fmt::format("exp_nui BETWEEN {} AND {}", description.quantum_number_nui.value() - 0.5,
216 0 : description.quantum_number_nui.value() + 0.5);
217 0 : separator = " AND ";
218 : }
219 30 : if (description.quantum_number_l.has_value()) {
220 30 : where += separator +
221 90 : fmt::format("exp_l BETWEEN {} AND {}", description.quantum_number_l.value() - 0.5,
222 60 : description.quantum_number_l.value() + 0.5);
223 30 : separator = " AND ";
224 : }
225 30 : if (description.quantum_number_s.has_value()) {
226 8 : where += separator +
227 24 : fmt::format("exp_s BETWEEN {} AND {}", description.quantum_number_s.value() - 0.5,
228 16 : description.quantum_number_s.value() + 0.5);
229 8 : separator = " AND ";
230 : }
231 30 : if (description.quantum_number_j.has_value()) {
232 19 : where += separator +
233 57 : fmt::format("exp_j BETWEEN {} AND {}", description.quantum_number_j.value() - 0.5,
234 38 : description.quantum_number_j.value() + 0.5);
235 19 : separator = " AND ";
236 : }
237 30 : if (description.quantum_number_l_ryd.has_value()) {
238 0 : where += separator +
239 0 : fmt::format("exp_l_ryd BETWEEN {} AND {}",
240 0 : description.quantum_number_l_ryd.value() - 0.5,
241 0 : description.quantum_number_l_ryd.value() + 0.5);
242 0 : separator = " AND ";
243 : }
244 30 : if (description.quantum_number_j_ryd.has_value()) {
245 0 : where += separator +
246 0 : fmt::format("exp_j_ryd BETWEEN {} AND {}",
247 0 : description.quantum_number_j_ryd.value() - 0.5,
248 0 : description.quantum_number_j_ryd.value() + 0.5);
249 0 : separator = " AND ";
250 : }
251 30 : if (separator.empty()) {
252 0 : where += "FALSE";
253 : }
254 :
255 30 : std::string orderby;
256 30 : separator = "";
257 30 : if (description.energy.has_value()) {
258 0 : orderby += separator +
259 0 : fmt::format("(SQRT(-1/(2*energy)) - {})^2",
260 0 : std::sqrt(-1 / (2 * description.energy.value())));
261 0 : separator = " + ";
262 : }
263 30 : if (description.quantum_number_nu.has_value()) {
264 0 : orderby += separator + fmt::format("(nu - {})^2", description.quantum_number_nu.value());
265 0 : separator = " + ";
266 : }
267 30 : if (description.quantum_number_nui.has_value()) {
268 : orderby +=
269 0 : separator + fmt::format("(exp_nui - {})^2", description.quantum_number_nui.value());
270 0 : separator = " + ";
271 : }
272 30 : if (description.quantum_number_l.has_value()) {
273 60 : orderby += separator + fmt::format("(exp_l - {})^2", description.quantum_number_l.value());
274 30 : separator = " + ";
275 : }
276 30 : if (description.quantum_number_s.has_value()) {
277 16 : orderby += separator + fmt::format("(exp_s - {})^2", description.quantum_number_s.value());
278 8 : separator = " + ";
279 : }
280 30 : if (description.quantum_number_j.has_value()) {
281 38 : orderby += separator + fmt::format("(exp_j - {})^2", description.quantum_number_j.value());
282 19 : separator = " + ";
283 : }
284 30 : if (description.quantum_number_l_ryd.has_value()) {
285 : orderby +=
286 0 : separator + fmt::format("(exp_l_ryd - {})^2", description.quantum_number_l_ryd.value());
287 0 : separator = " + ";
288 : }
289 30 : if (description.quantum_number_j_ryd.has_value()) {
290 : orderby +=
291 0 : separator + fmt::format("(exp_j_ryd - {})^2", description.quantum_number_j_ryd.value());
292 0 : separator = " + ";
293 : }
294 30 : if (separator.empty()) {
295 0 : orderby += "id";
296 : }
297 :
298 : // Ask the database for the described state
299 30 : auto result = con->Query(fmt::format(
300 : R"(SELECT energy, f, parity, id, n, nu, exp_nui, std_nui, exp_l, std_l, exp_s, std_s,
301 : exp_j, std_j, exp_l_ryd, std_l_ryd, exp_j_ryd, std_j_ryd, is_j_total_momentum, is_calculated_with_mqdt, underspecified_channel_contribution, {} AS order_val FROM '{}' WHERE {} ORDER BY order_val ASC LIMIT 2)",
302 90 : orderby, manager->get_path(species, "states"), where));
303 :
304 30 : if (result->HasError()) {
305 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
306 : }
307 :
308 30 : if (result->RowCount() == 0) {
309 0 : throw std::invalid_argument("No state found.");
310 : }
311 :
312 : // Check the types of the columns
313 30 : const auto &types = result->types;
314 30 : const auto &labels = result->names;
315 : const std::vector<duckdb::LogicalType> ref_types = {
316 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::BIGINT,
317 : duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT, duckdb::LogicalType::DOUBLE,
318 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
319 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
320 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
321 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
322 : duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::DOUBLE,
323 690 : duckdb::LogicalType::DOUBLE};
324 :
325 690 : for (size_t i = 0; i < types.size(); i++) {
326 660 : if (types[i] != ref_types[i]) {
327 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'. Got " +
328 0 : types[i].ToString() + " but expected " +
329 0 : ref_types[i].ToString());
330 : }
331 : }
332 :
333 : // Get the first chunk of the results (the first chunk is sufficient as we need two rows at
334 : // most)
335 30 : auto chunk = result->Fetch();
336 :
337 : // Check that the ket is uniquely specified
338 30 : if (chunk->size() > 1) {
339 1 : auto order_val_0 = duckdb::FlatVector::GetData<double>(chunk->data[21])[0];
340 1 : auto order_val_1 = duckdb::FlatVector::GetData<double>(chunk->data[21])[1];
341 :
342 1 : if (order_val_1 - order_val_0 <= order_val_0) {
343 : // Get a list of possible kets
344 1 : std::vector<KetAtom> kets;
345 1 : kets.reserve(2);
346 3 : for (size_t i = 0; i < 2; ++i) {
347 2 : auto result_quantum_number_m = description.quantum_number_m.value();
348 2 : auto result_energy = duckdb::FlatVector::GetData<double>(chunk->data[0])[i];
349 : auto result_quantum_number_f =
350 2 : duckdb::FlatVector::GetData<double>(chunk->data[1])[i];
351 2 : auto result_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[2])[i];
352 4 : auto result_id = utils::get_linearized_id_in_database(
353 2 : duckdb::FlatVector::GetData<int64_t>(chunk->data[3])[i],
354 2 : result_quantum_number_m);
355 : auto result_quantum_number_n =
356 2 : duckdb::FlatVector::GetData<int64_t>(chunk->data[4])[i];
357 : auto result_quantum_number_nu =
358 2 : duckdb::FlatVector::GetData<double>(chunk->data[5])[i];
359 : auto result_quantum_number_nui_exp =
360 2 : duckdb::FlatVector::GetData<double>(chunk->data[6])[i];
361 : auto result_quantum_number_nui_std =
362 2 : duckdb::FlatVector::GetData<double>(chunk->data[7])[i];
363 : auto result_quantum_number_l_exp =
364 2 : duckdb::FlatVector::GetData<double>(chunk->data[8])[i];
365 : auto result_quantum_number_l_std =
366 2 : duckdb::FlatVector::GetData<double>(chunk->data[9])[i];
367 : auto result_quantum_number_s_exp =
368 2 : duckdb::FlatVector::GetData<double>(chunk->data[10])[i];
369 : auto result_quantum_number_s_std =
370 2 : duckdb::FlatVector::GetData<double>(chunk->data[11])[i];
371 : auto result_quantum_number_j_exp =
372 2 : duckdb::FlatVector::GetData<double>(chunk->data[12])[i];
373 : auto result_quantum_number_j_std =
374 2 : duckdb::FlatVector::GetData<double>(chunk->data[13])[i];
375 : auto result_quantum_number_l_ryd_exp =
376 2 : duckdb::FlatVector::GetData<double>(chunk->data[14])[i];
377 : auto result_quantum_number_l_ryd_std =
378 2 : duckdb::FlatVector::GetData<double>(chunk->data[15])[i];
379 : auto result_quantum_number_j_ryd_exp =
380 2 : duckdb::FlatVector::GetData<double>(chunk->data[16])[i];
381 : auto result_quantum_number_j_ryd_std =
382 2 : duckdb::FlatVector::GetData<double>(chunk->data[17])[i];
383 : auto result_is_j_total_momentum =
384 2 : duckdb::FlatVector::GetData<bool>(chunk->data[18])[i];
385 : auto result_is_calculated_with_mqdt =
386 2 : duckdb::FlatVector::GetData<bool>(chunk->data[19])[i];
387 : auto result_underspecified_channel_contribution =
388 2 : duckdb::FlatVector::GetData<double>(chunk->data[20])[i];
389 0 : kets.emplace_back(typename KetAtom::Private(), result_energy,
390 : result_quantum_number_f, result_quantum_number_m,
391 2 : static_cast<Parity>(result_parity), species,
392 : result_quantum_number_n, result_quantum_number_nu,
393 : result_quantum_number_nui_exp, result_quantum_number_nui_std,
394 : result_quantum_number_l_exp, result_quantum_number_l_std,
395 : result_quantum_number_s_exp, result_quantum_number_s_std,
396 : result_quantum_number_j_exp, result_quantum_number_j_std,
397 : result_quantum_number_l_ryd_exp, result_quantum_number_l_ryd_std,
398 : result_quantum_number_j_ryd_exp, result_quantum_number_j_ryd_std,
399 : result_is_j_total_momentum, result_is_calculated_with_mqdt,
400 : result_underspecified_channel_contribution, *this, result_id);
401 : }
402 :
403 : // Throw an error with the possible kets
404 2 : throw std::invalid_argument(
405 1 : fmt::format("The ket is not uniquely specified. Possible kets are:\n{}\n{}",
406 2 : fmt::streamed(kets[0]), fmt::streamed(kets[1])));
407 1 : }
408 : }
409 :
410 : // Construct the state
411 29 : auto result_quantum_number_m = description.quantum_number_m.value();
412 29 : auto result_energy = duckdb::FlatVector::GetData<double>(chunk->data[0])[0];
413 29 : auto result_quantum_number_f = duckdb::FlatVector::GetData<double>(chunk->data[1])[0];
414 29 : auto result_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[2])[0];
415 58 : auto result_id = utils::get_linearized_id_in_database(
416 29 : duckdb::FlatVector::GetData<int64_t>(chunk->data[3])[0], result_quantum_number_m);
417 29 : auto result_quantum_number_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[4])[0];
418 29 : auto result_quantum_number_nu = duckdb::FlatVector::GetData<double>(chunk->data[5])[0];
419 29 : auto result_quantum_number_nui_exp = duckdb::FlatVector::GetData<double>(chunk->data[6])[0];
420 29 : auto result_quantum_number_nui_std = duckdb::FlatVector::GetData<double>(chunk->data[7])[0];
421 29 : auto result_quantum_number_l_exp = duckdb::FlatVector::GetData<double>(chunk->data[8])[0];
422 29 : auto result_quantum_number_l_std = duckdb::FlatVector::GetData<double>(chunk->data[9])[0];
423 29 : auto result_quantum_number_s_exp = duckdb::FlatVector::GetData<double>(chunk->data[10])[0];
424 29 : auto result_quantum_number_s_std = duckdb::FlatVector::GetData<double>(chunk->data[11])[0];
425 29 : auto result_quantum_number_j_exp = duckdb::FlatVector::GetData<double>(chunk->data[12])[0];
426 29 : auto result_quantum_number_j_std = duckdb::FlatVector::GetData<double>(chunk->data[13])[0];
427 29 : auto result_quantum_number_l_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[14])[0];
428 29 : auto result_quantum_number_l_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[15])[0];
429 29 : auto result_quantum_number_j_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[16])[0];
430 29 : auto result_quantum_number_j_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[17])[0];
431 29 : auto result_is_j_total_momentum = duckdb::FlatVector::GetData<bool>(chunk->data[18])[0];
432 29 : auto result_is_calculated_with_mqdt = duckdb::FlatVector::GetData<bool>(chunk->data[19])[0];
433 : auto result_underspecified_channel_contribution =
434 29 : duckdb::FlatVector::GetData<double>(chunk->data[20])[0];
435 :
436 : // Check the quantum number m
437 29 : if (std::abs(result_quantum_number_m) > result_quantum_number_f) {
438 1 : throw std::invalid_argument(
439 2 : "The absolute value of the quantum number m must be less than or equal to f.");
440 : }
441 28 : if (result_quantum_number_f + result_quantum_number_m !=
442 28 : std::rint(result_quantum_number_f + result_quantum_number_m)) {
443 0 : throw std::invalid_argument(
444 0 : "The quantum numbers f and m must be both either integers or half-integers.");
445 : }
446 :
447 : #ifndef NDEBUG
448 : // Check database consistency
449 28 : if (result_is_j_total_momentum && result_quantum_number_f != result_quantum_number_j_exp) {
450 0 : throw std::runtime_error("If j is the total momentum, f must be equal to j.");
451 : }
452 : #endif
453 :
454 : return std::make_shared<const KetAtom>(
455 0 : typename KetAtom::Private(), result_energy, result_quantum_number_f,
456 0 : result_quantum_number_m, static_cast<Parity>(result_parity), species,
457 : result_quantum_number_n, result_quantum_number_nu, result_quantum_number_nui_exp,
458 : result_quantum_number_nui_std, result_quantum_number_l_exp, result_quantum_number_l_std,
459 : result_quantum_number_s_exp, result_quantum_number_s_std, result_quantum_number_j_exp,
460 : result_quantum_number_j_std, result_quantum_number_l_ryd_exp,
461 : result_quantum_number_l_ryd_std, result_quantum_number_j_ryd_exp,
462 : result_quantum_number_j_ryd_std, result_is_j_total_momentum, result_is_calculated_with_mqdt,
463 56 : result_underspecified_channel_contribution, *this, result_id);
464 70 : }
465 :
466 : template <typename Scalar>
467 : std::shared_ptr<const BasisAtom<Scalar>>
468 27 : Database::get_basis(const std::string &species, const AtomDescriptionByRanges &description,
469 : std::vector<size_t> additional_ket_ids) {
470 : // Describe the states
471 27 : std::string where = "(";
472 27 : std::string separator;
473 27 : if (description.parity != Parity::UNKNOWN) {
474 0 : where += separator + fmt::format("parity = {}", fmt::streamed(description.parity));
475 0 : separator = " AND ";
476 : }
477 27 : if (description.range_energy.is_finite()) {
478 0 : where += separator +
479 0 : fmt::format("energy BETWEEN {} AND {}", description.range_energy.min(),
480 0 : description.range_energy.max());
481 0 : separator = " AND ";
482 : }
483 27 : if (description.range_quantum_number_f.is_finite()) {
484 0 : where += separator +
485 0 : fmt::format("f BETWEEN {} AND {}", description.range_quantum_number_f.min(),
486 0 : description.range_quantum_number_f.max());
487 0 : separator = " AND ";
488 : }
489 27 : if (description.range_quantum_number_m.is_finite()) {
490 5 : where += separator +
491 5 : fmt::format("m BETWEEN {} AND {}", description.range_quantum_number_m.min(),
492 5 : description.range_quantum_number_m.max());
493 5 : separator = " AND ";
494 : }
495 27 : if (description.range_quantum_number_n.is_finite()) {
496 23 : where += separator +
497 23 : fmt::format("n BETWEEN {} AND {}", description.range_quantum_number_n.min(),
498 23 : description.range_quantum_number_n.max());
499 23 : separator = " AND ";
500 : }
501 27 : if (description.range_quantum_number_nu.is_finite()) {
502 1 : where += separator +
503 1 : fmt::format("nu BETWEEN {} AND {}", description.range_quantum_number_nu.min(),
504 1 : description.range_quantum_number_nu.max());
505 1 : separator = " AND ";
506 : }
507 27 : if (description.range_quantum_number_nui.is_finite()) {
508 0 : where += separator +
509 : fmt::format("exp_nui BETWEEN {}-2*std_nui AND {}+2*std_nui",
510 0 : description.range_quantum_number_nui.min(),
511 0 : description.range_quantum_number_nui.max());
512 0 : separator = " AND ";
513 : }
514 27 : if (description.range_quantum_number_l.is_finite()) {
515 24 : where += separator +
516 : fmt::format("exp_l BETWEEN {}-2*std_l AND {}+2*std_l",
517 24 : description.range_quantum_number_l.min(),
518 24 : description.range_quantum_number_l.max());
519 24 : separator = " AND ";
520 : }
521 27 : if (description.range_quantum_number_s.is_finite()) {
522 0 : where += separator +
523 : fmt::format("exp_s BETWEEN {}-2*std_s AND {}+2*std_s",
524 0 : description.range_quantum_number_s.min(),
525 0 : description.range_quantum_number_s.max());
526 0 : separator = " AND ";
527 : }
528 27 : if (description.range_quantum_number_j.is_finite()) {
529 0 : where += separator +
530 : fmt::format("exp_j BETWEEN {}-2*std_j AND {}+2*std_j",
531 0 : description.range_quantum_number_j.min(),
532 0 : description.range_quantum_number_j.max());
533 0 : separator = " AND ";
534 : }
535 27 : if (description.range_quantum_number_l_ryd.is_finite()) {
536 0 : where += separator +
537 : fmt::format("exp_l_ryd BETWEEN {}-2*std_l_ryd AND {}+2*std_l_ryd",
538 0 : description.range_quantum_number_l_ryd.min(),
539 0 : description.range_quantum_number_l_ryd.max());
540 0 : separator = " AND ";
541 : }
542 27 : if (description.range_quantum_number_j_ryd.is_finite()) {
543 0 : where += separator +
544 : fmt::format("exp_j_ryd BETWEEN {}-2*std_j_ryd AND {}+2*std_j_ryd",
545 0 : description.range_quantum_number_j_ryd.min(),
546 0 : description.range_quantum_number_j_ryd.max());
547 0 : separator = " AND ";
548 : }
549 27 : if (separator.empty()) {
550 3 : where += "FALSE";
551 : }
552 27 : where += ")";
553 27 : if (!additional_ket_ids.empty()) {
554 3 : where += fmt::format(" OR {} IN ({})", utils::SQL_TERM_FOR_LINEARIZED_ID_IN_DATABASE,
555 6 : fmt::join(additional_ket_ids, ","));
556 : }
557 :
558 : // Create a table containing the described states
559 27 : std::string id_of_kets;
560 : {
561 27 : auto result = con->Query(R"(SELECT UUID()::varchar)");
562 27 : if (result->HasError()) {
563 0 : throw cpptrace::runtime_error("Error selecting id_of_kets: " + result->GetError());
564 : }
565 27 : id_of_kets =
566 54 : duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
567 27 : }
568 : {
569 54 : auto result = con->Query(fmt::format(
570 : R"(CREATE TEMP TABLE '{}' AS SELECT *, {} AS ketid FROM (
571 : SELECT *,
572 : UNNEST(list_transform(generate_series(0,(2*f)::bigint),
573 : x -> x::double-f)) AS m FROM '{}'
574 : ) WHERE {})",
575 : id_of_kets, utils::SQL_TERM_FOR_LINEARIZED_ID_IN_DATABASE,
576 : manager->get_path(species, "states"), where));
577 :
578 27 : if (result->HasError()) {
579 0 : throw cpptrace::runtime_error("Error creating table: " + result->GetError());
580 : }
581 27 : }
582 :
583 : // Ask the table for the extreme values of the quantum numbers
584 : {
585 27 : std::string select;
586 27 : std::string separator;
587 27 : if (description.range_energy.is_finite()) {
588 0 : select += separator + "MIN(energy) AS min_energy, MAX(energy) AS max_energy";
589 0 : separator = ", ";
590 : }
591 27 : if (description.range_quantum_number_f.is_finite()) {
592 0 : select += separator + "MIN(f) AS min_f, MAX(f) AS max_f";
593 0 : separator = ", ";
594 : }
595 27 : if (description.range_quantum_number_m.is_finite()) {
596 5 : select += separator + "MIN(m) AS min_m, MAX(m) AS max_m";
597 5 : separator = ", ";
598 : }
599 27 : if (description.range_quantum_number_n.is_finite()) {
600 23 : select += separator + "MIN(n) AS min_n, MAX(n) AS max_n";
601 23 : separator = ", ";
602 : }
603 27 : if (description.range_quantum_number_nu.is_finite()) {
604 1 : select += separator + "MIN(nu) AS min_nu, MAX(nu) AS max_nu";
605 1 : separator = ", ";
606 : }
607 27 : if (description.range_quantum_number_nui.is_finite()) {
608 0 : select += separator + "MIN(exp_nui) AS min_nui, MAX(exp_nui) AS max_nui";
609 0 : separator = ", ";
610 : }
611 27 : if (description.range_quantum_number_l.is_finite()) {
612 24 : select += separator + "MIN(exp_l) AS min_l, MAX(exp_l) AS max_l";
613 24 : separator = ", ";
614 : }
615 27 : if (description.range_quantum_number_s.is_finite()) {
616 0 : select += separator + "MIN(exp_s) AS min_s, MAX(exp_s) AS max_s";
617 0 : separator = ", ";
618 : }
619 27 : if (description.range_quantum_number_j.is_finite()) {
620 0 : select += separator + "MIN(exp_j) AS min_j, MAX(exp_j) AS max_j";
621 0 : separator = ", ";
622 : }
623 27 : if (description.range_quantum_number_l_ryd.is_finite()) {
624 0 : select += separator + "MIN(exp_l_ryd) AS min_l_ryd, MAX(exp_l_ryd) AS max_l_ryd";
625 0 : separator = ", ";
626 : }
627 27 : if (description.range_quantum_number_j_ryd.is_finite()) {
628 0 : select += separator + "MIN(exp_j_ryd) AS min_j_ryd, MAX(exp_j_ryd) AS max_j_ryd";
629 0 : separator = ", ";
630 : }
631 :
632 27 : if (!separator.empty()) {
633 48 : auto result = con->Query(fmt::format(R"(SELECT {} FROM '{}')", select, id_of_kets));
634 :
635 24 : if (result->HasError()) {
636 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
637 : }
638 :
639 24 : auto chunk = result->Fetch();
640 :
641 130 : for (size_t i = 0; i < chunk->ColumnCount(); i++) {
642 106 : if (duckdb::FlatVector::IsNull(chunk->data[i], 0)) {
643 0 : throw std::invalid_argument("No state found.");
644 : }
645 : }
646 :
647 24 : size_t idx = 0;
648 24 : if (description.range_energy.is_finite()) {
649 0 : auto min_energy = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
650 0 : if (std::sqrt(-1 / (2 * min_energy)) - 1 >
651 0 : std::sqrt(-1 / (2 * description.range_energy.min()))) {
652 0 : SPDLOG_DEBUG("No state found with the requested minimum energy. Requested: {}, "
653 : "found: {}.",
654 : description.range_energy.min(), min_energy);
655 : }
656 0 : auto max_energy = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
657 0 : if (std::sqrt(-1 / (2 * max_energy)) + 1 <
658 0 : std::sqrt(-1 / (2 * description.range_energy.max()))) {
659 0 : SPDLOG_DEBUG("No state found with the requested maximum energy. Requested: {}, "
660 : "found: {}.",
661 : description.range_energy.max(), max_energy);
662 : }
663 : }
664 24 : if (description.range_quantum_number_f.is_finite()) {
665 0 : auto min_f = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
666 0 : if (min_f > description.range_quantum_number_f.min()) {
667 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number f. "
668 : "Requested: {}, found: {}.",
669 : description.range_quantum_number_f.min(), min_f);
670 : }
671 0 : auto max_f = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
672 0 : if (max_f < description.range_quantum_number_f.max()) {
673 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number f. "
674 : "Requested: {}, found: {}.",
675 : description.range_quantum_number_f.max(), max_f);
676 : }
677 : }
678 24 : if (description.range_quantum_number_m.is_finite()) {
679 5 : auto min_m = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
680 5 : if (min_m > description.range_quantum_number_m.min()) {
681 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number m. "
682 : "Requested: {}, found: {}.",
683 : description.range_quantum_number_m.min(), min_m);
684 : }
685 5 : auto max_m = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
686 5 : if (max_m < description.range_quantum_number_m.max()) {
687 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number m. "
688 : "Requested: {}, found: {}.",
689 : description.range_quantum_number_m.max(), max_m);
690 : }
691 : }
692 24 : if (description.range_quantum_number_n.is_finite()) {
693 23 : auto min_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[idx++])[0];
694 23 : if (min_n > description.range_quantum_number_n.min()) {
695 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number n. "
696 : "Requested: {}, found: {}.",
697 : description.range_quantum_number_n.min(), min_n);
698 : }
699 23 : auto max_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[idx++])[0];
700 23 : if (max_n < description.range_quantum_number_n.max()) {
701 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number n. "
702 : "Requested: {}, found: {}.",
703 : description.range_quantum_number_n.max(), max_n);
704 : }
705 : }
706 24 : if (description.range_quantum_number_nu.is_finite()) {
707 1 : auto min_nu = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
708 1 : if (min_nu - 1 > description.range_quantum_number_nu.min()) {
709 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number nu. "
710 : "Requested: {}, found: {}.",
711 : description.range_quantum_number_nu.min(), min_nu);
712 : }
713 1 : auto max_nu = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
714 1 : if (max_nu + 1 < description.range_quantum_number_nu.max()) {
715 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number nu. "
716 : "Requested: {}, found: {}.",
717 : description.range_quantum_number_nu.max(), max_nu);
718 : }
719 : }
720 24 : if (description.range_quantum_number_nui.is_finite()) {
721 0 : auto min_nui = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
722 0 : if (min_nui - 1 > description.range_quantum_number_nui.min()) {
723 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number nui. "
724 : "Requested: {}, found: {}.",
725 : description.range_quantum_number_nui.min(), min_nui);
726 : }
727 0 : auto max_nui = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
728 0 : if (max_nui + 1 < description.range_quantum_number_nui.max()) {
729 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number nui. "
730 : "Requested: {}, found: {}.",
731 : description.range_quantum_number_nui.max(), max_nui);
732 : }
733 : }
734 24 : if (description.range_quantum_number_l.is_finite()) {
735 24 : auto min_l = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
736 24 : if (min_l - 1 > description.range_quantum_number_l.min()) {
737 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number l. "
738 : "Requested: {}, found: {}.",
739 : description.range_quantum_number_l.min(), min_l);
740 : }
741 24 : auto max_l = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
742 24 : if (max_l + 1 < description.range_quantum_number_l.max()) {
743 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number l. "
744 : "Requested: {}, found: {}.",
745 : description.range_quantum_number_l.max(), max_l);
746 : }
747 : }
748 24 : if (description.range_quantum_number_s.is_finite()) {
749 0 : auto min_s = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
750 0 : if (min_s - 1 > description.range_quantum_number_s.min()) {
751 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number s. "
752 : "Requested: {}, found: {}.",
753 : description.range_quantum_number_s.min(), min_s);
754 : }
755 0 : auto max_s = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
756 0 : if (max_s + 1 < description.range_quantum_number_s.max()) {
757 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number s. "
758 : "Requested: {}, found: {}.",
759 : description.range_quantum_number_s.max(), max_s);
760 : }
761 : }
762 24 : if (description.range_quantum_number_j.is_finite()) {
763 0 : auto min_j = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
764 0 : if (min_j - 1 > description.range_quantum_number_j.min()) {
765 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number j. "
766 : "Requested: {}, found: {}.",
767 : description.range_quantum_number_j.min(), min_j);
768 : }
769 0 : auto max_j = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
770 0 : if (max_j + 1 < description.range_quantum_number_j.max()) {
771 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number j. "
772 : "Requested: {}, found: {}.",
773 : description.range_quantum_number_j.max(), max_j);
774 : }
775 : }
776 24 : if (description.range_quantum_number_l_ryd.is_finite()) {
777 0 : auto min_l_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
778 0 : if (min_l_ryd - 1 > description.range_quantum_number_l_ryd.min()) {
779 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number l_ryd. "
780 : "Requested: {}, found: {}.",
781 : description.range_quantum_number_l_ryd.min(), min_l_ryd);
782 : }
783 0 : auto max_l_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
784 0 : if (max_l_ryd + 1 < description.range_quantum_number_l_ryd.max()) {
785 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number l_ryd. "
786 : "Requested: {}, found: {}.",
787 : description.range_quantum_number_l_ryd.max(), max_l_ryd);
788 : }
789 : }
790 24 : if (description.range_quantum_number_j_ryd.is_finite()) {
791 0 : auto min_j_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
792 0 : if (min_j_ryd - 1 > description.range_quantum_number_j_ryd.min()) {
793 0 : SPDLOG_DEBUG("No state found with the requested minimum quantum number j_ryd. "
794 : "Requested: {}, found: {}.",
795 : description.range_quantum_number_j_ryd.min(), min_j_ryd);
796 : }
797 0 : auto max_j_ryd = duckdb::FlatVector::GetData<double>(chunk->data[idx++])[0];
798 0 : if (max_j_ryd + 1 < description.range_quantum_number_j_ryd.max()) {
799 0 : SPDLOG_DEBUG("No state found with the requested maximum quantum number j_ryd. "
800 : "Requested: {}, found: {}.",
801 : description.range_quantum_number_j_ryd.max(), max_j_ryd);
802 : }
803 : }
804 24 : }
805 27 : }
806 :
807 : // Ask the table for the described states
808 54 : auto result = con->Query(fmt::format(
809 : R"(SELECT energy, f, m, parity, ketid, n, nu, exp_nui, std_nui, exp_l, std_l,
810 : exp_s, std_s, exp_j, std_j, exp_l_ryd, std_l_ryd, exp_j_ryd, std_j_ryd, is_j_total_momentum, is_calculated_with_mqdt, underspecified_channel_contribution FROM '{}' ORDER BY ketid ASC)",
811 : id_of_kets));
812 :
813 27 : if (result->HasError()) {
814 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
815 : }
816 :
817 27 : if (result->RowCount() == 0) {
818 0 : throw std::invalid_argument("No state found.");
819 : }
820 :
821 : // Check the types of the columns
822 27 : const auto &types = result->types;
823 27 : const auto &labels = result->names;
824 621 : const std::vector<duckdb::LogicalType> ref_types = {
825 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
826 : duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT, duckdb::LogicalType::BIGINT,
827 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
828 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
829 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
830 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE, duckdb::LogicalType::DOUBLE,
831 : duckdb::LogicalType::DOUBLE, duckdb::LogicalType::BOOLEAN, duckdb::LogicalType::BOOLEAN,
832 : duckdb::LogicalType::DOUBLE};
833 :
834 621 : for (size_t i = 0; i < types.size(); i++) {
835 594 : if (types[i] != ref_types[i]) {
836 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'. Got " +
837 0 : types[i].ToString() + " but expected " +
838 0 : ref_types[i].ToString());
839 : }
840 : }
841 :
842 : // Construct the states
843 27 : std::vector<std::shared_ptr<const KetAtom>> kets;
844 27 : kets.reserve(result->RowCount());
845 : #ifndef NDEBUG
846 27 : double last_energy = std::numeric_limits<double>::lowest();
847 : #endif
848 :
849 54 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
850 :
851 27 : auto *chunk_energy = duckdb::FlatVector::GetData<double>(chunk->data[0]);
852 27 : auto *chunk_quantum_number_f = duckdb::FlatVector::GetData<double>(chunk->data[1]);
853 27 : auto *chunk_quantum_number_m = duckdb::FlatVector::GetData<double>(chunk->data[2]);
854 27 : auto *chunk_parity = duckdb::FlatVector::GetData<int64_t>(chunk->data[3]);
855 27 : auto *chunk_id = duckdb::FlatVector::GetData<int64_t>(chunk->data[4]);
856 27 : auto *chunk_quantum_number_n = duckdb::FlatVector::GetData<int64_t>(chunk->data[5]);
857 27 : auto *chunk_quantum_number_nu = duckdb::FlatVector::GetData<double>(chunk->data[6]);
858 27 : auto *chunk_quantum_number_nui_exp = duckdb::FlatVector::GetData<double>(chunk->data[7]);
859 27 : auto *chunk_quantum_number_nui_std = duckdb::FlatVector::GetData<double>(chunk->data[8]);
860 27 : auto *chunk_quantum_number_l_exp = duckdb::FlatVector::GetData<double>(chunk->data[9]);
861 27 : auto *chunk_quantum_number_l_std = duckdb::FlatVector::GetData<double>(chunk->data[10]);
862 27 : auto *chunk_quantum_number_s_exp = duckdb::FlatVector::GetData<double>(chunk->data[11]);
863 27 : auto *chunk_quantum_number_s_std = duckdb::FlatVector::GetData<double>(chunk->data[12]);
864 27 : auto *chunk_quantum_number_j_exp = duckdb::FlatVector::GetData<double>(chunk->data[13]);
865 27 : auto *chunk_quantum_number_j_std = duckdb::FlatVector::GetData<double>(chunk->data[14]);
866 27 : auto *chunk_quantum_number_l_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[15]);
867 27 : auto *chunk_quantum_number_l_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[16]);
868 27 : auto *chunk_quantum_number_j_ryd_exp = duckdb::FlatVector::GetData<double>(chunk->data[17]);
869 27 : auto *chunk_quantum_number_j_ryd_std = duckdb::FlatVector::GetData<double>(chunk->data[18]);
870 27 : auto *chunk_is_j_total_momentum = duckdb::FlatVector::GetData<bool>(chunk->data[19]);
871 27 : auto *chunk_is_calculated_with_mqdt = duckdb::FlatVector::GetData<bool>(chunk->data[20]);
872 : auto *chunk_underspecified_channel_contribution =
873 27 : duckdb::FlatVector::GetData<double>(chunk->data[21]);
874 :
875 1214 : for (size_t i = 0; i < chunk->size(); i++) {
876 :
877 : #ifndef NDEBUG
878 : // Check database consistency
879 1187 : if (chunk_is_j_total_momentum[i] &&
880 1131 : chunk_quantum_number_f[i] != chunk_quantum_number_j_exp[i]) {
881 0 : throw std::runtime_error("If j is the total momentum, f must be equal to j.");
882 : }
883 1187 : if (chunk_energy[i] < last_energy) {
884 0 : throw std::runtime_error("The states are not sorted by energy.");
885 : }
886 1187 : last_energy = chunk_energy[i];
887 : #endif
888 :
889 : // Append a new state
890 1187 : kets.push_back(std::make_shared<const KetAtom>(
891 0 : typename KetAtom::Private(), chunk_energy[i], chunk_quantum_number_f[i],
892 1187 : chunk_quantum_number_m[i], static_cast<Parity>(chunk_parity[i]), species,
893 1187 : chunk_quantum_number_n[i], chunk_quantum_number_nu[i],
894 1187 : chunk_quantum_number_nui_exp[i], chunk_quantum_number_nui_std[i],
895 1187 : chunk_quantum_number_l_exp[i], chunk_quantum_number_l_std[i],
896 1187 : chunk_quantum_number_s_exp[i], chunk_quantum_number_s_std[i],
897 1187 : chunk_quantum_number_j_exp[i], chunk_quantum_number_j_std[i],
898 1187 : chunk_quantum_number_l_ryd_exp[i], chunk_quantum_number_l_ryd_std[i],
899 1187 : chunk_quantum_number_j_ryd_exp[i], chunk_quantum_number_j_ryd_std[i],
900 1187 : chunk_is_j_total_momentum[i], chunk_is_calculated_with_mqdt[i],
901 1187 : chunk_underspecified_channel_contribution[i], *this, chunk_id[i]));
902 : }
903 : }
904 :
905 0 : return std::make_shared<const BasisAtom<Scalar>>(typename BasisAtom<Scalar>::Private(),
906 54 : std::move(kets), std::move(id_of_kets), *this);
907 54 : }
908 :
909 : template <typename Scalar>
910 : Eigen::SparseMatrix<Scalar, Eigen::RowMajor>
911 193 : Database::get_matrix_elements(std::shared_ptr<const BasisAtom<Scalar>> initial_basis,
912 : std::shared_ptr<const BasisAtom<Scalar>> final_basis,
913 : OperatorType type, int q) {
914 : using real_t = typename traits::NumTraits<Scalar>::real_t;
915 :
916 193 : std::string specifier;
917 194 : int kappa{};
918 194 : switch (type) {
919 186 : case OperatorType::ELECTRIC_DIPOLE:
920 186 : specifier = "matrix_elements_d";
921 186 : kappa = 1;
922 186 : break;
923 0 : case OperatorType::ELECTRIC_QUADRUPOLE:
924 0 : specifier = "matrix_elements_q";
925 0 : kappa = 2;
926 0 : break;
927 0 : case OperatorType::ELECTRIC_QUADRUPOLE_ZERO:
928 0 : specifier = "matrix_elements_q0";
929 0 : kappa = 0;
930 0 : break;
931 0 : case OperatorType::ELECTRIC_OCTUPOLE:
932 0 : specifier = "matrix_elements_o";
933 0 : kappa = 3;
934 0 : break;
935 0 : case OperatorType::MAGNETIC_DIPOLE:
936 0 : specifier = "matrix_elements_mu";
937 0 : kappa = 1;
938 0 : break;
939 4 : case OperatorType::ENERGY:
940 4 : specifier = "energy";
941 4 : kappa = 0;
942 4 : break;
943 4 : case OperatorType::IDENTITY:
944 4 : specifier = "identity";
945 4 : kappa = 0;
946 4 : break;
947 0 : default:
948 0 : throw std::invalid_argument("Unknown operator type.");
949 : }
950 :
951 194 : if (initial_basis->get_id_of_kets() != final_basis->get_id_of_kets()) {
952 0 : throw std::invalid_argument(
953 : "The initial and final basis must be expressed using the same kets.");
954 : }
955 194 : std::string id_of_kets = initial_basis->get_id_of_kets();
956 0 : std::string cache_key = fmt::format("{}_{}_{}", specifier, q, id_of_kets);
957 :
958 194 : if (!get_matrix_elements_cache().contains(cache_key)) {
959 57 : Eigen::Index dim = initial_basis->get_number_of_kets();
960 :
961 56 : std::vector<int> outerIndexPtr;
962 56 : std::vector<int> innerIndices;
963 56 : std::vector<real_t> values;
964 :
965 56 : if (specifier == "identity") {
966 1 : outerIndexPtr.reserve(dim + 1);
967 1 : innerIndices.reserve(dim);
968 1 : values.reserve(dim);
969 :
970 91 : for (int i = 0; i < dim; i++) {
971 90 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
972 90 : innerIndices.push_back(i);
973 90 : values.push_back(1);
974 : }
975 1 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
976 :
977 : } else {
978 : // Check that the specifications are valid
979 54 : if (std::abs(q) > kappa) {
980 0 : throw std::invalid_argument("Invalid q.");
981 : }
982 :
983 : // Ask the database for the operator
984 54 : std::string species = initial_basis->get_species();
985 55 : duckdb::unique_ptr<duckdb::MaterializedQueryResult> result;
986 55 : if (specifier != "energy") {
987 107 : result = con->Query(fmt::format(
988 : R"(WITH s AS (
989 : SELECT id, f, m, ketid FROM '{}'
990 : ),
991 : b AS (
992 : SELECT MIN(f) AS min_f, MAX(f) AS max_f,
993 : MIN(id) AS min_id, MAX(id) AS max_id
994 : FROM s
995 : ),
996 : w_filtered AS (
997 : SELECT *
998 : FROM '{}'
999 : WHERE kappa = {} AND q = {} AND
1000 : f_initial BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b) AND
1001 : f_final BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b)
1002 : ),
1003 : e_filtered AS (
1004 : SELECT *
1005 : FROM '{}'
1006 : WHERE
1007 : id_initial BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b) AND
1008 : id_final BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b)
1009 : )
1010 : SELECT
1011 : s2.ketid AS row,
1012 : s1.ketid AS col,
1013 : e.val*w.val AS val
1014 : FROM e_filtered AS e
1015 : JOIN s AS s1 ON e.id_initial = s1.id
1016 : JOIN s AS s2 ON e.id_final = s2.id
1017 : JOIN w_filtered AS w ON
1018 : w.f_initial = s1.f AND w.m_initial = s1.m AND
1019 : w.f_final = s2.f AND w.m_final = s2.m
1020 : ORDER BY row ASC, col ASC)",
1021 : id_of_kets, manager->get_path("misc", "wigner"), kappa, q,
1022 : manager->get_path(species, specifier)));
1023 : } else {
1024 4 : result = con->Query(fmt::format(
1025 : R"(SELECT ketid as row, ketid as col, energy as val FROM '{}' ORDER BY row ASC)",
1026 : id_of_kets));
1027 : }
1028 :
1029 56 : if (result->HasError()) {
1030 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
1031 : }
1032 :
1033 : // Check the types of the columns
1034 56 : const auto &types = result->types;
1035 56 : const auto &labels = result->names;
1036 224 : const std::vector<duckdb::LogicalType> ref_types = {duckdb::LogicalType::BIGINT,
1037 : duckdb::LogicalType::BIGINT,
1038 : duckdb::LogicalType::DOUBLE};
1039 224 : for (size_t i = 0; i < types.size(); i++) {
1040 168 : if (types[i] != ref_types[i]) {
1041 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'.");
1042 : }
1043 : }
1044 :
1045 : // Construct the matrix
1046 56 : int num_entries = static_cast<int>(result->RowCount());
1047 56 : outerIndexPtr.reserve(dim + 1);
1048 56 : innerIndices.reserve(num_entries);
1049 56 : values.reserve(num_entries);
1050 :
1051 56 : int last_row = -1;
1052 :
1053 112 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
1054 :
1055 56 : auto *chunk_row = duckdb::FlatVector::GetData<int64_t>(chunk->data[0]);
1056 56 : auto *chunk_col = duckdb::FlatVector::GetData<int64_t>(chunk->data[1]);
1057 56 : auto *chunk_val = duckdb::FlatVector::GetData<double>(chunk->data[2]);
1058 :
1059 19783 : for (size_t i = 0; i < chunk->size(); i++) {
1060 19727 : int row = final_basis->get_ket_index_from_id(chunk_row[i]);
1061 19727 : if (row != last_row) {
1062 2643 : if (row < last_row) {
1063 0 : throw std::runtime_error("The rows are not sorted.");
1064 : }
1065 5767 : for (; last_row < row; last_row++) {
1066 3124 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1067 : }
1068 : }
1069 19727 : innerIndices.push_back(initial_basis->get_ket_index_from_id(chunk_col[i]));
1070 19727 : values.push_back(chunk_val[i]);
1071 : }
1072 : }
1073 :
1074 215 : for (; last_row < dim + 1; last_row++) {
1075 159 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
1076 : }
1077 56 : }
1078 :
1079 57 : Eigen::Map<const Eigen::SparseMatrix<real_t, Eigen::RowMajor>> matrix_map(
1080 57 : dim, dim, values.size(), outerIndexPtr.data(), innerIndices.data(), values.data());
1081 :
1082 : // Cache the matrix
1083 57 : get_matrix_elements_cache()[cache_key] = matrix_map;
1084 57 : }
1085 :
1086 : // Construct the operator and return it
1087 194 : return final_basis->get_coefficients().adjoint() *
1088 317 : get_matrix_elements_cache()[cache_key].template cast<Scalar>() *
1089 653 : initial_basis->get_coefficients();
1090 250 : }
1091 :
1092 2 : bool Database::get_download_missing() const { return download_missing_; }
1093 :
1094 0 : bool Database::get_use_cache() const { return use_cache_; }
1095 :
1096 0 : std::filesystem::path Database::get_database_dir() const { return database_dir_; }
1097 :
1098 : oneapi::tbb::concurrent_unordered_map<std::string, Eigen::SparseMatrix<double, Eigen::RowMajor>> &
1099 445 : Database::get_matrix_elements_cache() {
1100 : static oneapi::tbb::concurrent_unordered_map<std::string,
1101 : Eigen::SparseMatrix<double, Eigen::RowMajor>>
1102 445 : matrix_elements_cache;
1103 445 : return matrix_elements_cache;
1104 : }
1105 :
1106 31 : Database &Database::get_global_instance() {
1107 62 : return get_global_instance_without_checks(default_download_missing, default_use_cache,
1108 62 : default_database_dir);
1109 : }
1110 :
1111 0 : Database &Database::get_global_instance(bool download_missing) {
1112 0 : Database &database = get_global_instance_without_checks(download_missing, default_use_cache,
1113 : default_database_dir);
1114 0 : if (download_missing != database.download_missing_) {
1115 0 : throw std::invalid_argument(
1116 0 : "The 'download_missing' argument must not change between calls to the method.");
1117 : }
1118 0 : return database;
1119 : }
1120 :
1121 0 : Database &Database::get_global_instance(std::filesystem::path database_dir) {
1122 0 : if (database_dir.empty()) {
1123 0 : database_dir = default_database_dir;
1124 : }
1125 0 : Database &database = get_global_instance_without_checks(default_download_missing,
1126 : default_use_cache, database_dir);
1127 0 : if (!std::filesystem::exists(database_dir) ||
1128 0 : std::filesystem::canonical(database_dir) != database.database_dir_) {
1129 0 : throw std::invalid_argument(
1130 0 : "The 'database_dir' argument must not change between calls to the method.");
1131 : }
1132 0 : return database;
1133 : }
1134 :
1135 1 : Database &Database::get_global_instance(bool download_missing, bool use_cache,
1136 : std::filesystem::path database_dir) {
1137 1 : if (database_dir.empty()) {
1138 0 : database_dir = default_database_dir;
1139 : }
1140 : Database &database =
1141 1 : get_global_instance_without_checks(download_missing, use_cache, database_dir);
1142 1 : if (download_missing != database.download_missing_ || use_cache != database.use_cache_ ||
1143 3 : !std::filesystem::exists(database_dir) ||
1144 2 : std::filesystem::canonical(database_dir) != database.database_dir_) {
1145 0 : throw std::invalid_argument(
1146 : "The 'download_missing', 'use_cache' and 'database_dir' arguments must not "
1147 0 : "change between calls to the method.");
1148 : }
1149 1 : return database;
1150 : }
1151 :
1152 32 : Database &Database::get_global_instance_without_checks(bool download_missing, bool use_cache,
1153 : std::filesystem::path database_dir) {
1154 32 : static Database database(download_missing, use_cache, std::move(database_dir));
1155 32 : return database;
1156 : }
1157 :
1158 : struct database_dir_noexcept : std::filesystem::path {
1159 7 : explicit database_dir_noexcept() noexcept try : std
1160 7 : ::filesystem::path(paths::get_cache_directory() / "database") {}
1161 0 : catch (...) {
1162 0 : SPDLOG_ERROR("Error getting the pairinteraction cache directory.");
1163 0 : std::terminate();
1164 7 : }
1165 : };
1166 :
1167 : const std::filesystem::path Database::default_database_dir = database_dir_noexcept();
1168 :
1169 : // Explicit instantiations
1170 : // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
1171 : #define INSTANTIATE_GETTERS(SCALAR) \
1172 : template std::shared_ptr<const BasisAtom<SCALAR>> Database::get_basis<SCALAR>( \
1173 : const std::string &species, const AtomDescriptionByRanges &description, \
1174 : std::vector<size_t> additional_ket_ids); \
1175 : template Eigen::SparseMatrix<SCALAR, Eigen::RowMajor> Database::get_matrix_elements<SCALAR>( \
1176 : std::shared_ptr<const BasisAtom<SCALAR>> initial_basis, \
1177 : std::shared_ptr<const BasisAtom<SCALAR>> final_basis, OperatorType type, int q);
1178 : // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
1179 :
1180 : INSTANTIATE_GETTERS(double)
1181 : INSTANTIATE_GETTERS(std::complex<double>)
1182 :
1183 : #undef INSTANTIATE_GETTERS
1184 : } // namespace pairinteraction
|