Line data Source code
1 : // SPDX-FileCopyrightText: 2025 Pairinteraction Developers
2 : // SPDX-License-Identifier: LGPL-3.0-or-later
3 :
4 : #include "pairinteraction/database/ParquetManager.hpp"
5 :
6 : #include "pairinteraction/database/Database.hpp"
7 : #include "pairinteraction/database/GitHubDownloader.hpp"
8 :
#include <doctest/doctest.h>
#include <duckdb.hpp>
#include <filesystem>
#include <fstream>
#include <future>
#include <miniz.h>
#include <nlohmann/json.hpp>
#include <sstream>
#include <stdexcept>
#include <string>
#include <system_error>
#include <utility>
#include <vector>
15 :
16 : namespace pairinteraction {
17 : class MockDownloader : public GitHubDownloader {
18 : public:
19 : std::future<GitHubDownloader::Result>
20 3 : download(const std::string &remote_url, const std::string & /*if_modified_since*/ = "",
21 : bool /*use_octet_stream*/ = false) const override {
22 3 : GitHubDownloader::Result result;
23 3 : result.status_code = 200;
24 3 : result.rate_limit.remaining = 60;
25 3 : result.rate_limit.reset_time = 2147483647;
26 :
27 3 : if (remote_url == "/test/repo/path") {
28 : // This is the repo path request, return JSON with assets
29 1 : nlohmann::json assets = nlohmann::json::array();
30 1 : nlohmann::json asset;
31 1 : asset["name"] = "misc_v1.2.zip";
32 1 : asset["url"] = "https://api.github.com/test/path/misc_v1.2.zip";
33 1 : assets.push_back(asset);
34 1 : nlohmann::json response;
35 1 : response["assets"] = assets;
36 1 : result.body = response.dump();
37 3 : } else if (remote_url == "/rate_limit") {
38 : // This is the rate limit request
39 1 : result.body = "";
40 : } else {
41 : // This is the file download request
42 1 : std::string content = "updated_file_content";
43 1 : std::string filename = "misc_v1.2/wigner.parquet";
44 :
45 1 : mz_zip_archive zip_archive{};
46 1 : size_t zip_size = 0;
47 1 : void *zip_data = nullptr;
48 :
49 1 : mz_zip_writer_init_heap(&zip_archive, 0, 0);
50 1 : mz_zip_writer_add_mem(&zip_archive, filename.c_str(), content.data(), content.size(),
51 : MZ_BEST_SPEED);
52 1 : mz_zip_writer_finalize_heap_archive(&zip_archive, &zip_data, &zip_size);
53 :
54 1 : result.body = std::string(static_cast<char *>(zip_data), zip_size);
55 :
56 1 : mz_free(zip_data);
57 1 : mz_zip_writer_end(&zip_archive);
58 1 : }
59 :
60 12 : return std::async(std::launch::deferred, [result]() { return result; });
61 6 : }
62 : };
63 :
64 3 : TEST_CASE("ParquetManager functionality with mocked downloader") {
65 3 : MockDownloader downloader;
66 3 : auto test_dir = std::filesystem::temp_directory_path() / "pairinteraction_test_db";
67 3 : std::filesystem::create_directories(test_dir / "tables" / "misc_v1.0");
68 3 : std::filesystem::create_directories(test_dir / "tables" / "misc_v1.1");
69 3 : std::ofstream(test_dir / "tables" / "misc_v1.0" / "wigner.parquet").close();
70 3 : std::ofstream(test_dir / "tables" / "misc_v1.1" / "wigner.parquet").close();
71 3 : duckdb::DuckDB db(nullptr);
72 3 : duckdb::Connection con(db);
73 :
74 3 : SUBCASE("Check missing table") {
75 1 : std::vector<std::string> repo_paths;
76 1 : ParquetManager manager(test_dir, downloader, repo_paths, con, false);
77 1 : manager.scan_local();
78 1 : manager.scan_remote();
79 :
80 5 : CHECK_THROWS_WITH_AS(manager.get_path("misc", "missing_table"),
81 : "Table misc_missing_table not found.", std::runtime_error);
82 4 : }
83 :
84 3 : SUBCASE("Check version parsing") {
85 1 : std::vector<std::string> repo_paths;
86 1 : ParquetManager manager(test_dir, downloader, repo_paths, con, false);
87 1 : manager.scan_local();
88 1 : manager.scan_remote();
89 :
90 1 : std::string expected = (test_dir / "tables" / "misc_v1.1" / "wigner.parquet").string();
91 1 : CHECK(manager.get_path("misc", "wigner") == expected);
92 4 : }
93 :
94 3 : SUBCASE("Check update table") {
95 2 : std::vector<std::string> repo_paths = {"/test/repo/path"};
96 1 : ParquetManager manager(test_dir, downloader, repo_paths, con, false);
97 1 : manager.scan_local();
98 1 : manager.scan_remote();
99 :
100 1 : std::string expected = (test_dir / "tables" / "misc_v1.2" / "wigner.parquet").string();
101 1 : CHECK(manager.get_path("misc", "wigner") == expected);
102 :
103 1 : std::ifstream in(expected, std::ios::binary);
104 1 : std::stringstream buffer;
105 1 : buffer << in.rdbuf();
106 1 : CHECK(buffer.str() == "updated_file_content");
107 4 : }
108 :
109 3 : std::filesystem::remove_all(test_dir);
110 5 : }
111 :
112 1 : DOCTEST_TEST_CASE("ParquetManager functionality with github downloader") {
113 1 : if (!Database::get_global_instance().get_download_missing()) {
114 1 : DOCTEST_MESSAGE("Skipping test because download_missing is false.");
115 1 : return;
116 : }
117 0 : GitHubDownloader downloader;
118 0 : duckdb::DuckDB db(nullptr);
119 0 : duckdb::Connection con(db);
120 :
121 : std::vector<std::string> repo_paths = {"/repos/pairinteraction/database-sqdt/releases/latest",
122 0 : "/repos/pairinteraction/database-mqdt/releases/latest"};
123 0 : ParquetManager manager(Database::get_global_instance().get_database_dir(), downloader,
124 0 : repo_paths, con, Database::get_global_instance().get_use_cache());
125 0 : manager.scan_local();
126 0 : manager.scan_remote();
127 :
128 0 : std::string info = manager.get_versions_info();
129 :
130 : // Check that all species are present
131 : std::vector<std::string> should_contain = {
132 : "Cs", "K", "Li", "Na",
133 : "Rb", "Sr87_mqdt", "Sr88_singlet", "Sr88_triplet",
134 : "Sr88_mqdt", "Yb171_mqdt", "Yb173_mqdt", "Yb174_mqdt",
135 0 : "misc"};
136 0 : for (const auto &substr : should_contain) {
137 0 : DOCTEST_CHECK(info.find(substr) != std::string::npos);
138 : }
139 0 : }
140 :
141 : } // namespace pairinteraction
|