diff --git a/TTS/tts/layers/xtts/stream_generator.py b/TTS/tts/layers/xtts/stream_generator.py index 7921102cb2..b7e07589c5 100644 --- a/TTS/tts/layers/xtts/stream_generator.py +++ b/TTS/tts/layers/xtts/stream_generator.py @@ -885,10 +885,10 @@ def init_stream_support(): if __name__ == "__main__": - from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel + from transformers import AutoModelForCausalLM, AutoTokenizer + + init_stream_support() - PreTrainedModel.generate = NewGenerationMixin.generate - PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.float16) tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")