diff --git a/anonipy/anonymize/generators/llm_label_generator.py b/anonipy/anonymize/generators/llm_label_generator.py index 9846ec8..42e0ab8 100644 --- a/anonipy/anonymize/generators/llm_label_generator.py +++ b/anonipy/anonymize/generators/llm_label_generator.py @@ -149,7 +149,7 @@ def _generate_response( # tokenize the message input_ids = self.tokenizer.apply_chat_template( - message, tokenize=True, return_tensors="pt" + message, tokenize=True, return_tensors="pt", add_generation_prompt=True ).to(self.model.device) # generate the response @@ -166,18 +166,4 @@ def _generate_response( response = self.tokenizer.decode( output_ids[0][len(input_ids[0]) :], skip_special_tokens=True ) - return self._parse_response(response) - - def _parse_response(self, response: str) -> str: - """Parse the response from the LLM. - - Args: - response: The response to parse. - - Returns: - The parsed response. - - """ - - match = re.search(r"assistant\s*(.*)", response, re.IGNORECASE | re.DOTALL) - return match.group(1).strip() if match else response + return response