Implemented the banned_words feature.
It is now possible to send a list of banned words to the generator. It
will not generate these words. This allows clients to send their own
list of banned words.
Lyaaaaaaaaaaaaaaa committed May 7, 2024
1 parent 92e52a9 commit 966a8eb
Showing 3 changed files with 31 additions and 7 deletions.
18 changes: 15 additions & 3 deletions server/generator.py
@@ -40,6 +40,10 @@
 #-- - 31/01/2024 Lyaaaaa
 #-- - generate_text no longer receives memory and context as parameters.
 #--   They are embedded in the prompt parameter by the client.
+#--
+#-- - 07/05/2024 Lyaaaaa
+#-- - Updated generate_text to be able to censor generation. The words
+#--   passed in the p_banned_words parameter won't be generated anymore.
#------------------------------------------------------------------------------

from model import Model
@@ -48,16 +52,24 @@
import logger

 class Generator(Model):

   #----------------------------------------------------------------------------
   #-- generate_text
   #----------------------------------------------------------------------------
   def generate_text(self,
-                    p_prompt     = None,
-                    p_parameters = None):
+                    p_prompt       = None,
+                    p_parameters   = None,
+                    p_banned_words = []):

     model_input = self._Tokenizer(p_prompt, return_tensors = "pt")

+    if p_banned_words:
+      banned_words_ids = self._Tokenizer(
+                           p_banned_words,
+                           add_special_tokens=False
+                         ).input_ids
+
+      p_parameters["bad_words_ids"] = banned_words_ids

     if self.is_cuda_available:
       logger.log.info("Loading inputs to GPU")
       model_input.to("cuda")
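For reference, here is a minimal, self-contained sketch (not part of this commit) of the mechanism the new code relies on. In Hugging Face transformers, the bad_words_ids generation parameter takes a list of token-id sequences that generate() is forbidden to emit; the model name and banned words below are placeholders.

  from transformers import AutoModelForCausalLM, AutoTokenizer

  tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space = True)
  model     = AutoModelForCausalLM.from_pretrained("gpt2")

  # Hypothetical list a client might send.
  banned_words = ["dragon", "sword"]

  # Each banned word becomes a sequence of token ids; generate() will
  # refuse to emit any of these sequences.
  bad_words_ids = tokenizer(banned_words, add_special_tokens = False).input_ids

  model_input = tokenizer("The knight drew his", return_tensors = "pt")
  output      = model.generate(**model_input,
                               bad_words_ids  = bad_words_ids,
                               max_new_tokens = 20)
  print(tokenizer.decode(output[0], skip_special_tokens = True))
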
9 changes: 8 additions & 1 deletion server/model.py
@@ -202,6 +202,10 @@
#-- - p_model_path is now the second parameter of __init__. p_parameters the third.
#-- - Added a log message to display the model's name and its path.
#-- - Added a log message to display if cuda is supported.
+#--
+#-- - 07/05/2024 Lyaaaaa
+#-- - Updated _load_tokens to set add_prefix_space to True. It is needed
+#--   to use the bad_words_ids parameter for generation.
#------------------------------------------------------------------------------

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
@@ -296,7 +300,10 @@ def _load(self):
   #----------------------------------------------------------------------------
   def _load_tokens(self):
     try:
-      self._Tokenizer = AutoTokenizer.from_pretrained(self._model_path)
+      self._Tokenizer = AutoTokenizer.from_pretrained(
+                          self._model_path,
+                          add_prefix_space=True
+                        )
     except Exception as e:
       logger.log.error("Error loading tokens in " + self._model_path)
       logger.log.error(e)
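A short sketch (not part of this commit) of why add_prefix_space matters here: with a BPE tokenizer such as GPT-2's, a word preceded by a space maps to different token ids than the same word at the start of a string. Since banned words usually occur mid-sentence, tokenizing them without the prefix space would produce ids that never match the generated text.

  from transformers import AutoTokenizer

  plain    = AutoTokenizer.from_pretrained("gpt2")
  prefixed = AutoTokenizer.from_pretrained("gpt2", add_prefix_space = True)

  word = "dragon"  # placeholder banned word

  # Ids for "dragon" at the very start of a string.
  print(plain(word, add_special_tokens = False).input_ids)

  # Ids for " dragon", as the word appears mid-sentence; these differ from
  # the ids above, which is why the tokenizer must add the prefix space.
  print(prefixed(word, add_special_tokens = False).input_ids)
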
11 changes: 8 additions & 3 deletions server/server.py
@@ -118,6 +118,10 @@
 #-- - 31/01/2024 Lyaaaaa
 #-- - generate_text no longer receives memory and context as parameters.
 #--   They are embedded in the prompt parameter by the client.
+#--
+#-- - 07/05/2024 Lyaaaaa
+#-- - Updated handle_request and the generation case to receive a banned_words
+#--   parameter and pass it to generator.generate_text.
#------------------------------------------------------------------------------

import asyncio
@@ -177,10 +181,11 @@ def handle_request(p_websocket, p_data : dict):
   request = p_data['request']

   if request == Request.TEXT_GENERATION.value:
-    prompt     = p_data['prompt']
-    parameters = p_data['parameters']
+    prompt       = p_data['prompt']
+    parameters   = p_data['parameters']
+    banned_words = p_data['banned_words']

-    generated_text = generator.generate_text(prompt, parameters)
+    generated_text = generator.generate_text(prompt, parameters, banned_words)

     p_data["generated_text"] = generated_text

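A hypothetical client-side sketch (not part of this commit) of a TEXT_GENERATION request carrying the new banned_words field. The URL, port, and numeric value of Request.TEXT_GENERATION are assumptions; the field names match the keys read by handle_request above.

  import asyncio
  import json
  import websockets  # assumed client library, matching the asyncio server

  async def request_generation():
    # Placeholder address; the real server host and port may differ.
    async with websockets.connect("ws://localhost:9999") as websocket:
      payload = {"request"      : 2,  # assumed value of Request.TEXT_GENERATION
                 "prompt"       : "The knight drew his",
                 "parameters"   : {"max_new_tokens" : 40},
                 "banned_words" : ["dragon", "sword"]}
      await websocket.send(json.dumps(payload))
      reply = json.loads(await websocket.recv())
      print(reply["generated_text"])

  asyncio.run(request_generation())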
