Line data Source code
1 : // SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/database/ParquetManager.hpp"
5 :
6 : #include "pairinteraction/database/GitHubDownloader.hpp"
7 : #include "pairinteraction/version.hpp"
8 :
9 : #include <cpptrace/cpptrace.hpp>
10 : #include <ctime>
11 : #include <duckdb.hpp>
12 : #include <filesystem>
13 : #include <fmt/core.h>
14 : #include <fstream>
15 : #include <future>
16 : #include <iomanip>
17 : #include <miniz.h>
18 : #include <nlohmann/json.hpp>
19 : #include <regex>
20 : #include <set>
21 : #include <spdlog/spdlog.h>
22 : #include <sstream>
23 : #include <stdexcept>
24 : #include <string>
25 :
26 : namespace fs = std::filesystem;
27 : using json = nlohmann::json;
28 :
/**
 * Formats a time stamp as "YYYY-MM-DD HH:MM:SS" in the local time zone.
 *
 * The conversion writes into a local std::tm instead of the shared static
 * buffer used by std::localtime, so concurrent callers do not overwrite each
 * other's result.
 *
 * @param time_val Time stamp to format.
 * @return The formatted time, or "unknown time" if the time stamp cannot be
 *         converted to a calendar time.
 */
std::string format_time(std::time_t time_val) {
    std::tm local_tm{};
#if defined(_WIN32)
    const bool converted = localtime_s(&local_tm, &time_val) == 0;
#else
    const bool converted = localtime_r(&time_val, &local_tm) != nullptr;
#endif
    if (!converted) {
        // Passing an invalid calendar time to std::put_time would be undefined behavior
        return "unknown time";
    }

    std::ostringstream oss;
    oss << std::put_time(&local_tm, "%Y-%m-%d %H:%M:%S");
    return oss.str();
}
35 :
36 0 : json load_json(const fs::path &file) {
37 0 : std::ifstream in(file);
38 0 : in.exceptions(std::ifstream::failbit | std::ifstream::badbit);
39 0 : json doc;
40 0 : in >> doc;
41 0 : return doc;
42 0 : }
43 :
44 1 : void save_json(const fs::path &file, const json &doc) {
45 1 : std::ofstream out(file);
46 1 : if (!out) {
47 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", file.string()));
48 : }
49 1 : out << doc;
50 1 : out.close();
51 1 : }
52 :
53 : namespace pairinteraction {
54 5 : ParquetManager::ParquetManager(std::filesystem::path directory, const GitHubDownloader &downloader,
55 : std::vector<std::string> repo_paths, duckdb::Connection &con,
56 5 : bool use_cache)
57 5 : : directory_(std::move(directory)), downloader(downloader), repo_paths_(std::move(repo_paths)),
58 5 : con(con), use_cache_(use_cache) {
59 : // Ensure the local directory exists
60 5 : if (!std::filesystem::exists(directory_ / "tables")) {
61 0 : fs::create_directories(directory_ / "tables");
62 : }
63 :
64 : // If repo paths are provided, check the GitHub rate limit
65 5 : if (!repo_paths_.empty()) {
66 1 : auto rate_limit = downloader.get_rate_limit();
67 1 : if (rate_limit.remaining <= 0) {
68 0 : react_on_rate_limit_reached(rate_limit.reset_time);
69 : } else {
70 2 : SPDLOG_INFO("Remaining GitHub API requests: {}. Rate limit resets at {}.",
71 : rate_limit.remaining, format_time(rate_limit.reset_time));
72 : }
73 : }
74 5 : }
75 :
76 5 : void ParquetManager::scan_remote() {
77 5 : remote_asset_info.clear();
78 :
79 : struct RepoDownload {
80 : std::string endpoint;
81 : fs::path cache_file;
82 : json cached_doc;
83 : std::future<GitHubDownloader::Result> future;
84 : };
85 5 : std::vector<RepoDownload> downloads;
86 5 : downloads.reserve(repo_paths_.size());
87 :
88 : // For each repo path, load its cached JSON (or an empty JSON) and issue the download
89 6 : for (const auto &endpoint : repo_paths_) {
90 : // Generate a unique cache filename per endpoint
91 : std::string cache_filename =
92 1 : "homepage_cache_" + std::to_string(std::hash<std::string>{}(endpoint)) + ".json";
93 1 : auto cache_file = directory_ / cache_filename;
94 :
95 : // Load cached JSON from file if it exists
96 1 : json cached_doc;
97 1 : if (std::filesystem::exists(cache_file)) {
98 : try {
99 0 : cached_doc = load_json(cache_file);
100 0 : } catch (const std::exception &e) {
101 0 : SPDLOG_WARN("Error reading {}: {}. Discarding homepage cache.", cache_file.string(),
102 : e.what());
103 0 : cached_doc = json{};
104 0 : }
105 : }
106 :
107 : // Extract the last-modified header from the cached JSON if available
108 1 : std::string last_modified;
109 1 : if (!cached_doc.is_null() && cached_doc.contains("last-modified")) {
110 0 : last_modified = cached_doc["last-modified"].get<std::string>();
111 : }
112 :
113 : // Issue the asynchronous download using the cached last-modified value.
114 1 : downloads.push_back(
115 1 : {endpoint, cache_file, cached_doc, downloader.download(endpoint, last_modified)});
116 1 : }
117 :
118 : // Process downloads for each repo path
119 6 : for (auto &dl : downloads) {
120 1 : auto result = dl.future.get();
121 1 : json doc;
122 1 : if (result.status_code == 200) {
123 1 : doc = json::parse(result.body, nullptr, /*allow_exceptions=*/false);
124 1 : doc["last-modified"] = result.last_modified;
125 1 : save_json(dl.cache_file, doc);
126 2 : SPDLOG_INFO("Using downloaded overview of available tables from {}.", dl.endpoint);
127 0 : } else if (result.status_code == 304) {
128 0 : if (dl.cached_doc.is_null() || dl.cached_doc.empty()) {
129 0 : throw std::runtime_error(
130 0 : fmt::format("Received 304 Not Modified but cached response {} does not exist.",
131 0 : dl.cache_file.string()));
132 : }
133 0 : doc = dl.cached_doc;
134 0 : SPDLOG_INFO("Using cached overview of available tables from {}.", dl.endpoint);
135 0 : } else if (result.status_code == 403 || result.status_code == 429) {
136 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
137 0 : return;
138 : } else {
139 0 : throw std::runtime_error(fmt::format(
140 : "Failed to download overview of available tables from {}: status code {}.",
141 0 : dl.endpoint, result.status_code));
142 : }
143 :
144 : // Validate the JSON response
145 1 : if (doc.is_discarded() || !doc.contains("assets")) {
146 0 : throw std::runtime_error(fmt::format(
147 0 : "Failed to parse remote JSON or missing 'assets' key from {}.", dl.endpoint));
148 : }
149 :
150 : // Update remote_asset_info based on the asset entries
151 2 : for (auto &asset : doc["assets"]) {
152 1 : std::string name = asset["name"].get<std::string>();
153 1 : std::smatch match;
154 :
155 1 : if (std::regex_match(name, match, remote_regex) && match.size() == 4) {
156 1 : std::string key = match[1].str();
157 1 : int version_major = std::stoi(match[2].str());
158 1 : int version_minor = std::stoi(match[3].str());
159 :
160 1 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
161 0 : continue;
162 : }
163 :
164 1 : auto it = remote_asset_info.find(key);
165 1 : if (it == remote_asset_info.end() || version_minor > it->second.version_minor) {
166 1 : std::string remote_url = asset["url"].get<std::string>();
167 1 : const std::string host = downloader.get_host();
168 2 : remote_asset_info[key] = {version_minor, remote_url.erase(0, host.size())};
169 1 : }
170 1 : }
171 1 : }
172 1 : }
173 :
174 : // Ensure that scan_remote was successful
175 5 : if (!downloads.empty() && remote_asset_info.empty()) {
176 0 : throw std::runtime_error(
177 : "No compatible database tables were found in the remote repositories. Consider "
178 0 : "upgrading pairinteraction to a newer version.");
179 : }
180 8 : }
181 :
182 5 : void ParquetManager::scan_local() {
183 5 : local_asset_info.clear();
184 :
185 : // Iterate over files in the directory to update local_asset_info
186 53 : for (const auto &entry : fs::directory_iterator(directory_ / "tables")) {
187 24 : std::string name = entry.path().filename().string();
188 24 : std::smatch match;
189 :
190 48 : if (entry.is_directory() && std::regex_match(name, match, local_regex) &&
191 24 : match.size() == 4) {
192 24 : std::string key = match[1].str();
193 24 : int version_major = std::stoi(match[2].str());
194 24 : int version_minor = std::stoi(match[3].str());
195 :
196 24 : if (version_major != COMPATIBLE_DATABASE_VERSION_MAJOR) {
197 0 : continue;
198 : }
199 :
200 24 : auto it = local_asset_info.find(key);
201 24 : if (it == local_asset_info.end() || version_minor > it->second.version_minor) {
202 21 : local_asset_info[key].version_minor = version_minor;
203 203 : for (const auto &subentry : fs::directory_iterator(entry)) {
204 91 : if (subentry.is_regular_file() && subentry.path().extension() == ".parquet") {
205 182 : local_asset_info[key].paths[subentry.path().stem().string()] = {
206 273 : subentry.path().string(), false};
207 : }
208 21 : }
209 : }
210 24 : }
211 29 : }
212 96 : }
213 :
214 0 : void ParquetManager::react_on_rate_limit_reached(std::time_t reset_time) {
215 0 : repo_paths_.clear();
216 0 : remote_asset_info.clear();
217 0 : SPDLOG_WARN("Rate limit reached, resets at {}. The download of database tables is disabled.",
218 : format_time(reset_time));
219 0 : }
220 :
221 632 : void ParquetManager::update_local_asset(const std::string &key) {
222 : // Get remote version if available
223 632 : int remote_version = -1;
224 632 : auto remote_it = remote_asset_info.find(key);
225 632 : if (remote_it != remote_asset_info.end()) {
226 1 : remote_version = remote_it->second.version_minor;
227 : }
228 :
229 : // Get local version if available and check if it is up-to-date
230 : {
231 632 : int local_version = -1;
232 632 : std::shared_lock<std::shared_mutex> lock(mtx_local);
233 632 : auto local_it = local_asset_info.find(key);
234 632 : if (local_it != local_asset_info.end()) {
235 632 : local_version = local_it->second.version_minor;
236 : }
237 632 : if (local_version >= remote_version) {
238 631 : return;
239 : }
240 632 : }
241 :
242 : // If it is not up-to-date, acquire a unique lock for updating the table
243 1 : std::unique_lock<std::shared_mutex> lock(mtx_local);
244 :
245 : // Re-check if the table is up to date because another thread might have updated it
246 1 : int local_version = -1;
247 1 : auto local_it = local_asset_info.find(key);
248 1 : if (local_it != local_asset_info.end()) {
249 1 : local_version = local_it->second.version_minor;
250 : }
251 1 : if (local_version >= remote_version) {
252 0 : return;
253 : }
254 :
255 : // Download the remote file
256 1 : std::string endpoint = remote_it->second.endpoint;
257 2 : SPDLOG_INFO("Downloading {}_v{}.{} from {}", key, COMPATIBLE_DATABASE_VERSION_MAJOR,
258 : remote_version, endpoint);
259 :
260 1 : auto result = downloader.download(endpoint, "", true).get();
261 1 : if (result.status_code == 403 || result.status_code == 429) {
262 0 : react_on_rate_limit_reached(result.rate_limit.reset_time);
263 0 : return;
264 : }
265 1 : if (result.status_code != 200) {
266 0 : throw std::runtime_error(fmt::format("Failed to download table {}: status code {}.",
267 0 : endpoint, result.status_code));
268 : }
269 :
270 : // Unzip the downloaded file
271 1 : mz_zip_archive zip_archive{};
272 1 : if (mz_zip_reader_init_mem(&zip_archive, result.body.data(), result.body.size(), 0) == 0) {
273 0 : throw std::runtime_error("Failed to initialize zip archive.");
274 : }
275 :
276 2 : for (mz_uint i = 0; i < mz_zip_reader_get_num_files(&zip_archive); i++) {
277 : mz_zip_archive_file_stat file_stat;
278 1 : if (mz_zip_reader_file_stat(&zip_archive, i, &file_stat) == 0) {
279 0 : throw std::runtime_error("Failed to get file stat from zip archive.");
280 : }
281 :
282 : // Skip directories
283 1 : const char *filename = static_cast<const char *>(file_stat.m_filename);
284 1 : size_t len = std::strlen(filename);
285 1 : if (len > 0 && filename[len - 1] == '/') {
286 0 : continue;
287 : }
288 :
289 : // Ensure that the filename matches the expectations
290 1 : std::string dir = fs::path(filename).parent_path().string();
291 1 : std::string stem = fs::path(filename).stem().string();
292 1 : std::string suffix = fs::path(filename).extension().string();
293 1 : std::smatch match;
294 2 : if (!std::regex_match(dir, match, local_regex) || match.size() != 4 ||
295 1 : suffix != ".parquet") {
296 0 : throw std::runtime_error(
297 0 : fmt::format("Unexpected filename {} in zip archive.", filename));
298 : }
299 :
300 : // Construct the path to store the table
301 1 : auto path = directory_ / "tables" / dir / (stem + suffix);
302 2 : SPDLOG_INFO("Storing table to {}", path.string());
303 :
304 : // Extract the file to memory
305 1 : std::vector<char> buffer(file_stat.m_uncomp_size);
306 1 : if (mz_zip_reader_extract_to_mem(&zip_archive, i, buffer.data(), buffer.size(), 0) == 0) {
307 0 : throw std::runtime_error(fmt::format("Failed to extract {}.", filename));
308 : }
309 :
310 : // Ensure the parent directory exists
311 1 : fs::create_directories(path.parent_path());
312 :
313 : // Save the extracted file
314 1 : std::ofstream out(path.string(), std::ios::binary);
315 1 : if (!out) {
316 0 : throw std::runtime_error(fmt::format("Failed to open {} for writing", path.string()));
317 : }
318 1 : out.write(buffer.data(), static_cast<std::streamsize>(buffer.size()));
319 1 : out.close();
320 :
321 : // Update the local asset/table info
322 1 : local_asset_info[key].version_minor = remote_version;
323 1 : local_asset_info[key].paths[path.stem().string()] = {path.string(), false};
324 1 : }
325 :
326 1 : mz_zip_reader_end(&zip_archive);
327 2 : }
328 :
329 629 : void ParquetManager::cache_table(std::unordered_map<std::string, PathInfo>::iterator table_it) {
330 : // Check if the table is already cached
331 : {
332 629 : std::shared_lock<std::shared_mutex> lock(mtx_local);
333 629 : if (table_it->second.cached) {
334 604 : return;
335 : }
336 629 : }
337 :
338 : // Acquire a unique lock for caching the table
339 25 : std::unique_lock<std::shared_mutex> lock(mtx_local);
340 :
341 : // Re-check if the table is already cached because another thread might have cached it
342 25 : if (table_it->second.cached) {
343 0 : return;
344 : }
345 :
346 : // Cache the table in memory
347 25 : std::string table_name;
348 : {
349 25 : auto result = con.Query(R"(SELECT UUID()::varchar)");
350 25 : if (result->HasError()) {
351 0 : throw cpptrace::runtime_error("Error selecting a unique table name: " +
352 0 : result->GetError());
353 : }
354 : table_name =
355 25 : duckdb::FlatVector::GetData<duckdb::string_t>(result->Fetch()->data[0])[0].GetString();
356 25 : }
357 :
358 : {
359 25 : auto result = con.Query(fmt::format(R"(CREATE TEMP TABLE '{}' AS SELECT * FROM '{}')",
360 50 : table_name, table_it->second.path));
361 25 : if (result->HasError()) {
362 0 : throw cpptrace::runtime_error("Error creating table: " + result->GetError());
363 : }
364 25 : }
365 :
366 25 : table_it->second.path = table_name;
367 25 : table_it->second.cached = true;
368 25 : }
369 :
370 632 : std::string ParquetManager::get_path(const std::string &key, const std::string &table) {
371 : // Update the local table if a newer version is available remotely
372 632 : this->update_local_asset(key);
373 :
374 : // Ensure availability of the local table file
375 632 : auto asset_it = local_asset_info.find(key);
376 632 : if (asset_it == local_asset_info.end()) {
377 0 : throw std::runtime_error("Table " + key + "_" + table + " not found.");
378 : }
379 632 : auto table_it = asset_it->second.paths.find(table);
380 632 : if (table_it == asset_it->second.paths.end()) {
381 1 : throw std::runtime_error("Table " + key + "_" + table + " not found.");
382 : }
383 :
384 : // Cache the local table in memory if requested
385 631 : if (use_cache_) {
386 629 : this->cache_table(table_it);
387 : }
388 :
389 : // Return the path to the local table file
390 1262 : return table_it->second.path;
391 : }
392 :
393 2 : std::string ParquetManager::get_versions_info() const {
394 : // Helper lambda returns the version string if available
395 36 : auto get_version = [](const auto &map, const std::string &table) -> int {
396 36 : if (auto it = map.find(table); it != map.end()) {
397 18 : return it->second.version_minor;
398 : }
399 18 : return -1;
400 : };
401 :
402 : // Gather all unique table names
403 2 : std::set<std::string> tables;
404 20 : for (const auto &entry : local_asset_info) {
405 18 : tables.insert(entry.first);
406 : }
407 2 : for (const auto &entry : remote_asset_info) {
408 0 : tables.insert(entry.first);
409 : }
410 :
411 : // Output versions info
412 2 : std::ostringstream oss;
413 :
414 2 : oss << " ";
415 2 : oss << std::left << std::setw(17) << "Asset";
416 2 : oss << std::left << std::setw(6 + 4) << "Local";
417 2 : oss << std::left << std::setw(7) << "Remote\n";
418 2 : oss << std::string(35, '-') << "\n";
419 :
420 20 : for (const auto &table : tables) {
421 18 : int local_version = get_version(local_asset_info, table);
422 18 : int remote_version = get_version(remote_asset_info, table);
423 :
424 : std::string comparator = (local_version < remote_version)
425 : ? "<"
426 18 : : ((local_version > remote_version) ? ">" : "==");
427 : std::string local_version_str = local_version == -1
428 18 : ? "N/A"
429 36 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
430 54 : std::to_string(local_version);
431 : std::string remote_version_str = remote_version == -1
432 18 : ? "N/A"
433 18 : : "v" + std::to_string(COMPATIBLE_DATABASE_VERSION_MAJOR) + "." +
434 36 : std::to_string(remote_version);
435 :
436 18 : oss << " ";
437 18 : oss << std::left << std::setw(17) << table;
438 18 : oss << std::left << std::setw(6) << local_version_str;
439 18 : oss << std::left << std::setw(4) << comparator;
440 18 : oss << std::left << std::setw(7) << remote_version_str << "\n";
441 18 : }
442 4 : return oss.str();
443 2 : }
444 :
445 : } // namespace pairinteraction
|