Skip to content

Commit

Permalink
store user's locale in DB
Browse files Browse the repository at this point in the history
  • Loading branch information
wetterkrank committed Jun 24, 2023
1 parent 668b7e3 commit 4e87647
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 25 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
.env*
__pycache__
.mypy_cache
.vscode
.vscode
tmp
18 changes: 11 additions & 7 deletions dasbot/controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import logging
from aiogram import Bot

from aiogram import types
from aiogram.types import Message
from dasbot.db.chats_repo import ChatsRepo
from dasbot.db.stats_repo import StatsRepo
from dasbot.models.dictionary import Dictionary

from dasbot.models.quiz import Quiz
from .interface import Interface
Expand All @@ -9,20 +13,20 @@


class Controller(object):
def __init__(self, bot, chats_repo, stats_repo, dictionary):
def __init__(self, bot: Bot, chats_repo: ChatsRepo, stats_repo: StatsRepo, dictionary: Dictionary):
self.bot = bot
self.chats_repo = chats_repo
self.stats_repo = stats_repo
self.ui = Interface(bot)
self.dictionary = dictionary

# /help
async def help(self, message: types.Message):
async def help(self, message: Message):
await self.ui.reply_with_help(message)

# /start
async def start(self, message: types.Message):
chat = self.chats_repo.load_chat(message.chat)
async def start(self, message: Message):
chat = self.chats_repo.load_chat(message)
if not chat.last_seen:
await self.ui.welcome(chat)
scores = self.chats_repo.load_scores(chat.id)
Expand All @@ -34,12 +38,12 @@ async def start(self, message: types.Message):
self.chats_repo.save_chat(chat, update_last_seen=True)

# /stats
async def stats(self, message: types.Message, dictionary):
async def stats(self, message: Message, dictionary):
stats = self.stats_repo.get_stats(message.chat.id)
await self.ui.send_stats(message, stats, dictionary.wordcount())

# not-a-command
async def generic(self, message: types.Message):
async def generic(self, message: Message):
answer = message.text.strip().lower()
chat = self.chats_repo.load_chat(message.chat)
quiz = chat.quiz
Expand Down
31 changes: 22 additions & 9 deletions dasbot/db/chats_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from datetime import datetime
from datetime import timezone

from aiogram.types import Chat as TelegramChat
from aiogram.types import Message

from dasbot.models.chat import Chat, ChatSchema


log = logging.getLogger(__name__)


class ChatsRepo(object):

def __init__(self, chats_col, scores_col):
self._chats = chats_col
self._scores = scores_col
Expand All @@ -20,17 +20,27 @@ def __status(self):
log.info("%s chat(s) in DB" % self._chats.count_documents({}))
log.info("%s scores(s) in DB" % self._scores.count_documents({}))

def load_chat(self, tg_chat: TelegramChat):
# NOTE: Could make this more explicit by separating load_chat and create_chat methods
def load_chat(self, message: Message):
"""
:param chat: Telegram chat
:param message: Telegram message
:return: Chat instance, loaded from DB, or new if not found
"""
tg_chat = message.chat # NOTE: Chat may be a group etc and have many users
locale = message.from_user.locale if message.from_user else None
chat_data = self._chats.find_one({"chat_id": tg_chat.id}, {"_id": 0})
log.debug("requested chat %s, result: %s", tg_chat.id, chat_data)
if chat_data:
chat = ChatSchema().load(chat_data)
chat: Chat = ChatSchema().load(chat_data)
chat.user['last_used_locale'] = locale # locale may change over time and depend on device
else:
user = {'username': tg_chat.username, 'first_name': tg_chat.first_name, 'last_name': tg_chat.last_name}
user = {
'username': tg_chat.username,
'first_name': tg_chat.first_name,
'last_name': tg_chat.last_name,
'locale': locale,
'last_used_locale': locale,
}
chat = Chat(tg_chat.id, user)
return chat

Expand Down Expand Up @@ -65,9 +75,12 @@ def load_scores(self, chat_id):
"""
query = {"chat_id": chat_id}
results_cursor = self._scores.find(query, {"_id": 0})
scores = {item["word"]: (item["score"], item["revisit"])
for item in results_cursor}
log.debug("loaded all scores for chat %s, count: %s", chat_id, len(scores))
scores = {
item["word"]: (item["score"], item["revisit"])
for item in results_cursor
}
log.debug("loaded all scores for chat %s, count: %s", chat_id,
len(scores))
return scores

# TODO: check if saved successfully?
Expand Down
8 changes: 5 additions & 3 deletions dasbot/menu_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from aiogram.types import Message, CallbackQuery, InlineKeyboardMarkup, InlineKeyboardButton
from aiogram.utils.callback_data import CallbackData

from dasbot.db.chats_repo import ChatsRepo

log = logging.getLogger(__name__)


class MenuController(object):
def __init__(self, ui, chats_repo):
self.chats_repo = chats_repo
self.chats_repo: ChatsRepo = chats_repo
self.ui = ui
# NOTE: Can't use colon in callback actions, it's used as a separator
self.TIME_OPTIONS = ['0900', '1200', '1500', '1800', '2100', '0000', '0300', '0600']
Expand Down Expand Up @@ -76,7 +78,7 @@ def settings_kb(self, level, menu_id):

# TODO: Refactor into a generic function?
async def set_quiz_time(self, query, _level, selection):
chat = self.chats_repo.load_chat(query.message.chat)
chat = self.chats_repo.load_chat(query.message)
if selection == 'UNSUBSCRIBE':
chat.unsubscribe()
log.debug('Chat %s unsubscribed', chat.id)
Expand All @@ -91,7 +93,7 @@ async def set_quiz_time(self, query, _level, selection):
await self.settings_confirm(query, self.ui.quiz_time_set(selection))

async def set_quiz_length(self, query, _level, selection):
chat = self.chats_repo.load_chat(query.message.chat)
chat = self.chats_repo.load_chat(query.message)
new_length = int(selection) if selection in self.LENGTH_OPTIONS else 10
chat.quiz_length = new_length
self.chats_repo.save_chat(chat, update_last_seen=True)
Expand Down
9 changes: 6 additions & 3 deletions dasbot/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@


class Chat(object):
def __init__(self, chat_id, user=None, subscribed=True, last_seen=None, quiz=None,
def __init__(self, chat_id, user={}, subscribed=True, last_seen=None, quiz=None,
quiz_scheduled_time=None, quiz_length=None):
self.id = chat_id
self.user = user
self.user = user # our User is just a dictionary so far
self.subscribed = subscribed
self.last_seen = last_seen
self.quiz = quiz
Expand Down Expand Up @@ -52,18 +52,21 @@ def set_quiz_time(self, hhmm, skip_today=False):
now = datetime.now().astimezone(berlin)
self.quiz_scheduled_time = util.next_hhmm(hhmm, now, skip_today=skip_today)

# TODO: Add Account object instead of current User dictionary
class UserSchema(Schema):
class Meta:
unknown = EXCLUDE # Skip unknown fields on deserialization
username = fields.String(missing=None)
first_name = fields.String(missing=None)
last_name = fields.String(missing=None)
locale = fields.String(missing=None)
last_used_locale = fields.String(missing=None)

class ChatSchema(Schema):
class Meta:
unknown = EXCLUDE # Skip unknown fields on deserialization
chat_id = fields.Integer()
user = fields.Nested(UserSchema, missing=None)
user = fields.Nested(UserSchema, missing={})
subscribed = fields.Boolean(missing=True)
last_seen = fields.Raw(missing=None) # Keep the raw datetime for Mongo
quiz = fields.Nested(QuizSchema, missing=None)
Expand Down
19 changes: 17 additions & 2 deletions tests/test_chats_repo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import logging
from datetime import datetime, timezone
from datetime import timedelta

Expand All @@ -14,13 +15,25 @@ def __init__(self, id, username=None, first_name=None, last_name=None):
self.first_name = first_name
self.last_name = last_name

class MockTGMessage(object):
def __init__(self, chat, from_user):
self.chat = chat
self.from_user = from_user

class MockTGUser(object):
def __init__(self, locale):
self.locale = locale

# TODO: Test save/load of a chat with attached quiz
class TestChatsRepo(unittest.TestCase):
def setUp(self):
self.chats_col = mongomock.MongoClient().db.collection
scores_col = mongomock.MongoClient().db.collection
self.chats_repo = ChatsRepo(self.chats_col, scores_col)

logging.basicConfig(level=logging.DEBUG)
logging.getLogger().setLevel(logging.DEBUG)

def test_save_chat(self):
chat = Chat(chat_id=1001)
self.chats_repo.save_chat(chat)
Expand All @@ -29,13 +42,15 @@ def test_save_chat(self):
self.assertEqual(1001, saved_chats[0]['chat_id'])

def test_load_saved_chat(self):
mock_message = MockTGMessage(chat=MockTGChat(1001), from_user=MockTGUser(locale='xx'))
chat = Chat(chat_id=1001, subscribed=False)
self.chats_repo.save_chat(chat)
result = self.chats_repo.load_chat(MockTGChat(1001))
result = self.chats_repo.load_chat(mock_message)
self.assertEqual(False, result.subscribed)

def test_load_new_chat(self):
result: Chat = self.chats_repo.load_chat(MockTGChat(1001, 'vassily'))
mock_message = MockTGMessage(chat=MockTGChat(1001, 'vassily'), from_user=MockTGUser(locale='xx'))
result: Chat = self.chats_repo.load_chat(mock_message)
self.assertEqual(1001, result.id)
self.assertEqual('vassily', result.user['username'])
self.assertEqual(None, result.user['first_name'])
Expand Down

0 comments on commit 4e87647

Please sign in to comment.