Skip to content

Commit

Permalink
pass on tinygrad set_on_download_progress
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexCheema committed Jul 31, 2024
1 parent d6a7e46 commit 1d54f10
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
6 changes: 5 additions & 1 deletion exo/inference/inference_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from typing import Tuple, Optional
from typing import Tuple, Optional, Callable
from abc import ABC, abstractmethod
from .shard import Shard

Expand All @@ -13,3 +13,7 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr
@abstractmethod
async def infer_prompt(self, request_id: str, shard: Shard, prompt: str, inference_state: Optional[str] = None) -> Tuple[np.ndarray, str, bool]:
pass

@abstractmethod
def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
pass
5 changes: 4 additions & 1 deletion exo/inference/tinygrad/inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from functools import partial
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional, Union, Callable
import json
import tiktoken
from tiktoken.load import load_tiktoken_bpe
Expand Down Expand Up @@ -294,3 +294,6 @@ async def ensure_shard(self, shard: Shard):
self.shard = shard
self.model = model
self.tokenizer = tokenizer

def set_on_download_progress(self, on_download_progress: Callable[[int, int], None]):
pass

0 comments on commit 1d54f10

Please sign in to comment.