Line data Source code
1 : // SPDX-FileCopyrightText: 2024 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/database/Database.hpp"
5 :
6 : #include "pairinteraction/basis/BasisAtom.hpp"
7 : #include "pairinteraction/database/AtomDescriptionByParameters.hpp"
8 : #include "pairinteraction/database/AtomDescriptionByRanges.hpp"
9 : #include "pairinteraction/database/GitHubDownloader.hpp"
10 : #include "pairinteraction/database/ParquetManager.hpp"
11 : #include "pairinteraction/enums/OperatorType.hpp"
12 : #include "pairinteraction/ket/KetAtom.hpp"
13 : #include "pairinteraction/utils/TaskControl.hpp"
14 : #include "pairinteraction/utils/hash.hpp"
15 : #include "pairinteraction/utils/ket_id.hpp"
16 : #include "pairinteraction/utils/paths.hpp"
17 : #include "pairinteraction/utils/streamed.hpp"
18 :
#include <algorithm>
#include <cmath>
#include <cpptrace/cpptrace.hpp>
#include <cstdint>
#include <duckdb.hpp>
#include <fmt/core.h>
#include <fmt/format.h>
#include <fmt/ranges.h>
#include <fstream>
#include <iterator>
#include <nlohmann/json.hpp>
#include <oneapi/tbb.h>
#include <spdlog/spdlog.h>
#include <string>
#include <system_error>
#include <unordered_map>
#include <unordered_set>
34 :
35 : namespace pairinteraction {
36 :
37 : namespace {
38 :
39 315 : std::string format_expectation_value_range(const std::string &value_column,
40 : const std::string &std_column,
41 : const Range<double> &range,
42 : double standard_deviation_factor) {
43 0 : return fmt::format("{} BETWEEN {}-{}*{} AND {}+{}*{}", value_column, range.min(),
44 315 : standard_deviation_factor, std_column, range.max(),
45 630 : standard_deviation_factor, std_column);
46 : }
47 :
48 : // Find the index of the result column with the given name.
49 2903 : size_t get_column_index(const std::vector<std::string> &names, const std::string &name) {
50 2903 : auto it = std::find(names.begin(), names.end(), name);
51 2903 : if (it == names.end()) {
52 0 : throw std::runtime_error("Missing database column '" + name + "'.");
53 : }
54 2903 : return static_cast<size_t>(std::distance(names.begin(), it));
55 : }
56 :
57 : // Read a single value of a duckdb result column as a double, regardless of its logical type.
58 845864 : double get_entry_as_double(duckdb::Vector &vector, const duckdb::LogicalType &type, size_t row) {
59 845864 : switch (type.id()) {
60 684340 : case duckdb::LogicalTypeId::DOUBLE:
61 684340 : return duckdb::FlatVector::GetData<double>(vector)[row];
62 81074 : case duckdb::LogicalTypeId::BIGINT:
63 81074 : return static_cast<double>(duckdb::FlatVector::GetData<int64_t>(vector)[row]);
64 80450 : case duckdb::LogicalTypeId::BOOLEAN:
65 80450 : return duckdb::FlatVector::GetData<bool>(vector)[row] ? 1.0 : 0.0;
66 0 : default:
67 0 : throw std::runtime_error("Cannot read database column of type " + type.ToString() +
68 0 : " as a quantum number.");
69 : }
70 : }
71 :
72 : struct QuantumNumbers {
73 : std::unordered_map<std::string, double> values;
74 : std::unordered_map<std::string, double> stds;
75 : };
76 :
77 40225 : QuantumNumbers get_quantum_numbers_from_row(duckdb::DataChunk &chunk,
78 : const std::vector<duckdb::LogicalType> &types,
79 : const std::vector<std::string> &names,
80 : const std::unordered_set<std::string> &excluded_columns,
81 : size_t row) {
82 40225 : QuantumNumbers quantum_numbers;
83 964975 : for (size_t col = 0; col < names.size(); ++col) {
84 924750 : const std::string &name = names[col];
85 924750 : if (excluded_columns.contains(name)) {
86 120675 : continue;
87 : }
88 804075 : double value = get_entry_as_double(chunk.data[col], types[col], row);
89 804075 : if (name.starts_with("std_")) {
90 241350 : quantum_numbers.stds[name.substr(4)] = value;
91 562725 : } else if (name.starts_with("exp_")) {
92 241350 : quantum_numbers.values[name.substr(4)] = value;
93 : } else {
94 321375 : quantum_numbers.values[name] = value;
95 : }
96 : }
97 40225 : return quantum_numbers;
98 0 : }
99 : } // namespace
100 :
101 40223 : void ensure_consistent_quantum_numbers(double quantum_number_f, double quantum_number_m) {
102 40223 : if (2 * quantum_number_m != std::rint(2 * quantum_number_m)) {
103 0 : throw std::runtime_error("The quantum number m must be an integer or half-integer.");
104 : }
105 40223 : if (2 * quantum_number_f != std::rint(2 * quantum_number_f)) {
106 0 : throw std::runtime_error("The quantum number f must be an integer or half-integer.");
107 : }
108 40223 : if (quantum_number_f + quantum_number_m != std::rint(quantum_number_f + quantum_number_m)) {
109 0 : throw std::invalid_argument(
110 0 : "The quantum numbers f and m must be both either integers or half-integers.");
111 : }
112 40223 : if (std::abs(quantum_number_m) > quantum_number_f) {
113 1 : throw std::invalid_argument(
114 2 : "The absolute value of the quantum number m must be less than or equal to f.");
115 : }
116 40222 : }
117 :
118 0 : Database::Database() : Database(default_download_missing) {}
119 :
120 0 : Database::Database(bool download_missing)
121 0 : : Database(download_missing, default_use_cache, default_database_dir) {}
122 :
123 0 : Database::Database(std::filesystem::path database_dir)
124 0 : : Database(default_download_missing, default_use_cache, std::move(database_dir)) {}
125 :
126 3 : Database::Database(bool download_missing, bool use_cache, std::filesystem::path database_dir)
127 3 : : download_missing_(download_missing), use_cache_(use_cache),
128 3 : database_dir_(std::move(database_dir)), db(std::make_unique<duckdb::DuckDB>(nullptr)),
129 15 : con(std::make_unique<duckdb::Connection>(*db)) {
130 :
131 3 : if (database_dir_.empty()) {
132 0 : database_dir_ = default_database_dir;
133 : }
134 :
135 : // Ensure the database directory exists
136 3 : if (!std::filesystem::exists(database_dir_)) {
137 0 : std::filesystem::create_directories(database_dir_);
138 : }
139 3 : database_dir_ = std::filesystem::canonical(database_dir_);
140 3 : if (!std::filesystem::is_directory(database_dir_)) {
141 0 : throw std::filesystem::filesystem_error("Cannot access database", database_dir_.string(),
142 0 : std::make_error_code(std::errc::not_a_directory));
143 : }
144 3 : SPDLOG_INFO("Using database directory: {}", database_dir_.string());
145 :
146 : // Ensure that the config directory exists
147 3 : std::filesystem::path configdir = paths::get_config_directory();
148 3 : if (!std::filesystem::exists(configdir)) {
149 1 : std::filesystem::create_directories(configdir);
150 2 : } else if (!std::filesystem::is_directory(configdir)) {
151 0 : throw std::filesystem::filesystem_error("Cannot access config directory ",
152 0 : configdir.string(),
153 0 : std::make_error_code(std::errc::not_a_directory));
154 : }
155 :
156 : // Read in the database_repo_paths if a config file exists, otherwise use the default and
157 : // write it to the config file
158 3 : std::filesystem::path configfile = configdir / "database.json";
159 3 : std::string database_repo_host;
160 3 : std::vector<std::string> database_repo_paths;
161 3 : if (std::filesystem::exists(configfile)) {
162 2 : std::ifstream file(configfile);
163 2 : nlohmann::json doc = nlohmann::json::parse(file, nullptr, false);
164 :
165 4 : if (!doc.is_discarded() && doc.contains("hash") && doc.contains("database_repo_host") &&
166 2 : doc.contains("database_repo_paths")) {
167 2 : database_repo_host = doc["database_repo_host"].get<std::string>();
168 2 : database_repo_paths = doc["database_repo_paths"].get<std::vector<std::string>>();
169 :
170 : // If the values are not equal to the default values but the hash is consistent (i.e.,
171 : // the user has not changed anything manually), clear the values so that they can be
172 : // updated
173 4 : if (database_repo_host != default_database_repo_host ||
174 2 : database_repo_paths != default_database_repo_paths) {
175 0 : std::size_t seed = 0;
176 0 : utils::hash_combine(seed, database_repo_paths);
177 0 : utils::hash_combine(seed, database_repo_host);
178 0 : if (seed == doc["hash"].get<std::size_t>()) {
179 0 : database_repo_host.clear();
180 0 : database_repo_paths.clear();
181 : } else {
182 0 : SPDLOG_INFO("The database repository host and paths have been changed "
183 : "manually. Thus, they will not be updated automatically. To reset "
184 : "them, delete the file '{}'.",
185 : configfile.string());
186 : }
187 : }
188 : }
189 2 : }
190 :
191 : // Read in and store the default values if necessary
192 3 : if (database_repo_host.empty() || database_repo_paths.empty()) {
193 2 : SPDLOG_INFO("Updating the database repository host and paths:");
194 :
195 1 : database_repo_host = default_database_repo_host;
196 1 : database_repo_paths = default_database_repo_paths;
197 1 : std::ofstream file(configfile);
198 1 : nlohmann::json doc;
199 :
200 1 : SPDLOG_INFO("* New host: {}", default_database_repo_host);
201 2 : SPDLOG_INFO("* New paths: {}", fmt::join(default_database_repo_paths, ", "));
202 :
203 1 : doc["database_repo_host"] = default_database_repo_host;
204 1 : doc["database_repo_paths"] = database_repo_paths;
205 :
206 1 : std::size_t seed = 0;
207 1 : utils::hash_combine(seed, default_database_repo_paths);
208 1 : utils::hash_combine(seed, default_database_repo_host);
209 1 : doc["hash"] = seed;
210 :
211 1 : file << doc.dump(4);
212 1 : }
213 :
214 : // Limit the memory usage of duckdb's buffer manager
215 : {
216 3 : auto result = con->Query("PRAGMA max_memory = '8GB';");
217 3 : if (result->HasError()) {
218 0 : throw cpptrace::runtime_error("Error setting the memory limit: " + result->GetError());
219 : }
220 3 : }
221 :
222 : // Instantiate a database manager that provides access to database tables. If a table
223 : // is outdated/not available locally, it will be downloaded if download_missing_ is true.
224 3 : if (!download_missing_) {
225 3 : database_repo_paths.clear();
226 : }
227 3 : downloader = std::make_unique<GitHubDownloader>();
228 6 : manager = std::make_unique<ParquetManager>(database_dir_, *downloader, database_repo_paths,
229 6 : *con, use_cache_);
230 3 : manager->scan_local();
231 3 : manager->scan_remote();
232 :
233 : // Print versions of tables
234 3 : std::istringstream iss(manager->get_versions_info());
235 36 : for (std::string line; std::getline(iss, line);) {
236 33 : SPDLOG_INFO(line);
237 3 : }
238 9 : }
239 :
240 2 : Database::~Database() = default;
241 :
242 1457 : const std::unordered_set<std::string> &Database::get_column_names(const std::string &table_path) {
243 1457 : if (auto it = column_names_cache.find(table_path); it != column_names_cache.end()) {
244 1445 : return it->second;
245 : }
246 24 : auto result = con->Query(fmt::format(R"(SELECT * FROM '{}' LIMIT 0)", table_path));
247 12 : if (result->HasError()) {
248 0 : throw cpptrace::runtime_error("Error querying the database columns: " + result->GetError());
249 : }
250 12 : std::unordered_set<std::string> names(result->names.begin(), result->names.end());
251 :
252 : // Every states table must provide the columns required for constructing kets and must not
253 : // contain an 'm' column, since m added separately below.
254 48 : for (const auto &required : {"id", "f", "energy"}) {
255 36 : if (!names.contains(required)) {
256 0 : throw std::runtime_error(
257 0 : fmt::format("The database table '{}' is missing the required column '{}'.",
258 0 : table_path, required));
259 : }
260 : }
261 12 : if (names.contains("m")) {
262 0 : throw std::runtime_error(
263 0 : fmt::format("The database table '{}' must not contain a column 'm'.", table_path));
264 : }
265 :
266 : // If another thread inserted the same entry concurrently, insert keeps the existing value and
267 : // returns an iterator to it, so the reference stays valid (elements are never erased).
268 12 : return column_names_cache.insert({table_path, std::move(names)}).first->second;
269 12 : }
270 :
271 431 : std::shared_ptr<const KetAtom> Database::get_ket(const std::string &species,
272 : const AtomDescriptionByParameters &description) {
273 : // Check that the specifications are valid
274 431 : if (!description.quantum_numbers.contains("m")) {
275 0 : throw std::invalid_argument("The quantum number m must be specified.");
276 : }
277 2131 : for (const auto &[name, value] : description.quantum_numbers) {
278 1700 : if ((name == "f" || name == "m") && 2 * value != std::rint(2 * value)) {
279 0 : throw std::invalid_argument("The quantum number " + name +
280 0 : " must be an integer or half-integer.");
281 : }
282 1700 : if (name == "f" && value < 0) {
283 0 : throw std::invalid_argument("The quantum number " + name + " must be positive.");
284 : }
285 : }
286 :
287 431 : const auto &columns = get_column_names(manager->get_path(species, "states"));
288 :
289 : // Describe the state. The quantum numbers n, f and parity are matched exactly, while all other
290 : // quantum numbers are matched within a +-0.5 window (they can deviate from the requested value,
291 : // e.g. expectation values in MQDT). The result is ordered by the distance to the requested
292 : // values, so the nearest state is returned.
293 431 : std::string where;
294 431 : std::string where_separator;
295 431 : std::string orderby;
296 431 : std::string orderby_separator;
297 431 : if (description.energy.has_value()) {
298 : // The following condition derives from demanding that quantum number n that corresponds to
299 : // the energy "E_n = -1/(2*n^2)" is not off by more than 1 from the actual quantum number n,
300 : // i.e., "sqrt(-1/(2*E_n)) - sqrt(-1/(2*E_{n-1})) = 1"
301 0 : double n_from_energy = std::sqrt(-1 / (2 * description.energy.value()));
302 0 : where += where_separator +
303 0 : fmt::format("SQRT(-1/(2*energy)) BETWEEN {} AND {}", n_from_energy - 0.5,
304 0 : n_from_energy + 0.5);
305 0 : where_separator = " AND ";
306 0 : orderby += orderby_separator + fmt::format("(SQRT(-1/(2*energy)) - {})^2", n_from_energy);
307 0 : orderby_separator = " + ";
308 : }
309 2129 : for (const auto &[name, value] : description.quantum_numbers) {
310 1699 : if (name == "m") {
311 431 : continue; // m is encoded into the id, not stored as a queryable column
312 : }
313 1268 : std::string column = columns.contains("exp_" + name) ? "exp_" + name : name;
314 1268 : if (!columns.contains(column)) {
315 2 : throw std::invalid_argument(
316 2 : fmt::format("The quantum number '{}' is not stored in the database table for "
317 : "species '{}'.",
318 2 : name, species));
319 : }
320 1267 : double tolerance = (name == "n" || name == "f" || name == "parity") ? 0.0 : 0.5;
321 1267 : where += where_separator +
322 3801 : fmt::format("{} BETWEEN {} AND {}", column, value - tolerance, value + tolerance);
323 1267 : where_separator = " AND ";
324 2534 : orderby += orderby_separator + fmt::format("({} - {})^2", column, value);
325 1267 : orderby_separator = " + ";
326 1268 : }
327 430 : if (where_separator.empty()) {
328 0 : where += "FALSE";
329 : }
330 430 : if (orderby_separator.empty()) {
331 0 : orderby += "id";
332 : }
333 :
334 : // Ask the database for the described state
335 430 : set_task_status("Loading atomic ket from database...");
336 860 : auto result = con->Query(fmt::format(
337 : R"(SELECT *, {} AS order_val FROM '{}' WHERE {} ORDER BY order_val ASC LIMIT 2)", orderby,
338 1290 : manager->get_path(species, "states"), where));
339 :
340 430 : if (result->HasError()) {
341 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
342 : }
343 :
344 430 : if (result->RowCount() == 0) {
345 6 : throw std::invalid_argument("No state found.");
346 : }
347 :
348 : // Get the first chunk of the results (the first chunk is sufficient as we need two rows at
349 : // most). Every column except energy, id and the synthetic order_val is treated as a quantum
350 : // number; m is not a database column and is injected from the description.
351 424 : const auto &types = result->types;
352 424 : const auto &names = result->names;
353 1696 : const std::unordered_set<std::string> excluded_columns = {"energy", "id", "order_val"};
354 424 : auto chunk = result->Fetch();
355 :
356 424 : size_t energy_column = get_column_index(names, "energy");
357 424 : size_t id_column = get_column_index(names, "id");
358 424 : double quantum_number_m = description.quantum_numbers.at("m");
359 :
360 425 : auto make_ket = [&](size_t row) {
361 : auto quantum_numbers =
362 425 : get_quantum_numbers_from_row(*chunk, types, names, excluded_columns, row);
363 425 : quantum_numbers.values["m"] = quantum_number_m;
364 425 : double energy = get_entry_as_double(chunk->data[energy_column], types[energy_column], row);
365 : auto id =
366 850 : utils::encode_as_ket_id({.id = static_cast<size_t>(duckdb::FlatVector::GetData<int64_t>(
367 425 : chunk->data[id_column])[row]),
368 425 : .m = quantum_number_m});
369 425 : return KetAtom(typename KetAtom::Private(), energy, species,
370 850 : std::move(quantum_numbers.values), std::move(quantum_numbers.stds), *this,
371 1275 : id);
372 425 : };
373 :
374 : // Check that the ket is uniquely specified
375 424 : if (chunk->size() > 1) {
376 5 : size_t order_val_column = get_column_index(names, "order_val");
377 : auto order_val_0 =
378 5 : get_entry_as_double(chunk->data[order_val_column], types[order_val_column], 0);
379 : auto order_val_1 =
380 5 : get_entry_as_double(chunk->data[order_val_column], types[order_val_column], 1);
381 :
382 5 : if (order_val_1 - order_val_0 <= order_val_0) {
383 2 : throw std::invalid_argument(
384 2 : fmt::format("The ket is not uniquely specified. Possible kets are:\n{}\n{}",
385 2 : fmt::streamed(make_ket(0)), fmt::streamed(make_ket(1))));
386 : }
387 : }
388 :
389 : // Construct the state
390 423 : auto ket = make_ket(0);
391 :
392 : // Check database consistency
393 425 : ensure_consistent_quantum_numbers(ket.get_quantum_number("f"), quantum_number_m);
394 :
395 844 : return std::make_shared<const KetAtom>(std::move(ket));
396 1319 : }
397 :
398 : template <typename Scalar>
399 : std::shared_ptr<const BasisAtom<Scalar>>
400 1026 : Database::get_basis(const std::string &species, const AtomDescriptionByRanges &description,
401 : const std::vector<size_t> &additional_ket_ids) {
402 : // The quantum number m is restricted separately because it is generated by UNNEST below.
403 2052 : auto range_quantum_number_m = [&description]() {
404 1026 : auto it = description.quantum_number_ranges.find("m");
405 2052 : return it != description.quantum_number_ranges.end() ? it->second : Range<double>{};
406 1026 : }();
407 :
408 1026 : const auto &columns = get_column_names(manager->get_path(species, "states"));
409 :
410 : // Describe the states by all restrictions that do not involve the quantum number m
411 1026 : std::string where = "(";
412 1026 : std::string separator;
413 1026 : if (description.range_energy.is_finite()) {
414 43 : where += separator +
415 43 : fmt::format("energy BETWEEN {} AND {}", description.range_energy.min(),
416 43 : description.range_energy.max());
417 43 : separator = " AND ";
418 : }
419 2404 : for (const auto &[name, range] : description.quantum_number_ranges) {
420 735 : if (name == "m" || !range.is_finite()) {
421 92 : continue;
422 : }
423 643 : std::string exp_column = "exp_" + name;
424 643 : std::string std_column = "std_" + name;
425 643 : if (columns.contains(exp_column) && columns.contains(std_column)) {
426 315 : where += separator +
427 : format_expectation_value_range(
428 : exp_column, std_column, range,
429 315 : description.quantum_number_standard_deviation_factor);
430 328 : } else if (columns.contains(name)) {
431 327 : where +=
432 654 : separator + fmt::format("{} BETWEEN {} AND {}", name, range.min(), range.max());
433 : } else {
434 2 : throw std::invalid_argument(
435 : fmt::format("The quantum number '{}' is not stored in the database table for "
436 : "species '{}'.",
437 : name, species));
438 : }
439 642 : separator = " AND ";
440 : }
441 1025 : if (separator.empty()) {
442 : // If the description contains no restrictions at all, it describes no states
443 702 : where += range_quantum_number_m.is_finite() ? "TRUE" : "FALSE";
444 : }
445 1025 : where += ")";
446 :
447 : // Describe the restriction of the quantum number m
448 1025 : std::string where_m = "(";
449 1025 : if (range_quantum_number_m.is_finite()) {
450 184 : where_m += fmt::format("m BETWEEN {} AND {}", range_quantum_number_m.min(),
451 184 : range_quantum_number_m.max());
452 : } else {
453 933 : where_m += "TRUE";
454 : }
455 1025 : where_m += ")";
456 :
457 : // Create a table containing the described states
458 1025 : std::string canonical_basis_id;
459 : {
460 1025 : auto result = con->Query(R"(SELECT UUID()::varchar)");
461 1025 : if (result->HasError()) {
462 0 : throw cpptrace::runtime_error("Error selecting canonical_basis_id: " +
463 0 : result->GetError());
464 : }
465 1025 : canonical_basis_id =
466 2050 : duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
467 1025 : }
468 : {
469 1025 : set_task_status("Selecting atomic basis states...");
470 3075 : auto result = con->Query(fmt::format(
471 : R"(CREATE TEMP TABLE '{}' AS SELECT *, id*{}+(2*m+{})::bigint AS ketid FROM (
472 : SELECT *,
473 : UNNEST(list_transform(generate_series(0,(2*f)::bigint),
474 : x -> x::double-f)) AS m FROM (
475 : SELECT * FROM '{}' WHERE {}
476 : )
477 : ) WHERE {})",
478 : canonical_basis_id, utils::KET_ID_STRIDE, utils::M_OFFSET,
479 2050 : manager->get_path(species, "states"), where, where_m));
480 :
481 1025 : if (result->HasError()) {
482 0 : throw cpptrace::runtime_error("Error creating table: " + result->GetError());
483 : }
484 1025 : }
485 : {
486 2050 : auto result = con->Query(
487 : fmt::format(R"(ALTER TABLE '{}' ADD PRIMARY KEY (ketid))", canonical_basis_id));
488 :
489 1025 : if (result->HasError()) {
490 0 : throw cpptrace::runtime_error("Error adding primary key: " + result->GetError());
491 : }
492 1025 : }
493 :
494 : // Add the additional kets to the table if they are not already contained in it
495 1025 : if (!additional_ket_ids.empty()) {
496 750 : set_task_status("Selecting additional kets...");
497 750 : std::vector<std::string> values;
498 750 : values.reserve(additional_ket_ids.size());
499 1545 : for (size_t ket_id : additional_ket_ids) {
500 795 : auto [id, m] = utils::decode_from_ket_id(ket_id);
501 1590 : values.push_back(fmt::format("({}, {}::double, {})", id, m, ket_id));
502 : }
503 3000 : auto result = con->Query(fmt::format(
504 : R"(INSERT OR IGNORE INTO '{0}'
505 : SELECT s.*, v.m, v.ketid FROM '{1}' AS s
506 : JOIN (VALUES {2}) AS v(id, m, ketid) ON s.id = v.id)",
507 2250 : canonical_basis_id, manager->get_path(species, "states"), fmt::join(values, ",")));
508 :
509 750 : if (result->HasError()) {
510 0 : throw cpptrace::runtime_error("Error adding additional kets: " + result->GetError());
511 : }
512 750 : }
513 :
514 : // Ask the table for the extreme values of the quantum numbers
515 : {
516 1025 : set_task_status("Validating atomic basis coverage...");
517 :
518 : // Collect the finite quantum-number ranges to validate against the loaded basis. Energy is
519 : // handled separately because it is compared via its corresponding effective quantum number.
520 : struct CoverageCheck {
521 : std::string name;
522 : std::string column;
523 : double tolerance{}; // allowed slack between the requested and the available range
524 : Range<double> range;
525 : };
526 1025 : std::vector<CoverageCheck> checks;
527 1759 : for (const auto &[name, range] : description.quantum_number_ranges) {
528 734 : if (!range.is_finite()) {
529 0 : continue;
530 : }
531 734 : bool is_expectation_value =
532 734 : columns.contains("exp_" + name) && columns.contains("std_" + name);
533 734 : std::string column = is_expectation_value ? "exp_" + name : name;
534 734 : double tolerance =
535 734 : (name == "n" || name == "f" || name == "m" || name == "parity") ? 0.0 : 1.0;
536 734 : checks.push_back({name, column, tolerance, range});
537 : }
538 :
539 1025 : std::string select;
540 1025 : std::string separator;
541 1025 : if (description.range_energy.is_finite()) {
542 43 : select += separator + "MIN(energy) AS min_energy, MAX(energy) AS max_energy";
543 43 : separator = ", ";
544 : }
545 1759 : for (const auto &check : checks) {
546 734 : select += separator +
547 734 : fmt::format("MIN({0}) AS min_{1}, MAX({0}) AS max_{1}", check.column, check.name);
548 734 : separator = ", ";
549 : }
550 :
551 1025 : if (!separator.empty()) {
552 646 : auto result =
553 : con->Query(fmt::format(R"(SELECT {} FROM '{}')", select, canonical_basis_id));
554 :
555 323 : if (result->HasError()) {
556 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
557 : }
558 :
559 323 : auto chunk = result->Fetch();
560 323 : const auto &types = result->types;
561 :
562 1877 : for (size_t i = 0; i < chunk->ColumnCount(); i++) {
563 1554 : if (duckdb::FlatVector::IsNull(chunk->data[i], 0)) {
564 0 : throw std::invalid_argument("No state found.");
565 : }
566 : }
567 :
568 323 : size_t idx = 0;
569 323 : if (description.range_energy.is_finite()) {
570 43 : auto min_energy = get_entry_as_double(chunk->data[idx], types[idx], 0);
571 43 : idx++;
572 86 : if (std::sqrt(-1 / (2 * min_energy)) - 1 >
573 43 : std::sqrt(-1 / (2 * description.range_energy.min()))) {
574 0 : SPDLOG_DEBUG("No state found with the requested minimum energy. Requested: {}, "
575 : "found: {}.",
576 : description.range_energy.min(), min_energy);
577 : }
578 43 : auto max_energy = get_entry_as_double(chunk->data[idx], types[idx], 0);
579 43 : idx++;
580 86 : if (std::sqrt(-1 / (2 * max_energy)) + 1 <
581 43 : std::sqrt(-1 / (2 * description.range_energy.max()))) {
582 0 : SPDLOG_DEBUG("No state found with the requested maximum energy. Requested: {}, "
583 : "found: {}.",
584 : description.range_energy.max(), max_energy);
585 : }
586 : }
587 1057 : for (const auto &check : checks) {
588 734 : auto min_value = get_entry_as_double(chunk->data[idx], types[idx], 0);
589 734 : idx++;
590 734 : if (min_value - check.tolerance > check.range.min()) {
591 103 : SPDLOG_DEBUG("No state found with the requested minimum quantum number {}. "
592 : "Requested: {}, found: {}.",
593 : check.name, check.range.min(), min_value);
594 : }
595 734 : auto max_value = get_entry_as_double(chunk->data[idx], types[idx], 0);
596 734 : idx++;
597 734 : if (max_value + check.tolerance < check.range.max()) {
598 38 : SPDLOG_DEBUG("No state found with the requested maximum quantum number {}. "
599 : "Requested: {}, found: {}.",
600 : check.name, check.range.max(), max_value);
601 : }
602 : }
603 323 : }
604 1025 : }
605 :
606 : // Ask the table for the described states
607 1025 : set_task_status("Loading atomic basis states...");
608 2050 : auto result =
609 : con->Query(fmt::format(R"(SELECT * FROM '{}' ORDER BY ketid ASC)", canonical_basis_id));
610 :
611 1025 : if (result->HasError()) {
612 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
613 : }
614 :
615 1025 : if (result->RowCount() == 0) {
616 0 : throw std::invalid_argument("No state found.");
617 : }
618 :
619 : // Construct the states. Every column except energy and the raw/encoded id is treated as a
620 : // quantum number ("id" is the raw states-table id, "ketid" the m-encoded id of the basis
621 : // state).
622 1025 : const auto &types = result->types;
623 1025 : const auto &names = result->names;
624 4100 : const std::unordered_set<std::string> excluded_columns = {"energy", "id", "ketid"};
625 1025 : size_t energy_column = get_column_index(names, "energy");
626 1025 : size_t ketid_column = get_column_index(names, "ketid");
627 :
628 1025 : std::vector<std::shared_ptr<const KetAtom>> kets;
629 1025 : kets.reserve(result->RowCount());
630 1025 : double last_energy = std::numeric_limits<double>::lowest();
631 1025 : double min_quantum_number_nu = std::numeric_limits<double>::max();
632 :
633 2050 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
634 1025 : set_task_status("Constructing atomic basis...");
635 :
636 40825 : for (size_t i = 0; i < chunk->size(); i++) {
637 79600 : auto quantum_numbers =
638 39800 : get_quantum_numbers_from_row(*chunk, types, names, excluded_columns, i);
639 39800 : double energy =
640 39800 : get_entry_as_double(chunk->data[energy_column], types[energy_column], i);
641 39800 : auto id = static_cast<size_t>(
642 39800 : duckdb::FlatVector::GetData<int64_t>(chunk->data[ketid_column])[i]);
643 :
644 : // Check database consistency
645 39800 : ensure_consistent_quantum_numbers(quantum_numbers.values.at("f"),
646 79600 : quantum_numbers.values.at("m"));
647 39800 : if (energy < last_energy) {
648 0 : throw std::runtime_error("The states are not sorted by energy.");
649 : }
650 39800 : last_energy = energy;
651 :
652 39800 : if (auto it = quantum_numbers.values.find("nu"); it != quantum_numbers.values.end()) {
653 39800 : min_quantum_number_nu = std::min(min_quantum_number_nu, it->second);
654 : }
655 :
656 : // Append a new state
657 39800 : kets.push_back(std::make_shared<const KetAtom>(
658 79600 : typename KetAtom::Private(), energy, species, std::move(quantum_numbers.values),
659 39800 : std::move(quantum_numbers.stds), *this, id));
660 : }
661 : }
662 :
663 : // Show a warning for low-lying states
664 1025 : if (min_quantum_number_nu < 25) {
665 36 : if (species.ends_with("_mqdt")) {
666 16 : SPDLOG_WARN("The multi-channel quantum defect theory might produce inaccurate results "
667 : "for effective principal quantum numbers < 25. The models get increasingly "
668 : "unreliable for small principal quantum numbers, leading to inaccurate "
669 : "matrix elements and energies. Due to missing data, even some states might "
670 : "not be present.");
671 : } else {
672 56 : SPDLOG_WARN(
673 : "The single-channel quantum defect theory can be inaccurate for effective "
674 : "principal quantum numbers < 25. This can lead to inaccurate matrix elements.");
675 : }
676 : }
677 :
678 0 : return std::make_shared<const BasisAtom<Scalar>>(typename BasisAtom<Scalar>::Private(),
679 1025 : std::move(kets), std::move(canonical_basis_id),
680 2050 : *this);
681 3811 : }
682 :
683 : template <typename Scalar>
684 4808 : Eigen::SparseMatrix<Scalar, Eigen::RowMajor> Database::get_matrix_elements_in_canonical_basis(
685 : std::shared_ptr<const BasisAtom<Scalar>> initial_basis,
686 : std::shared_ptr<const BasisAtom<Scalar>> final_basis, OperatorType type, int q) {
687 : using real_t = typename traits::NumTraits<Scalar>::real_t;
688 : using cached_matrix_ptr_t = std::shared_ptr<const cached_matrix_t>;
689 :
690 4808 : if (&initial_basis->get_database() != this || &final_basis->get_database() != this) {
691 1 : throw std::invalid_argument(
692 : "The initial and final bases must belong to the Database instance used for the "
693 : "matrix element calculation.");
694 : }
695 4807 : if (initial_basis->get_species() != final_basis->get_species()) {
696 0 : throw std::invalid_argument(fmt::format(
697 : "The initial and final bases must have the same species, but got '{}' and '{}'.",
698 0 : initial_basis->get_species(), final_basis->get_species()));
699 : }
700 :
701 4807 : std::string specifier;
702 4807 : int kappa{};
703 4807 : switch (type) {
704 3722 : case OperatorType::ELECTRIC_DIPOLE:
705 3722 : specifier = "matrix_elements_d";
706 3722 : kappa = 1;
707 3722 : break;
708 257 : case OperatorType::ELECTRIC_QUADRUPOLE:
709 257 : specifier = "matrix_elements_q";
710 257 : kappa = 2;
711 257 : break;
712 49 : case OperatorType::ELECTRIC_QUADRUPOLE_ZERO:
713 49 : specifier = "matrix_elements_q0";
714 49 : kappa = 0;
715 49 : break;
716 0 : case OperatorType::ELECTRIC_OCTUPOLE:
717 0 : specifier = "matrix_elements_o";
718 0 : kappa = 3;
719 0 : break;
720 58 : case OperatorType::MAGNETIC_DIPOLE:
721 58 : specifier = "matrix_elements_mu";
722 58 : kappa = 1;
723 58 : break;
724 4 : case OperatorType::ENERGY:
725 4 : specifier = "energy";
726 4 : kappa = 0;
727 4 : break;
728 717 : case OperatorType::IDENTITY:
729 717 : specifier = "identity";
730 717 : kappa = 0;
731 717 : break;
732 0 : default:
733 0 : throw std::invalid_argument("Unknown operator type.");
734 : }
735 :
736 4807 : std::string canonical_basis_id_initial = initial_basis->get_canonical_basis_id();
737 4807 : std::string canonical_basis_id_final = final_basis->get_canonical_basis_id();
738 4807 : std::string cache_key = fmt::format("{}_{}_{}_{}", specifier, q, canonical_basis_id_initial,
739 : canonical_basis_id_final);
740 4807 : auto &matrix_elements_cache = get_matrix_elements_cache();
741 4807 : std::promise<cached_matrix_ptr_t> matrix_promise;
742 4807 : auto [cache_it, inserted] =
743 : matrix_elements_cache.insert({cache_key, matrix_promise.get_future().share()});
744 :
745 4807 : if (inserted) {
746 : try {
747 1180 : Eigen::Index num_rows = final_basis->get_number_of_kets();
748 1180 : Eigen::Index num_cols = initial_basis->get_number_of_kets();
749 :
750 1180 : std::vector<int> outerIndexPtr;
751 1180 : std::vector<int> innerIndices;
752 1180 : std::vector<real_t> values;
753 :
754 : // Check that the specifications are valid
755 1180 : if (std::abs(q) > kappa) {
756 0 : throw std::invalid_argument("Invalid q.");
757 : }
758 :
759 : // Ask the database for the operator
760 1180 : set_task_status("Loading matrix elements from database...");
761 1180 : std::string species = initial_basis->get_species();
762 1180 : duckdb::unique_ptr<duckdb::MaterializedQueryResult> result;
763 1180 : if (specifier == "identity") {
764 1204 : result = con->Query(fmt::format(
765 : R"(SELECT s2.ketid AS row, s1.ketid AS col, 1.0::DOUBLE AS val
766 : FROM '{}' AS s1
767 : INNER JOIN '{}' AS s2 ON s1.ketid = s2.ketid
768 : ORDER BY row ASC)",
769 : canonical_basis_id_initial, canonical_basis_id_final));
770 578 : } else if (specifier == "energy") {
771 6 : result = con->Query(fmt::format(
772 : R"(SELECT s2.ketid AS row, s1.ketid AS col, s1.energy AS val
773 : FROM '{}' AS s1
774 : INNER JOIN '{}' AS s2 ON s1.ketid = s2.ketid
775 : ORDER BY row ASC)",
776 : canonical_basis_id_initial, canonical_basis_id_final));
777 : } else {
778 2300 : result = con->Query(fmt::format(
779 : R"(WITH s1 AS (
780 : SELECT id, f, m, ketid FROM '{}'
781 : ),
782 : s2 AS (
783 : SELECT id, f, m, ketid FROM '{}'
784 : ),
785 : b AS (
786 : SELECT MIN(f) AS min_f, MAX(f) AS max_f,
787 : MIN(id) AS min_id, MAX(id) AS max_id
788 : FROM (SELECT f, id FROM s1 UNION ALL SELECT f, id FROM s2)
789 : ),
790 : w_filtered AS (
791 : SELECT *
792 : FROM '{}'
793 : WHERE kappa = {} AND q = {} AND
794 : f_initial BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b) AND
795 : f_final BETWEEN (SELECT min_f FROM b) AND (SELECT max_f FROM b)
796 : ),
797 : e_filtered AS (
798 : SELECT *
799 : FROM '{}'
800 : WHERE
801 : id_initial BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b) AND
802 : id_final BETWEEN (SELECT min_id FROM b) AND (SELECT max_id FROM b)
803 : )
804 : SELECT
805 : s2.ketid AS row,
806 : s1.ketid AS col,
807 : e.val*w.val AS val
808 : FROM e_filtered AS e
809 : JOIN s1 ON e.id_initial = s1.id
810 : JOIN s2 ON e.id_final = s2.id
811 : JOIN w_filtered AS w ON
812 : w.f_initial = s1.f AND w.m_initial = s1.m AND
813 : w.f_final = s2.f AND w.m_final = s2.m
814 : ORDER BY row ASC, col ASC)",
815 : canonical_basis_id_initial, canonical_basis_id_final,
816 1150 : manager->get_path("misc", "wigner"), kappa, q,
817 : manager->get_path(species, specifier)));
818 : }
819 :
820 1180 : if (result->HasError()) {
821 0 : throw cpptrace::runtime_error("Error querying the database: " + result->GetError());
822 : }
823 :
824 : // Check the types of the columns
825 1180 : const auto &types = result->types;
826 1180 : const auto &labels = result->names;
827 4720 : const std::vector<duckdb::LogicalType> ref_types = {duckdb::LogicalType::BIGINT,
828 : duckdb::LogicalType::BIGINT,
829 : duckdb::LogicalType::DOUBLE};
830 4720 : for (size_t i = 0; i < types.size(); i++) {
831 3540 : if (types[i] != ref_types[i]) {
832 0 : throw std::runtime_error("Wrong type for '" + labels[i] + "'.");
833 : }
834 : }
835 :
836 1180 : set_task_status("Constructing matrix elements...");
837 :
838 : // Construct the matrix
839 1180 : int num_entries = static_cast<int>(result->RowCount());
840 1180 : outerIndexPtr.reserve(num_rows + 1);
841 1180 : innerIndices.reserve(num_entries);
842 1180 : values.reserve(num_entries);
843 :
844 1180 : int last_row = -1;
845 :
846 2595 : for (auto chunk = result->Fetch(); chunk; chunk = result->Fetch()) {
847 1415 : auto *chunk_row = duckdb::FlatVector::GetData<int64_t>(chunk->data[0]);
848 1415 : auto *chunk_col = duckdb::FlatVector::GetData<int64_t>(chunk->data[1]);
849 1415 : auto *chunk_val = duckdb::FlatVector::GetData<double>(chunk->data[2]);
850 :
851 791696 : for (size_t i = 0; i < chunk->size(); i++) {
852 790281 : int row = final_basis->get_ket_index_from_id(chunk_row[i]);
853 790281 : if (row != last_row) {
854 51904 : if (row < last_row) {
855 0 : throw std::runtime_error("The rows are not sorted.");
856 : }
857 109424 : for (; last_row < row; last_row++) {
858 57520 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
859 : }
860 : }
861 790281 : innerIndices.push_back(initial_basis->get_ket_index_from_id(chunk_col[i]));
862 790281 : values.push_back(chunk_val[i]);
863 : }
864 : }
865 :
866 3948 : for (; last_row < num_rows + 1; last_row++) {
867 2768 : outerIndexPtr.push_back(static_cast<int>(innerIndices.size()));
868 : }
869 :
870 1180 : Eigen::Map<const cached_matrix_t> matrix_map(num_rows, num_cols, values.size(),
871 1180 : outerIndexPtr.data(), innerIndices.data(),
872 1180 : values.data());
873 :
874 1180 : auto cached_matrix = std::make_shared<const cached_matrix_t>(matrix_map);
875 1180 : matrix_promise.set_value(std::move(cached_matrix));
876 1180 : } catch (...) {
877 0 : matrix_promise.set_exception(std::current_exception());
878 0 : throw;
879 : }
880 : }
881 :
882 4807 : set_task_status("Returning matrix elements in canonical basis...");
883 :
884 9614 : return cache_it->second.get()->template cast<Scalar>();
885 5987 : }
886 :
887 4 : bool Database::get_download_missing() const { return download_missing_; }
888 :
889 2 : bool Database::get_use_cache() const { return use_cache_; }
890 :
891 6 : std::filesystem::path Database::get_database_dir() const { return database_dir_; }
892 :
893 0 : std::string Database::get_versions_info() const { return manager->get_versions_info(); }
894 :
895 4807 : Database::matrix_elements_cache_t &Database::get_matrix_elements_cache() {
896 4807 : static matrix_elements_cache_t matrix_elements_cache;
897 4807 : return matrix_elements_cache;
898 : }
899 :
900 44 : Database &Database::get_global_instance() {
901 88 : return get_global_instance_without_checks(default_download_missing, default_use_cache,
902 88 : default_database_dir);
903 : }
904 :
905 0 : Database &Database::get_global_instance(bool download_missing) {
906 0 : Database &database = get_global_instance_without_checks(download_missing, default_use_cache,
907 : default_database_dir);
908 0 : if (download_missing != database.download_missing_) {
909 0 : throw std::invalid_argument(
910 0 : "The 'download_missing' argument must not change between calls to the method.");
911 : }
912 0 : return database;
913 : }
914 :
915 0 : Database &Database::get_global_instance(std::filesystem::path database_dir) {
916 0 : if (database_dir.empty()) {
917 0 : database_dir = default_database_dir;
918 : }
919 0 : Database &database = get_global_instance_without_checks(default_download_missing,
920 : default_use_cache, database_dir);
921 0 : if (!std::filesystem::exists(database_dir) ||
922 0 : std::filesystem::canonical(database_dir) != database.database_dir_) {
923 0 : throw std::invalid_argument(
924 0 : "The 'database_dir' argument must not change between calls to the method.");
925 : }
926 0 : return database;
927 : }
928 :
929 1 : Database &Database::get_global_instance(bool download_missing, bool use_cache,
930 : std::filesystem::path database_dir) {
931 1 : if (database_dir.empty()) {
932 0 : database_dir = default_database_dir;
933 : }
934 : Database &database =
935 1 : get_global_instance_without_checks(download_missing, use_cache, database_dir);
936 1 : if (download_missing != database.download_missing_ || use_cache != database.use_cache_ ||
937 3 : !std::filesystem::exists(database_dir) ||
938 2 : std::filesystem::canonical(database_dir) != database.database_dir_) {
939 0 : throw std::invalid_argument(
940 : "The 'download_missing', 'use_cache' and 'database_dir' arguments must not "
941 0 : "change between calls to the method.");
942 : }
943 1 : return database;
944 : }
945 :
946 45 : Database &Database::get_global_instance_without_checks(bool download_missing, bool use_cache,
947 : std::filesystem::path database_dir) {
948 45 : static Database database(download_missing, use_cache, std::move(database_dir));
949 45 : return database;
950 : }
951 :
952 : struct database_dir_noexcept : std::filesystem::path {
953 2 : explicit database_dir_noexcept() noexcept try : std
954 2 : ::filesystem::path(paths::get_cache_directory() / "database") {}
955 0 : catch (...) {
956 0 : SPDLOG_ERROR("Error getting the PairInteraction cache directory.");
957 0 : std::terminate();
958 2 : }
959 : };
960 :
961 : const std::filesystem::path Database::default_database_dir = database_dir_noexcept();
962 :
963 : // Explicit instantiations
964 : // NOLINTBEGIN(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
965 : #define INSTANTIATE_GETTERS(SCALAR) \
966 : template std::shared_ptr<const BasisAtom<SCALAR>> Database::get_basis<SCALAR>( \
967 : const std::string &species, const AtomDescriptionByRanges &description, \
968 : const std::vector<size_t> &additional_ket_ids); \
969 : template Eigen::SparseMatrix<SCALAR, Eigen::RowMajor> \
970 : Database::get_matrix_elements_in_canonical_basis<SCALAR>( \
971 : std::shared_ptr<const BasisAtom<SCALAR>> initial_basis, \
972 : std::shared_ptr<const BasisAtom<SCALAR>> final_basis, OperatorType type, int q);
973 : // NOLINTEND(bugprone-macro-parentheses, cppcoreguidelines-macro-usage)
974 :
975 : INSTANTIATE_GETTERS(double)
976 : INSTANTIATE_GETTERS(std::complex<double>)
977 :
978 : #undef INSTANTIATE_GETTERS
979 : } // namespace pairinteraction
|