Line data Source code
1 : // SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
#include "pairinteraction/database/ParquetManager.hpp"

#include "pairinteraction/database/GitHubDownloader.hpp"
#include "pairinteraction/version.hpp"

#include <cassert>
#include <cpptrace/cpptrace.hpp>
#include <cstring>
#include <ctime>
#include <duckdb.hpp>
#include <filesystem>
#include <fmt/core.h>
#include <fstream>
#include <future>
#include <iomanip>
#include <miniz.h>
#include <nlohmann/json.hpp>
#include <regex>
#include <set>
#include <spdlog/spdlog.h>
#include <sstream>
#include <stdexcept>
#include <string>
25 :
26 : namespace fs = std::filesystem;
27 : using json = nlohmann::json;
28 :
29 : namespace {
30 :
// Format a UNIX timestamp as a human-readable local time ("YYYY-MM-DD HH:MM:SS").
// Used only for log messages (rate-limit resets).
std::string format_time(std::time_t time_val) {
    std::tm *ptm = std::localtime(&time_val);
    // std::localtime may return nullptr when the conversion fails; guard against
    // dereferencing it. (It also returns a pointer to a shared internal buffer,
    // so this is not thread-safe -- acceptable here since the result only feeds
    // log output.)
    if (ptm == nullptr) {
        return "unknown time";
    }
    std::ostringstream oss;
    oss << std::put_time(ptm, "%Y-%m-%d %H:%M:%S");
    return oss.str();
}
37 :
38 0 : json load_json(const fs::path &file) {
39 0 : std::ifstream in(file);
40 0 : in.exceptions(std::ifstream::failbit | std::ifstream::badbit);
41 0 : json doc;
42 0 : in >> doc;
43 0 : return doc;
44 0 : }
45 :
46 1 : void save_json(const fs::path &file, const json &doc) {
47 1 : std::ofstream out(file);
48 1 : if (!out) {
49 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", file.string()));
50 : }
51 1 : out << doc;
52 1 : out.close();
53 1 : }
54 :
55 : } // namespace
56 :
57 : namespace pairinteraction {
// Degrade to offline mode after an exception during a download: clearing
// repo_paths_ and remote_asset_info makes subsequent scan_remote() and
// update_local_asset() calls no-ops, so a failing endpoint is not hit again.
void ParquetManager::react_on_exception(const std::string &context, const std::exception &e) {
    repo_paths_.clear();
    remote_asset_info.clear();
    SPDLOG_ERROR("{}: {}. The download of database tables is disabled.", context, e.what());
}
63 :
// Degrade to offline mode after a non-success HTTP status code: clearing
// repo_paths_ and remote_asset_info disables all further download attempts.
void ParquetManager::react_on_error_code(const std::string &context, int error_code) {
    repo_paths_.clear();
    remote_asset_info.clear();
    SPDLOG_ERROR("{}: status code {}. The download of database tables is disabled.", context,
                 error_code);
}
70 :
// Degrade to offline mode when the GitHub API rate limit is exhausted; logs
// the local time at which the limit resets so the user knows when to retry.
void ParquetManager::react_on_rate_limit_reached(std::time_t reset_time) {
    repo_paths_.clear();
    remote_asset_info.clear();
    SPDLOG_ERROR("Rate limit reached, resets at {}. The download of database tables is disabled.",
                 format_time(reset_time));
}
77 :
// Construct the manager: ensures the local "tables" directory exists and, if
// remote repo paths were supplied, probes the GitHub rate limit once up front.
// Any failure during the probe (exception, non-200 status, exhausted limit)
// disables downloading via the react_on_* helpers instead of throwing, so the
// manager still works offline with whatever tables exist locally.
ParquetManager::ParquetManager(std::filesystem::path directory, const GitHubDownloader &downloader,
                               std::vector<std::string> repo_paths, duckdb::Connection &con,
                               bool use_cache)
    : directory_(std::move(directory)), downloader(downloader), repo_paths_(std::move(repo_paths)),
      con(con), use_cache_(use_cache) {
    // Ensure the local directory exists
    if (!std::filesystem::exists(directory_ / "tables")) {
        fs::create_directories(directory_ / "tables");
    }

    // If repo paths are provided, check the GitHub rate limit
    if (!repo_paths_.empty()) {
        GitHubDownloader::Result result;
        try {
            // Blocking call: wait for the /rate_limit response before continuing
            result = downloader.download("/rate_limit", "", false).get();
        } catch (const std::exception &e) {
            react_on_exception("Failed obtaining the rate limit", e);
            return;
        }
        if (result.status_code != 200) {
            react_on_error_code("Failed obtaining the rate limit", result.status_code);
            return;
        }
        if (result.rate_limit.remaining == 0) {
            react_on_rate_limit_reached(result.rate_limit.reset_time);
            return;
        }
        SPDLOG_INFO("Remaining GitHub API requests: {}. Rate limit resets at {}.",
                    result.rate_limit.remaining, format_time(result.rate_limit.reset_time));
    }
}
109 :
110 5 : void ParquetManager::scan_remote() {
111 : // If repo_paths_ is empty, we have nothing to do
112 5 : if (repo_paths_.empty()) {
113 4 : return;
114 : }
115 :
116 1 : remote_asset_info.clear();
117 :
118 : struct RepoDownload {
119 : std::string endpoint;
120 : fs::path cache_file;
121 : json cached_doc;
122 : std::future<GitHubDownloader::Result> future;
123 : };
124 1 : std::vector<RepoDownload> downloads;
125 1 : downloads.reserve(repo_paths_.size());
126 :
127 : // For each repo path, load its cached JSON (or an empty JSON) and issue the download
128 2 : for (const auto &endpoint : repo_paths_) {
129 : // Generate a unique cache filename per endpoint
130 : std::string cache_filename =
131 1 : "homepage_cache_" + std::to_string(std::hash<std::string>{}(endpoint)) + ".json";
132 1 : auto cache_file = directory_ / cache_filename;
133 :
134 : // Load cached JSON from file if it exists
135 1 : json cached_doc;
136 1 : if (std::filesystem::exists(cache_file)) {
137 : try {
138 0 : cached_doc = load_json(cache_file);
139 0 : } catch (const std::exception &e) {
140 0 : SPDLOG_WARN("Error reading {}: {}. Discarding homepage cache.", cache_file.string(),
141 : e.what());
142 0 : cached_doc = json{};
143 0 : }
144 : }
145 :
146 : // Extract the last-modified header from the cached JSON if available
147 1 : std::string last_modified;
148 1 : if (!cached_doc.is_null() && cached_doc.contains("last-modified")) {
149 0 : last_modified = cached_doc["last-modified"].get<std::string>();
150 : }
151 :
152 : // Issue the asynchronous download using the cached last-modified value.
153 : try {
154 1 : downloads.push_back(
155 1 : {endpoint, cache_file, cached_doc, downloader.download(endpoint, last_modified)});
156 0 : } catch (const std::exception &e) {
157 0 : react_on_exception(
158 0 : fmt::format("Failed to download overview of available tables from {}",
159 0 : downloads.back().endpoint),
160 : e);
161 0 : return;
162 0 : }
163 1 : }
164 :
165 : // Process downloads for each repo path
166 2 : for (auto &dl : downloads) {
167 1 : auto result = dl.future.get();
168 1 : if ((result.status_code == 403 || result.status_code == 429) &&
169 0 : result.rate_limit.remaining == 0) {
170 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
171 0 : return;
172 : }
173 1 : if (result.status_code != 200 && result.status_code != 304) {
174 0 : react_on_error_code(
175 0 : fmt::format("Failed to download overview of available tables from {}", dl.endpoint),
176 : result.status_code);
177 0 : return;
178 : }
179 :
180 1 : json doc;
181 1 : if (result.status_code == 304) {
182 0 : if (dl.cached_doc.is_null() || dl.cached_doc.empty()) {
183 0 : throw std::runtime_error(
184 0 : fmt::format("Received 304 Not Modified but cached response {} does not exist.",
185 0 : dl.cache_file.string()));
186 : }
187 0 : doc = dl.cached_doc;
188 0 : SPDLOG_INFO("Using cached overview of available tables from {}.", dl.endpoint);
189 : } else {
190 1 : doc = json::parse(result.body, nullptr, /*allow_exceptions=*/false);
191 1 : doc["last-modified"] = result.last_modified;
192 1 : save_json(dl.cache_file, doc);
193 1 : SPDLOG_INFO("Using downloaded overview of available tables from {}.", dl.endpoint);
194 : }
195 :
196 : // Validate the JSON response
197 1 : if (doc.is_discarded() || !doc.contains("assets")) {
198 0 : throw std::runtime_error(fmt::format(
199 0 : "Failed to parse remote JSON or missing 'assets' key from {}.", dl.endpoint));
200 : }
201 :
202 : // Update remote_asset_info based on the asset entries
203 2 : for (auto &asset : doc["assets"]) {
204 1 : std::string name = asset["name"].get<std::string>();
205 1 : std::smatch match;
206 :
207 1 : if (std::regex_match(name, match, remote_regex) && match.size() == 4) {
208 1 : std::string key = match[1].str();
209 1 : int version_major = std::stoi(match[2].str());
210 1 : int version_minor = std::stoi(match[3].str());
211 :
212 1 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
213 0 : continue;
214 : }
215 :
216 1 : auto it = remote_asset_info.find(key);
217 1 : if (it == remote_asset_info.end() || version_minor > it->second.version_minor) {
218 1 : std::string remote_url = asset["url"].get<std::string>();
219 1 : const std::string host = downloader.get_host();
220 2 : remote_asset_info[key] = {version_minor, remote_url.erase(0, host.size())};
221 1 : }
222 1 : }
223 1 : }
224 1 : }
225 :
226 : // Ensure that scan_remote was successful
227 1 : if (!downloads.empty() && remote_asset_info.empty()) {
228 0 : throw std::runtime_error(
229 : "No compatible database tables were found in the remote repositories. Consider "
230 0 : "upgrading PairInteraction to a newer version.");
231 : }
232 4 : }
233 :
234 5 : void ParquetManager::scan_local() {
235 5 : local_asset_info.clear();
236 :
237 : // Iterate over files in the directory to update local_asset_info
238 53 : for (const auto &entry : fs::directory_iterator(directory_ / "tables")) {
239 24 : std::string name = entry.path().filename().string();
240 24 : std::smatch match;
241 :
242 48 : if (entry.is_directory() && std::regex_match(name, match, local_regex) &&
243 24 : match.size() == 4) {
244 24 : std::string key = match[1].str();
245 24 : int version_major = std::stoi(match[2].str());
246 24 : int version_minor = std::stoi(match[3].str());
247 :
248 24 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
249 0 : continue;
250 : }
251 :
252 24 : auto it = local_asset_info.find(key);
253 24 : if (it == local_asset_info.end() || version_minor > it->second.version_minor) {
254 21 : local_asset_info[key].version_minor = version_minor;
255 203 : for (const auto &subentry : fs::directory_iterator(entry)) {
256 91 : if (subentry.is_regular_file() && subentry.path().extension() == ".parquet") {
257 182 : local_asset_info[key].paths[subentry.path().stem().string()] = {
258 273 : subentry.path().string(), false};
259 : }
260 21 : }
261 : }
262 24 : }
263 29 : }
264 96 : }
265 :
266 1690 : void ParquetManager::update_local_asset(const std::string &key) {
267 1690 : assert(remote_asset_info.empty() == repo_paths_.empty());
268 :
269 : // If remote_asset_info is empty, we have nothing to do
270 1690 : if (remote_asset_info.empty()) {
271 1689 : return;
272 : }
273 :
274 : // Get remote version if available
275 1 : int remote_version = -1;
276 1 : auto remote_it = remote_asset_info.find(key);
277 1 : if (remote_it != remote_asset_info.end()) {
278 1 : remote_version = remote_it->second.version_minor;
279 : }
280 :
281 : // Get local version if available and check if it is up-to-date
282 : {
283 1 : int local_version = -1;
284 1 : std::shared_lock<std::shared_mutex> lock(mtx_local);
285 1 : auto local_it = local_asset_info.find(key);
286 1 : if (local_it != local_asset_info.end()) {
287 1 : local_version = local_it->second.version_minor;
288 : }
289 1 : if (local_version >= remote_version) {
290 0 : return;
291 : }
292 1 : }
293 :
294 : // If it is not up-to-date, acquire a unique lock for updating the table
295 1 : std::unique_lock<std::shared_mutex> lock(mtx_local);
296 :
297 : // Re-check if the table is up to date because another thread might have updated it
298 1 : int local_version = -1;
299 1 : auto local_it = local_asset_info.find(key);
300 1 : if (local_it != local_asset_info.end()) {
301 1 : local_version = local_it->second.version_minor;
302 : }
303 1 : if (local_version >= remote_version) {
304 0 : return;
305 : }
306 :
307 : // Download the remote file
308 1 : std::string endpoint = remote_it->second.endpoint;
309 1 : SPDLOG_INFO("Downloading {}_v{}.{} from {}", key, COMPATIBLE_DATABASE_VERSION_MAJOR,
310 : remote_version, endpoint);
311 :
312 1 : GitHubDownloader::Result result;
313 : try {
314 1 : result = downloader.download(endpoint, "", true).get();
315 0 : } catch (const std::exception &e) {
316 0 : react_on_exception(fmt::format("Failed to download table {}", endpoint), e);
317 0 : return;
318 0 : }
319 1 : if ((result.status_code == 403 || result.status_code == 429) &&
320 0 : result.rate_limit.remaining == 0) {
321 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
322 0 : return;
323 : }
324 1 : if (result.status_code != 200) {
325 0 : react_on_error_code(fmt::format("Failed to download table {}", endpoint),
326 : result.status_code);
327 0 : return;
328 : }
329 :
330 : // Unzip the downloaded file
331 1 : mz_zip_archive zip_archive{};
332 1 : if (mz_zip_reader_init_mem(&zip_archive, result.body.data(), result.body.size(), 0) == 0) {
333 0 : throw std::runtime_error("Failed to initialize zip archive.");
334 : }
335 :
336 2 : for (mz_uint i = 0; i < mz_zip_reader_get_num_files(&zip_archive); i++) {
337 : mz_zip_archive_file_stat file_stat;
338 1 : if (mz_zip_reader_file_stat(&zip_archive, i, &file_stat) == 0) {
339 0 : throw std::runtime_error("Failed to get file stat from zip archive.");
340 : }
341 :
342 : // Skip directories
343 1 : const char *filename = static_cast<const char *>(file_stat.m_filename);
344 1 : size_t len = std::strlen(filename);
345 1 : if (len > 0 && filename[len - 1] == '/') {
346 0 : continue;
347 : }
348 :
349 : // Ensure that the filename matches the expectations
350 1 : std::string dir = fs::path(filename).parent_path().string();
351 1 : std::string stem = fs::path(filename).stem().string();
352 1 : std::string suffix = fs::path(filename).extension().string();
353 1 : std::smatch match;
354 2 : if (!std::regex_match(dir, match, local_regex) || match.size() != 4 ||
355 1 : suffix != ".parquet") {
356 0 : throw std::runtime_error(
357 0 : fmt::format("Unexpected filename {} in zip archive.", filename));
358 : }
359 :
360 : // Construct the path to store the table
361 1 : auto path = directory_ / "tables" / dir / (stem + suffix);
362 1 : SPDLOG_INFO("Storing table to {}", path.string());
363 :
364 : // Extract the file to memory
365 1 : std::vector<char> buffer(file_stat.m_uncomp_size);
366 1 : if (mz_zip_reader_extract_to_mem(&zip_archive, i, buffer.data(), buffer.size(), 0) == 0) {
367 0 : throw std::runtime_error(fmt::format("Failed to extract {}.", filename));
368 : }
369 :
370 : // Ensure the parent directory exists
371 1 : fs::create_directories(path.parent_path());
372 :
373 : // Save the extracted file
374 1 : std::ofstream out(path.string(), std::ios::binary);
375 1 : if (!out) {
376 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", path.string()));
377 : }
378 1 : out.write(buffer.data(), static_cast<std::streamsize>(buffer.size()));
379 1 : out.close();
380 :
381 : // Update the local asset/table info
382 1 : local_asset_info[key].version_minor = remote_version;
383 1 : local_asset_info[key].paths[path.stem().string()] = {path.string(), false};
384 1 : }
385 :
386 1 : mz_zip_reader_end(&zip_archive);
387 2 : }
388 :
// Materialize the parquet file referenced by `table_it` as a DuckDB temp table
// and rewrite the entry's path to the temp table's generated name, so later
// queries hit the in-memory copy. Uses double-checked locking on mtx_local so
// concurrent callers cache each table at most once.
// NOTE(review): `table_it` is an iterator into local_asset_info's nested map;
// presumably no concurrent rehash of that map can invalidate it while this
// runs under mtx_local -- confirm against update_local_asset's locking.
void ParquetManager::cache_table(std::unordered_map<std::string, PathInfo>::iterator table_it) {
    // Check if the table is already cached
    {
        std::shared_lock<std::shared_mutex> lock(mtx_local);
        if (table_it->second.cached) {
            return;
        }
    }

    // Acquire a unique lock for caching the table
    std::unique_lock<std::shared_mutex> lock(mtx_local);

    // Re-check if the table is already cached because another thread might have cached it
    if (table_it->second.cached) {
        return;
    }

    // Cache the table in memory
    // A UUID is used as the temp table name to avoid collisions between tables
    std::string table_name;
    {
        auto result = con.Query(R"(SELECT UUID()::varchar)");
        if (result->HasError()) {
            throw cpptrace::runtime_error("Error selecting a unique table name: " +
                                          result->GetError());
        }
        table_name =
            duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
    }

    {
        auto result = con.Query(fmt::format(R"(CREATE TEMP TABLE '{}' AS SELECT * FROM '{}')",
                                            table_name, table_it->second.path));
        if (result->HasError()) {
            throw cpptrace::runtime_error("Error creating table: " + result->GetError());
        }
    }

    // From now on, the entry's path refers to the temp table instead of the file
    table_it->second.path = table_name;
    table_it->second.cached = true;
}
429 :
// Resolve the location of `table` for species `key`: first give the remote
// repositories a chance to deliver a newer version, then look the table up
// locally and (if use_cache_ is set) replace the file path with the name of an
// in-memory DuckDB temp table. Returns either a filesystem path or a temp
// table name; throws std::runtime_error when the species or table is unknown.
std::string ParquetManager::get_path(const std::string &key, const std::string &table) {
    // Update the local table if a newer version is available remotely
    this->update_local_asset(key);

    // Ensure availability of the local table file
    // NOTE(review): local_asset_info is read here without holding mtx_local;
    // presumably safe because entries are only added, never removed, after
    // scan_local() -- confirm against the locking in update_local_asset.
    auto asset_it = local_asset_info.find(key);
    if (asset_it == local_asset_info.end()) {
        // If we do not know about any table that can be downloaded, downloading might be blocked.
        // Otherwise, the species might be misspelled.
        if (remote_asset_info.empty()) {
            throw std::runtime_error(
                "No tables found for species '" + key +
                "'. Check whether you have allowed downloading missing tables, or download the "
                "tables manually via `pairinteraction database download " +
                key + "`.");
        }
        throw std::runtime_error("No tables found for species '" + key +
                                 "'. Check the spelling of the species.");
    }
    auto table_it = asset_it->second.paths.find(table);
    if (table_it == asset_it->second.paths.end()) {
        throw std::runtime_error("No table '" + table + ".parquet' found for species '" + key +
                                 "'. The tables for the species are incomplete.");
    }

    // Cache the local table in memory if requested
    if (use_cache_) {
        this->cache_table(table_it);
    }

    // Return the path to the local table file
    return table_it->second.path;
}
463 :
464 2 : std::string ParquetManager::get_versions_info() const {
465 : // Helper lambda returns the version string if available
466 36 : auto get_version = [](const auto &map, const std::string &table) -> int {
467 36 : if (auto it = map.find(table); it != map.end()) {
468 18 : return it->second.version_minor;
469 : }
470 18 : return -1;
471 : };
472 :
473 : // Gather all unique table names
474 2 : std::set<std::string> tables;
475 20 : for (const auto &entry : local_asset_info) {
476 18 : tables.insert(entry.first);
477 : }
478 2 : for (const auto &entry : remote_asset_info) {
479 0 : tables.insert(entry.first);
480 : }
481 :
482 : // Output versions info
483 2 : std::ostringstream oss;
484 :
485 2 : oss << " ";
486 2 : oss << std::left << std::setw(17) << "Asset";
487 2 : oss << std::left << std::setw(6 + 4) << "Local";
488 2 : oss << std::left << std::setw(7) << "Remote\n";
489 2 : oss << std::string(35, '-') << "\n";
490 :
491 20 : for (const auto &table : tables) {
492 18 : int local_version = get_version(local_asset_info, table);
493 18 : int remote_version = get_version(remote_asset_info, table);
494 :
495 : std::string comparator = (local_version < remote_version)
496 : ? "<"
497 18 : : ((local_version > remote_version) ? ">" : "==");
498 : std::string local_version_str = local_version == -1
499 18 : ? "N/A"
500 36 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
501 54 : std::to_string(local_version);
502 : std::string remote_version_str = remote_version == -1
503 18 : ? "N/A"
504 18 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
505 36 : std::to_string(remote_version);
506 :
507 18 : oss << " ";
508 18 : oss << std::left << std::setw(17) << table;
509 18 : oss << std::left << std::setw(6) << local_version_str;
510 18 : oss << std::left << std::setw(4) << comparator;
511 18 : oss << std::left << std::setw(7) << remote_version_str << "\n";
512 18 : }
513 4 : return oss.str();
514 2 : }
515 :
516 : } // namespace pairinteraction
|