Skip to content

Commit

Permalink
Update of state handlers
Browse files Browse the repository at this point in the history
No need to create state handlers
  • Loading branch information
coder2020official committed Oct 1, 2021
1 parent 4a6b5b3 commit 2e4280a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 93 deletions.
42 changes: 28 additions & 14 deletions examples/custom_states.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,48 @@
import telebot

from telebot.handler_backends import State
from telebot import custom_filters

bot = telebot.TeleBot("")



@bot.message_handler(commands=['start'])
def start_ex(message):
bot.set_state(message.chat.id, 1)
bot.send_message(message.chat.id, 'Hi, write me a name')



@bot.state_handler(state=1)
def name_get(message, state:State):
@bot.message_handler(state="*", commands='cancel')
def any_state(message):
bot.send_message(message.chat.id, "Your state was cancelled.")
bot.delete_state(message.chat.id)

@bot.message_handler(state=1)
def name_get(message):
bot.send_message(message.chat.id, f'Now write me a surname')
state.set(message.chat.id, 2)
with state.retrieve_data(message.chat.id) as data:
bot.set_state(message.chat.id, 2)
with bot.retrieve_data(message.chat.id) as data:
data['name'] = message.text


@bot.state_handler(state=2)
def ask_age(message, state:State):
@bot.message_handler(state=2)
def ask_age(message):
bot.send_message(message.chat.id, "What is your age?")
state.set(message.chat.id, 3)
with state.retrieve_data(message.chat.id) as data:
bot.set_state(message.chat.id, 3)
with bot.retrieve_data(message.chat.id) as data:
data['surname'] = message.text

@bot.state_handler(state=3)
def ready_for_answer(message, state: State):
with state.retrieve_data(message.chat.id) as data:
@bot.message_handler(state=3, is_digit=True)
def ready_for_answer(message):
with bot.retrieve_data(message.chat.id) as data:
bot.send_message(message.chat.id, "Ready, take a look:\n<b>Name: {name}\nSurname: {surname}\nAge: {age}</b>".format(name=data['name'], surname=data['surname'], age=message.text), parse_mode="html")
state.finish(message.chat.id)
bot.delete_state(message.chat.id)

@bot.message_handler(state=3, is_digit=False)
def age_incorrect(message):
bot.send_message(message.chat.id, 'Looks like you are submitting a string in the field age. Please enter a number')

bot.infinity_polling()
bot.add_custom_filter(custom_filters.StateFilter(bot))
bot.add_custom_filter(custom_filters.IsDigitFilter())
bot.infinity_polling(skip_pending=True)
97 changes: 18 additions & 79 deletions telebot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,6 @@ def process_new_updates(self, updates):
def process_new_messages(self, new_messages):
self._notify_next_handlers(new_messages)
self._notify_reply_handlers(new_messages)
self._notify_state_handlers(new_messages)
self.__notify_update(new_messages)
self._notify_command_handlers(self.message_handlers, new_messages)

Expand Down Expand Up @@ -2386,6 +2385,9 @@ def delete_state(self, chat_id):
"""
self.current_states.delete_state(chat_id)

def retrieve_data(self, chat_id):
return self.current_states.retrieve_data(chat_id)

def get_state(self, chat_id):
"""
Get current state of a user.
Expand All @@ -2394,6 +2396,14 @@ def get_state(self, chat_id):
"""
return self.current_states.current_state(chat_id)

def add_data(self, chat_id, **kwargs):
"""
Add data to states.
:param chat_id:
"""
for key, value in kwargs.items():
self.current_states._add_data(chat_id, key, value)

def register_next_step_handler_by_chat_id(
self, chat_id: Union[int, str], callback: Callable, *args, **kwargs) -> None:
"""
Expand Down Expand Up @@ -2459,32 +2469,6 @@ def _notify_next_handlers(self, new_messages):
new_messages.pop(i) # removing message that was detected with next_step_handler


def _notify_state_handlers(self, new_messages):
"""
Description: TBD
:param new_messages:
:return:
"""
if not self.current_states: return

for i, message in enumerate(new_messages):
need_pop = False
user_state = self.current_states.current_state(message.from_user.id)
if user_state:
for handler in self.state_handlers:
if handler['filters']['state'] == user_state:
for message_filter, filter_value in handler['filters'].items():
if filter_value is None:
continue
if not self._test_filter(message_filter, filter_value, message):
return False
need_pop = True
state = self.current_states
self._exec_task(handler["function"], message, state)
if need_pop:
new_messages.pop(i) # removing message that was detected by states


@staticmethod
def _build_handler_dict(handler, **filters):
"""
Expand Down Expand Up @@ -2548,7 +2532,7 @@ def add_middleware_handler(self, handler, update_types=None):
else:
self.default_middleware_handlers.append(handler)

def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, **kwargs):
def message_handler(self, commands=None, regexp=None, func=None, content_types=None, chat_types=None, state=None, **kwargs):
"""
Message handler decorator.
This decorator can be used to decorate functions that must handle certain types of messages.
Expand Down Expand Up @@ -2591,6 +2575,9 @@ def default_command(message):
if content_types is None:
content_types = ["text"]

if type(state) is not list and state is not None:
state = [state]

if isinstance(commands, str):
logger.warning("message_handler: 'commands' filter should be List of strings (commands), not string.")
commands = [commands]
Expand All @@ -2605,6 +2592,7 @@ def decorator(handler):
content_types=content_types,
commands=commands,
regexp=regexp,
state=state,
func=func,
**kwargs)
self.add_message_handler(handler_dict)
Expand All @@ -2620,7 +2608,7 @@ def add_message_handler(self, handler_dict):
"""
self.message_handlers.append(handler_dict)

def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, **kwargs):
def register_message_handler(self, callback, content_types=None, commands=None, regexp=None, func=None, chat_types=None, state=None, **kwargs):
"""
Registers message handler.
:param callback: function to be called
Expand All @@ -2644,6 +2632,7 @@ def register_message_handler(self, callback, content_types=None, commands=None,
content_types=content_types,
commands=commands,
regexp=regexp,
state=state,
func=func,
**kwargs)
self.add_message_handler(handler_dict)
Expand Down Expand Up @@ -2721,54 +2710,6 @@ def register_edited_message_handler(self, callback, content_types=None, commands
self.add_edited_message_handler(handler_dict)


def state_handler(self, state, content_types=None, regexp=None, func=None, chat_types=None, **kwargs):
"""
State handler for getting input from a user.
:param state: state of a user
:param content_types:
:param regexp:
:param func:
:param chat_types:
"""
def decorator(handler):
handler_dict = self._build_handler_dict(handler,
state=state,
content_types=content_types,
regexp=regexp,
chat_types=chat_types,
func=func,
**kwargs)
self.add_state_handler(handler_dict)
return handler

return decorator

def add_state_handler(self, handler_dict):
"""
Adds the edit message handler
:param handler_dict:
:return:
"""
self.state_handlers.append(handler_dict)

def register_state_handler(self, callback, state, content_types=None, regexp=None, func=None, chat_types=None, **kwargs):
"""
Register a state handler.
:param callback: function to be called
:param state: state to be checked
:param content_types:
:param func:
"""
handler_dict = self._build_handler_dict(callback=callback,
state=state,
content_types=content_types,
regexp=regexp,
chat_types=chat_types,
func=func,
**kwargs)
self.add_state_handler(handler_dict)


def channel_post_handler(self, commands=None, regexp=None, func=None, content_types=None, **kwargs):
"""
Channel post handler decorator
Expand Down Expand Up @@ -3251,8 +3192,6 @@ def _test_filter(self, message_filter, filter_value, message):
return filter_value(message)
elif self.custom_filters and message_filter in self.custom_filters:
return self._check_filter(message_filter,filter_value,message)
elif message_filter == 'state':
return True
else:
return False

Expand Down
27 changes: 27 additions & 0 deletions telebot/custom_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,30 @@ def __init__(self, bot):
def check(self, message):
return self._bot.get_chat_member(message.chat.id, message.from_user.id).status in ['creator', 'administrator']

class StateFilter(AdvancedCustomFilter):
"""
Filter to check state.
Example:
@bot.message_handler(state=1)
"""
def __init__(self, bot):
self.bot = bot
key = 'state'

def check(self, message, text):
if self.bot.current_states.current_state(message.from_user.id) is False:return False
elif '*' in text:return True
return self.bot.current_states.current_state(message.from_user.id) in text

class IsDigitFilter(SimpleCustomFilter):
"""
Filter to check whether the string is made up of only digits.
Example:
@bot.message_handler(is_digit=True)
"""
key = 'is_digit'

def check(self, message):
return message.text.isdigit()
4 changes: 4 additions & 0 deletions telebot/handler_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ def set(self, chat_id, new_state):
"""
self.add_state(chat_id,new_state)

def _add_data(self, chat_id, key, value):
result = self._states[chat_id]['data'][key] = value
return result

def finish(self, chat_id):
"""
Finish(delete) state of a user.
Expand Down

0 comments on commit 2e4280a

Please sign in to comment.