Line data Source code
1 : // SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
#include "pairinteraction/database/ParquetManager.hpp"

#include "pairinteraction/database/GitHubDownloader.hpp"
#include "pairinteraction/version.hpp"

#include <cassert>
#include <cpptrace/cpptrace.hpp>
#include <cstring>
#include <ctime>
#include <duckdb.hpp>
#include <filesystem>
#include <fmt/core.h>
#include <fstream>
#include <functional>
#include <future>
#include <iomanip>
#include <memory>
#include <miniz.h>
#include <mutex>
#include <nlohmann/json.hpp>
#include <regex>
#include <set>
#include <shared_mutex>
#include <spdlog/spdlog.h>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
25 :
26 : namespace fs = std::filesystem;
27 : using json = nlohmann::json;
28 :
/// Format a calendar time as local time in the form "YYYY-MM-DD HH:MM:SS".
///
/// Uses the reentrant localtime_r / localtime_s instead of std::localtime: the latter returns a
/// pointer into shared static storage (not thread-safe) and may return a null pointer, which
/// must never be handed to std::put_time. If the conversion fails, the raw numeric value is
/// returned instead of a formatted date.
///
/// @param time_val Seconds since the epoch.
/// @return The formatted local time, or the number itself if it cannot be converted.
std::string format_time(std::time_t time_val) {
    std::tm local_tm{};
#ifdef _WIN32
    const bool converted = localtime_s(&local_tm, &time_val) == 0;
#else
    const bool converted = localtime_r(&time_val, &local_tm) != nullptr;
#endif
    if (!converted) {
        return std::to_string(time_val);
    }
    std::ostringstream oss;
    oss << std::put_time(&local_tm, "%Y-%m-%d %H:%M:%S");
    return oss.str();
}
35 :
36 0 : json load_json(const fs::path &file) {
37 0 : std::ifstream in(file);
38 0 : in.exceptions(std::ifstream::failbit | std::ifstream::badbit);
39 0 : json doc;
40 0 : in >> doc;
41 0 : return doc;
42 0 : }
43 :
44 1 : void save_json(const fs::path &file, const json &doc) {
45 1 : std::ofstream out(file);
46 1 : if (!out) {
47 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", file.string()));
48 : }
49 1 : out << doc;
50 1 : out.close();
51 1 : }
52 :
53 : namespace pairinteraction {
54 5 : ParquetManager::ParquetManager(std::filesystem::path directory, const GitHubDownloader &downloader,
55 : std::vector<std::string> repo_paths, duckdb::Connection &con,
56 5 : bool use_cache)
57 5 : : directory_(std::move(directory)), downloader(downloader), repo_paths_(std::move(repo_paths)),
58 5 : con(con), use_cache_(use_cache) {
59 : // Ensure the local directory exists
60 5 : if (!std::filesystem::exists(directory_ / "tables")) {
61 0 : fs::create_directories(directory_ / "tables");
62 : }
63 :
64 : // If repo paths are provided, check the GitHub rate limit
65 5 : if (!repo_paths_.empty()) {
66 1 : auto rate_limit = downloader.get_rate_limit();
67 1 : if (rate_limit.remaining == 0) {
68 0 : react_on_rate_limit_reached(rate_limit.reset_time);
69 : } else {
70 2 : SPDLOG_INFO("Remaining GitHub API requests: {}. Rate limit resets at {}.",
71 : rate_limit.remaining, format_time(rate_limit.reset_time));
72 : }
73 : }
74 5 : }
75 :
76 5 : void ParquetManager::scan_remote() {
77 : // If repo_paths_ is empty, we have nothing to do
78 5 : if (repo_paths_.empty()) {
79 4 : return;
80 : }
81 :
82 1 : remote_asset_info.clear();
83 :
84 : struct RepoDownload {
85 : std::string endpoint;
86 : fs::path cache_file;
87 : json cached_doc;
88 : std::future<GitHubDownloader::Result> future;
89 : };
90 1 : std::vector<RepoDownload> downloads;
91 1 : downloads.reserve(repo_paths_.size());
92 :
93 : // For each repo path, load its cached JSON (or an empty JSON) and issue the download
94 2 : for (const auto &endpoint : repo_paths_) {
95 : // Generate a unique cache filename per endpoint
96 : std::string cache_filename =
97 1 : "homepage_cache_" + std::to_string(std::hash<std::string>{}(endpoint)) + ".json";
98 1 : auto cache_file = directory_ / cache_filename;
99 :
100 : // Load cached JSON from file if it exists
101 1 : json cached_doc;
102 1 : if (std::filesystem::exists(cache_file)) {
103 : try {
104 0 : cached_doc = load_json(cache_file);
105 0 : } catch (const std::exception &e) {
106 0 : SPDLOG_WARN("Error reading {}: {}. Discarding homepage cache.", cache_file.string(),
107 : e.what());
108 0 : cached_doc = json{};
109 0 : }
110 : }
111 :
112 : // Extract the last-modified header from the cached JSON if available
113 1 : std::string last_modified;
114 1 : if (!cached_doc.is_null() && cached_doc.contains("last-modified")) {
115 0 : last_modified = cached_doc["last-modified"].get<std::string>();
116 : }
117 :
118 : // Issue the asynchronous download using the cached last-modified value.
119 1 : downloads.push_back(
120 1 : {endpoint, cache_file, cached_doc, downloader.download(endpoint, last_modified)});
121 1 : }
122 :
123 : // Process downloads for each repo path
124 2 : for (auto &dl : downloads) {
125 1 : auto result = dl.future.get();
126 1 : json doc;
127 1 : if (result.status_code == 200) {
128 1 : doc = json::parse(result.body, nullptr, /*allow_exceptions=*/false);
129 1 : doc["last-modified"] = result.last_modified;
130 1 : save_json(dl.cache_file, doc);
131 2 : SPDLOG_INFO("Using downloaded overview of available tables from {}.", dl.endpoint);
132 0 : } else if (result.status_code == 304) {
133 0 : if (dl.cached_doc.is_null() || dl.cached_doc.empty()) {
134 0 : throw std::runtime_error(
135 0 : fmt::format("Received 304 Not Modified but cached response {} does not exist.",
136 0 : dl.cache_file.string()));
137 : }
138 0 : doc = dl.cached_doc;
139 0 : SPDLOG_INFO("Using cached overview of available tables from {}.", dl.endpoint);
140 0 : } else if (result.status_code == 403 || result.status_code == 429) {
141 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
142 0 : return;
143 : } else {
144 0 : throw std::runtime_error(fmt::format(
145 : "Failed to download overview of available tables from {}: status code {}.",
146 0 : dl.endpoint, result.status_code));
147 : }
148 :
149 : // Validate the JSON response
150 1 : if (doc.is_discarded() || !doc.contains("assets")) {
151 0 : throw std::runtime_error(fmt::format(
152 0 : "Failed to parse remote JSON or missing 'assets' key from {}.", dl.endpoint));
153 : }
154 :
155 : // Update remote_asset_info based on the asset entries
156 2 : for (auto &asset : doc["assets"]) {
157 1 : std::string name = asset["name"].get<std::string>();
158 1 : std::smatch match;
159 :
160 1 : if (std::regex_match(name, match, remote_regex) && match.size() == 4) {
161 1 : std::string key = match[1].str();
162 1 : int version_major = std::stoi(match[2].str());
163 1 : int version_minor = std::stoi(match[3].str());
164 :
165 1 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
166 0 : continue;
167 : }
168 :
169 1 : auto it = remote_asset_info.find(key);
170 1 : if (it == remote_asset_info.end() || version_minor > it->second.version_minor) {
171 1 : std::string remote_url = asset["url"].get<std::string>();
172 1 : const std::string host = downloader.get_host();
173 2 : remote_asset_info[key] = {version_minor, remote_url.erase(0, host.size())};
174 1 : }
175 1 : }
176 1 : }
177 1 : }
178 :
179 : // Ensure that scan_remote was successful
180 1 : if (!downloads.empty() && remote_asset_info.empty()) {
181 0 : throw std::runtime_error(
182 : "No compatible database tables were found in the remote repositories. Consider "
183 0 : "upgrading PairInteraction to a newer version.");
184 : }
185 4 : }
186 :
187 5 : void ParquetManager::scan_local() {
188 5 : local_asset_info.clear();
189 :
190 : // Iterate over files in the directory to update local_asset_info
191 53 : for (const auto &entry : fs::directory_iterator(directory_ / "tables")) {
192 24 : std::string name = entry.path().filename().string();
193 24 : std::smatch match;
194 :
195 48 : if (entry.is_directory() && std::regex_match(name, match, local_regex) &&
196 24 : match.size() == 4) {
197 24 : std::string key = match[1].str();
198 24 : int version_major = std::stoi(match[2].str());
199 24 : int version_minor = std::stoi(match[3].str());
200 :
201 24 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
202 0 : continue;
203 : }
204 :
205 24 : auto it = local_asset_info.find(key);
206 24 : if (it == local_asset_info.end() || version_minor > it->second.version_minor) {
207 21 : local_asset_info[key].version_minor = version_minor;
208 203 : for (const auto &subentry : fs::directory_iterator(entry)) {
209 91 : if (subentry.is_regular_file() && subentry.path().extension() == ".parquet") {
210 182 : local_asset_info[key].paths[subentry.path().stem().string()] = {
211 273 : subentry.path().string(), false};
212 : }
213 21 : }
214 : }
215 24 : }
216 29 : }
217 96 : }
218 :
219 0 : void ParquetManager::react_on_rate_limit_reached(std::time_t reset_time) {
220 0 : repo_paths_.clear();
221 0 : remote_asset_info.clear();
222 0 : SPDLOG_WARN("Rate limit reached, resets at {}. The download of database tables is disabled.",
223 : format_time(reset_time));
224 0 : }
225 :
226 894 : void ParquetManager::update_local_asset(const std::string &key) {
227 894 : assert(remote_asset_info.empty() == repo_paths_.empty());
228 :
229 : // If remote_asset_info is empty, we have nothing to do
230 894 : if (remote_asset_info.empty()) {
231 893 : return;
232 : }
233 :
234 : // Get remote version if available
235 1 : int remote_version = -1;
236 1 : auto remote_it = remote_asset_info.find(key);
237 1 : if (remote_it != remote_asset_info.end()) {
238 1 : remote_version = remote_it->second.version_minor;
239 : }
240 :
241 : // Get local version if available and check if it is up-to-date
242 : {
243 1 : int local_version = -1;
244 1 : std::shared_lock<std::shared_mutex> lock(mtx_local);
245 1 : auto local_it = local_asset_info.find(key);
246 1 : if (local_it != local_asset_info.end()) {
247 1 : local_version = local_it->second.version_minor;
248 : }
249 1 : if (local_version >= remote_version) {
250 0 : return;
251 : }
252 1 : }
253 :
254 : // If it is not up-to-date, acquire a unique lock for updating the table
255 1 : std::unique_lock<std::shared_mutex> lock(mtx_local);
256 :
257 : // Re-check if the table is up to date because another thread might have updated it
258 1 : int local_version = -1;
259 1 : auto local_it = local_asset_info.find(key);
260 1 : if (local_it != local_asset_info.end()) {
261 1 : local_version = local_it->second.version_minor;
262 : }
263 1 : if (local_version >= remote_version) {
264 0 : return;
265 : }
266 :
267 : // Download the remote file
268 1 : std::string endpoint = remote_it->second.endpoint;
269 2 : SPDLOG_INFO("Downloading {}_v{}.{} from {}", key, COMPATIBLE_DATABASE_VERSION_MAJOR,
270 : remote_version, endpoint);
271 :
272 1 : auto result = downloader.download(endpoint, "", true).get();
273 1 : if ((result.status_code == 403 || result.status_code == 429) &&
274 0 : result.rate_limit.remaining == 0) {
275 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
276 0 : return;
277 : }
278 1 : if (result.status_code != 200) {
279 0 : throw std::runtime_error(fmt::format("Failed to download table {}: status code {}.",
280 0 : endpoint, result.status_code));
281 : }
282 :
283 : // Unzip the downloaded file
284 1 : mz_zip_archive zip_archive{};
285 1 : if (mz_zip_reader_init_mem(&zip_archive, result.body.data(), result.body.size(), 0) == 0) {
286 0 : throw std::runtime_error("Failed to initialize zip archive.");
287 : }
288 :
289 2 : for (mz_uint i = 0; i < mz_zip_reader_get_num_files(&zip_archive); i++) {
290 : mz_zip_archive_file_stat file_stat;
291 1 : if (mz_zip_reader_file_stat(&zip_archive, i, &file_stat) == 0) {
292 0 : throw std::runtime_error("Failed to get file stat from zip archive.");
293 : }
294 :
295 : // Skip directories
296 1 : const char *filename = static_cast<const char *>(file_stat.m_filename);
297 1 : size_t len = std::strlen(filename);
298 1 : if (len > 0 && filename[len - 1] == '/') {
299 0 : continue;
300 : }
301 :
302 : // Ensure that the filename matches the expectations
303 1 : std::string dir = fs::path(filename).parent_path().string();
304 1 : std::string stem = fs::path(filename).stem().string();
305 1 : std::string suffix = fs::path(filename).extension().string();
306 1 : std::smatch match;
307 2 : if (!std::regex_match(dir, match, local_regex) || match.size() != 4 ||
308 1 : suffix != ".parquet") {
309 0 : throw std::runtime_error(
310 0 : fmt::format("Unexpected filename {} in zip archive.", filename));
311 : }
312 :
313 : // Construct the path to store the table
314 1 : auto path = directory_ / "tables" / dir / (stem + suffix);
315 2 : SPDLOG_INFO("Storing table to {}", path.string());
316 :
317 : // Extract the file to memory
318 1 : std::vector<char> buffer(file_stat.m_uncomp_size);
319 1 : if (mz_zip_reader_extract_to_mem(&zip_archive, i, buffer.data(), buffer.size(), 0) == 0) {
320 0 : throw std::runtime_error(fmt::format("Failed to extract {}.", filename));
321 : }
322 :
323 : // Ensure the parent directory exists
324 1 : fs::create_directories(path.parent_path());
325 :
326 : // Save the extracted file
327 1 : std::ofstream out(path.string(), std::ios::binary);
328 1 : if (!out) {
329 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", path.string()));
330 : }
331 1 : out.write(buffer.data(), static_cast<std::streamsize>(buffer.size()));
332 1 : out.close();
333 :
334 : // Update the local asset/table info
335 1 : local_asset_info[key].version_minor = remote_version;
336 1 : local_asset_info[key].paths[path.stem().string()] = {path.string(), false};
337 1 : }
338 :
339 1 : mz_zip_reader_end(&zip_archive);
340 2 : }
341 :
342 891 : void ParquetManager::cache_table(std::unordered_map<std::string, PathInfo>::iterator table_it) {
343 : // Check if the table is already cached
344 : {
345 891 : std::shared_lock<std::shared_mutex> lock(mtx_local);
346 891 : if (table_it->second.cached) {
347 865 : return;
348 : }
349 891 : }
350 :
351 : // Acquire a unique lock for caching the table
352 26 : std::unique_lock<std::shared_mutex> lock(mtx_local);
353 :
354 : // Re-check if the table is already cached because another thread might have cached it
355 26 : if (table_it->second.cached) {
356 0 : return;
357 : }
358 :
359 : // Cache the table in memory
360 26 : std::string table_name;
361 : {
362 26 : auto result = con.Query(R"(SELECT UUID()::varchar)");
363 26 : if (result->HasError()) {
364 0 : throw cpptrace::runtime_error("Error selecting a unique table name: " +
365 0 : result->GetError());
366 : }
367 : table_name =
368 26 : duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
369 26 : }
370 :
371 : {
372 26 : auto result = con.Query(fmt::format(R"(CREATE TEMP TABLE '{}' AS SELECT * FROM '{}')",
373 52 : table_name, table_it->second.path));
374 26 : if (result->HasError()) {
375 0 : throw cpptrace::runtime_error("Error creating table: " + result->GetError());
376 : }
377 26 : }
378 :
379 26 : table_it->second.path = table_name;
380 26 : table_it->second.cached = true;
381 26 : }
382 :
383 894 : std::string ParquetManager::get_path(const std::string &key, const std::string &table) {
384 : // Update the local table if a newer version is available remotely
385 894 : this->update_local_asset(key);
386 :
387 : // Ensure availability of the local table file
388 894 : auto asset_it = local_asset_info.find(key);
389 894 : if (asset_it == local_asset_info.end()) {
390 : // If we do not know about any table that can be downloaded, downloading might be blocked.
391 : // Otherwise, the species might be misspelled.
392 0 : if (remote_asset_info.empty()) {
393 0 : throw std::runtime_error(
394 0 : "No tables found for species '" + key +
395 : "'. Check whether you have allowed downloading missing tables, or download the "
396 0 : "tables manually via `pairinteraction download " +
397 0 : key + "`.");
398 : }
399 0 : throw std::runtime_error("No tables found for species '" + key +
400 0 : "'. Check the spelling of the species.");
401 : }
402 894 : auto table_it = asset_it->second.paths.find(table);
403 894 : if (table_it == asset_it->second.paths.end()) {
404 1 : throw std::runtime_error("No table '" + table + ".parquet' found for species '" + key +
405 2 : "'. The tables for the species are incomplete.");
406 : }
407 :
408 : // Cache the local table in memory if requested
409 893 : if (use_cache_) {
410 891 : this->cache_table(table_it);
411 : }
412 :
413 : // Return the path to the local table file
414 1786 : return table_it->second.path;
415 : }
416 :
417 2 : std::string ParquetManager::get_versions_info() const {
418 : // Helper lambda returns the version string if available
419 36 : auto get_version = [](const auto &map, const std::string &table) -> int {
420 36 : if (auto it = map.find(table); it != map.end()) {
421 18 : return it->second.version_minor;
422 : }
423 18 : return -1;
424 : };
425 :
426 : // Gather all unique table names
427 2 : std::set<std::string> tables;
428 20 : for (const auto &entry : local_asset_info) {
429 18 : tables.insert(entry.first);
430 : }
431 2 : for (const auto &entry : remote_asset_info) {
432 0 : tables.insert(entry.first);
433 : }
434 :
435 : // Output versions info
436 2 : std::ostringstream oss;
437 :
438 2 : oss << " ";
439 2 : oss << std::left << std::setw(17) << "Asset";
440 2 : oss << std::left << std::setw(6 + 4) << "Local";
441 2 : oss << std::left << std::setw(7) << "Remote\n";
442 2 : oss << std::string(35, '-') << "\n";
443 :
444 20 : for (const auto &table : tables) {
445 18 : int local_version = get_version(local_asset_info, table);
446 18 : int remote_version = get_version(remote_asset_info, table);
447 :
448 : std::string comparator = (local_version < remote_version)
449 : ? "<"
450 18 : : ((local_version > remote_version) ? ">" : "==");
451 : std::string local_version_str = local_version == -1
452 18 : ? "N/A"
453 36 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
454 54 : std::to_string(local_version);
455 : std::string remote_version_str = remote_version == -1
456 18 : ? "N/A"
457 18 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
458 36 : std::to_string(remote_version);
459 :
460 18 : oss << " ";
461 18 : oss << std::left << std::setw(17) << table;
462 18 : oss << std::left << std::setw(6) << local_version_str;
463 18 : oss << std::left << std::setw(4) << comparator;
464 18 : oss << std::left << std::setw(7) << remote_version_str << "\n";
465 18 : }
466 4 : return oss.str();
467 2 : }
468 :
469 : } // namespace pairinteraction
|