LCOV - code coverage report
Current view: top level - src/database - ParquetManager.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 200 256 78.1 %
Date: 2025-04-29 15:56:08 Functions: 11 13 84.6 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2025 Pairinteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/database/ParquetManager.hpp"
       5             : 
       6             : #include "pairinteraction/database/GitHubDownloader.hpp"
       7             : #include "pairinteraction/version.hpp"
       8             : 
       9             : #include <cpptrace/cpptrace.hpp>
      10             : #include <ctime>
      11             : #include <duckdb.hpp>
      12             : #include <filesystem>
      13             : #include <fmt/core.h>
      14             : #include <fstream>
      15             : #include <future>
      16             : #include <iomanip>
      17             : #include <miniz.h>
      18             : #include <nlohmann/json.hpp>
      19             : #include <regex>
      20             : #include <set>
      21             : #include <spdlog/spdlog.h>
      22             : #include <sstream>
      23             : #include <stdexcept>
      24             : #include <string>
      25             : 
      26             : namespace fs = std::filesystem;
      27             : using json = nlohmann::json;
      28             : 
      29           1 : std::string format_time(std::time_t time_val) {
      30           1 :     std::tm *ptm = std::localtime(&time_val);
      31           1 :     std::ostringstream oss;
      32           1 :     oss << std::put_time(ptm, "%Y-%m-%d %H:%M:%S");
      33           2 :     return oss.str();
      34           1 : }
      35             : 
      36           0 : json load_json(const fs::path &file) {
      37           0 :     std::ifstream in(file);
      38           0 :     in.exceptions(std::ifstream::failbit | std::ifstream::badbit);
      39           0 :     json doc;
      40           0 :     in >> doc;
      41           0 :     return doc;
      42           0 : }
      43             : 
      44           1 : void save_json(const fs::path &file, const json &doc) {
      45           1 :     std::ofstream out(file);
      46           1 :     if (!out) {
      47           0 :         throw std::runtime_error(fmt::format("Failed to open {} for writing", file.string()));
      48             :     }
      49           1 :     out << doc;
      50           1 :     out.close();
      51           1 : }
      52             : 
      53             : namespace pairinteraction {
      54           8 : ParquetManager::ParquetManager(std::filesystem::path directory, const GitHubDownloader &downloader,
      55             :                                std::vector<std::string> repo_paths, duckdb::Connection &con,
      56           8 :                                bool use_cache)
      57           8 :     : directory_(std::move(directory)), downloader(downloader), repo_paths_(std::move(repo_paths)),
      58           8 :       con(con), use_cache_(use_cache) {
      59             :     // Ensure the local directory exists
      60           8 :     if (!std::filesystem::exists(directory_ / "tables")) {
      61           0 :         fs::create_directories(directory_ / "tables");
      62             :     }
      63             : 
      64             :     // If repo paths are provided, check the GitHub rate limit
      65           8 :     if (!repo_paths_.empty()) {
      66           1 :         auto rate_limit = downloader.get_rate_limit();
      67           1 :         if (rate_limit.remaining <= 0) {
      68           0 :             react_on_rate_limit_reached(rate_limit.reset_time);
      69             :         } else {
      70           2 :             SPDLOG_INFO("Remaining GitHub API requests: {}. Rate limit resets at {}.",
      71             :                         rate_limit.remaining, format_time(rate_limit.reset_time));
      72             :         }
      73             :     }
      74           8 : }
      75             : 
      76           8 : void ParquetManager::scan_remote() {
      77           8 :     remote_asset_info.clear();
      78             : 
      79             :     struct RepoDownload {
      80             :         std::string endpoint;
      81             :         fs::path cache_file;
      82             :         json cached_doc;
      83             :         std::future<GitHubDownloader::Result> future;
      84             :     };
      85           8 :     std::vector<RepoDownload> downloads;
      86           8 :     downloads.reserve(repo_paths_.size());
      87             : 
      88             :     // For each repo path, load its cached JSON (or an empty JSON) and issue the download
      89           9 :     for (const auto &endpoint : repo_paths_) {
      90             :         // Generate a unique cache filename per endpoint
      91             :         std::string cache_filename =
      92           1 :             "homepage_cache_" + std::to_string(std::hash<std::string>{}(endpoint)) + ".json";
      93           1 :         auto cache_file = directory_ / cache_filename;
      94             : 
      95             :         // Load cached JSON from file if it exists
      96           1 :         json cached_doc;
      97           1 :         if (std::filesystem::exists(cache_file)) {
      98             :             try {
      99           0 :                 cached_doc = load_json(cache_file);
     100           0 :             } catch (const std::exception &e) {
     101           0 :                 SPDLOG_WARN("Error reading {}: {}. Discarding homepage cache.", cache_file.string(),
     102             :                             e.what());
     103           0 :                 cached_doc = json{};
     104           0 :             }
     105             :         }
     106             : 
     107             :         // Extract the last-modified header from the cached JSON if available
     108           1 :         std::string last_modified;
     109           1 :         if (!cached_doc.is_null() && cached_doc.contains("last-modified")) {
     110           0 :             last_modified = cached_doc["last-modified"].get<std::string>();
     111             :         }
     112             : 
     113             :         // Issue the asynchronous download using the cached last-modified value.
     114           1 :         downloads.push_back(
     115           1 :             {endpoint, cache_file, cached_doc, downloader.download(endpoint, last_modified)});
     116           1 :     }
     117             : 
     118             :     // Process downloads for each repo path
     119           9 :     for (auto &dl : downloads) {
     120           1 :         auto result = dl.future.get();
     121           1 :         json doc;
     122           1 :         if (result.status_code == 200) {
     123           1 :             doc = json::parse(result.body, nullptr, /*allow_exceptions=*/false);
     124           1 :             doc["last-modified"] = result.last_modified;
     125           1 :             save_json(dl.cache_file, doc);
     126           2 :             SPDLOG_INFO("Using downloaded overview of available tables from {}.", dl.endpoint);
     127           0 :         } else if (result.status_code == 304) {
     128           0 :             if (dl.cached_doc.is_null() || dl.cached_doc.empty()) {
     129           0 :                 throw std::runtime_error(
     130           0 :                     fmt::format("Received 304 Not Modified but cached response {} does not exist.",
     131           0 :                                 dl.cache_file.string()));
     132             :             }
     133           0 :             doc = dl.cached_doc;
     134           0 :             SPDLOG_INFO("Using cached overview of available tables from {}.", dl.endpoint);
     135           0 :         } else if (result.status_code == 403 || result.status_code == 429) {
     136           0 :             react_on_rate_limit_reached(result.rate_limit.reset_time);
     137           0 :             return;
     138             :         } else {
     139           0 :             throw std::runtime_error(fmt::format(
     140             :                 "Failed to download overview of available tables from {}: status code {}.",
     141           0 :                 dl.endpoint, result.status_code));
     142             :         }
     143             : 
     144             :         // Validate the JSON response
     145           1 :         if (doc.is_discarded() || !doc.contains("assets")) {
     146           0 :             throw std::runtime_error(fmt::format(
     147           0 :                 "Failed to parse remote JSON or missing 'assets' key from {}.", dl.endpoint));
     148             :         }
     149             : 
     150             :         // Update remote_asset_info based on the asset entries
     151           2 :         for (auto &asset : doc["assets"]) {
     152           1 :             std::string name = asset["name"].get<std::string>();
     153           1 :             std::smatch match;
     154             : 
     155           1 :             if (std::regex_match(name, match, remote_regex) && match.size() == 4) {
     156           1 :                 std::string key = match[1].str();
     157           1 :                 int version_major = std::stoi(match[2].str());
     158           1 :                 int version_minor = std::stoi(match[3].str());
     159             : 
     160           1 :                 if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
     161           0 :                     continue;
     162             :                 }
     163             : 
     164           1 :                 auto it = remote_asset_info.find(key);
     165           1 :                 if (it == remote_asset_info.end() || version_minor > it->second.version_minor) {
     166           1 :                     std::string remote_url = asset["url"].get<std::string>();
     167           1 :                     const std::string host = downloader.get_host();
     168           2 :                     remote_asset_info[key] = {version_minor, remote_url.erase(0, host.size())};
     169           1 :                 }
     170           1 :             }
     171           1 :         }
     172           1 :     }
     173             : 
     174             :     // Ensure that scan_remote was successful
     175           8 :     if (!downloads.empty() && remote_asset_info.empty()) {
     176           0 :         throw std::runtime_error(
     177             :             "No compatible database tables were found in the remote repositories. Consider "
     178           0 :             "upgrading pairinteraction to a newer version.");
     179             :     }
     180          11 : }
     181             : 
     182           8 : void ParquetManager::scan_local() {
     183           8 :     local_asset_info.clear();
     184             : 
     185             :     // Iterate over files in the directory to update local_asset_info
     186         110 :     for (const auto &entry : fs::directory_iterator(directory_ / "tables")) {
     187          51 :         std::string name = entry.path().filename().string();
     188          51 :         std::smatch match;
     189             : 
     190         102 :         if (entry.is_directory() && std::regex_match(name, match, local_regex) &&
     191          51 :             match.size() == 4) {
     192          51 :             std::string key = match[1].str();
     193          51 :             int version_major = std::stoi(match[2].str());
     194          51 :             int version_minor = std::stoi(match[3].str());
     195             : 
     196          51 :             if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
     197           0 :                 continue;
     198             :             }
     199             : 
     200          51 :             auto it = local_asset_info.find(key);
     201          51 :             if (it == local_asset_info.end() || version_minor > it->second.version_minor) {
     202          48 :                 local_asset_info[key].version_minor = version_minor;
     203         494 :                 for (const auto &subentry : fs::directory_iterator(entry)) {
     204         223 :                     if (subentry.is_regular_file() && subentry.path().extension() == ".parquet") {
     205         446 :                         local_asset_info[key].paths[subentry.path().stem().string()] = {
     206         669 :                             subentry.path().string(), false};
     207             :                     }
     208          48 :                 }
     209             :             }
     210          51 :         }
     211          59 :     }
     212         231 : }
     213             : 
     214           0 : void ParquetManager::react_on_rate_limit_reached(std::time_t reset_time) {
     215           0 :     repo_paths_.clear();
     216           0 :     remote_asset_info.clear();
     217           0 :     SPDLOG_WARN("Rate limit reached, resets at {}. The download of database tables is disabled.",
     218             :                 format_time(reset_time));
     219           0 : }
     220             : 
     221         166 : void ParquetManager::update_local_asset(const std::string &key) {
     222             :     // Get remote version if available
     223         166 :     int remote_version = -1;
     224         166 :     auto remote_it = remote_asset_info.find(key);
     225         166 :     if (remote_it != remote_asset_info.end()) {
     226           1 :         remote_version = remote_it->second.version_minor;
     227             :     }
     228             : 
     229             :     // Get local version if available and check if it is up-to-date
     230             :     {
     231         166 :         int local_version = -1;
     232         166 :         std::shared_lock<std::shared_mutex> lock(mtx_local);
     233         168 :         auto local_it = local_asset_info.find(key);
     234         167 :         if (local_it != local_asset_info.end()) {
     235         168 :             local_version = local_it->second.version_minor;
     236             :         }
     237         168 :         if (local_version >= remote_version) {
     238         167 :             return;
     239             :         }
     240         168 :     }
     241             : 
     242             :     // If it is not up-to-date, acquire a unique lock for updating the table
     243           1 :     std::unique_lock<std::shared_mutex> lock(mtx_local);
     244             : 
     245             :     // Re-check if the table is up to date because another thread might have updated it
     246           1 :     int local_version = -1;
     247           1 :     auto local_it = local_asset_info.find(key);
     248           1 :     if (local_it != local_asset_info.end()) {
     249           1 :         local_version = local_it->second.version_minor;
     250             :     }
     251           1 :     if (local_version >= remote_version) {
     252           0 :         return;
     253             :     }
     254             : 
     255             :     // Download the remote file
     256           1 :     std::string endpoint = remote_it->second.endpoint;
     257           2 :     SPDLOG_INFO("Downloading {}_v{}.{} from {}", key, COMPATIBLE_DATABASE_VERSION_MAJOR,
     258             :                 remote_version, endpoint);
     259             : 
     260           1 :     auto result = downloader.download(endpoint, "", true).get();
     261           1 :     if (result.status_code == 403 || result.status_code == 429) {
     262           0 :         react_on_rate_limit_reached(result.rate_limit.reset_time);
     263           0 :         return;
     264             :     }
     265           1 :     if (result.status_code != 200) {
     266           0 :         throw std::runtime_error(fmt::format("Failed to download table {}: status code {}.",
     267           0 :                                              endpoint, result.status_code));
     268             :     }
     269             : 
     270             :     // Unzip the downloaded file
     271           1 :     mz_zip_archive zip_archive{};
     272           1 :     if (mz_zip_reader_init_mem(&zip_archive, result.body.data(), result.body.size(), 0) == 0) {
     273           0 :         throw std::runtime_error("Failed to initialize zip archive.");
     274             :     }
     275             : 
     276           2 :     for (mz_uint i = 0; i < mz_zip_reader_get_num_files(&zip_archive); i++) {
     277             :         mz_zip_archive_file_stat file_stat;
     278           1 :         if (mz_zip_reader_file_stat(&zip_archive, i, &file_stat) == 0) {
     279           0 :             throw std::runtime_error("Failed to get file stat from zip archive.");
     280             :         }
     281             : 
     282             :         // Skip directories
     283           1 :         const char *filename = static_cast<const char *>(file_stat.m_filename);
     284           1 :         size_t len = std::strlen(filename);
     285           1 :         if (len > 0 && filename[len - 1] == '/') {
     286           0 :             continue;
     287             :         }
     288             : 
     289             :         // Ensure that the filename matches the expectations
     290           1 :         std::string dir = fs::path(filename).parent_path().string();
     291           1 :         std::string stem = fs::path(filename).stem().string();
     292           1 :         std::string suffix = fs::path(filename).extension().string();
     293           1 :         std::smatch match;
     294           2 :         if (!std::regex_match(dir, match, local_regex) || match.size() != 4 ||
     295           1 :             suffix != ".parquet") {
     296           0 :             throw std::runtime_error(
     297           0 :                 fmt::format("Unexpected filename {} in zip archive.", filename));
     298             :         }
     299             : 
     300             :         // Construct the path to store the table
     301           1 :         auto path = directory_ / "tables" / dir / (stem + suffix);
     302           2 :         SPDLOG_INFO("Storing table to {}", path.string());
     303             : 
     304             :         // Extract the file to memory
     305           1 :         std::vector<char> buffer(file_stat.m_uncomp_size);
     306           1 :         if (mz_zip_reader_extract_to_mem(&zip_archive, i, buffer.data(), buffer.size(), 0) == 0) {
     307           0 :             throw std::runtime_error(fmt::format("Failed to extract {}.", filename));
     308             :         }
     309             : 
     310             :         // Ensure the parent directory exists
     311           1 :         fs::create_directories(path.parent_path());
     312             : 
     313             :         // Save the extracted file
     314           1 :         std::ofstream out(path.string(), std::ios::binary);
     315           1 :         if (!out) {
     316           0 :             throw std::runtime_error(fmt::format("Failed to open {} for writing", path.string()));
     317             :         }
     318           1 :         out.write(buffer.data(), static_cast<std::streamsize>(buffer.size()));
     319           1 :         out.close();
     320             : 
     321             :         // Update the local asset/table info
     322           1 :         local_asset_info[key].version_minor = remote_version;
     323           1 :         local_asset_info[key].paths[path.stem().string()] = {path.string(), false};
     324           1 :     }
     325             : 
     326           1 :     mz_zip_reader_end(&zip_archive);
     327           2 : }
     328             : 
     329         164 : void ParquetManager::cache_table(std::unordered_map<std::string, PathInfo>::iterator table_it) {
     330             :     // Check if the table is already cached
     331             :     {
     332         164 :         std::shared_lock<std::shared_mutex> lock(mtx_local);
     333         164 :         if (table_it->second.cached) {
     334         144 :             return;
     335             :         }
     336         164 :     }
     337             : 
     338             :     // Acquire a unique lock for caching the table
     339          19 :     std::unique_lock<std::shared_mutex> lock(mtx_local);
     340             : 
     341             :     // Re-check if the table is already cached because another thread might have cached it
     342          20 :     if (table_it->second.cached) {
     343           3 :         return;
     344             :     }
     345             : 
     346             :     // Cache the table in memory
     347          17 :     std::string table_name;
     348             :     {
     349          17 :         auto result = con.Query(R"(SELECT UUID()::varchar)");
     350          17 :         if (result->HasError()) {
     351           0 :             throw cpptrace::runtime_error("Error selecting a unique table name: " +
     352           0 :                                           result->GetError());
     353             :         }
     354             :         table_name =
     355          17 :             duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
     356          17 :     }
     357             : 
     358             :     {
     359          17 :         auto result = con.Query(fmt::format(R"(CREATE TEMP TABLE '{}' AS SELECT * FROM '{}')",
     360          34 :                                             table_name, table_it->second.path));
     361          17 :         if (result->HasError()) {
     362           0 :             throw cpptrace::runtime_error("Error creating table: " + result->GetError());
     363             :         }
     364          17 :     }
     365             : 
     366          17 :     table_it->second.path = table_name;
     367          17 :     table_it->second.cached = true;
     368          20 : }
     369             : 
     370         168 : std::string ParquetManager::get_path(const std::string &key, const std::string &table) {
     371             :     // Update the local table if a newer version is available remotely
     372         168 :     this->update_local_asset(key);
     373             : 
     374             :     // Ensure availability of the local table file
     375         168 :     auto asset_it = local_asset_info.find(key);
     376         168 :     if (asset_it == local_asset_info.end()) {
     377           0 :         throw std::runtime_error("Table " + key + "_" + table + " not found.");
     378             :     }
     379         168 :     auto table_it = asset_it->second.paths.find(table);
     380         168 :     if (table_it == asset_it->second.paths.end()) {
     381           1 :         throw std::runtime_error("Table " + key + "_" + table + " not found.");
     382             :     }
     383             : 
     384             :     // Cache the local table in memory if requested
     385         166 :     if (use_cache_) {
     386         164 :         this->cache_table(table_it);
     387             :     }
     388             : 
     389             :     // Return the path to the local table file
     390         333 :     return table_it->second.path;
     391             : }
     392             : 
     393           5 : std::string ParquetManager::get_versions_info() const {
     394             :     // Helper lambda returns the version string if available
     395          90 :     auto get_version = [](const auto &map, const std::string &table) -> int {
     396          90 :         if (auto it = map.find(table); it != map.end()) {
     397          45 :             return it->second.version_minor;
     398             :         }
     399          45 :         return -1;
     400             :     };
     401             : 
     402             :     // Gather all unique table names
     403           5 :     std::set<std::string> tables;
     404          50 :     for (const auto &entry : local_asset_info) {
     405          45 :         tables.insert(entry.first);
     406             :     }
     407           5 :     for (const auto &entry : remote_asset_info) {
     408           0 :         tables.insert(entry.first);
     409             :     }
     410             : 
     411             :     // Output versions info
     412           5 :     std::ostringstream oss;
     413             : 
     414           5 :     oss << " ";
     415           5 :     oss << std::left << std::setw(17) << "Asset";
     416           5 :     oss << std::left << std::setw(6 + 4) << "Local";
     417           5 :     oss << std::left << std::setw(7) << "Remote\n";
     418           5 :     oss << std::string(35, '-') << "\n";
     419             : 
     420          50 :     for (const auto &table : tables) {
     421          45 :         int local_version = get_version(local_asset_info, table);
     422          45 :         int remote_version = get_version(remote_asset_info, table);
     423             : 
     424             :         std::string comparator = (local_version < remote_version)
     425             :             ? "<"
     426          45 :             : ((local_version > remote_version) ? ">" : "==");
     427             :         std::string local_version_str = local_version == -1
     428          45 :             ? "N/A"
     429          90 :             : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
     430         135 :                 std::to_string(local_version);
     431             :         std::string remote_version_str = remote_version == -1
     432          45 :             ? "N/A"
     433          45 :             : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
     434          90 :                 std::to_string(remote_version);
     435             : 
     436          45 :         oss << " ";
     437          45 :         oss << std::left << std::setw(17) << table;
     438          45 :         oss << std::left << std::setw(6) << local_version_str;
     439          45 :         oss << std::left << std::setw(4) << comparator;
     440          45 :         oss << std::left << std::setw(7) << remote_version_str << "\n";
     441          45 :     }
     442          10 :     return oss.str();
     443           5 : }
     444             : 
     445             : } // namespace pairinteraction

Generated by: LCOV version 1.16