Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enma-api support #14

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions shimeji/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ class ModelSampleArgs(BaseModel):
bad_words: Optional[List[str]] = None
logit_biases: Optional[List[ModelLogitBiasArgs]] = None
phrase_biases: Optional[List[ModelPhraseBiasArgs]] = None
#enma specific
do_sample: Optional[bool] = None
penalty_alpha: Optional[float] = None
num_return_sequences: Optional[int] = None
stop_sequence: Optional[str] = None


def toJSON(self):
return json.dumps(self.dict())
Expand Down Expand Up @@ -458,6 +464,196 @@ async def response_async(self, context):
response = await self.generate_async(args)
return response

class Enma_ModelProvider(ModelProvider):
def __init__(self, endpoint_url: str, **kwargs):
"""Constructor for Enma_ModelProvider.

:param endpoint_url: The URL for the Enma endpoint. (this is the completion endpoint on the gateway!)
:type endpoint_url: str
"""

super().__init__(endpoint_url, **kwargs)

def auth(self):
"""
:drollwide:\n
enma doesnt have authentication (at least on fab8e60) so this just returns true
"""

return True
def conv_listobj_to_listdict(self, list_objects):
"""Convert the elements of a list to a dictionary for JSON compatability.

:param list_objects: The list.
:type list_objects: list
:return: A list which has it's elements converted to dictionaries.
:rtype: list
"""

list_dict = []
if list_objects:
for object in list_objects:
list_dict.append(vars(object))
return list_dict
else:
return list_objects

def generate(self, args: ModelGenRequest):
"""Generate a response from the Enma endpoint.

:param args: The arguments to pass to the endpoint.
:type args: dict
:return: The response from the endpoint.
:rtype: str
:raises Exception: If the request fails.
"""
argdict = {
'engine': args.model, #enma uses engine instead of model
'prompt': args.prompt,
'temperature': args.sample_args.temp,
'top_p': args.sample_args.top_p,
'top_k': args.sample_args.top_k,
'repetition_penalty': args.sample_args.rep_p,
'do_sample': args.sample_args.do_sample,
'penalty_alpha': args.sample_args.penalty_alpha,
'num_return_sequences': args.sample_args.num_return_sequences,
'stop_sequence': args.sample_args.stop_sequence,
}

for arg in argdict.values():
if arg is None:
raise Exception('Missing required argument: ' + arg)
try:
r = requests.post(f'{self.endpoint_url}', data=json.dumps(argdict))
except Exception as e:
raise e
if r.status_code == 200:
return r.json()[0]['generated_text'][len(argdict['prompt']):]
else:
raise Exception(f'Could not generate text with Enma. Error: {r.json()}')

async def generate_async(self, args: ModelGenRequest):
"""Generate a response from the Enma endpoint asynchronously.

:param args: The arguments to pass to the endpoint.
:type args: dict
:return: The response from the endpoint.
:rtype: str
:raises Exception: If the request fails.
"""

argdict = {
'engine': args.model, #enma uses engine instead of model
'prompt': args.prompt,
'temperature': args.sample_args.temp,
'top_p': args.sample_args.top_p,
'top_k': args.sample_args.top_k,
'repetition_penalty': args.sample_args.rep_p,
'do_sample': args.sample_args.do_sample,
'penalty_alpha': args.sample_args.penalty_alpha,
'num_return_sequences': args.sample_args.num_return_sequences,
'stop_sequence': args.sample_args.stop_sequence,
}
for arg in argdict.values():
if arg is None:
raise Exception('Missing required argument: ' + arg)
async with aiohttp.ClientSession() as session:
try:
async with session.post(f'{self.endpoint_url}', json=argdict) as resp:
if resp.status == 200:
js = await resp.json()
return js[0]['generated_text'][len(argdict['prompt']):]
else:
raise Exception(f'Could not generate response. Error: {await resp.text()}')
except Exception as e:
raise e

def should_respond(self, context, name):
"""Determine if the Enma endpoint predicts that the name should respond to the given context.

:param context: The context to use.
:type context: str
:param name: The name to check.
:type name: str
:return: Whether or not the name should respond to the given context.
:rtype: bool
"""


args = copy.deepcopy(self.kwargs['args'])
args.prompt = context
args.sample_args.temp = 0.25
args.sample_args.top_p = 0.9
args.sample_args.top_k = 40
args.sample_args.rep_p = None
args.sample_args.do_sample = None
args.sample_args.penalty_alpha = None
args.sample_args.num_return_sequences = None #i have no idea what these should be
args.sample_args.stop_sequence = None
response = self.generate(args)
if name in response:
return True
else:
return False

async def should_respond_async(self, context, name):
"""Determine if the Enma endpoint predicts that the name should respond to the given context asynchronously.

:param context: The context to use.
:type context: str
:param name: The name to check.
:type name: str
:return: Whether or not the name should respond to the given context.
:rtype: bool
"""

args = copy.deepcopy(self.kwargs['args'])
args.prompt = context
args.sample_args.temp = 0.25
args.sample_args.top_p = 0.9
args.sample_args.top_k = 40
args.sample_args.rep_p = None
args.sample_args.do_sample = None
args.sample_args.penalty_alpha = None
args.sample_args.num_return_sequences = None #i have no idea what these should be
args.sample_args.stop_sequence = None
response = await self.generate_async(args)
if response.startswith(name):
return True
else:
return False

def response(self, context):
"""Generate a response from the Enma endpoint.

:param context: The context to use.
:type context: str
:return: The response from the endpoint.
:rtype: str
"""
args = self.kwargs['args']
args.prompt = context
args.gen_args.eos_token_id = 198
args.gen_args.min_length = 1
response = self.generate(args)
return response

async def response_async(self, context):
"""Generate a response from the Enma endpoint asynchronously.

:param context: The context to use.
:type context: str
:return: The response from the endpoint.
:rtype: str
"""
args = self.kwargs['args']
args.prompt = context
args.gen_args.eos_token_id = 198
args.gen_args.min_length = 1
response = await self.generate_async(args)
return response


class TextSynth_ModelProvider(ModelProvider):
def __init__(self, endpoint_url: str = 'https://api.textsynth.com', **kwargs):
"""Constructor for TextSynth_ModelProvider.
Expand Down