Line data Source code
1 : // SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/database/ParquetManager.hpp"
5 :
6 : #include "pairinteraction/database/GitHubDownloader.hpp"
7 : #include "pairinteraction/version.hpp"
8 :
9 : #include <cpptrace/cpptrace.hpp>
10 : #include <ctime>
11 : #include <duckdb.hpp>
12 : #include <filesystem>
13 : #include <fmt/core.h>
14 : #include <fmt/format.h>
15 : #include <fstream>
16 : #include <future>
17 : #include <iomanip>
18 : #include <miniz.h>
19 : #include <nlohmann/json.hpp>
20 : #include <regex>
21 : #include <set>
22 : #include <spdlog/spdlog.h>
23 : #include <sstream>
24 : #include <stdexcept>
25 : #include <string>
26 :
27 : namespace fs = std::filesystem;
28 : using json = nlohmann::json;
29 :
30 : namespace pairinteraction {
31 : namespace {
32 :
33 1 : std::string format_time(std::time_t time_val) {
34 1 : std::tm *ptm = std::localtime(&time_val);
35 1 : std::ostringstream oss;
36 1 : oss << std::put_time(ptm, "%Y-%m-%d %H:%M:%S");
37 2 : return oss.str();
38 1 : }
39 :
40 0 : json load_json(const fs::path &file) {
41 0 : std::ifstream in(file);
42 0 : in.exceptions(std::ifstream::failbit | std::ifstream::badbit);
43 0 : json doc;
44 0 : in >> doc;
45 0 : return doc;
46 0 : }
47 :
48 1 : void save_json(const fs::path &file, const json &doc) {
49 1 : std::ofstream out(file);
50 1 : if (!out) {
51 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", file.string()));
52 : }
53 1 : out << doc;
54 1 : out.close();
55 1 : }
56 :
57 : } // namespace
58 :
59 0 : void ParquetManager::react_on_exception(const std::string &context, const std::exception &e) {
60 0 : repo_paths_.clear();
61 0 : remote_asset_info.clear();
62 0 : SPDLOG_ERROR("{}: {}. The download of database tables is disabled.", context, e.what());
63 0 : }
64 :
65 0 : void ParquetManager::react_on_error_code(const std::string &context, int error_code) {
66 0 : repo_paths_.clear();
67 0 : remote_asset_info.clear();
68 0 : SPDLOG_ERROR("{}: status code {}. The download of database tables is disabled.", context,
69 : error_code);
70 0 : }
71 :
72 0 : void ParquetManager::react_on_rate_limit_reached(std::time_t reset_time) {
73 0 : repo_paths_.clear();
74 0 : remote_asset_info.clear();
75 0 : SPDLOG_ERROR("Rate limit reached, resets at {}. The download of database tables is disabled.",
76 : format_time(reset_time));
77 0 : }
78 :
79 6 : ParquetManager::ParquetManager(std::filesystem::path directory, const GitHubDownloader &downloader,
80 : std::vector<std::string> repo_paths, duckdb::Connection &con,
81 6 : bool use_cache)
82 6 : : directory_(std::move(directory)), downloader(downloader), repo_paths_(std::move(repo_paths)),
83 6 : con(con), use_cache_(use_cache) {
84 : // Ensure the local directory exists
85 6 : if (!std::filesystem::exists(directory_ / "tables")) {
86 0 : fs::create_directories(directory_ / "tables");
87 : }
88 :
89 : // If repo paths are provided, check the GitHub rate limit
90 6 : if (!repo_paths_.empty()) {
91 1 : GitHubDownloader::Result result;
92 : try {
93 1 : result = downloader.download("/rate_limit", "", false).get();
94 0 : } catch (const std::exception &e) {
95 0 : react_on_exception("Failed obtaining the rate limit", e);
96 0 : return;
97 0 : }
98 1 : if (result.status_code != 200) {
99 0 : react_on_error_code("Failed obtaining the rate limit", result.status_code);
100 0 : return;
101 : }
102 1 : if (result.rate_limit.remaining == 0) {
103 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
104 0 : return;
105 : }
106 1 : SPDLOG_INFO("Remaining GitHub API requests: {}. Rate limit resets at {}.",
107 : result.rate_limit.remaining, format_time(result.rate_limit.reset_time));
108 1 : }
109 0 : }
110 :
111 6 : void ParquetManager::scan_remote() {
112 : // If repo_paths_ is empty, we have nothing to do
113 6 : if (repo_paths_.empty()) {
114 5 : return;
115 : }
116 :
117 1 : remote_asset_info.clear();
118 :
119 : struct RepoDownload {
120 : std::string endpoint;
121 : fs::path cache_file;
122 : json cached_doc;
123 : std::future<GitHubDownloader::Result> future;
124 : };
125 1 : std::vector<RepoDownload> downloads;
126 1 : downloads.reserve(repo_paths_.size());
127 :
128 : // For each repo path, load its cached JSON (or an empty JSON) and issue the download
129 2 : for (const auto &endpoint : repo_paths_) {
130 : // Generate a unique cache filename per endpoint
131 : std::string cache_filename =
132 1 : "homepage_cache_" + std::to_string(std::hash<std::string>{}(endpoint)) + ".json";
133 1 : auto cache_file = directory_ / cache_filename;
134 :
135 : // Load cached JSON from file if it exists
136 1 : json cached_doc;
137 1 : if (std::filesystem::exists(cache_file)) {
138 : try {
139 0 : cached_doc = load_json(cache_file);
140 0 : } catch (const std::exception &e) {
141 0 : SPDLOG_WARN("Error reading {}: {}. Discarding homepage cache.", cache_file.string(),
142 : e.what());
143 0 : cached_doc = json{};
144 0 : }
145 : }
146 :
147 : // Extract the last-modified header from the cached JSON if available
148 1 : std::string last_modified;
149 1 : if (!cached_doc.is_null() && cached_doc.contains("last-modified")) {
150 0 : last_modified = cached_doc["last-modified"].get<std::string>();
151 : }
152 :
153 : // Issue the asynchronous download using the cached last-modified value.
154 : try {
155 1 : downloads.push_back(
156 1 : {endpoint, cache_file, cached_doc, downloader.download(endpoint, last_modified)});
157 0 : } catch (const std::exception &e) {
158 0 : react_on_exception(
159 0 : fmt::format("Failed to download overview of available tables from {}",
160 0 : downloads.back().endpoint),
161 : e);
162 0 : return;
163 0 : }
164 1 : }
165 :
166 : // Process downloads for each repo path
167 2 : for (auto &dl : downloads) {
168 1 : auto result = dl.future.get();
169 1 : if ((result.status_code == 403 || result.status_code == 429) &&
170 0 : result.rate_limit.remaining == 0) {
171 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
172 0 : return;
173 : }
174 1 : if (result.status_code != 200 && result.status_code != 304) {
175 0 : react_on_error_code(
176 0 : fmt::format("Failed to download overview of available tables from {}", dl.endpoint),
177 : result.status_code);
178 0 : return;
179 : }
180 :
181 1 : json doc;
182 1 : if (result.status_code == 304) {
183 0 : if (dl.cached_doc.is_null() || dl.cached_doc.empty()) {
184 0 : throw std::runtime_error(
185 0 : fmt::format("Received 304 Not Modified but cached response {} does not exist.",
186 0 : dl.cache_file.string()));
187 : }
188 0 : doc = dl.cached_doc;
189 0 : SPDLOG_INFO("Using cached overview of available tables from {}.", dl.endpoint);
190 : } else {
191 1 : doc = json::parse(result.body, nullptr, /*allow_exceptions=*/false);
192 1 : doc["last-modified"] = result.last_modified;
193 1 : save_json(dl.cache_file, doc);
194 1 : SPDLOG_INFO("Using downloaded overview of available tables from {}.", dl.endpoint);
195 : }
196 :
197 : // Validate the JSON response
198 1 : if (doc.is_discarded() || !doc.contains("assets")) {
199 0 : throw std::runtime_error(fmt::format(
200 0 : "Failed to parse remote JSON or missing 'assets' key from {}.", dl.endpoint));
201 : }
202 :
203 : // Update remote_asset_info based on the asset entries
204 2 : for (auto &asset : doc["assets"]) {
205 1 : std::string name = asset["name"].get<std::string>();
206 1 : std::smatch match;
207 :
208 1 : if (std::regex_match(name, match, remote_regex) && match.size() == 4) {
209 1 : std::string key = match[1].str();
210 1 : int version_major = std::stoi(match[2].str());
211 1 : int version_minor = std::stoi(match[3].str());
212 :
213 1 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
214 0 : continue;
215 : }
216 :
217 1 : auto it = remote_asset_info.find(key);
218 1 : if (it == remote_asset_info.end() || version_minor > it->second.version_minor) {
219 1 : std::string remote_url = asset["url"].get<std::string>();
220 1 : const std::string host = downloader.get_host();
221 2 : remote_asset_info[key] = {version_minor, remote_url.erase(0, host.size())};
222 1 : }
223 1 : }
224 1 : }
225 1 : }
226 :
227 : // Ensure that scan_remote was successful
228 1 : if (!downloads.empty() && remote_asset_info.empty()) {
229 0 : throw std::runtime_error(
230 : "No compatible database tables were found in the remote repositories. Consider "
231 0 : "upgrading PairInteraction to a newer version.");
232 : }
233 4 : }
234 :
235 6 : void ParquetManager::scan_local() {
236 6 : local_asset_info.clear();
237 :
238 : // Iterate over files in the directory to update local_asset_info
239 72 : for (const auto &entry : fs::directory_iterator(directory_ / "tables")) {
240 33 : std::string name = entry.path().filename().string();
241 33 : std::smatch match;
242 :
243 66 : if (entry.is_directory() && std::regex_match(name, match, local_regex) &&
244 33 : match.size() == 4) {
245 33 : std::string key = match[1].str();
246 33 : int version_major = std::stoi(match[2].str());
247 33 : int version_minor = std::stoi(match[3].str());
248 :
249 33 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
250 0 : continue;
251 : }
252 :
253 33 : auto it = local_asset_info.find(key);
254 33 : if (it == local_asset_info.end() || version_minor > it->second.version_minor) {
255 33 : local_asset_info[key].version_minor = version_minor;
256 309 : for (const auto &subentry : fs::directory_iterator(entry)) {
257 138 : if (subentry.is_regular_file() && subentry.path().extension() == ".parquet") {
258 276 : local_asset_info[key].paths[subentry.path().stem().string()] = {
259 414 : subentry.path().string(), false};
260 : }
261 33 : }
262 : }
263 33 : }
264 39 : }
265 144 : }
266 :
267 4815 : void ParquetManager::update_local_asset(const std::string &key) {
268 4815 : assert(remote_asset_info.empty() == repo_paths_.empty());
269 :
270 : // If remote_asset_info is empty, we have nothing to do
271 4815 : if (remote_asset_info.empty()) {
272 4814 : return;
273 : }
274 :
275 : // Get remote version if available
276 1 : int remote_version = -1;
277 1 : auto remote_it = remote_asset_info.find(key);
278 1 : if (remote_it != remote_asset_info.end()) {
279 1 : remote_version = remote_it->second.version_minor;
280 : }
281 :
282 : // Get local version if available and check if it is up-to-date
283 : {
284 1 : int local_version = -1;
285 1 : std::shared_lock<std::shared_mutex> lock(mtx_local);
286 1 : auto local_it = local_asset_info.find(key);
287 1 : if (local_it != local_asset_info.end()) {
288 1 : local_version = local_it->second.version_minor;
289 : }
290 1 : if (local_version >= remote_version) {
291 0 : return;
292 : }
293 1 : }
294 :
295 : // If it is not up-to-date, acquire a unique lock for updating the table
296 1 : std::unique_lock<std::shared_mutex> lock(mtx_local);
297 :
298 : // Re-check if the table is up to date because another thread might have updated it
299 1 : int local_version = -1;
300 1 : auto local_it = local_asset_info.find(key);
301 1 : if (local_it != local_asset_info.end()) {
302 1 : local_version = local_it->second.version_minor;
303 : }
304 1 : if (local_version >= remote_version) {
305 0 : return;
306 : }
307 :
308 : // Download the remote file
309 1 : std::string endpoint = remote_it->second.endpoint;
310 1 : SPDLOG_INFO("Downloading {}_v{}.{} from {}", key, COMPATIBLE_DATABASE_VERSION_MAJOR,
311 : remote_version, endpoint);
312 :
313 1 : GitHubDownloader::Result result;
314 : try {
315 1 : result = downloader.download(endpoint, "", true).get();
316 0 : } catch (const std::exception &e) {
317 0 : react_on_exception(fmt::format("Failed to download table {}", endpoint), e);
318 0 : return;
319 0 : }
320 1 : if ((result.status_code == 403 || result.status_code == 429) &&
321 0 : result.rate_limit.remaining == 0) {
322 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
323 0 : return;
324 : }
325 1 : if (result.status_code != 200) {
326 0 : react_on_error_code(fmt::format("Failed to download table {}", endpoint),
327 : result.status_code);
328 0 : return;
329 : }
330 :
331 : // Unzip the downloaded file
332 1 : mz_zip_archive zip_archive{};
333 1 : if (mz_zip_reader_init_mem(&zip_archive, result.body.data(), result.body.size(), 0) == 0) {
334 0 : throw std::runtime_error("Failed to initialize zip archive.");
335 : }
336 :
337 2 : for (mz_uint i = 0; i < mz_zip_reader_get_num_files(&zip_archive); i++) {
338 : mz_zip_archive_file_stat file_stat;
339 1 : if (mz_zip_reader_file_stat(&zip_archive, i, &file_stat) == 0) {
340 0 : throw std::runtime_error("Failed to get file stat from zip archive.");
341 : }
342 :
343 : // Skip directories
344 1 : const char *filename = static_cast<const char *>(file_stat.m_filename);
345 1 : size_t len = std::strlen(filename);
346 1 : if (len > 0 && filename[len - 1] == '/') {
347 0 : continue;
348 : }
349 :
350 : // Ensure that the filename matches the expectations
351 1 : std::string dir = fs::path(filename).parent_path().string();
352 1 : std::string stem = fs::path(filename).stem().string();
353 1 : std::string suffix = fs::path(filename).extension().string();
354 1 : std::smatch match;
355 2 : if (!std::regex_match(dir, match, local_regex) || match.size() != 4 ||
356 1 : suffix != ".parquet") {
357 0 : throw std::runtime_error(
358 0 : fmt::format("Unexpected filename {} in zip archive.", filename));
359 : }
360 :
361 : // Construct the path to store the table
362 1 : auto path = directory_ / "tables" / dir / (stem + suffix);
363 1 : SPDLOG_INFO("Storing table to {}", path.string());
364 :
365 : // Extract the file to memory
366 1 : std::vector<char> buffer(file_stat.m_uncomp_size);
367 1 : if (mz_zip_reader_extract_to_mem(&zip_archive, i, buffer.data(), buffer.size(), 0) == 0) {
368 0 : throw std::runtime_error(fmt::format("Failed to extract {}.", filename));
369 : }
370 :
371 : // Ensure the parent directory exists
372 1 : fs::create_directories(path.parent_path());
373 :
374 : // Save the extracted file
375 1 : std::ofstream out(path.string(), std::ios::binary);
376 1 : if (!out) {
377 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", path.string()));
378 : }
379 1 : out.write(buffer.data(), static_cast<std::streamsize>(buffer.size()));
380 1 : out.close();
381 :
382 : // Update the local asset/table info
383 1 : local_asset_info[key].version_minor = remote_version;
384 1 : local_asset_info[key].paths[path.stem().string()] = {path.string(), false};
385 1 : }
386 :
387 1 : mz_zip_reader_end(&zip_archive);
388 2 : }
389 :
390 4812 : void ParquetManager::cache_table(std::unordered_map<std::string, PathInfo>::iterator table_it) {
391 : // Check if the table is already cached
392 : {
393 4812 : std::shared_lock<std::shared_mutex> lock(mtx_local);
394 4812 : if (table_it->second.cached) {
395 4785 : return;
396 : }
397 4812 : }
398 :
399 : // Acquire a unique lock for caching the table
400 27 : std::unique_lock<std::shared_mutex> lock(mtx_local);
401 :
402 : // Re-check if the table is already cached because another thread might have cached it
403 27 : if (table_it->second.cached) {
404 0 : return;
405 : }
406 :
407 : // Cache the table in memory
408 27 : std::string table_name;
409 : {
410 27 : auto result = con.Query(R"(SELECT UUID()::varchar)");
411 27 : if (result->HasError()) {
412 0 : throw cpptrace::runtime_error("Error selecting a unique table name: " +
413 0 : result->GetError());
414 : }
415 : table_name =
416 27 : duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
417 27 : }
418 :
419 : {
420 27 : auto result = con.Query(fmt::format(R"(CREATE TEMP TABLE '{}' AS SELECT * FROM '{}')",
421 54 : table_name, table_it->second.path));
422 27 : if (result->HasError()) {
423 0 : throw cpptrace::runtime_error("Error creating table: " + result->GetError());
424 : }
425 27 : }
426 :
427 27 : table_it->second.path = table_name;
428 27 : table_it->second.cached = true;
429 27 : }
430 :
431 4815 : std::string ParquetManager::get_path(const std::string &key, const std::string &table) {
432 : // Update the local table if a newer version is available remotely
433 4815 : this->update_local_asset(key);
434 :
435 : // Ensure availability of the local table file
436 4815 : auto asset_it = local_asset_info.find(key);
437 4815 : if (asset_it == local_asset_info.end()) {
438 : // If we do not know about any table that can be downloaded, downloading might be blocked.
439 : // Otherwise, the species might be misspelled.
440 0 : if (remote_asset_info.empty()) {
441 0 : throw std::runtime_error(
442 0 : "No tables found for species '" + key +
443 : "'. Check whether you have allowed downloading missing tables, or download the "
444 0 : "tables manually via `pairinteraction database download " +
445 0 : key + "`.");
446 : }
447 0 : throw std::runtime_error("No tables found for species '" + key +
448 0 : "'. Check the spelling of the species.");
449 : }
450 4815 : auto table_it = asset_it->second.paths.find(table);
451 4815 : if (table_it == asset_it->second.paths.end()) {
452 1 : throw std::runtime_error("No table '" + table + ".parquet' found for species '" + key +
453 2 : "'. The tables for the species are incomplete.");
454 : }
455 :
456 : // Cache the local table in memory if requested
457 4814 : if (use_cache_) {
458 4812 : this->cache_table(table_it);
459 : }
460 :
461 : // Return the path to the local table file
462 9628 : return table_it->second.path;
463 : }
464 :
465 3 : std::string ParquetManager::get_versions_info() const {
466 : // Helper lambda returns the version string if available
467 54 : auto get_version = [](const auto &map, const std::string &table) -> int {
468 54 : if (auto it = map.find(table); it != map.end()) {
469 27 : return it->second.version_minor;
470 : }
471 27 : return -1;
472 : };
473 :
474 : // Gather all unique table names
475 3 : std::set<std::string> tables;
476 30 : for (const auto &entry : local_asset_info) {
477 27 : tables.insert(entry.first);
478 : }
479 3 : for (const auto &entry : remote_asset_info) {
480 0 : tables.insert(entry.first);
481 : }
482 :
483 : // Output versions info
484 3 : std::ostringstream oss;
485 :
486 3 : oss << " ";
487 3 : oss << std::left << std::setw(17) << "Asset";
488 3 : oss << std::left << std::setw(6 + 4) << "Local";
489 3 : oss << std::left << std::setw(7) << "Remote\n";
490 3 : oss << std::string(35, '-') << "\n";
491 :
492 30 : for (const auto &table : tables) {
493 27 : int local_version = get_version(local_asset_info, table);
494 27 : int remote_version = get_version(remote_asset_info, table);
495 :
496 : std::string comparator = (local_version < remote_version)
497 : ? "<"
498 27 : : ((local_version > remote_version) ? ">" : "==");
499 : std::string local_version_str = local_version == -1
500 27 : ? "N/A"
501 54 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
502 81 : std::to_string(local_version);
503 : std::string remote_version_str = remote_version == -1
504 27 : ? "N/A"
505 27 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
506 54 : std::to_string(remote_version);
507 :
508 27 : oss << " ";
509 27 : oss << std::left << std::setw(17) << table;
510 27 : oss << std::left << std::setw(6) << local_version_str;
511 27 : oss << std::left << std::setw(4) << comparator;
512 27 : oss << std::left << std::setw(7) << remote_version_str << "\n";
513 27 : }
514 6 : return oss.str();
515 3 : }
516 :
517 : } // namespace pairinteraction
|