LCOV - code coverage report
Current view: top level - src/database - ParquetManager.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 209 302 69.2 %
Date: 2026-06-19 12:50:25 Functions: 11 15 73.3 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/database/ParquetManager.hpp"
       5             : 
       6             : #include "pairinteraction/database/GitHubDownloader.hpp"
       7             : #include "pairinteraction/version.hpp"
       8             : 
       9             : #include <cpptrace/cpptrace.hpp>
      10             : #include <ctime>
      11             : #include <duckdb.hpp>
      12             : #include <filesystem>
      13             : #include <fmt/core.h>
      14             : #include <fmt/format.h>
      15             : #include <fstream>
      16             : #include <future>
      17             : #include <iomanip>
      18             : #include <miniz.h>
      19             : #include <nlohmann/json.hpp>
      20             : #include <regex>
      21             : #include <set>
      22             : #include <spdlog/spdlog.h>
      23             : #include <sstream>
      24             : #include <stdexcept>
      25             : #include <string>
      26             : 
      27             : namespace fs = std::filesystem;
      28             : using json = nlohmann::json;
      29             : 
      30             : namespace pairinteraction {
      31             : namespace {
      32             : 
      33           1 : std::string format_time(std::time_t time_val) {
      34           1 :     std::tm *ptm = std::localtime(&time_val);
      35           1 :     std::ostringstream oss;
      36           1 :     oss << std::put_time(ptm, "%Y-%m-%d %H:%M:%S");
      37           2 :     return oss.str();
      38           1 : }
      39             : 
      40           0 : json load_json(const fs::path &file) {
      41           0 :     std::ifstream in(file);
      42           0 :     in.exceptions(std::ifstream::failbit | std::ifstream::badbit);
      43           0 :     json doc;
      44           0 :     in >> doc;
      45           0 :     return doc;
      46           0 : }
      47             : 
      48           1 : void save_json(const fs::path &file, const json &doc) {
      49           1 :     std::ofstream out(file);
      50           1 :     if (!out) {
      51           0 :         throw std::runtime_error(fmt::format("Failed to open {} for writing", file.string()));
      52             :     }
      53           1 :     out << doc;
      54           1 :     out.close();
      55           1 : }
      56             : 
      57             : } // namespace
      58             : 
      59           0 : void ParquetManager::react_on_exception(const std::string &context, const std::exception &e) {
      60           0 :     repo_paths_.clear();
      61           0 :     remote_asset_info.clear();
      62           0 :     SPDLOG_ERROR("{}: {}. The download of database tables is disabled.", context, e.what());
      63           0 : }
      64             : 
      65           0 : void ParquetManager::react_on_error_code(const std::string &context, int error_code) {
      66           0 :     repo_paths_.clear();
      67           0 :     remote_asset_info.clear();
      68           0 :     SPDLOG_ERROR("{}: status code {}. The download of database tables is disabled.", context,
      69             :                  error_code);
      70           0 : }
      71             : 
      72           0 : void ParquetManager::react_on_rate_limit_reached(std::time_t reset_time) {
      73           0 :     repo_paths_.clear();
      74           0 :     remote_asset_info.clear();
      75           0 :     SPDLOG_ERROR("Rate limit reached, resets at {}. The download of database tables is disabled.",
      76             :                  format_time(reset_time));
      77           0 : }
      78             : 
      79           6 : ParquetManager::ParquetManager(std::filesystem::path directory, const GitHubDownloader &downloader,
      80             :                                std::vector<std::string> repo_paths, duckdb::Connection &con,
      81           6 :                                bool use_cache)
      82           6 :     : directory_(std::move(directory)), downloader(downloader), repo_paths_(std::move(repo_paths)),
      83           6 :       con(con), use_cache_(use_cache) {
      84             :     // Ensure the local directory exists
      85           6 :     if (!std::filesystem::exists(directory_ / "tables")) {
      86           0 :         fs::create_directories(directory_ / "tables");
      87             :     }
      88             : 
      89             :     // If repo paths are provided, check the GitHub rate limit
      90           6 :     if (!repo_paths_.empty()) {
      91           1 :         GitHubDownloader::Result result;
      92             :         try {
      93           1 :             result = downloader.download("/rate_limit", "", false).get();
      94           0 :         } catch (const std::exception &e) {
      95           0 :             react_on_exception("Failed obtaining the rate limit", e);
      96           0 :             return;
      97           0 :         }
      98           1 :         if (result.status_code != 200) {
      99           0 :             react_on_error_code("Failed obtaining the rate limit", result.status_code);
     100           0 :             return;
     101             :         }
     102           1 :         if (result.rate_limit.remaining == 0) {
     103           0 :             react_on_rate_limit_reached(result.rate_limit.reset_time);
     104           0 :             return;
     105             :         }
     106           1 :         SPDLOG_INFO("Remaining GitHub API requests: {}. Rate limit resets at {}.",
     107             :                     result.rate_limit.remaining, format_time(result.rate_limit.reset_time));
     108           1 :     }
     109           0 : }
     110             : 
     111           6 : void ParquetManager::scan_remote() {
     112             :     // If repo_paths_ is empty, we have nothing to do
     113           6 :     if (repo_paths_.empty()) {
     114           5 :         return;
     115             :     }
     116             : 
     117           1 :     remote_asset_info.clear();
     118             : 
     119             :     struct RepoDownload {
     120             :         std::string endpoint;
     121             :         fs::path cache_file;
     122             :         json cached_doc;
     123             :         std::future<GitHubDownloader::Result> future;
     124             :     };
     125           1 :     std::vector<RepoDownload> downloads;
     126           1 :     downloads.reserve(repo_paths_.size());
     127             : 
     128             :     // For each repo path, load its cached JSON (or an empty JSON) and issue the download
     129           2 :     for (const auto &endpoint : repo_paths_) {
     130             :         // Generate a unique cache filename per endpoint
     131             :         std::string cache_filename =
     132           1 :             "homepage_cache_" + std::to_string(std::hash<std::string>{}(endpoint)) + ".json";
     133           1 :         auto cache_file = directory_ / cache_filename;
     134             : 
     135             :         // Load cached JSON from file if it exists
     136           1 :         json cached_doc;
     137           1 :         if (std::filesystem::exists(cache_file)) {
     138             :             try {
     139           0 :                 cached_doc = load_json(cache_file);
     140           0 :             } catch (const std::exception &e) {
     141           0 :                 SPDLOG_WARN("Error reading {}: {}. Discarding homepage cache.", cache_file.string(),
     142             :                             e.what());
     143           0 :                 cached_doc = json{};
     144           0 :             }
     145             :         }
     146             : 
     147             :         // Extract the last-modified header from the cached JSON if available
     148           1 :         std::string last_modified;
     149           1 :         if (!cached_doc.is_null() && cached_doc.contains("last-modified")) {
     150           0 :             last_modified = cached_doc["last-modified"].get<std::string>();
     151             :         }
     152             : 
     153             :         // Issue the asynchronous download using the cached last-modified value.
     154             :         try {
     155           1 :             downloads.push_back(
     156           1 :                 {endpoint, cache_file, cached_doc, downloader.download(endpoint, last_modified)});
     157           0 :         } catch (const std::exception &e) {
     158           0 :             react_on_exception(
     159           0 :                 fmt::format("Failed to download overview of available tables from {}",
     160           0 :                             downloads.back().endpoint),
     161             :                 e);
     162           0 :             return;
     163           0 :         }
     164           1 :     }
     165             : 
     166             :     // Process downloads for each repo path
     167           2 :     for (auto &dl : downloads) {
     168           1 :         auto result = dl.future.get();
     169           1 :         if ((result.status_code == 403 || result.status_code == 429) &&
     170           0 :             result.rate_limit.remaining == 0) {
     171           0 :             react_on_rate_limit_reached(result.rate_limit.reset_time);
     172           0 :             return;
     173             :         }
     174           1 :         if (result.status_code != 200 && result.status_code != 304) {
     175           0 :             react_on_error_code(
     176           0 :                 fmt::format("Failed to download overview of available tables from {}", dl.endpoint),
     177             :                 result.status_code);
     178           0 :             return;
     179             :         }
     180             : 
     181           1 :         json doc;
     182           1 :         if (result.status_code == 304) {
     183           0 :             if (dl.cached_doc.is_null() || dl.cached_doc.empty()) {
     184           0 :                 throw std::runtime_error(
     185           0 :                     fmt::format("Received 304 Not Modified but cached response {} does not exist.",
     186           0 :                                 dl.cache_file.string()));
     187             :             }
     188           0 :             doc = dl.cached_doc;
     189           0 :             SPDLOG_INFO("Using cached overview of available tables from {}.", dl.endpoint);
     190             :         } else {
     191           1 :             doc = json::parse(result.body, nullptr, /*allow_exceptions=*/false);
     192           1 :             doc["last-modified"] = result.last_modified;
     193           1 :             save_json(dl.cache_file, doc);
     194           1 :             SPDLOG_INFO("Using downloaded overview of available tables from {}.", dl.endpoint);
     195             :         }
     196             : 
     197             :         // Validate the JSON response
     198           1 :         if (doc.is_discarded() || !doc.contains("assets")) {
     199           0 :             throw std::runtime_error(fmt::format(
     200           0 :                 "Failed to parse remote JSON or missing 'assets' key from {}.", dl.endpoint));
     201             :         }
     202             : 
     203             :         // Update remote_asset_info based on the asset entries
     204           2 :         for (auto &asset : doc["assets"]) {
     205           1 :             std::string name = asset["name"].get<std::string>();
     206           1 :             std::smatch match;
     207             : 
     208           1 :             if (std::regex_match(name, match, remote_regex) && match.size() == 4) {
     209           1 :                 std::string key = match[1].str();
     210           1 :                 int version_major = std::stoi(match[2].str());
     211           1 :                 int version_minor = std::stoi(match[3].str());
     212             : 
     213           1 :                 if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
     214           0 :                     continue;
     215             :                 }
     216             : 
     217           1 :                 auto it = remote_asset_info.find(key);
     218           1 :                 if (it == remote_asset_info.end() || version_minor > it->second.version_minor) {
     219           1 :                     std::string remote_url = asset["url"].get<std::string>();
     220           1 :                     const std::string host = downloader.get_host();
     221           2 :                     remote_asset_info[key] = {version_minor, remote_url.erase(0, host.size())};
     222           1 :                 }
     223           1 :             }
     224           1 :         }
     225           1 :     }
     226             : 
     227             :     // Ensure that scan_remote was successful
     228           1 :     if (!downloads.empty() && remote_asset_info.empty()) {
     229           0 :         throw std::runtime_error(
     230             :             "No compatible database tables were found in the remote repositories. Consider "
     231           0 :             "upgrading PairInteraction to a newer version.");
     232             :     }
     233           4 : }
     234             : 
     235           6 : void ParquetManager::scan_local() {
     236           6 :     local_asset_info.clear();
     237             : 
     238             :     // Iterate over files in the directory to update local_asset_info
     239          72 :     for (const auto &entry : fs::directory_iterator(directory_ / "tables")) {
     240          33 :         std::string name = entry.path().filename().string();
     241          33 :         std::smatch match;
     242             : 
     243          66 :         if (entry.is_directory() && std::regex_match(name, match, local_regex) &&
     244          33 :             match.size() == 4) {
     245          33 :             std::string key = match[1].str();
     246          33 :             int version_major = std::stoi(match[2].str());
     247          33 :             int version_minor = std::stoi(match[3].str());
     248             : 
     249          33 :             if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
     250           0 :                 continue;
     251             :             }
     252             : 
     253          33 :             auto it = local_asset_info.find(key);
     254          33 :             if (it == local_asset_info.end() || version_minor > it->second.version_minor) {
     255          33 :                 local_asset_info[key].version_minor = version_minor;
     256         309 :                 for (const auto &subentry : fs::directory_iterator(entry)) {
     257         138 :                     if (subentry.is_regular_file() && subentry.path().extension() == ".parquet") {
     258         276 :                         local_asset_info[key].paths[subentry.path().stem().string()] = {
     259         414 :                             subentry.path().string(), false};
     260             :                     }
     261          33 :                 }
     262             :             }
     263          33 :         }
     264          39 :     }
     265         144 : }
     266             : 
     267        4815 : void ParquetManager::update_local_asset(const std::string &key) {
     268        4815 :     assert(remote_asset_info.empty() == repo_paths_.empty());
     269             : 
     270             :     // If remote_asset_info is empty, we have nothing to do
     271        4815 :     if (remote_asset_info.empty()) {
     272        4814 :         return;
     273             :     }
     274             : 
     275             :     // Get remote version if available
     276           1 :     int remote_version = -1;
     277           1 :     auto remote_it = remote_asset_info.find(key);
     278           1 :     if (remote_it != remote_asset_info.end()) {
     279           1 :         remote_version = remote_it->second.version_minor;
     280             :     }
     281             : 
     282             :     // Get local version if available and check if it is up-to-date
     283             :     {
     284           1 :         int local_version = -1;
     285           1 :         std::shared_lock<std::shared_mutex> lock(mtx_local);
     286           1 :         auto local_it = local_asset_info.find(key);
     287           1 :         if (local_it != local_asset_info.end()) {
     288           1 :             local_version = local_it->second.version_minor;
     289             :         }
     290           1 :         if (local_version >= remote_version) {
     291           0 :             return;
     292             :         }
     293           1 :     }
     294             : 
     295             :     // If it is not up-to-date, acquire a unique lock for updating the table
     296           1 :     std::unique_lock<std::shared_mutex> lock(mtx_local);
     297             : 
     298             :     // Re-check if the table is up to date because another thread might have updated it
     299           1 :     int local_version = -1;
     300           1 :     auto local_it = local_asset_info.find(key);
     301           1 :     if (local_it != local_asset_info.end()) {
     302           1 :         local_version = local_it->second.version_minor;
     303             :     }
     304           1 :     if (local_version >= remote_version) {
     305           0 :         return;
     306             :     }
     307             : 
     308             :     // Download the remote file
     309           1 :     std::string endpoint = remote_it->second.endpoint;
     310           1 :     SPDLOG_INFO("Downloading {}_v{}.{} from {}", key, COMPATIBLE_DATABASE_VERSION_MAJOR,
     311             :                 remote_version, endpoint);
     312             : 
     313           1 :     GitHubDownloader::Result result;
     314             :     try {
     315           1 :         result = downloader.download(endpoint, "", true).get();
     316           0 :     } catch (const std::exception &e) {
     317           0 :         react_on_exception(fmt::format("Failed to download table {}", endpoint), e);
     318           0 :         return;
     319           0 :     }
     320           1 :     if ((result.status_code == 403 || result.status_code == 429) &&
     321           0 :         result.rate_limit.remaining == 0) {
     322           0 :         react_on_rate_limit_reached(result.rate_limit.reset_time);
     323           0 :         return;
     324             :     }
     325           1 :     if (result.status_code != 200) {
     326           0 :         react_on_error_code(fmt::format("Failed to download table {}", endpoint),
     327             :                             result.status_code);
     328           0 :         return;
     329             :     }
     330             : 
     331             :     // Unzip the downloaded file
     332           1 :     mz_zip_archive zip_archive{};
     333           1 :     if (mz_zip_reader_init_mem(&zip_archive, result.body.data(), result.body.size(), 0) == 0) {
     334           0 :         throw std::runtime_error("Failed to initialize zip archive.");
     335             :     }
     336             : 
     337           2 :     for (mz_uint i = 0; i < mz_zip_reader_get_num_files(&zip_archive); i++) {
     338             :         mz_zip_archive_file_stat file_stat;
     339           1 :         if (mz_zip_reader_file_stat(&zip_archive, i, &file_stat) == 0) {
     340           0 :             throw std::runtime_error("Failed to get file stat from zip archive.");
     341             :         }
     342             : 
     343             :         // Skip directories
     344           1 :         const char *filename = static_cast<const char *>(file_stat.m_filename);
     345           1 :         size_t len = std::strlen(filename);
     346           1 :         if (len > 0 && filename[len - 1] == '/') {
     347           0 :             continue;
     348             :         }
     349             : 
     350             :         // Ensure that the filename matches the expectations
     351           1 :         std::string dir = fs::path(filename).parent_path().string();
     352           1 :         std::string stem = fs::path(filename).stem().string();
     353           1 :         std::string suffix = fs::path(filename).extension().string();
     354           1 :         std::smatch match;
     355           2 :         if (!std::regex_match(dir, match, local_regex) || match.size() != 4 ||
     356           1 :             suffix != ".parquet") {
     357           0 :             throw std::runtime_error(
     358           0 :                 fmt::format("Unexpected filename {} in zip archive.", filename));
     359             :         }
     360             : 
     361             :         // Construct the path to store the table
     362           1 :         auto path = directory_ / "tables" / dir / (stem + suffix);
     363           1 :         SPDLOG_INFO("Storing table to {}", path.string());
     364             : 
     365             :         // Extract the file to memory
     366           1 :         std::vector<char> buffer(file_stat.m_uncomp_size);
     367           1 :         if (mz_zip_reader_extract_to_mem(&zip_archive, i, buffer.data(), buffer.size(), 0) == 0) {
     368           0 :             throw std::runtime_error(fmt::format("Failed to extract {}.", filename));
     369             :         }
     370             : 
     371             :         // Ensure the parent directory exists
     372           1 :         fs::create_directories(path.parent_path());
     373             : 
     374             :         // Save the extracted file
     375           1 :         std::ofstream out(path.string(), std::ios::binary);
     376           1 :         if (!out) {
     377           0 :             throw std::runtime_error(fmt::format("Failed to open {} for writing", path.string()));
     378             :         }
     379           1 :         out.write(buffer.data(), static_cast<std::streamsize>(buffer.size()));
     380           1 :         out.close();
     381             : 
     382             :         // Update the local asset/table info
     383           1 :         local_asset_info[key].version_minor = remote_version;
     384           1 :         local_asset_info[key].paths[path.stem().string()] = {path.string(), false};
     385           1 :     }
     386             : 
     387           1 :     mz_zip_reader_end(&zip_archive);
     388           2 : }
     389             : 
     390        4812 : void ParquetManager::cache_table(std::unordered_map<std::string, PathInfo>::iterator table_it) {
     391             :     // Check if the table is already cached
     392             :     {
     393        4812 :         std::shared_lock<std::shared_mutex> lock(mtx_local);
     394        4812 :         if (table_it->second.cached) {
     395        4785 :             return;
     396             :         }
     397        4812 :     }
     398             : 
     399             :     // Acquire a unique lock for caching the table
     400          27 :     std::unique_lock<std::shared_mutex> lock(mtx_local);
     401             : 
     402             :     // Re-check if the table is already cached because another thread might have cached it
     403          27 :     if (table_it->second.cached) {
     404           0 :         return;
     405             :     }
     406             : 
     407             :     // Cache the table in memory
     408          27 :     std::string table_name;
     409             :     {
     410          27 :         auto result = con.Query(R"(SELECT UUID()::varchar)");
     411          27 :         if (result->HasError()) {
     412           0 :             throw cpptrace::runtime_error("Error selecting a unique table name: " +
     413           0 :                                           result->GetError());
     414             :         }
     415             :         table_name =
     416          27 :             duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
     417          27 :     }
     418             : 
     419             :     {
     420          27 :         auto result = con.Query(fmt::format(R"(CREATE TEMP TABLE '{}' AS SELECT * FROM '{}')",
     421          54 :                                             table_name, table_it->second.path));
     422          27 :         if (result->HasError()) {
     423           0 :             throw cpptrace::runtime_error("Error creating table: " + result->GetError());
     424             :         }
     425          27 :     }
     426             : 
     427          27 :     table_it->second.path = table_name;
     428          27 :     table_it->second.cached = true;
     429          27 : }
     430             : 
     431        4815 : std::string ParquetManager::get_path(const std::string &key, const std::string &table) {
     432             :     // Update the local table if a newer version is available remotely
     433        4815 :     this->update_local_asset(key);
     434             : 
     435             :     // Ensure availability of the local table file
     436        4815 :     auto asset_it = local_asset_info.find(key);
     437        4815 :     if (asset_it == local_asset_info.end()) {
     438             :         // If we do not know about any table that can be downloaded, downloading might be blocked.
     439             :         // Otherwise, the species might be misspelled.
     440           0 :         if (remote_asset_info.empty()) {
     441           0 :             throw std::runtime_error(
     442           0 :                 "No tables found for species '" + key +
     443             :                 "'. Check whether you have allowed downloading missing tables, or download the "
     444           0 :                 "tables manually via `pairinteraction database download " +
     445           0 :                 key + "`.");
     446             :         }
     447           0 :         throw std::runtime_error("No tables found for species '" + key +
     448           0 :                                  "'. Check the spelling of the species.");
     449             :     }
     450        4815 :     auto table_it = asset_it->second.paths.find(table);
     451        4815 :     if (table_it == asset_it->second.paths.end()) {
     452           1 :         throw std::runtime_error("No table '" + table + ".parquet' found for species '" + key +
     453           2 :                                  "'. The tables for the species are incomplete.");
     454             :     }
     455             : 
     456             :     // Cache the local table in memory if requested
     457        4814 :     if (use_cache_) {
     458        4812 :         this->cache_table(table_it);
     459             :     }
     460             : 
     461             :     // Return the path to the local table file
     462        9628 :     return table_it->second.path;
     463             : }
     464             : 
     465           3 : std::string ParquetManager::get_versions_info() const {
     466             :     // Helper lambda returns the version string if available
     467          54 :     auto get_version = [](const auto &map, const std::string &table) -> int {
     468          54 :         if (auto it = map.find(table); it != map.end()) {
     469          27 :             return it->second.version_minor;
     470             :         }
     471          27 :         return -1;
     472             :     };
     473             : 
     474             :     // Gather all unique table names
     475           3 :     std::set<std::string> tables;
     476          30 :     for (const auto &entry : local_asset_info) {
     477          27 :         tables.insert(entry.first);
     478             :     }
     479           3 :     for (const auto &entry : remote_asset_info) {
     480           0 :         tables.insert(entry.first);
     481             :     }
     482             : 
     483             :     // Output versions info
     484           3 :     std::ostringstream oss;
     485             : 
     486           3 :     oss << " ";
     487           3 :     oss << std::left << std::setw(17) << "Asset";
     488           3 :     oss << std::left << std::setw(6 + 4) << "Local";
     489           3 :     oss << std::left << std::setw(7) << "Remote\n";
     490           3 :     oss << std::string(35, '-') << "\n";
     491             : 
     492          30 :     for (const auto &table : tables) {
     493          27 :         int local_version = get_version(local_asset_info, table);
     494          27 :         int remote_version = get_version(remote_asset_info, table);
     495             : 
     496             :         std::string comparator = (local_version < remote_version)
     497             :             ? "<"
     498          27 :             : ((local_version > remote_version) ? ">" : "==");
     499             :         std::string local_version_str = local_version == -1
     500          27 :             ? "N/A"
     501          54 :             : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
     502          81 :                 std::to_string(local_version);
     503             :         std::string remote_version_str = remote_version == -1
     504          27 :             ? "N/A"
     505          27 :             : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
     506          54 :                 std::to_string(remote_version);
     507             : 
     508          27 :         oss << " ";
     509          27 :         oss << std::left << std::setw(17) << table;
     510          27 :         oss << std::left << std::setw(6) << local_version_str;
     511          27 :         oss << std::left << std::setw(4) << comparator;
     512          27 :         oss << std::left << std::setw(7) << remote_version_str << "\n";
     513          27 :     }
     514           6 :     return oss.str();
     515           3 : }
     516             : 
     517             : } // namespace pairinteraction

Generated by: LCOV version 1.16