LCOV - code coverage report
Current view: top level - src/database - GitHubDownloader.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 26 102 25.5 %
Date: 2025-09-15 16:23:56 Functions: 5 10 50.0 %

          Line data    Source code
       1             : // SPDX-FileCopyrightText: 2025 PairInteraction Developers
       2             : // SPDX-License-Identifier: LGPL-3.0-or-later
       3             : 
       4             : #include "pairinteraction/database/GitHubDownloader.hpp"
       5             : 
       6             : #include "pairinteraction/utils/paths.hpp"
       7             : 
       8             : #include <filesystem>
       9             : #include <fmt/core.h>
      10             : #include <fstream>
      11             : #include <future>
      12             : #include <httplib.h>
      13             : #include <spdlog/spdlog.h>
      14             : #include <stdexcept>
      15             : 
      16             : namespace pairinteraction {
      17             : 
      18           0 : void log(const httplib::Request &req, const httplib::Response &res) {
      19           0 :     if (!spdlog::default_logger()->should_log(spdlog::level::debug)) {
      20           0 :         return;
      21             :     }
      22             : 
      23           0 :     SPDLOG_DEBUG("[httplib] {} {}", req.method, req.path);
      24           0 :     for (const auto &[k, v] : req.headers) {
      25           0 :         SPDLOG_DEBUG("[httplib]   {}: {}\n", k, v);
      26             :     }
      27             : 
      28           0 :     SPDLOG_DEBUG("[httplib] Response with status {}", res.status);
      29           0 :     for (const auto &[k, v] : res.headers) {
      30           0 :         SPDLOG_DEBUG("[httplib]   {}: {}\n", k, v);
      31             :     }
      32             : 
      33           0 :     if (res.body.empty()) {
      34           0 :         return;
      35             :     }
      36             : 
      37           0 :     if (res.body.size() > 1024) {
      38           0 :         SPDLOG_DEBUG("[httplib] Body ({} bytes, first 1024 bytes):", res.body.size());
      39             :     } else {
      40           0 :         SPDLOG_DEBUG("[httplib] Body ({} bytes):", res.body.size());
      41             :     }
      42           0 :     SPDLOG_DEBUG("[httplib]   {}", res.body.substr(0, 1024));
      43             : }
      44             : 
      45           8 : GitHubDownloader::GitHubDownloader() : client(std::make_unique<httplib::SSLClient>(host)) {
      46           8 :     std::filesystem::path configdir = paths::get_config_directory();
      47           8 :     if (!std::filesystem::exists(configdir)) {
      48           0 :         std::filesystem::create_directories(configdir);
      49           8 :     } else if (!std::filesystem::is_directory(configdir)) {
      50           0 :         throw std::filesystem::filesystem_error("Cannot access config directory ",
      51           0 :                                                 configdir.string(),
      52           0 :                                                 std::make_error_code(std::errc::not_a_directory));
      53             :     }
      54             : 
      55           8 :     std::filesystem::path cert_path = configdir / "ca-bundle.crt";
      56           8 :     if (!std::filesystem::exists(cert_path)) {
      57           1 :         std::ofstream out(cert_path);
      58           1 :         if (!out) {
      59           0 :             throw std::runtime_error("Failed to create certificate file at " + cert_path.string());
      60             :         }
      61           1 :         out << cert;
      62           1 :         out.close();
      63           1 :     }
      64           8 :     cert_path_str = cert_path.string();
      65             : 
      66           8 :     client->set_follow_location(true);
      67           8 :     client->set_connection_timeout(5, 0); // seconds
      68           8 :     client->set_read_timeout(60, 0);      // seconds
      69           8 :     client->set_write_timeout(1, 0);      // seconds
      70           8 :     client->set_ca_cert_path(cert_path_str);
      71           8 :     client->set_logger(log);
      72           8 : }
      73             : 
// Destructor, defaulted out of line.
// NOTE(review): presumably defined in this .cpp (rather than in the header) so that
// httplib::SSLClient, owned through the std::unique_ptr `client`, is a complete type where the
// destructor is instantiated -- confirm against GitHubDownloader.hpp.
GitHubDownloader::~GitHubDownloader() = default;
      75             : 
      76             : std::future<GitHubDownloader::Result>
      77           0 : GitHubDownloader::download(const std::string &remote_url, const std::string &if_modified_since,
      78             :                            bool use_octet_stream) const {
      79             :     return std::async(
      80           0 :         std::launch::async, [this, remote_url, if_modified_since, use_octet_stream]() -> Result {
      81           0 :             SPDLOG_DEBUG("Downloading from GitHub: {}", remote_url);
      82             : 
      83             :             // Prepare headers
      84             :             httplib::Headers headers{
      85             :                 {"User-Agent", "pairinteraction"},
      86             :                 {"X-GitHub-Api-Version", "2022-11-28"},
      87             :                 {"Accept",
      88           0 :                  use_octet_stream ? "application/octet-stream" : "application/vnd.github+json"}};
      89             : 
      90           0 :             if (!if_modified_since.empty()) {
      91           0 :                 headers.emplace("if-modified-since", if_modified_since);
      92             :             }
      93             : 
      94             :             // Use the GitHub token if available; otherwise, if we have a conditional request,
      95             :             // insert a dummy authorization header to avoid increasing rate limits
      96           0 :             if (auto *token = std::getenv("GITHUB_TOKEN"); token) {
      97           0 :                 headers.emplace("Authorization", fmt::format("Bearer {}", token));
      98           0 :             } else if (!if_modified_since.empty()) {
      99           0 :                 headers.emplace("Authorization",
     100             :                                 "avoids-an-increase-in-ratelimits-used-if-304-is-returned");
     101             :             }
     102             : 
     103             :             // If we're fetching binary, stream with a progress callback; otherwise use a simple get
     104           0 :             httplib::Result response;
     105           0 :             std::string streamed_body;
     106           0 :             if (use_octet_stream) {
     107           0 :                 auto content_receiver = [&](const char *data, size_t len) {
     108           0 :                     streamed_body.append(data, len);
     109           0 :                     return true;
     110           0 :                 };
     111             : 
     112             :                 // Progress display
     113           0 :                 int last_pct = -1;
     114           0 :                 auto progress_display = [&last_pct, remote_url](uint64_t cur, uint64_t total) {
     115           0 :                     if (total == 0) {
     116           0 :                         fmt::print(stderr, "\rDownloading {}...", remote_url);
     117           0 :                         (void)std::fflush(stderr);
     118           0 :                     } else if (int pct = static_cast<int>((cur * 100) / total); pct != last_pct) {
     119           0 :                         last_pct = pct;
     120           0 :                         fmt::print(stderr, "\rDownloading {}... {:3d}%", remote_url, pct);
     121           0 :                         (void)std::fflush(stderr);
     122             :                     }
     123           0 :                     return true;
     124           0 :                 };
     125             : 
     126           0 :                 response = client->Get(remote_url, headers, content_receiver, progress_display);
     127             : 
     128             :                 // Ensure the progress display ends cleanly if we showed it
     129           0 :                 if (last_pct >= 0) {
     130           0 :                     fmt::print(stderr, "\n");
     131           0 :                     (void)std::fflush(stderr);
     132             :                 }
     133           0 :             } else {
     134           0 :                 response = client->Get(remote_url, headers);
     135             :             }
     136             : 
     137             :             // Handle if the response is null
     138           0 :             if (!response) {
     139             :                 // Defensive handling: if response is null and the error is unknown,
     140             :                 // treat this as a 304 Not Modified
     141           0 :                 if (response.error() == httplib::Error::Unknown) {
     142           0 :                     return Result{304, "", "", {}};
     143             :                 }
     144           0 :                 throw std::runtime_error(fmt::format("Error downloading '{}': {}", remote_url,
     145           0 :                                                      httplib::to_string(response.error())));
     146             :             }
     147             : 
     148             :             // Parse the response
     149           0 :             Result result;
     150           0 :             if (response->has_header("x-ratelimit-remaining")) {
     151           0 :                 result.rate_limit.remaining =
     152           0 :                     std::stoi(response->get_header_value("x-ratelimit-remaining"));
     153             :             }
     154           0 :             if (response->has_header("x-ratelimit-reset")) {
     155           0 :                 result.rate_limit.reset_time =
     156           0 :                     std::stoi(response->get_header_value("x-ratelimit-reset"));
     157             :             }
     158           0 :             if (response->has_header("last-modified")) {
     159           0 :                 result.last_modified = response->get_header_value("last-modified");
     160             :             }
     161           0 :             result.body = use_octet_stream ? std::move(streamed_body) : response->body;
     162           0 :             result.status_code = response->status;
     163             : 
     164           0 :             SPDLOG_DEBUG("Response status: {}", response->status);
     165           0 :             return result;
     166           0 :         });
     167           0 : }
     168             : 
     169           1 : GitHubDownloader::RateLimit GitHubDownloader::get_rate_limit() const {
     170             :     // This call now either returns valid rate limit data or throws an exception on error
     171           1 :     Result result = download("/rate_limit", "", false).get();
     172           1 :     if (result.status_code != 200) {
     173           0 :         throw std::runtime_error(
     174           0 :             fmt::format("Failed obtaining the rate limit: status code {}.", result.status_code));
     175             :     }
     176           1 :     return result.rate_limit;
     177           1 : }
     178             : 
     179           1 : std::string GitHubDownloader::get_host() const { return "https://" + host; }
     180             : 
     181             : } // namespace pairinteraction

Generated by: LCOV version 1.16