Skip to content

Commit

Permalink
Added versioning manager (#8092)
Browse files Browse the repository at this point in the history
  • Loading branch information
qstokkink authored Aug 19, 2024
2 parents db74f82 + b8b2563 commit efe5db5
Show file tree
Hide file tree
Showing 25 changed files with 1,061 additions and 24 deletions.
1 change: 1 addition & 0 deletions .ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ lint.ignore = [
"ARG002",
"ARG005",
"ASYNC109",
"ASYNC110",
"BLE001",
"COM812",
"COM819",
Expand Down
10 changes: 6 additions & 4 deletions src/run_tribler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from pathlib import Path

import pystray
import tribler
from aiohttp import ClientSession
from PIL import Image

import tribler
from tribler.core.session import Session
from tribler.tribler_config import TriblerConfigManager
from tribler.tribler_config import VERSION_SUBDIR, TriblerConfigManager

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,7 +47,7 @@ def get_root_state_directory(requested_path: os.PathLike | None) -> Path:
Get the default application state directory.
"""
root_state_dir = (Path(requested_path) if os.path.isabs(requested_path)
else (Path(os.environ.get("APPDATA", "~")) / ".TriblerExperimental").expanduser().absolute())
else (Path(os.environ.get("APPDATA", "~")) / ".Tribler").expanduser().absolute())
root_state_dir.mkdir(parents=True, exist_ok=True)
return root_state_dir

Expand All @@ -73,8 +74,9 @@ async def main() -> None:
logger.info("Run Tribler: %s", parsed_args)

root_state_dir = get_root_state_directory(os.environ.get('TSTATEDIR', 'state_directory'))
(root_state_dir / VERSION_SUBDIR).mkdir(exist_ok=True, parents=True)
logger.info("Root state dir: %s", root_state_dir)
config = TriblerConfigManager(root_state_dir / "configuration.json")
config = TriblerConfigManager(root_state_dir / VERSION_SUBDIR / "configuration.json")
config.set("state_dir", str(root_state_dir))

if "CORE_API_PORT" in os.environ:
Expand Down
36 changes: 31 additions & 5 deletions src/tribler/core/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def prepare(self, ipv8: IPv8, session: Session) -> None:
from tribler.core.database.tribler_database import TriblerDatabase
from tribler.core.notifier import Notification

db_path = str(Path(session.config.get("state_dir")) / "sqlite" / "tribler.db")
mds_path = str(Path(session.config.get("state_dir")) / "sqlite" / "metadata.db")
db_path = str(Path(session.config.get_version_state_dir()) / "sqlite" / "tribler.db")
mds_path = str(Path(session.config.get_version_state_dir()) / "sqlite" / "metadata.db")
if session.config.get("memory_db"):
db_path = ":memory:"
mds_path = ":memory:"
Expand Down Expand Up @@ -221,7 +221,8 @@ def get_kwargs(self, session: Session) -> dict:
from tribler.core.rendezvous.database import RendezvousDatabase

out = super().get_kwargs(session)
out["database"] = RendezvousDatabase(db_path=Path(session.config.get("state_dir")) / "sqlite" / "rendezvous.db")
out["database"] = (RendezvousDatabase(db_path=Path(session.config.get_version_state_dir()) / "sqlite"
/ "rendezvous.db"))

return out

Expand Down Expand Up @@ -249,7 +250,8 @@ def prepare(self, overlay_provider: IPv8, session: Session) -> None:
from tribler.core.torrent_checker.torrent_checker import TorrentChecker
from tribler.core.torrent_checker.tracker_manager import TrackerManager

tracker_manager = TrackerManager(state_dir=session.config.get("state_dir"), metadata_store=session.mds)
tracker_manager = TrackerManager(state_dir=Path(session.config.get_version_state_dir()),
metadata_store=session.mds)
torrent_checker = TorrentChecker(config=session.config,
download_manager=session.download_manager,
notifier=session.notifier,
Expand Down Expand Up @@ -298,7 +300,7 @@ def get_kwargs(self, session: Session) -> dict:
from ipv8.dht.provider import DHTCommunityProvider

out = super().get_kwargs(session)
out["exitnode_cache"] = Path(session.config.get("state_dir")) / "exitnode_cache.dat"
out["exitnode_cache"] = Path(session.config.get_version_state_dir()) / "exitnode_cache.dat"
out["notifier"] = session.notifier
out["download_manager"] = session.download_manager
out["socks_servers"] = session.socks_servers
Expand Down Expand Up @@ -336,3 +338,27 @@ def get_kwargs(self, session: Session) -> dict:
max_query_history = session.config.get("user_activity/max_query_history")
out["manager"] = UserActivityManager(TaskManager(), session, max_query_history)
return out

@precondition('session.config.get("versioning/enabled")')
class VersioningComponent(ComponentLauncher):
"""
Launch instructions for the versioning of Tribler.
"""

def finalize(self, ipv8: IPv8, session: Session, community: Community) -> None:
"""
When we are done launching, register our REST API.
"""
from tribler.core.versioning.manager import VersioningManager

session.rest_manager.get_endpoint("/api/versioning").versioning_manager = VersioningManager(
community, session.config
)

def get_endpoints(self) -> list[RESTEndpoint]:
"""
Add the database endpoint.
"""
from tribler.core.versioning.restapi.versioning_endpoint import VersioningEndpoint

return [*super().get_endpoints(), VersioningEndpoint()]
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_spec_file_name(settings: TriblerConfigManager) -> str:
"""
Get the file name of the download spec.
"""
return str(Path(settings.get("state_dir")) / SPEC_FILENAME)
return str(Path(settings.get_version_state_dir()) / SPEC_FILENAME)

@staticmethod
def from_defaults(settings: TriblerConfigManager) -> DownloadConfig:
Expand All @@ -127,6 +127,7 @@ def from_defaults(settings: TriblerConfigManager) -> DownloadConfig:
spec_file_name = DownloadConfig.get_spec_file_name(settings)
defaults = ConfigObj(StringIO(SPEC_CONTENT))
defaults["filename"] = spec_file_name
Path(spec_file_name).parent.mkdir(parents=True, exist_ok=True) # Required for the next write
with open(spec_file_name, "wb") as spec_file:
defaults.write(spec_file)
defaults = ConfigObj(StringIO(), configspec=spec_file_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, config: TriblerConfigManager, notifier: Notifier,
super().__init__()
self.config = config

self.state_dir = Path(config.get("state_dir"))
self.state_dir = Path(config.get_version_state_dir())
self.ltsettings: dict[lt.session, dict] = {} # Stores a copy of the settings dict for each libtorrent session
self.ltsessions: dict[int, lt.session] = {}
self.dht_health_manager: DHTHealthManager | None = None
Expand Down Expand Up @@ -176,7 +176,7 @@ def initialize(self) -> None:
Initialize the directory structure, launch the periodic tasks and start libtorrent background processes.
"""
# Create the checkpoints directory
self.checkpoint_directory.mkdir(exist_ok=True)
self.checkpoint_directory.mkdir(exist_ok=True, parents=True)

# Start upnp
if self.config.get("libtorrent/upnp"):
Expand Down Expand Up @@ -245,7 +245,7 @@ async def shutdown(self, timeout: int = 30) -> None:
if self.has_session():
logger.info("Saving state...")
self.notify_shutdown_state("Writing session state to disk.")
with open(self.state_dir / LTSTATE_FILENAME, "wb") as ltstate_file: # noqa: ASYNC101
with open(self.state_dir / LTSTATE_FILENAME, "wb") as ltstate_file: # noqa: ASYNC230
ltstate_file.write(lt.bencode(self.get_session().save_state()))

if self.has_session() and self.config.get("libtorrent/upnp"):
Expand Down
7 changes: 5 additions & 2 deletions src/tribler/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TorrentCheckerComponent,
TunnelComponent,
UserActivityComponent,
VersioningComponent,
)
from tribler.core.libtorrent.download_manager.download_manager import DownloadManager
from tribler.core.libtorrent.restapi.create_torrent_endpoint import CreateTorrentEndpoint
Expand Down Expand Up @@ -121,7 +122,8 @@ def register_launchers(self) -> None:
Register all IPv8 launchers that allow communities to be loaded.
"""
for launcher_class in [ContentDiscoveryComponent, DatabaseComponent, DHTDiscoveryComponent, KnowledgeComponent,
RendezvousComponent, TorrentCheckerComponent, TunnelComponent, UserActivityComponent]:
RendezvousComponent, TorrentCheckerComponent, TunnelComponent, UserActivityComponent,
VersioningComponent]:
instance = launcher_class()
for rest_ep in instance.get_endpoints():
self.rest_manager.add_endpoint(rest_ep)
Expand Down Expand Up @@ -168,7 +170,8 @@ async def start(self) -> None:
self.rest_manager.get_endpoint("/api/ipv8").initialize(self.ipv8)
self.rest_manager.get_endpoint("/api/statistics").ipv8 = self.ipv8
if self.config.get("statistics"):
self.rest_manager.get_endpoint("/api/ipv8").endpoints["/overlays"].enable_overlay_statistics(True, None, True)
self.rest_manager.get_endpoint("/api/ipv8").endpoints["/overlays"].enable_overlay_statistics(True, None,
True)

async def find_api_server(self) -> str | None:
"""
Expand Down
Empty file.
114 changes: 114 additions & 0 deletions src/tribler/core/versioning/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from __future__ import annotations

import logging
import os
import platform
import shutil
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import TYPE_CHECKING

from aiohttp import ClientSession
from packaging.version import Version

from tribler.tribler_config import TriblerConfigManager
from tribler.upgrade_script import FROM, TO, upgrade

if TYPE_CHECKING:
from ipv8.taskmanager import TaskManager

logger = logging.getLogger(__name__)


class VersioningManager:
"""
Version related logic.
"""

def __init__(self, task_manager: TaskManager, config: TriblerConfigManager | None) -> None:
"""
Create a new versioning manager.
"""
super().__init__()
self.task_manager = task_manager
self.config = config or TriblerConfigManager()

def get_current_version(self) -> str | None:
"""
Get the current release version, or None when running from archive or GIT.
"""
try:
return version("tribler")
except PackageNotFoundError:
return None

def get_versions(self) -> list[str]:
"""
Get all versions in our state directory.
"""
return [p for p in os.listdir(self.config.get("state_dir"))
if os.path.isdir(os.path.join(self.config.get("state_dir"), p))]

async def check_version(self) -> str | None:
"""
Check the tribler.org + GitHub websites for a new version.
"""
current_version = self.get_current_version()
if current_version is None:
return None

headers = {
"User-Agent": (f"Tribler/{current_version} "
f"(machine={platform.machine()}; os={platform.system()} {platform.release()}; "
f"python={platform.python_version()}; executable={platform.architecture()[0]})")
}
urls = [
f"https://release.tribler.org/releases/latest?current={current_version}",
"https://api.github.com/repos/tribler/tribler/releases/latest"
]

for url in urls:
try:
async with ClientSession(raise_for_status=True) as session:
response = await session.get(url, headers=headers, timeout=5.0)
response_dict = await response.json(content_type=None)
response_version = response_dict["name"]
if response_version.startswith("v"):
response_version = response_version[1:]
except Exception as e:
logger.info(e)
continue # Case 1: this failed, but we may still have another URL to check. Continue.
if Version(response_version) > Version(current_version):
return response_version # Case 2: we found a newer version. Stop.
break # Case 3: we got a response, but we are already at a newer or equal version. Stop.
return None # Either Case 3 or repeated Case 1: no URLs responded. No new version available.

def can_upgrade(self) -> str | bool:
"""
Check if we have old database/download files to port to our current version.
Returns the version that can be upgraded from.
"""
if os.path.isfile(os.path.join(self.config.get_version_state_dir(), ".upgraded")):
return False # We have the upgraded marker: nothing to do.

if FROM not in self.get_versions():
return False # We can't upgrade from this version.

return FROM if (self.get_current_version() in [None, TO]) else False # Always allow upgrades to git (None).

def perform_upgrade(self) -> None:
"""
Upgrade old database/download files to our current version.
"""
src_dir = Path(self.config.get("state_dir")) / FROM
dst_dir = Path(self.config.get_version_state_dir())
self.task_manager.register_executor_task("Upgrade", upgrade, self.config,
str(src_dir.expanduser().absolute()),
str(dst_dir.expanduser().absolute()))

def remove_version(self, version: str) -> None:
"""
Remove the files for a version.
"""
shutil.rmtree(os.path.join(self.config.get("state_dir"), version), ignore_errors=True)
Empty file.
Loading

0 comments on commit efe5db5

Please sign in to comment.