Line data Source code
1 : // SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
#include "pairinteraction/database/ParquetManager.hpp"

#include "pairinteraction/database/GitHubDownloader.hpp"
#include "pairinteraction/version.hpp"

#include <cassert>
#include <cpptrace/cpptrace.hpp>
#include <cstring>
#include <ctime>
#include <duckdb.hpp>
#include <filesystem>
#include <fmt/core.h>
#include <fstream>
#include <future>
#include <iomanip>
#include <miniz.h>
#include <nlohmann/json.hpp>
#include <regex>
#include <set>
#include <spdlog/spdlog.h>
#include <sstream>
#include <stdexcept>
#include <string>
25 :
26 : namespace fs = std::filesystem;
27 : using json = nlohmann::json;
28 :
// Format a time_t as "YYYY-MM-DD HH:MM:SS" in the local timezone.
// Returns an empty string if the time cannot be converted.
std::string format_time(std::time_t time_val) {
    // NOTE(review): std::localtime returns a pointer to a shared static buffer
    // and is not thread-safe; assumed acceptable because callers only format
    // times for logging — TODO confirm.
    std::tm *ptm = std::localtime(&time_val);
    if (ptm == nullptr) {
        // localtime returns nullptr for unrepresentable times; passing nullptr
        // to std::put_time would be undefined behavior.
        return "";
    }
    std::ostringstream oss;
    oss << std::put_time(ptm, "%Y-%m-%d %H:%M:%S");
    return oss.str();
}
35 :
36 0 : json load_json(const fs::path &file) {
37 0 : std::ifstream in(file);
38 0 : in.exceptions(std::ifstream::failbit | std::ifstream::badbit);
39 0 : json doc;
40 0 : in >> doc;
41 0 : return doc;
42 0 : }
43 :
44 1 : void save_json(const fs::path &file, const json &doc) {
45 1 : std::ofstream out(file);
46 1 : if (!out) {
47 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", file.string()));
48 : }
49 1 : out << doc;
50 1 : out.close();
51 1 : }
52 :
53 : namespace pairinteraction {
54 0 : void ParquetManager::react_on_exception(const std::string &context, const std::exception &e) {
55 0 : repo_paths_.clear();
56 0 : remote_asset_info.clear();
57 0 : SPDLOG_ERROR("{}: {}. The download of database tables is disabled.", context, e.what());
58 0 : }
59 :
60 0 : void ParquetManager::react_on_error_code(const std::string &context, int error_code) {
61 0 : repo_paths_.clear();
62 0 : remote_asset_info.clear();
63 0 : SPDLOG_ERROR("{}: status code {}. The download of database tables is disabled.", context,
64 : error_code);
65 0 : }
66 :
67 0 : void ParquetManager::react_on_rate_limit_reached(std::time_t reset_time) {
68 0 : repo_paths_.clear();
69 0 : remote_asset_info.clear();
70 0 : SPDLOG_ERROR("Rate limit reached, resets at {}. The download of database tables is disabled.",
71 : format_time(reset_time));
72 0 : }
73 :
74 5 : ParquetManager::ParquetManager(std::filesystem::path directory, const GitHubDownloader &downloader,
75 : std::vector<std::string> repo_paths, duckdb::Connection &con,
76 5 : bool use_cache)
77 5 : : directory_(std::move(directory)), downloader(downloader), repo_paths_(std::move(repo_paths)),
78 5 : con(con), use_cache_(use_cache) {
79 : // Ensure the local directory exists
80 5 : if (!std::filesystem::exists(directory_ / "tables")) {
81 0 : fs::create_directories(directory_ / "tables");
82 : }
83 :
84 : // If repo paths are provided, check the GitHub rate limit
85 5 : if (!repo_paths_.empty()) {
86 1 : GitHubDownloader::Result result;
87 : try {
88 1 : result = downloader.download("/rate_limit", "", false).get();
89 0 : } catch (const std::exception &e) {
90 0 : react_on_exception("Failed obtaining the rate limit", e);
91 0 : return;
92 0 : }
93 1 : if (result.status_code != 200) {
94 0 : react_on_error_code("Failed obtaining the rate limit", result.status_code);
95 0 : return;
96 : }
97 1 : if (result.rate_limit.remaining == 0) {
98 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
99 0 : return;
100 : }
101 1 : SPDLOG_INFO("Remaining GitHub API requests: {}. Rate limit resets at {}.",
102 : result.rate_limit.remaining, format_time(result.rate_limit.reset_time));
103 1 : }
104 0 : }
105 :
106 5 : void ParquetManager::scan_remote() {
107 : // If repo_paths_ is empty, we have nothing to do
108 5 : if (repo_paths_.empty()) {
109 4 : return;
110 : }
111 :
112 1 : remote_asset_info.clear();
113 :
114 : struct RepoDownload {
115 : std::string endpoint;
116 : fs::path cache_file;
117 : json cached_doc;
118 : std::future<GitHubDownloader::Result> future;
119 : };
120 1 : std::vector<RepoDownload> downloads;
121 1 : downloads.reserve(repo_paths_.size());
122 :
123 : // For each repo path, load its cached JSON (or an empty JSON) and issue the download
124 2 : for (const auto &endpoint : repo_paths_) {
125 : // Generate a unique cache filename per endpoint
126 : std::string cache_filename =
127 1 : "homepage_cache_" + std::to_string(std::hash<std::string>{}(endpoint)) + ".json";
128 1 : auto cache_file = directory_ / cache_filename;
129 :
130 : // Load cached JSON from file if it exists
131 1 : json cached_doc;
132 1 : if (std::filesystem::exists(cache_file)) {
133 : try {
134 0 : cached_doc = load_json(cache_file);
135 0 : } catch (const std::exception &e) {
136 0 : SPDLOG_WARN("Error reading {}: {}. Discarding homepage cache.", cache_file.string(),
137 : e.what());
138 0 : cached_doc = json{};
139 0 : }
140 : }
141 :
142 : // Extract the last-modified header from the cached JSON if available
143 1 : std::string last_modified;
144 1 : if (!cached_doc.is_null() && cached_doc.contains("last-modified")) {
145 0 : last_modified = cached_doc["last-modified"].get<std::string>();
146 : }
147 :
148 : // Issue the asynchronous download using the cached last-modified value.
149 : try {
150 1 : downloads.push_back(
151 1 : {endpoint, cache_file, cached_doc, downloader.download(endpoint, last_modified)});
152 0 : } catch (const std::exception &e) {
153 0 : react_on_exception(
154 0 : fmt::format("Failed to download overview of available tables from {}",
155 0 : downloads.back().endpoint),
156 : e);
157 0 : return;
158 0 : }
159 1 : }
160 :
161 : // Process downloads for each repo path
162 2 : for (auto &dl : downloads) {
163 1 : auto result = dl.future.get();
164 1 : if ((result.status_code == 403 || result.status_code == 429) &&
165 0 : result.rate_limit.remaining == 0) {
166 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
167 0 : return;
168 : }
169 1 : if (result.status_code != 200 && result.status_code != 304) {
170 0 : react_on_error_code(
171 0 : fmt::format("Failed to download overview of available tables from {}", dl.endpoint),
172 : result.status_code);
173 0 : return;
174 : }
175 :
176 1 : json doc;
177 1 : if (result.status_code == 304) {
178 0 : if (dl.cached_doc.is_null() || dl.cached_doc.empty()) {
179 0 : throw std::runtime_error(
180 0 : fmt::format("Received 304 Not Modified but cached response {} does not exist.",
181 0 : dl.cache_file.string()));
182 : }
183 0 : doc = dl.cached_doc;
184 0 : SPDLOG_INFO("Using cached overview of available tables from {}.", dl.endpoint);
185 : } else {
186 1 : doc = json::parse(result.body, nullptr, /*allow_exceptions=*/false);
187 1 : doc["last-modified"] = result.last_modified;
188 1 : save_json(dl.cache_file, doc);
189 1 : SPDLOG_INFO("Using downloaded overview of available tables from {}.", dl.endpoint);
190 : }
191 :
192 : // Validate the JSON response
193 1 : if (doc.is_discarded() || !doc.contains("assets")) {
194 0 : throw std::runtime_error(fmt::format(
195 0 : "Failed to parse remote JSON or missing 'assets' key from {}.", dl.endpoint));
196 : }
197 :
198 : // Update remote_asset_info based on the asset entries
199 2 : for (auto &asset : doc["assets"]) {
200 1 : std::string name = asset["name"].get<std::string>();
201 1 : std::smatch match;
202 :
203 1 : if (std::regex_match(name, match, remote_regex) && match.size() == 4) {
204 1 : std::string key = match[1].str();
205 1 : int version_major = std::stoi(match[2].str());
206 1 : int version_minor = std::stoi(match[3].str());
207 :
208 1 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
209 0 : continue;
210 : }
211 :
212 1 : auto it = remote_asset_info.find(key);
213 1 : if (it == remote_asset_info.end() || version_minor > it->second.version_minor) {
214 1 : std::string remote_url = asset["url"].get<std::string>();
215 1 : const std::string host = downloader.get_host();
216 2 : remote_asset_info[key] = {version_minor, remote_url.erase(0, host.size())};
217 1 : }
218 1 : }
219 1 : }
220 1 : }
221 :
222 : // Ensure that scan_remote was successful
223 1 : if (!downloads.empty() && remote_asset_info.empty()) {
224 0 : throw std::runtime_error(
225 : "No compatible database tables were found in the remote repositories. Consider "
226 0 : "upgrading PairInteraction to a newer version.");
227 : }
228 4 : }
229 :
230 5 : void ParquetManager::scan_local() {
231 5 : local_asset_info.clear();
232 :
233 : // Iterate over files in the directory to update local_asset_info
234 53 : for (const auto &entry : fs::directory_iterator(directory_ / "tables")) {
235 24 : std::string name = entry.path().filename().string();
236 24 : std::smatch match;
237 :
238 48 : if (entry.is_directory() && std::regex_match(name, match, local_regex) &&
239 24 : match.size() == 4) {
240 24 : std::string key = match[1].str();
241 24 : int version_major = std::stoi(match[2].str());
242 24 : int version_minor = std::stoi(match[3].str());
243 :
244 24 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
245 0 : continue;
246 : }
247 :
248 24 : auto it = local_asset_info.find(key);
249 24 : if (it == local_asset_info.end() || version_minor > it->second.version_minor) {
250 24 : local_asset_info[key].version_minor = version_minor;
251 212 : for (const auto &subentry : fs::directory_iterator(entry)) {
252 94 : if (subentry.is_regular_file() && subentry.path().extension() == ".parquet") {
253 188 : local_asset_info[key].paths[subentry.path().stem().string()] = {
254 282 : subentry.path().string(), false};
255 : }
256 24 : }
257 : }
258 24 : }
259 29 : }
260 99 : }
261 :
262 1721 : void ParquetManager::update_local_asset(const std::string &key) {
263 1721 : assert(remote_asset_info.empty() == repo_paths_.empty());
264 :
265 : // If remote_asset_info is empty, we have nothing to do
266 1721 : if (remote_asset_info.empty()) {
267 1720 : return;
268 : }
269 :
270 : // Get remote version if available
271 1 : int remote_version = -1;
272 1 : auto remote_it = remote_asset_info.find(key);
273 1 : if (remote_it != remote_asset_info.end()) {
274 1 : remote_version = remote_it->second.version_minor;
275 : }
276 :
277 : // Get local version if available and check if it is up-to-date
278 : {
279 1 : int local_version = -1;
280 1 : std::shared_lock<std::shared_mutex> lock(mtx_local);
281 1 : auto local_it = local_asset_info.find(key);
282 1 : if (local_it != local_asset_info.end()) {
283 1 : local_version = local_it->second.version_minor;
284 : }
285 1 : if (local_version >= remote_version) {
286 0 : return;
287 : }
288 1 : }
289 :
290 : // If it is not up-to-date, acquire a unique lock for updating the table
291 1 : std::unique_lock<std::shared_mutex> lock(mtx_local);
292 :
293 : // Re-check if the table is up to date because another thread might have updated it
294 1 : int local_version = -1;
295 1 : auto local_it = local_asset_info.find(key);
296 1 : if (local_it != local_asset_info.end()) {
297 1 : local_version = local_it->second.version_minor;
298 : }
299 1 : if (local_version >= remote_version) {
300 0 : return;
301 : }
302 :
303 : // Download the remote file
304 1 : std::string endpoint = remote_it->second.endpoint;
305 1 : SPDLOG_INFO("Downloading {}_v{}.{} from {}", key, COMPATIBLE_DATABASE_VERSION_MAJOR,
306 : remote_version, endpoint);
307 :
308 1 : GitHubDownloader::Result result;
309 : try {
310 1 : result = downloader.download(endpoint, "", true).get();
311 0 : } catch (const std::exception &e) {
312 0 : react_on_exception(fmt::format("Failed to download table {}", endpoint), e);
313 0 : return;
314 0 : }
315 1 : if ((result.status_code == 403 || result.status_code == 429) &&
316 0 : result.rate_limit.remaining == 0) {
317 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
318 0 : return;
319 : }
320 1 : if (result.status_code != 200) {
321 0 : react_on_error_code(fmt::format("Failed to download table {}", endpoint),
322 : result.status_code);
323 0 : return;
324 : }
325 :
326 : // Unzip the downloaded file
327 1 : mz_zip_archive zip_archive{};
328 1 : if (mz_zip_reader_init_mem(&zip_archive, result.body.data(), result.body.size(), 0) == 0) {
329 0 : throw std::runtime_error("Failed to initialize zip archive.");
330 : }
331 :
332 2 : for (mz_uint i = 0; i < mz_zip_reader_get_num_files(&zip_archive); i++) {
333 : mz_zip_archive_file_stat file_stat;
334 1 : if (mz_zip_reader_file_stat(&zip_archive, i, &file_stat) == 0) {
335 0 : throw std::runtime_error("Failed to get file stat from zip archive.");
336 : }
337 :
338 : // Skip directories
339 1 : const char *filename = static_cast<const char *>(file_stat.m_filename);
340 1 : size_t len = std::strlen(filename);
341 1 : if (len > 0 && filename[len - 1] == '/') {
342 0 : continue;
343 : }
344 :
345 : // Ensure that the filename matches the expectations
346 1 : std::string dir = fs::path(filename).parent_path().string();
347 1 : std::string stem = fs::path(filename).stem().string();
348 1 : std::string suffix = fs::path(filename).extension().string();
349 1 : std::smatch match;
350 2 : if (!std::regex_match(dir, match, local_regex) || match.size() != 4 ||
351 1 : suffix != ".parquet") {
352 0 : throw std::runtime_error(
353 0 : fmt::format("Unexpected filename {} in zip archive.", filename));
354 : }
355 :
356 : // Construct the path to store the table
357 1 : auto path = directory_ / "tables" / dir / (stem + suffix);
358 1 : SPDLOG_INFO("Storing table to {}", path.string());
359 :
360 : // Extract the file to memory
361 1 : std::vector<char> buffer(file_stat.m_uncomp_size);
362 1 : if (mz_zip_reader_extract_to_mem(&zip_archive, i, buffer.data(), buffer.size(), 0) == 0) {
363 0 : throw std::runtime_error(fmt::format("Failed to extract {}.", filename));
364 : }
365 :
366 : // Ensure the parent directory exists
367 1 : fs::create_directories(path.parent_path());
368 :
369 : // Save the extracted file
370 1 : std::ofstream out(path.string(), std::ios::binary);
371 1 : if (!out) {
372 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", path.string()));
373 : }
374 1 : out.write(buffer.data(), static_cast<std::streamsize>(buffer.size()));
375 1 : out.close();
376 :
377 : // Update the local asset/table info
378 1 : local_asset_info[key].version_minor = remote_version;
379 1 : local_asset_info[key].paths[path.stem().string()] = {path.string(), false};
380 1 : }
381 :
382 1 : mz_zip_reader_end(&zip_archive);
383 2 : }
384 :
// Cache the parquet file referenced by `table_it` as a DuckDB temp table.
// On success, the entry's path is rewritten to the temp-table name and its
// `cached` flag is set, so later get_path() calls return the in-memory table.
// Thread-safe via double-checked locking on mtx_local; idempotent.
//
// @param table_it Iterator into an asset's paths map. NOTE(review): the
//        iterator is assumed to stay valid across the lock acquisitions
//        below (i.e. no concurrent insertion into the same map rehashes
//        it away) -- TODO confirm against the callers' locking discipline.
// @throws cpptrace::runtime_error if either DuckDB query fails.
void ParquetManager::cache_table(std::unordered_map<std::string, PathInfo>::iterator table_it) {
    // Check if the table is already cached (cheap shared-lock fast path)
    {
        std::shared_lock<std::shared_mutex> lock(mtx_local);
        if (table_it->second.cached) {
            return;
        }
    }

    // Acquire a unique lock for caching the table
    std::unique_lock<std::shared_mutex> lock(mtx_local);

    // Re-check if the table is already cached because another thread might have cached it
    if (table_it->second.cached) {
        return;
    }

    // Cache the table in memory: first ask DuckDB for a UUID to use as a
    // collision-free temp-table name
    std::string table_name;
    {
        auto result = con.Query(R"(SELECT UUID()::varchar)");
        if (result->HasError()) {
            throw cpptrace::runtime_error("Error selecting a unique table name: " +
                                          result->GetError());
        }
        // Extract the single varchar cell from the first (only) result chunk
        table_name =
            duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
    }

    // Materialize the parquet file into a temp table under the UUID name.
    // NOTE(review): the path is spliced into the SQL string; assumed safe
    // because paths come from our own directory scan, not user input.
    {
        auto result = con.Query(fmt::format(R"(CREATE TEMP TABLE '{}' AS SELECT * FROM '{}')",
                                            table_name, table_it->second.path));
        if (result->HasError()) {
            throw cpptrace::runtime_error("Error creating table: " + result->GetError());
        }
    }

    // Repoint the entry at the temp table and mark it cached
    table_it->second.path = table_name;
    table_it->second.cached = true;
}
425 :
// Resolve the path (or cached DuckDB temp-table name) for `table` of species
// `key`, downloading a newer version first if one is available remotely and
// caching the table in memory when use_cache_ is set.
//
// @param key   Species identifier, e.g. used as "<key>_v<major>.<minor>".
// @param table Table name without the ".parquet" extension.
// @return Path to the parquet file, or the temp-table name once cached.
// @throws std::runtime_error if no tables exist for the species or the
//         requested table is missing from the species' set.
//
// NOTE(review): local_asset_info is read here without holding mtx_local,
// while update_local_asset() mutates it under a unique lock on another
// thread — looks like a potential data race / iterator invalidation;
// confirm the intended locking discipline with the callers.
std::string ParquetManager::get_path(const std::string &key, const std::string &table) {
    // Update the local table if a newer version is available remotely
    this->update_local_asset(key);

    // Ensure availability of the local table file
    auto asset_it = local_asset_info.find(key);
    if (asset_it == local_asset_info.end()) {
        // If we do not know about any table that can be downloaded, downloading might be blocked.
        // Otherwise, the species might be misspelled.
        if (remote_asset_info.empty()) {
            throw std::runtime_error(
                "No tables found for species '" + key +
                "'. Check whether you have allowed downloading missing tables, or download the "
                "tables manually via `pairinteraction download " +
                key + "`.");
        }
        throw std::runtime_error("No tables found for species '" + key +
                                 "'. Check the spelling of the species.");
    }
    auto table_it = asset_it->second.paths.find(table);
    if (table_it == asset_it->second.paths.end()) {
        throw std::runtime_error("No table '" + table + ".parquet' found for species '" + key +
                                 "'. The tables for the species are incomplete.");
    }

    // Cache the local table in memory if requested
    if (use_cache_) {
        this->cache_table(table_it);
    }

    // Return the path to the local table file
    return table_it->second.path;
}
459 :
460 2 : std::string ParquetManager::get_versions_info() const {
461 : // Helper lambda returns the version string if available
462 36 : auto get_version = [](const auto &map, const std::string &table) -> int {
463 36 : if (auto it = map.find(table); it != map.end()) {
464 18 : return it->second.version_minor;
465 : }
466 18 : return -1;
467 : };
468 :
469 : // Gather all unique table names
470 2 : std::set<std::string> tables;
471 20 : for (const auto &entry : local_asset_info) {
472 18 : tables.insert(entry.first);
473 : }
474 2 : for (const auto &entry : remote_asset_info) {
475 0 : tables.insert(entry.first);
476 : }
477 :
478 : // Output versions info
479 2 : std::ostringstream oss;
480 :
481 2 : oss << " ";
482 2 : oss << std::left << std::setw(17) << "Asset";
483 2 : oss << std::left << std::setw(6 + 4) << "Local";
484 2 : oss << std::left << std::setw(7) << "Remote\n";
485 2 : oss << std::string(35, '-') << "\n";
486 :
487 20 : for (const auto &table : tables) {
488 18 : int local_version = get_version(local_asset_info, table);
489 18 : int remote_version = get_version(remote_asset_info, table);
490 :
491 : std::string comparator = (local_version < remote_version)
492 : ? "<"
493 18 : : ((local_version > remote_version) ? ">" : "==");
494 : std::string local_version_str = local_version == -1
495 18 : ? "N/A"
496 36 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
497 54 : std::to_string(local_version);
498 : std::string remote_version_str = remote_version == -1
499 18 : ? "N/A"
500 18 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
501 36 : std::to_string(remote_version);
502 :
503 18 : oss << " ";
504 18 : oss << std::left << std::setw(17) << table;
505 18 : oss << std::left << std::setw(6) << local_version_str;
506 18 : oss << std::left << std::setw(4) << comparator;
507 18 : oss << std::left << std::setw(7) << remote_version_str << "\n";
508 18 : }
509 4 : return oss.str();
510 2 : }
511 :
512 : } // namespace pairinteraction
|