diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index ce06895..42dcddb 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -27,6 +27,8 @@ jobs: pip install poetry -U poetry install + - name: Run the tests + run: pytest --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html - name: Build and publish if: startsWith(github.ref, 'refs/tags/v') diff --git a/pyproject.toml b/pyproject.toml index 2cb8d0a..b38bcb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pieces_os_client" -version = "3.0.1" +version = "3.1.0" description = "A powerful code engine package for writing applications on top of Pieces OS" authors = ["Pieces "] license = "MIT" diff --git a/src/pieces_os_client/__init__.py b/src/pieces_os_client/__init__.py index a149be8..8a77a0e 100644 --- a/src/pieces_os_client/__init__.py +++ b/src/pieces_os_client/__init__.py @@ -15,7 +15,7 @@ """ # noqa: E501 -__version__ = "3.0.1" +__version__ = "3.1.0" # import apis into sdk package from pieces_os_client.api.activities_api import ActivitiesApi diff --git a/src/pieces_os_client/api_client.py b/src/pieces_os_client/api_client.py index 884b9df..0922e65 100644 --- a/src/pieces_os_client/api_client.py +++ b/src/pieces_os_client/api_client.py @@ -77,7 +77,7 @@ def __init__(self, configuration=None, header_name=None, header_value=None, self.default_headers[header_name] = header_value self.cookie = cookie # Set default User-Agent. - self.user_agent = 'OpenAPI-Generator/3.0.1/python' + self.user_agent = 'OpenAPI-Generator/3.1.0/python' self.client_side_validation = configuration.client_side_validation def __enter__(self): diff --git a/src/pieces_os_client/configuration.py b/src/pieces_os_client/configuration.py index fe2200a..2d0618f 100644 --- a/src/pieces_os_client/configuration.py +++ b/src/pieces_os_client/configuration.py @@ -412,7 +412,7 @@ def to_debug_report(self): "OS: {env}\n"\ "Python Version: {pyversion}\n"\ "Version of the API: 1.0\n"\ - "SDK Package Version: 3.0.1".\ + "SDK Package Version: 3.1.0".\ format(env=sys.platform, pyversion=sys.version) def get_host_settings(self): diff --git a/src/pieces_os_client/wrapper/basic_identifier/__init__.py b/src/pieces_os_client/wrapper/basic_identifier/__init__.py index 1c6851d..f1ebe12 100644 --- a/src/pieces_os_client/wrapper/basic_identifier/__init__.py +++ b/src/pieces_os_client/wrapper/basic_identifier/__init__.py @@ -1,5 +1,6 @@ from .asset import BasicAsset from .chat import BasicChat from .message import BasicMessage +from .user import BasicUser -__all__ = ["BasicAsset","BasicChat","BasicMessage"] +__all__ = ["BasicAsset","BasicChat","BasicMessage","BasicUser"] diff --git a/src/pieces_os_client/wrapper/basic_identifier/asset.py b/src/pieces_os_client/wrapper/basic_identifier/asset.py index 3105a7a..ce4647c 100644 --- a/src/pieces_os_client/wrapper/basic_identifier/asset.py +++ b/src/pieces_os_client/wrapper/basic_identifier/asset.py @@ -16,10 +16,14 @@ SeededFragment, TransferableString, FragmentMetadata, - AssetReclassification) + AssetReclassification, + Linkify, + Shares +) from typing import Optional from .basic import Basic +from .user import BasicUser # Friendly wrapper (to avoid interacting with the pieces_os_client sdks models) @@ -202,22 +206,46 @@ def delete(self): AssetSnapshot.pieces_client.assets_api.assets_delete_asset(self.id) @classmethod - def create(cls,raw: str, metadata: Optional[FragmentMetadata] = None) -> str: + def create(cls,raw_content: str, metadata: Optional[FragmentMetadata] = None) -> str: """ Create a new asset. Args: - raw (str): The raw content of the asset. + raw_content (str): The raw content of the asset. metadata (Optional[FragmentMetadata]): The metadata of the asset. Returns: str: The ID of the created asset. """ - seed = cls._get_seed(raw,metadata) + seed = cls._get_seed(raw_content,metadata) created_asset_id = AssetSnapshot.pieces_client.assets_api.assets_create_new_asset(transferables=False, seed=seed).id return created_asset_id + def share(self) -> Shares: + """ + Generates a shareable link for the given asset. + + Raises: + PermissionError: If the user is not logged in or is not connected to the cloud. + """ + return self._share(self.asset) + + + @classmethod + def share_raw_content(cls,raw_content:str) -> Shares: + """ + Generates a shareable link for the given user raw content. + Note: this will create an asset + + Args: + raw_content (str): The raw content of the asset that will be shared. + + Raises: + PermissionError: If the user is not logged in or is not connected to the cloud. + """ + return cls._share(seed = cls._get_seed(raw_content)) + @staticmethod def _get_seed(raw: str, metadata: Optional[FragmentMetadata] = None) -> Seed: return Seed( @@ -283,4 +311,27 @@ def _ocr_from_format(src: Optional[Format]) -> Optional[str]: def _edit_asset(asset): AssetSnapshot.pieces_client.asset_api.asset_update(False,asset) + @staticmethod + def _share(asset=None,seed=None): + """ + You need to either give the seed or the asset_id + """ + if asset: + kwargs = {"asset" : asset} + else: + kwargs = {"seed" : seed} + + user = BasicUser.user_profile + + if not user: + raise PermissionError("You need to be logged in to generate a shareable link") + if not user.allocation: + raise PermissionError("You need to connect to the cloud to generate a shareable link") + + return AssetSnapshot.pieces_client.linkfy_api.linkify( + linkify=Linkify( + access="PUBLIC", + **kwargs + ) + ) diff --git a/src/pieces_os_client/wrapper/basic_identifier/user.py b/src/pieces_os_client/wrapper/basic_identifier/user.py new file mode 100644 index 0000000..61eaf19 --- /dev/null +++ b/src/pieces_os_client/wrapper/basic_identifier/user.py @@ -0,0 +1,144 @@ +import threading +from pieces_os_client import UserProfile,AllocationStatusEnum +from typing import Optional + +## TODO: Modify the Basic class to be able to fit in the BasicUser +class BasicUser: + """ + A class to represent a basic user and manage their connection to the cloud. + + Attributes: + user_profile: The profile of the user. + pieces_client: The client used to interact with the pieces OS API. + """ + + user_profile: Optional[UserProfile] = None + + def __init__(self, pieces_client) -> None: + """ + Initializes the BasicUser with a pieces client. + + Args: + pieces_client: The client used to interact with the pieces OS API. + """ + self.pieces_client = pieces_client + + def on_user_callback(self, user: Optional[UserProfile], connecting=False): + """ + Callback function to set the user profile. + + Args: + user: The profile of the user. + connecting: A flag indicating if the user is connecting to the cloud (default is False). + """ + self.user_profile = user + + def _on_login_connect(self, thread, timeout): + """ + Waits for the user to login and then connects to the cloud. + + Args: + thread: The thread handling the login process. + timeout: The maximum time to wait for the login process. + """ + thread.get(timeout) # Wait for the user to login + self.connect() + + def login(self, connect_after_login=True, timeout=120): + """ + Logs the user into the OS and optionally connects to the cloud. + + Args: + connect_after_login: A flag indicating if the user should connect to the cloud after login (default is True). + timeout: The maximum time to wait for the login process (default is 120 seconds). + """ + thread = self.pieces_client.os_api.sign_into_os(async_req=True) + if connect_after_login: + threading.Thread(target=lambda: self._on_login_connect(thread, timeout)) + + def logout(self): + """ + Logs the user out of the OS. + """ + self.pieces_client.api_client.os_api.sign_out_of_os() + + def connect(self): + """ + Connects the user to the cloud. + + Raises: + PermissionError: If the user is not logged in. + """ + if not self.user_profile: + raise PermissionError("You must be logged in to use this feature") + self.on_user_callback(self.user_profile, True) # Set the connecting to cloud bool to true + self.pieces_client.allocations_api.allocations_connect_new_cloud(self.user_profile) + + def disconnect(self): + """ + Disconnects the user from the cloud. + + Raises: + PermissionError: If the user is not logged in. + """ + if not self.user_profile: + raise PermissionError("You must be logged in to use this feature") + if self.user_profile.allocation: # Check if there is an allocation iterable + self.pieces_client.api_client.allocations_api.allocations_disconnect_cloud(self.user_profile.allocation) + + @property + def picture(self) -> Optional[str]: + """ + Returns the user's profile picture URL. + + Returns: + The URL of the user's profile picture, or None if not available. + """ + if self.user_profile: + return self.user_profile.picture + + + @property + def name(self) -> Optional[str]: + """ + Returns the name of the user. + + Returns: + Optional[str]: The name of the user if the user logged in, otherwise None. + """ + if self.user_profile: + return self.user_profile.name + + @property + def email(self) -> Optional[str]: + """ + Returns the email of the user. + + Returns: + Optional[str]: The email of the user if the user logged in, otherwise None. + """ + if self.user_profile: + return self.user_profile.email + + @property + def vanity_name(self) -> Optional[str]: # TODO: implements the setter object + """ + Returns the vanity name of the user which is the base name of the cloud url. + For example, if the cloud URL is 'bishoyatpieces.pieces.cloud', this method returns 'bishoyatpieces'. + + Returns: + Optional[str]: The vanity name of the user if the user user logged in and connected to the cloud, otherwise None. + """ + if self.user_profile: + return self.user_profile.vanityname + + @property + def cloud_status(self) -> Optional[AllocationStatusEnum]: + """ + Returns the cloud status of the user's cloud. + + Returns: + Optional[AllocationStatusEnum]: The cloud status of the user's cloud. + """ + if self.user_profile and self.user_profile.allocation: + return self.user_profile.allocation.status.cloud diff --git a/src/pieces_os_client/wrapper/client.py b/src/pieces_os_client/wrapper/client.py index 6b56ed3..2bce4ed 100644 --- a/src/pieces_os_client/wrapper/client.py +++ b/src/pieces_os_client/wrapper/client.py @@ -1,6 +1,5 @@ from pieces_os_client import ( ApiClient, - Application, Configuration, ConversationApi, ConversationMessageApi, @@ -16,32 +15,33 @@ AssetsApi, FragmentMetadata, ModelsApi, - AnnotationApi + AnnotationApi, + LinkifyApi, + WellKnownApi, + OSApi, + AllocationsApi, + __version__ ) from typing import Optional,Dict import platform import atexit +import subprocess + from .copilot import Copilot -from .basic_identifier import BasicAsset +from .basic_identifier import BasicAsset,BasicUser from .streamed_identifiers import AssetSnapshot from .websockets import * class PiecesClient: - def __init__(self, host:str="",config: dict={}, seeded_connector: Optional[SeededConnectorConnection] = None): + def __init__(self, host:str="", seeded_connector: Optional[SeededConnectorConnection] = None,**kwargs): if host: self.host = host else: self.host = "http://localhost:5323" if 'Linux' in platform.platform() else "http://localhost:1000" - connect_websockets= True - if "connect_websockets" in config: - connect_websockets = config["connect_websockets"] - del config["connect_websockets"] - - self.config = Configuration(**config) - self.api_client = ApiClient(self.config) + self.api_client = ApiClient(Configuration(self.host)) self.conversation_message_api = ConversationMessageApi(self.api_client) self.conversation_messages_api = ConversationMessagesApi(self.api_client) @@ -55,10 +55,14 @@ def __init__(self, host:str="",config: dict={}, seeded_connector: Optional[Seede self.connector_api = ConnectorApi(self.api_client) self.models_api = ModelsApi(self.api_client) self.annotation_api = AnnotationApi(self.api_client) + self.well_known_api = WellKnownApi(self.api_client) + self.os_api = OSApi(self.api_client) + self.allocations_api = AllocationsApi(self.api_client) + self.linkfy_api = LinkifyApi(self.api_client) # Websocket urls - if 'http' not in self.host: - raise TypeError("Invalid host url\n Host should start with http or https") + if not self.host.startswith("http"): + raise ValueError("Invalid host url\n Host should start with http or https") ws_base_url:str = self.host.replace('http','ws') self.ASSETS_IDENTIFIERS_WS_URL = ws_base_url + "/assets/stream/identifiers" @@ -67,25 +71,28 @@ def __init__(self, host:str="",config: dict={}, seeded_connector: Optional[Seede self.CONVERSATION_WS_URL = ws_base_url + "/conversations/stream/identifiers" self.HEALTH_WS_URL = ws_base_url + "/.well-known/stream/health" - local_os = platform.system().upper() if platform.system().upper() in ["WINDOWS","LINUX","DARWIN"] else "WEB" - local_os = "MACOS" if local_os == "DARWIN" else local_os + self.local_os = platform.system().upper() if platform.system().upper() in ["WINDOWS","LINUX","DARWIN"] else "WEB" + self.local_os = "MACOS" if self.local_os == "DARWIN" else self.local_os seeded_connector = seeded_connector or SeededConnectorConnection( application=SeededTrackedApplication( name = "OPEN_SOURCE", - platform = local_os, - version = "0.0.1")) + platform = self.local_os, + version = __version__)) self.tracked_application = self.connector_api.connect(seeded_connector_connection=seeded_connector).application - if connect_websockets: + self.user = BasicUser(self) + + if kwargs.get("connect_wesockets",True): self.conversation_ws = ConversationWS(self) self.assets_ws = AssetsIdentifiersWS(self) - + self.user_websocket = AuthWS(self,self.user.on_user_callback) # Start all initilized websockets BaseWebsocket.start_all() self.models = None self.model_name = "GPT-3.5-turbo Chat Model" + self.copilot = Copilot(self) def assets(self): @@ -100,13 +107,6 @@ def asset(self,asset_id): def create_asset(content:str,metadata:Optional[FragmentMetadata]=None): return BasicAsset.create(content,metadata) - def get_user_profile_picture(self) -> Optional[str]: - try: - user_res = self.user_api.user_snapshot() - return user_res.user.picture or None - except Exception as error: - print(f'Error getting user profile picture: {error}') - return None def get_models(self) -> Dict[str, str]: if self.models: @@ -132,10 +132,6 @@ def model_name(self,model): def available_models_names(self) -> list: return list(self.get_models().keys()) - @property - def copilot(self): - return Copilot(self) - def ensure_initialization(self): """ Waits for all the assets/conversations and all the started websockets to open @@ -148,6 +144,38 @@ def close(self): """ BaseWebsocket.close_all() + @property + def version(self) -> str: + """ + Returns Pieces OS Version + """ + return self.well_known_api.get_well_known_version() + + @property + def health(self) -> bool: + """ + Returns True Pieces OS health is ok else False + """ + try: + return self.well_known_api.get_well_known_health_with_http_info().status_code == 200 + except: + pass + return False + + def open_pieces_os(self) -> bool: + """ + Open Pieces OS + + Returns (bool): true if Pieces OS runned successfully else false + """ + if self.local_os == "WINDOWS": + subprocess.run(["start", "pieces://launch"], shell=True) + elif self.local_os == "MACOS": + subprocess.run(["open","pieces://launch"]) + elif self.local_os == "LINUX": + subprocess.run(["xdg-open","pieces://launch"]) + return self.health + # Register the function to be called on exit atexit.register(BaseWebsocket.close_all) diff --git a/src/pieces_os_client/wrapper/copilot.py b/src/pieces_os_client/wrapper/copilot.py index 2793a77..75bdb57 100644 --- a/src/pieces_os_client/wrapper/copilot.py +++ b/src/pieces_os_client/wrapper/copilot.py @@ -141,3 +141,19 @@ def chat(self, chat: Optional[BasicChat]): self._chat = chat + def create_chat(self, name:Optional[str]=None): + """ + Creates a New Chat and change the current Copilot chat state to the new generated one + """ + new_conversation = self.pieces_client.conversations_api.conversations_create_specific_conversation( + seeded_conversation={ + 'name': name, + 'type': 'COPILOT', + } + ) + + ConversationsSnapshot.identifiers_snapshot[new_conversation.id] = new_conversation # Make sure to update the cache + self.chat = BasicChat(new_conversation.id) + + return self.chat + diff --git a/src/pieces_os_client/wrapper/streamed_identifiers/_streamed_identifiers.py b/src/pieces_os_client/wrapper/streamed_identifiers/_streamed_identifiers.py index 0fe5b7d..6a464b5 100644 --- a/src/pieces_os_client/wrapper/streamed_identifiers/_streamed_identifiers.py +++ b/src/pieces_os_client/wrapper/streamed_identifiers/_streamed_identifiers.py @@ -98,7 +98,7 @@ def update_identifier(cls, identifier: str): id_value = cls._api_call(identifier) with cls._lock: cls.identifiers_snapshot[identifier] = id_value - cls.on_update(id_value) + cls.on_update(id_value) return id_value except Exception as e: print(f"Error updating identifier {identifier}: {e}") diff --git a/src/pieces_os_client/wrapper/streamed_identifiers/assets_snapshot.py b/src/pieces_os_client/wrapper/streamed_identifiers/assets_snapshot.py index 82faddf..547b180 100644 --- a/src/pieces_os_client/wrapper/streamed_identifiers/assets_snapshot.py +++ b/src/pieces_os_client/wrapper/streamed_identifiers/assets_snapshot.py @@ -10,7 +10,7 @@ class AssetSnapshot(StreamedIdentifiersCache): @classmethod def _api_call(cls, id): asset = cls.pieces_client.asset_api.asset_snapshot(id) - cls.on_update(asset) + # cls.on_update(asset) return asset diff --git a/src/pieces_os_client/wrapper/streamed_identifiers/conversations_snapshot.py b/src/pieces_os_client/wrapper/streamed_identifiers/conversations_snapshot.py index 5fdd4ff..16fd991 100644 --- a/src/pieces_os_client/wrapper/streamed_identifiers/conversations_snapshot.py +++ b/src/pieces_os_client/wrapper/streamed_identifiers/conversations_snapshot.py @@ -19,5 +19,5 @@ def _sort_first_shot(cls): @classmethod def _api_call(cls,id): con = cls.pieces_client.conversation_api.conversation_get_specific_conversation(id) - cls.on_update(con) + # cls.on_update(con) return con diff --git a/src/pieces_os_client/wrapper/version_compatibility.py b/src/pieces_os_client/wrapper/version_compatibility.py new file mode 100644 index 0000000..f5f3257 --- /dev/null +++ b/src/pieces_os_client/wrapper/version_compatibility.py @@ -0,0 +1,49 @@ +from enum import Enum +import re +from typing import Optional + +class VersionChecker: + def __init__(self, min_version: str, max_version: str, pieces_os_version:str): + self.min_version = min_version + self.max_version = max_version + self.pieces_os_version = pieces_os_version + + @staticmethod + def _parse_version(version_str): + """Parse a version string into a tuple of integers and pre-release labels.""" + match = re.match(r'^(\d+)\.(\d+)\.(\d+)(?:[-.](\S+))?$', version_str) + if not match: + raise ValueError(f"Invalid version format: {version_str}") + + major, minor, patch, pre_release = match.groups() + version_tuple = (int(major), int(minor), int(patch)) + pre_release_tuple = tuple(pre_release.split('.')) if pre_release else () + return version_tuple, pre_release_tuple + + def version_check(self): + """Check if the Pieces OS version is within the supported range.""" + # Parse version numbers + os_version_parsed, os_pre_release = self._parse_version(self.pieces_os_version) + min_version_parsed, min_pre_release = self._parse_version(self.min_version) + max_version_parsed, max_pre_release = self._parse_version(self.max_version) + + # Determine compatibility + if (os_version_parsed < min_version_parsed or + (os_version_parsed == min_version_parsed and os_pre_release < min_pre_release)): + return VersionCheckResult(False, UpdateEnum.PiecesOS) + elif (os_version_parsed > max_version_parsed or + (os_version_parsed == max_version_parsed and os_pre_release >= max_pre_release)): + return VersionCheckResult(False, UpdateEnum.Plugin) + + return VersionCheckResult(True) + + +class UpdateEnum(Enum): + PiecesOS = 1 + Plugin = 2 + +class VersionCheckResult: + """Result of the version check.""" + def __init__(self, compatible, update:Optional[UpdateEnum]=None): + self.compatible = compatible + self.update = update diff --git a/tests/test_basic_asset.py b/tests/test_basic_asset.py index af67014..d37f248 100644 --- a/tests/test_basic_asset.py +++ b/tests/test_basic_asset.py @@ -1,140 +1,191 @@ import pytest -from unittest.mock import patch, MagicMock -from pieces_os_client import Asset, Format, ClassificationGenericEnum, ClassificationSpecificEnum, Annotations, Annotation +from unittest.mock import Mock, patch, MagicMock +from pieces_os_client import (Asset, + Format, + ClassificationGenericEnum, + ClassificationSpecificEnum, + Annotations, + Annotation, + Linkify) from datetime import datetime from pieces_os_client.wrapper.basic_identifier import BasicAsset from pieces_os_client.wrapper.streamed_identifiers.assets_snapshot import AssetSnapshot -from pieces_os_client.wrapper.client import PiecesClient - -class BasicAssetTest(BasicAsset): - def __init__(self, id) -> None: - self.asset:Asset = AssetSnapshot.identifiers_snapshot.get(id) - if not self.asset: - print(f"Asset not found for ID: {id}") - print("Available IDs in AssetSnapshot.identifiers_snapshot:", AssetSnapshot.identifiers_snapshot.keys()) - raise ValueError("Asset not found") - -@pytest.fixture(scope="function") -def pieces_client(): - return PiecesClient() - -@pytest.fixture(scope="function", autouse=True) -def setup_asset_snapshot(pieces_client): - AssetSnapshot.identifiers_snapshot = {} - print("Before test:", AssetSnapshot.identifiers_snapshot) - yield - print("After test:", AssetSnapshot.identifiers_snapshot) - AssetSnapshot.identifiers_snapshot = {} - -@pytest.fixture -def mock_asset(): - mock = MagicMock(spec=Asset) - mock.id = "test_asset_id" - mock.name = "Test Asset" - mock.original = MagicMock() - mock.original.reference.fragment.string.raw = "Test content" - mock.original.reference.classification.specific = ClassificationSpecificEnum.PY - mock.original.reference.classification.generic = ClassificationGenericEnum.CODE - mock.formats = MagicMock() - mock.formats.iterable = [] - return mock - -@pytest.fixture -def mock_asset_snapshot(mock_asset): - AssetSnapshot.identifiers_snapshot[mock_asset.id] = mock_asset - print(f"Added mock asset with ID {mock_asset.id} to AssetSnapshot.identifiers_snapshot") - return AssetSnapshot - -def test_basic_asset_initialization(mock_asset_snapshot, mock_asset): - print("In test_basic_asset_initialization:") - print("AssetSnapshot.identifiers_snapshot:", AssetSnapshot.identifiers_snapshot) - print("mock_asset.id:", mock_asset.id) - asset = BasicAsset(mock_asset.id) - assert asset.asset == mock_asset - -def test_raw_content_property(mock_asset_snapshot, mock_asset): - asset = BasicAsset(mock_asset.id) - assert asset.raw_content == "Test content" - -def test_is_image(mock_asset_snapshot, mock_asset): - asset = BasicAsset(mock_asset.id) - assert not asset.is_image - - mock_asset.original.reference.classification.generic = ClassificationGenericEnum.IMAGE - assert asset.is_image - -def test_classification_property(mock_asset_snapshot, mock_asset): - asset = BasicAsset(mock_asset.id) - assert asset.classification == ClassificationSpecificEnum.PY - -def test_edit_content(mock_asset_snapshot, mock_asset, pieces_client): - asset = BasicAsset(mock_asset.id) - new_content = "Updated content" - - with patch.object(pieces_client, 'format_api') as mock_format_api: - mock_format = MagicMock(spec=Format) - mock_format.classification = MagicMock() - mock_format.classification.generic = ClassificationGenericEnum.CODE - mock_format.fragment = MagicMock() - mock_format.fragment.string = MagicMock() - mock_format.fragment.string.raw = "Test content" - mock_format_api.format_snapshot.return_value = mock_format - - asset.raw_content = new_content - - mock_format_api.format_update_value.assert_called_once() - assert mock_format.fragment.string.raw == new_content - -def test_edit_name(mock_asset_snapshot, mock_asset, pieces_client): - asset = BasicAsset(mock_asset.id) - new_name = "New Asset Name" - - with patch.object(pieces_client, 'asset_api') as mock_asset_api: - asset.name = new_name - - assert asset.asset.name == new_name - mock_asset_api.asset_update.assert_called_once() - -def test_name_property(mock_asset_snapshot, mock_asset): - asset = BasicAsset(mock_asset.id) - assert asset.name == "Test Asset" - - mock_asset.name = None - assert asset.name == "Unnamed snippet" - -def test_description_property(mock_asset_snapshot, mock_asset): - asset = BasicAsset(mock_asset.id) - mock_annotation = MagicMock(spec=Annotation) - mock_annotation.type = "DESCRIPTION" - mock_annotation.text = "Test description" - mock_annotation.updated = MagicMock() - mock_annotation.updated.value = datetime.now() - mock_annotations = [ - mock_annotation, - MagicMock(spec=Annotation, type="OTHER", text="Other annotation", updated=MagicMock(value=datetime.now())) - ] - mock_asset.annotations = MagicMock(spec=Annotations, iterable=mock_annotations) - - assert asset.description == "Test description" - -def test_annotations_property(mock_asset_snapshot, mock_asset): - asset = BasicAsset(mock_asset.id) - mock_annotations = [MagicMock(spec=Annotation), MagicMock(spec=Annotation)] - mock_asset.annotations = MagicMock(spec=Annotations, iterable=mock_annotations) - - assert asset.annotations == mock_annotations - -def test_delete(mock_asset_snapshot, mock_asset, pieces_client): - asset = BasicAsset(mock_asset.id) - with patch.object(pieces_client, 'assets_api') as mock_assets_api: - asset.delete() - mock_assets_api.assets_delete_asset.assert_called_once_with(mock_asset.id) - -def test_create(pieces_client): - with patch.object(pieces_client, 'assets_api') as mock_assets_api: - mock_assets_api.assets_create_new_asset.return_value = MagicMock(id="new_asset_id") - - new_asset_id = BasicAsset.create("New asset content") + +class TestBasicAsset: + @pytest.fixture(autouse=True) + def setup(self): + self.mock_asset = Mock(spec=Asset) + self.mock_asset.id = "test_asset_id" + self.mock_asset.name = "Test Asset" + self.mock_asset.original = Mock() + self.mock_asset.original.reference.fragment.string.raw = "Test content" + self.mock_asset.original.reference.classification.specific = ClassificationSpecificEnum.PY + self.mock_asset.original.reference.classification.generic = ClassificationGenericEnum.CODE + self.mock_asset.formats = Mock() + self.mock_asset.formats.iterable = [] + AssetSnapshot.identifiers_snapshot = {"test_asset_id": self.mock_asset} + AssetSnapshot.pieces_client = Mock() + + def test_basic_asset_initialization(self): + asset = BasicAsset("test_asset_id") + assert asset.id == "test_asset_id" + assert asset.asset == self.mock_asset + + def test_basic_asset_initialization_invalid_id(self): + with pytest.raises(ValueError, match="Asset not found"): + BasicAsset("invalid_id").asset + + def test_raw_content_property(self): + asset = BasicAsset("test_asset_id") + assert asset.raw_content == "Test content" + + def test_is_image(self): + asset = BasicAsset("test_asset_id") + assert not asset.is_image + + self.mock_asset.original.reference.classification.generic = ClassificationGenericEnum.IMAGE + assert asset.is_image + + def test_classification_property(self): + asset = BasicAsset("test_asset_id") + assert asset.classification == ClassificationSpecificEnum.PY + + def test_edit_content(self): + asset = BasicAsset("test_asset_id") + new_content = "Updated content" + + with patch.object(AssetSnapshot.pieces_client, 'format_api') as mock_format_api: + mock_format = MagicMock(spec=Format) + mock_format.classification = MagicMock() + mock_format.classification.generic = ClassificationGenericEnum.CODE + mock_format.fragment = MagicMock() + mock_format.fragment.string = MagicMock() + mock_format.fragment.string.raw = "Test content" + mock_format_api.format_snapshot.return_value = mock_format + + asset.raw_content = new_content + + mock_format_api.format_update_value.assert_called_once() + assert mock_format.fragment.string.raw == new_content + + def test_edit_name(self): + asset = BasicAsset("test_asset_id") + new_name = "New Asset Name" + + with patch.object(AssetSnapshot.pieces_client, 'asset_api') as mock_asset_api: + asset.name = new_name + + assert asset.asset.name == new_name + mock_asset_api.asset_update.assert_called_once() + + def test_name_property(self): + asset = BasicAsset("test_asset_id") + assert asset.name == "Test Asset" + + self.mock_asset.name = None + assert asset.name == "Unnamed snippet" + + def test_description_property(self): + asset = BasicAsset("test_asset_id") + mock_annotation = MagicMock(spec=Annotation) + mock_annotation.type = "DESCRIPTION" + mock_annotation.text = "Test description" + mock_annotation.updated = MagicMock() + mock_annotation.updated.value = datetime.now() + mock_annotations = [ + mock_annotation, + MagicMock(spec=Annotation, type="OTHER", text="Other annotation", updated=MagicMock(value=datetime.now())) + ] + self.mock_asset.annotations = MagicMock(spec=Annotations, iterable=mock_annotations) + + assert asset.description == "Test description" + + def test_annotations_property(self): + asset = BasicAsset("test_asset_id") + mock_annotations = [MagicMock(spec=Annotation), MagicMock(spec=Annotation)] + self.mock_asset.annotations = MagicMock(spec=Annotations, iterable=mock_annotations) + + assert asset.annotations == mock_annotations + + def test_delete(self): + asset = BasicAsset("test_asset_id") + with patch.object(AssetSnapshot.pieces_client, 'assets_api') as mock_assets_api: + asset.delete() + mock_assets_api.assets_delete_asset.assert_called_once_with("test_asset_id") + + def test_create(self): + raw_content = "New asset content" + mock_seed = MagicMock() - assert new_asset_id == "new_asset_id" - mock_assets_api.assets_create_new_asset.assert_called_once() + with patch.object(BasicAsset, '_get_seed', return_value=mock_seed) as mock_get_seed, \ + patch.object(AssetSnapshot.pieces_client.assets_api, 'assets_create_new_asset') as mock_create_new_asset: + + mock_create_new_asset.return_value = MagicMock(id="new_asset_id") + + new_asset_id = BasicAsset.create(raw_content) + + mock_get_seed.assert_called_once_with(raw_content, None) + mock_create_new_asset.assert_called_once_with(transferables=False, seed=mock_seed) + assert new_asset_id == "new_asset_id" + + def test_share_with_asset(self): + asset = BasicAsset("test_asset_id") + mock_user_profile = MagicMock() + mock_user_profile.allocation = True + + with patch('pieces_os_client.wrapper.basic_identifier.BasicUser.user_profile', mock_user_profile), \ + patch.object(AssetSnapshot.pieces_client.linkfy_api, 'linkify') as mock_linkify: + + mock_linkify.return_value = "shareable_link" + + shareable_link = asset._share(asset=asset.asset) + + mock_linkify.assert_called_once_with( + linkify=Linkify( + access="PUBLIC", + asset=asset.asset + ) + ) + assert shareable_link == "shareable_link" + + def test_share_with_seed(self): + mock_seed = MagicMock() + mock_user_profile = MagicMock() + mock_user_profile.allocation = True + + with patch('pieces_os_client.wrapper.basic_identifier.BasicUser.user_profile', mock_user_profile), \ + patch.object(AssetSnapshot.pieces_client.linkfy_api, 'linkify') as mock_linkify: + + mock_linkify.return_value = "shareable_link" + + shareable_link = BasicAsset._share(seed=mock_seed) + + mock_linkify.assert_called_once_with( + linkify=Linkify( + access="PUBLIC", + seed=mock_seed + ) + ) + assert shareable_link == "shareable_link" + + def test_share_without_user_profile(self): + basic_asset = BasicAsset("test_asset_id") + + with patch('pieces_os_client.wrapper.basic_identifier.BasicUser.user_profile', None): + with pytest.raises(PermissionError, match="You need to be logged in to generate a shareable link"): + basic_asset._share(asset=basic_asset.asset) + + def test_share_without_allocation(self): + basic_asset = BasicAsset("test_asset_id") + mock_user_profile = MagicMock() + mock_user_profile.allocation = False + + with patch('pieces_os_client.wrapper.basic_identifier.BasicUser.user_profile', mock_user_profile): + with pytest.raises(PermissionError, match="You need to connect to the cloud to generate a shareable link"): + basic_asset._share(asset=basic_asset.asset) + + + +if __name__ == '__main__': + pytest.main([__file__]) + diff --git a/tests/test_basic_chat.py b/tests/test_basic_chat.py index c101086..ad25698 100644 --- a/tests/test_basic_chat.py +++ b/tests/test_basic_chat.py @@ -19,7 +19,7 @@ def test_init_valid_id(self): def test_init_invalid_id(self): with pytest.raises(ValueError, match="Conversation not found"): - BasicChat("invalid_id") + b = BasicChat("invalid_id").conversation # Call the conversation to check if it is vaild def test_name_property(self): chat = BasicChat("test_id") @@ -35,19 +35,18 @@ def test_name_property_default(self): assert chat.name == "New Conversation" @patch.object(BasicMessage, '__init__', return_value=None) - @patch.object(BasicChat, '_get_message') - def test_messages(self, mock_get_message, mock_basic_message_init): + def test_messages(self, mock_basic_message_init): + ConversationsSnapshot.identifiers_snapshot["test_id"].messages = Mock() ConversationsSnapshot.identifiers_snapshot["test_id"].messages.indices = { "msg1": 0, "msg2": 1, "msg3": -1 # Deleted message } - # Mock the _get_message method to return a Mock object - mock_get_message.side_effect = lambda message_id: Mock(id=message_id) chat = BasicChat("test_id") messages = chat.messages() + assert len(messages) == 2 assert all(isinstance(msg, BasicMessage) for msg in messages) @@ -55,16 +54,10 @@ def test_messages(self, mock_get_message, mock_basic_message_init): # Check that BasicMessage.__init__ was called twice assert mock_basic_message_init.call_count == 2 - # Check that _get_message was called with the correct message IDs - mock_get_message.assert_has_calls([ - call("msg1"), - call("msg2") - ], any_order=True) - # Check that BasicMessage.__init__ was called with the results of _get_message for call_args in mock_basic_message_init.call_args_list: assert isinstance(call_args[0][0], Mock) - assert call_args[0][0].id in ["msg1", "msg2"] + assert call_args[0][1] in ["msg1", "msg2"] def test_annotations_property(self): mock_annotations = Mock(iterable=["annotation1", "annotation2"]) @@ -91,3 +84,7 @@ def test_edit_conversation(self, mock_pieces_client): BasicChat._edit_conversation(mock_conversation) mock_pieces_client.conversation_api.conversation_update.assert_called_once_with(False, mock_conversation) + +if __name__ == '__main__': + pytest.main([__file__]) + diff --git a/tests/test_basic_copilot.py b/tests/test_basic_copilot.py index ea49bd0..f535229 100644 --- a/tests/test_basic_copilot.py +++ b/tests/test_basic_copilot.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch from pieces_os_client import ( QGPTStreamEnum, + RelevantQGPTSeeds ) from pieces_os_client.wrapper.websockets import AskStreamWS from pieces_os_client.wrapper.copilot import Copilot @@ -16,16 +17,15 @@ def setUp(self): self.mock_client.model_id = "mock_model_id" self.mock_client.qgpt_api = Mock() - # Define a real BasicChat class for testing - global BasicChat - class BasicChat: - def __init__(self, id): - self.id = id - self.copilot = Copilot(self.mock_client) - # Mock ConversationsSnapshot - self.mock_conversations = patch('__main__.ConversationsSnapshot.identifiers_snapshot', {"test_conversation_id": Mock()}).start() + self.mock_conversation = Mock() + self.mock_conversation.id = "test_conversation_id" + self.mock_conversations_snapshot = patch( + 'pieces_os_client.wrapper.streamed_identifiers.conversations_snapshot.ConversationsSnapshot.identifiers_snapshot', + {"test_conversation_id": self.mock_conversation} + ).start() + self.addCleanup(patch.stopall) def tearDown(self): patch.stopall() @@ -37,22 +37,28 @@ def test_init(self): self.assertIsInstance(self.copilot.ask_stream_ws, AskStreamWS) self.assertIsNone(self.copilot._chat) - @patch('__main__.AskStreamWS') + @patch('pieces_os_client.wrapper.websockets.AskStreamWS') def test_ask(self, mock_ask_stream_ws): + conversation_id = "test_conversation_id" query = "Test query" - mock_output = Mock(status=QGPTStreamEnum.COMPLETED, conversation="test_conversation_id", text="Test response") + mock_output = Mock(status=QGPTStreamEnum.COMPLETED, conversation=conversation_id, text="Test response") self.copilot._on_message_queue.put(mock_output) # Create a mock for send_message mock_send_message = Mock() self.copilot.ask_stream_ws.send_message = mock_send_message + self.copilot.context._relevance_api = lambda query: RelevantQGPTSeeds(iterable=[])# Mock the contexts for now + result = list(self.copilot.stream_question(query)) - result = list(self.copilot.ask(query)) - + # Mock the conversation created + mock_conversation = Mock() + mock_conversation.name = "Test Conversation" + mock_conversation.id = conversation_id + ConversationsSnapshot.identifiers_snapshot = {conversation_id: mock_conversation} self.assertEqual(len(result), 1) self.assertEqual(result[0], mock_output) self.assertIsInstance(self.copilot.chat, BasicChat) - self.assertEqual(self.copilot.chat.id, "test_conversation_id") + self.assertEqual(self.copilot.chat.id, conversation_id) # Assert that send_message was called once mock_send_message.assert_called_once() @@ -84,3 +90,9 @@ def test_chat_property(self): with self.assertRaises(ValueError): self.copilot.chat = "invalid_chat" + self.copilot.chat = None + self.assertEqual(self.copilot.chat, None) + +if __name__ == '__main__': + unittest.main() + diff --git a/tests/test_basic_message.py b/tests/test_basic_message.py index 7e48c6a..d2009e5 100644 --- a/tests/test_basic_message.py +++ b/tests/test_basic_message.py @@ -34,11 +34,11 @@ def test_init_invalid_id(self): def test_raw_property(self): message = BasicMessage(self.mock_pieces_client, "test_message_id") - assert message.raw == "Test message content" + assert message.raw_content == "Test message content" def test_raw_setter(self): message = BasicMessage(self.mock_pieces_client, "test_message_id") - message.raw = "New content" + message.raw_content = "New content" assert message.message.fragment.string.raw == "New content" self.mock_pieces_client.conversation_message_api.message_update_value.assert_called_once_with( False, message.message @@ -63,50 +63,53 @@ def test_annotations_property_none(self): message = BasicMessage(self.mock_pieces_client, "test_message_id") assert message.annotations is None - def test_annotations_property_with_annotations(self): - mock_annotation = Mock() - mock_annotation.id = "test_annotation_id" - self.mock_message.annotations = Mock(iterable=[Mock(id="test_annotation_id")]) - self.mock_pieces_client.annotation_api.annotation_specific_annotation_snapshot.return_value = mock_annotation - - message = BasicMessage(self.mock_pieces_client, "test_message_id") - annotations = message.annotations - - assert annotations is not None - assert isinstance(annotations, Annotations) - assert len(annotations.iterable) == 1 - assert annotations.iterable[0].id == "test_annotation_id" - def test_description_property_no_annotations(self): message = BasicMessage(self.mock_pieces_client, "test_message_id") assert message.description is None - def test_description_property_with_description(self): - # Create a mock annotation with the correct structure - mock_annotation = Mock() - mock_annotation.type = AnnotationTypeEnum.DESCRIPTION - mock_annotation.text = "Test description" + # Did not see a message that have an annotation before but it still in the model.. + # def test_annotations_property_with_annotations(self): + # mock_annotation = Mock() + # mock_annotation.id = "test_annotation_id" + # self.mock_message.annotations = Mock(iterable=[Mock(id="test_annotation_id")]) + # self.mock_pieces_client.annotation_api.annotation_specific_annotation_snapshot.return_value = mock_annotation - # Set up the mock message with annotations - self.mock_message.annotations = Mock(iterable=[Mock(id="test_annotation_id")]) + # message = BasicMessage(self.mock_pieces_client, "test_message_id") + # annotations = message.annotations - # Mock the annotation API call - self.mock_pieces_client.annotation_api.annotation_specific_annotation_snapshot.return_value = mock_annotation + # assert annotations is not None + # assert isinstance(annotations, Annotations) + # assert len(annotations.iterable) == 1 + # assert annotations.iterable[0].id == "test_annotation_id" - # Create the BasicMessage instance - message = BasicMessage(self.mock_pieces_client, "test_message_id") - # Replace the message's annotations with our mocked annotations - message.message.annotations = self.mock_message.annotations - # Now test the description property - description = message.description - print("Returned description:", description) + # def test_description_property_with_description(self): + # # Create a mock annotation with the correct structure + # mock_annotation = Mock() + # mock_annotation.type = AnnotationTypeEnum.DESCRIPTION + # mock_annotation.text = "Test description" + + # # Set up the mock message with annotations + # self.mock_message.annotations = Mock(iterable=[Mock(id="test_annotation_id")]) + + # # Mock the annotation API call + # self.mock_pieces_client.annotation_api.annotation_specific_annotation_snapshot.return_value = mock_annotation - assert description == "Test description" + # # Create the BasicMessage instance + # message = BasicMessage(self.mock_pieces_client, "test_message_id") - # Verify that the annotation API was called - self.mock_pieces_client.annotation_api.annotation_specific_annotation_snapshot.assert_called_once_with("test_annotation_id") + # # Replace the message's annotations with our mocked annotations + # message.message.annotations = self.mock_message.annotations + + # # Now test the description property + # description = message.description + # print("Returned description:", description) + + # assert description == "Test description" + + # # Verify that the annotation API was called + # self.mock_pieces_client.annotation_api.annotation_specific_annotation_snapshot.assert_called_once_with("test_annotation_id") def test_repr(self): message = BasicMessage(self.mock_pieces_client, "test_message_id") @@ -154,3 +157,8 @@ def test_str(self): def test_hash(self): message = BasicMessage(self.mock_pieces_client, "test_message_id") assert hash(message) == hash("test_message_id") + + +if __name__ == '__main__': + pytest.main([__file__]) + diff --git a/tests/test_basic_user.py b/tests/test_basic_user.py new file mode 100644 index 0000000..246a391 --- /dev/null +++ b/tests/test_basic_user.py @@ -0,0 +1,70 @@ +import unittest +from unittest.mock import MagicMock, patch +from pieces_os_client import UserProfile, AllocationStatusEnum +from pieces_os_client.wrapper.basic_identifier import BasicUser + +class BasicUserTest(unittest.TestCase): + def setUp(self): + self.mock_pieces_client = MagicMock() + self.basic_user = BasicUser(self.mock_pieces_client) + self.mock_user_profile = MagicMock(spec=UserProfile) + self.basic_user.user_profile = self.mock_user_profile + + def test_login(self): + self.mock_pieces_client.os_api.sign_into_os.return_value = MagicMock() + self.basic_user.login(connect_after_login=False) + self.mock_pieces_client.os_api.sign_into_os.assert_called_once_with(async_req=True) + + def test_login_and_connect(self): + self.mock_pieces_client.os_api.sign_into_os.return_value = MagicMock() + with patch('threading.Thread') as mock_thread: + self.basic_user.login(connect_after_login=True) + mock_thread.assert_called_once() + + def test_logout(self): + self.basic_user.logout() + self.mock_pieces_client.api_client.os_api.sign_out_of_os.assert_called_once() + + def test_connect(self): + self.basic_user.connect() + self.mock_pieces_client.allocations_api.allocations_connect_new_cloud.assert_called_once_with(self.mock_user_profile) + + def test_connect_without_login(self): + self.basic_user.user_profile = None + with self.assertRaises(PermissionError): + self.basic_user.connect() + + def test_disconnect(self): + self.mock_user_profile.allocation = MagicMock() + self.basic_user.disconnect() + self.mock_pieces_client.api_client.allocations_api.allocations_disconnect_cloud.assert_called_once_with(self.mock_user_profile.allocation) + + def test_disconnect_without_login(self): + self.basic_user.user_profile = None + with self.assertRaises(PermissionError): + self.basic_user.disconnect() + + def test_picture_property(self): + self.mock_user_profile.picture = "http://example.com/picture.jpg" + self.assertEqual(self.basic_user.picture, "http://example.com/picture.jpg") + + def test_name_property(self): + self.mock_user_profile.name = "John Doe" + self.assertEqual(self.basic_user.name, "John Doe") + + def test_email_property(self): + self.mock_user_profile.email = "john.doe@example.com" + self.assertEqual(self.basic_user.email, "john.doe@example.com") + + def test_vanity_name_property(self): + self.mock_user_profile.vanityname = "johnatpieces" + self.assertEqual(self.basic_user.vanity_name, "johnatpieces") + + def test_cloud_status_property(self): + mock_allocation = MagicMock() + mock_allocation.status.cloud = AllocationStatusEnum.RUNNING # Use a valid enum value + self.mock_user_profile.allocation = mock_allocation + self.assertEqual(self.basic_user.cloud_status, AllocationStatusEnum.RUNNING) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_version_compatibility.py b/tests/test_version_compatibility.py new file mode 100644 index 0000000..859c66c --- /dev/null +++ b/tests/test_version_compatibility.py @@ -0,0 +1,54 @@ +import unittest +from pieces_os_client.wrapper.version_compatibility import VersionChecker,UpdateEnum + +class TestVersionChecker(unittest.TestCase): + def test_version_in_range(self): + checker = VersionChecker("1.0.0", "2.0.0", "1.5.0") + result = checker.version_check() + self.assertTrue(result.compatible) + self.assertIsNone(result.update) + + def test_version_below_minimum(self): + checker = VersionChecker("1.0.0", "2.0.0", "0.9.0") + result = checker.version_check() + self.assertFalse(result.compatible) + self.assertEqual(result.update, UpdateEnum.PiecesOS) + + def test_version_above_maximum(self): + checker = VersionChecker("1.0.0", "2.0.0", "2.1.0") + result = checker.version_check() + self.assertFalse(result.compatible) + self.assertEqual(result.update, UpdateEnum.Plugin) + + def test_version_at_minimum(self): + checker = VersionChecker("1.0.0", "2.0.0", "1.0.0") + result = checker.version_check() + self.assertTrue(result.compatible) + self.assertIsNone(result.update) + + def test_version_at_maximum(self): + checker = VersionChecker("1.0.0", "2.0.0", "2.0.0") + result = checker.version_check() + self.assertFalse(result.compatible) + self.assertEqual(result.update,UpdateEnum.Plugin) + + def test_version_with_pre_release(self): + checker = VersionChecker("1.0.0", "2.0.0", "1.0.0-alpha") + result = checker.version_check() + self.assertTrue(result.compatible) + self.assertIsNone(result.update) + + def test_version_below_minimum_with_pre_release(self): + checker = VersionChecker("1.0.0", "2.0.0", "0.9.0-beta") + result = checker.version_check() + self.assertFalse(result.compatible) + self.assertEqual(result.update, UpdateEnum.PiecesOS) + + def test_version_above_maximum_with_pre_release(self): + checker = VersionChecker("1.0.0", "2.0.0", "2.1.0-beta") + result = checker.version_check() + self.assertFalse(result.compatible) + self.assertEqual(result.update, UpdateEnum.Plugin) + +if __name__ == "__main__": + unittest.main()