Line data Source code
1 : // SPDX-FileCopyrightText: 2025 PairInteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
#include "pairinteraction/database/ParquetManager.hpp"

#include "pairinteraction/database/Database.hpp"
#include "pairinteraction/database/GitHubDownloader.hpp"

#include <doctest/doctest.h>
#include <duckdb.hpp>
#include <filesystem>
#include <fstream>
#include <future>
#include <memory>
#include <miniz.h>
#include <nlohmann/json.hpp>
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <vector>
15 :
16 : namespace pairinteraction {
17 : class MockDownloader : public GitHubDownloader {
18 : public:
19 : std::future<GitHubDownloader::Result>
20 3 : download(const std::string &remote_url, const std::string & /*if_modified_since*/ = "",
21 : bool /*use_octet_stream*/ = false) const override {
22 3 : GitHubDownloader::Result result;
23 3 : result.status_code = 200;
24 3 : result.rate_limit.remaining = 60;
25 3 : result.rate_limit.reset_time = 2147483647;
26 :
27 3 : if (remote_url == "/test/repo/path") {
28 : // This is the repo path request, return JSON with assets
29 1 : nlohmann::json assets = nlohmann::json::array();
30 1 : nlohmann::json asset;
31 1 : asset["name"] = "misc_v1.2.zip";
32 1 : asset["url"] = "https://api.github.com/test/path/misc_v1.2.zip";
33 1 : assets.push_back(asset);
34 1 : nlohmann::json response;
35 1 : response["assets"] = assets;
36 1 : result.body = response.dump();
37 3 : } else if (remote_url == "/rate_limit") {
38 : // This is the rate limit request
39 1 : result.body = "";
40 : } else {
41 : // This is the file download request
42 1 : std::string content = "updated_file_content";
43 1 : std::string filename = "misc_v1.2/wigner.parquet";
44 :
45 1 : mz_zip_archive zip_archive{};
46 1 : size_t zip_size = 0;
47 1 : void *zip_data = nullptr;
48 :
49 1 : mz_zip_writer_init_heap(&zip_archive, 0, 0);
50 1 : mz_zip_writer_add_mem(&zip_archive, filename.c_str(), content.data(), content.size(),
51 : MZ_BEST_SPEED);
52 1 : mz_zip_writer_finalize_heap_archive(&zip_archive, &zip_data, &zip_size);
53 :
54 1 : result.body = std::string(static_cast<char *>(zip_data), zip_size);
55 :
56 1 : mz_free(zip_data);
57 1 : mz_zip_writer_end(&zip_archive);
58 1 : }
59 :
60 12 : return std::async(std::launch::deferred, [result]() { return result; });
61 6 : }
62 : };
63 :
64 3 : TEST_CASE("ParquetManager functionality with mocked downloader") {
65 3 : MockDownloader downloader;
66 3 : auto test_dir = std::filesystem::temp_directory_path() / "pairinteraction_test_db";
67 3 : std::filesystem::create_directories(test_dir / "tables" / "misc_v1.0");
68 3 : std::filesystem::create_directories(test_dir / "tables" / "misc_v1.1");
69 3 : std::ofstream(test_dir / "tables" / "misc_v1.0" / "wigner.parquet").close();
70 3 : std::ofstream(test_dir / "tables" / "misc_v1.1" / "wigner.parquet").close();
71 3 : duckdb::DuckDB db(nullptr);
72 3 : duckdb::Connection con(db);
73 :
74 3 : SUBCASE("Check missing table") {
75 1 : std::vector<std::string> repo_paths;
76 1 : ParquetManager manager(test_dir, downloader, repo_paths, con, false);
77 1 : manager.scan_local();
78 1 : manager.scan_remote();
79 :
80 5 : CHECK_THROWS_WITH_AS(manager.get_path("misc", "missing_table"),
81 : "No table 'missing_table.parquet' found for species 'misc'. The "
82 : "tables for the species are incomplete.",
83 : std::runtime_error);
84 4 : }
85 :
86 3 : SUBCASE("Check version parsing") {
87 1 : std::vector<std::string> repo_paths;
88 1 : ParquetManager manager(test_dir, downloader, repo_paths, con, false);
89 1 : manager.scan_local();
90 1 : manager.scan_remote();
91 :
92 1 : std::string expected = (test_dir / "tables" / "misc_v1.1" / "wigner.parquet").string();
93 1 : CHECK(manager.get_path("misc", "wigner") == expected);
94 4 : }
95 :
96 3 : SUBCASE("Check update table") {
97 2 : std::vector<std::string> repo_paths = {"/test/repo/path"};
98 1 : ParquetManager manager(test_dir, downloader, repo_paths, con, false);
99 1 : manager.scan_local();
100 1 : manager.scan_remote();
101 :
102 1 : std::string expected = (test_dir / "tables" / "misc_v1.2" / "wigner.parquet").string();
103 1 : CHECK(manager.get_path("misc", "wigner") == expected);
104 :
105 1 : std::ifstream in(expected, std::ios::binary);
106 1 : std::stringstream buffer;
107 1 : buffer << in.rdbuf();
108 1 : CHECK(buffer.str() == "updated_file_content");
109 4 : }
110 :
111 3 : std::filesystem::remove_all(test_dir);
112 5 : }
113 :
114 1 : DOCTEST_TEST_CASE("ParquetManager functionality with GitHub downloader") {
115 1 : if (!Database::get_global_instance().get_download_missing()) {
116 1 : DOCTEST_MESSAGE("Skipping test because download_missing is false.");
117 1 : return;
118 : }
119 0 : GitHubDownloader downloader;
120 0 : duckdb::DuckDB db(nullptr);
121 0 : duckdb::Connection con(db);
122 :
123 : std::vector<std::string> repo_paths = {"/repos/pairinteraction/database-sqdt/releases/latest",
124 0 : "/repos/pairinteraction/database-mqdt/releases/latest"};
125 0 : ParquetManager manager(Database::get_global_instance().get_database_dir(), downloader,
126 0 : repo_paths, con, Database::get_global_instance().get_use_cache());
127 0 : manager.scan_local();
128 0 : manager.scan_remote();
129 :
130 0 : std::string info = manager.get_versions_info();
131 :
132 : // Check that all species are present
133 : std::vector<std::string> should_contain = {
134 : "Cs", "K", "Li", "Na",
135 : "Rb", "Sr87_mqdt", "Sr88_singlet", "Sr88_triplet",
136 : "Sr88_mqdt", "Yb171_mqdt", "Yb173_mqdt", "Yb174_mqdt",
137 0 : "misc"};
138 0 : for (const auto &substr : should_contain) {
139 0 : DOCTEST_CHECK(info.find(substr) != std::string::npos);
140 : }
141 0 : }
142 :
143 : } // namespace pairinteraction
|