Skip to content

Commit

Permalink
Merge pull request #186 from Bishoy-at-pieces/fix-models
Browse files Browse the repository at this point in the history
fix: list models command
  • Loading branch information
bishoy-at-pieces authored Sep 30, 2024
2 parents 144f47b + dce2399 commit 7bd73c0
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/pieces/autocommit/autocommit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_current_working_changes() -> Optional[Tuple[str, "Seeds"]]:
List of seeded asset to be input to the relevance
"""
from pieces_os_client.models.seed import Seed
from pieces_os_client.models.seeds import Seeds
from pieces_os_client.models.seeded_asset import SeededAsset
from pieces_os_client.models.seeded_asset_metadata import SeededAssetMetadata
from pieces_os_client.models.seeded_format import SeededFormat
Expand Down Expand Up @@ -223,7 +224,7 @@ def get_commit_message(changes_summary,seeds):
query=message_prompt,
seeds=seeds,
application=Settings.pieces_client.application.id,
model=Settings.get_model(),
model=Settings.get_model_id(),
options=QGPTRelevanceInputOptions(question=True)
)).answer.answers.iterable[0].text

Expand Down
34 changes: 26 additions & 8 deletions src/pieces/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class Settings:
pieces_data_dir, "model_data.pkl"
) # model data file just store the model_id that the user is using (eg. {"model_id": UUID })

file_cache = {}

config_file = Path(pieces_data_dir, "pieces_config.json")

run_in_loop = False # is CLI looping?
Expand All @@ -53,21 +55,37 @@ def get_model(cls):
return cls._model_name

model_id = cls.get_from_pickle(cls.models_file,"model_id")
models_reverse = {v:k for k,v in cls.pieces_client.get_models().items()}
cls._model_name = models_reverse.get(model_id)
if model_id:
models_reverse = {v:k for k,v in cls.pieces_client.get_models().items()}
cls._model_name = models_reverse.get(model_id)
else:
cls._model_name = cls.pieces_client.model_name

try:
cls.pieces_client.model_name = cls._model_name
except ValueError:
return cls.pieces_client.model_name
return cls._model_name
return cls._model_name

@classmethod
def get_model_id(cls):
"""
Retrives the model id from the saved file
"""
cls.pieces_client.model_name # Let's load the models first
return cls.get_from_pickle(cls.models_file,"model_id") or cls.pieces_client.model_id

@staticmethod
def get_from_pickle(file,key):
with open(file, 'rb') as f:
data = pickle.load(f)
return data.get(key)
@classmethod
def get_from_pickle(cls, file, key):
try:
cache = cls.file_cache.get(str(file))
if not cache:
with open(file, 'rb') as f:
cache = pickle.load(f)
cls.file_cache[str(file)] = cache
return cache.get(key)
except FileNotFoundError:
return None

@staticmethod
def dump_pickle(file,**data):
Expand Down
2 changes: 1 addition & 1 deletion src/pieces/wrapper/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def application(self) -> "Application":
if not self._application:
self._application = self.connector_api.connect(seeded_connector_connection=SeededConnectorConnection(
application=SeededTrackedApplication(
name = "OPEN_SOURCE",
name = "PIECES_FOR_DEVELOPERS_CLI",
platform = self.local_os,
version = __version__))).application
self.api_client.set_default_header("application",self._application.id)
Expand Down

0 comments on commit 7bd73c0

Please sign in to comment.