diff --git a/commands/set_channel.py b/commands/set_channel.py index cd95492..6b003a2 100644 --- a/commands/set_channel.py +++ b/commands/set_channel.py @@ -3,12 +3,12 @@ from aiogram.filters import Command from aiogram.types import Message from database.database import pg_con +from filters.chat_type import ChatTypeFilter router = Router() -@router.message(Command('set_channel')) +@router.message(Command('set_channel'), ChatTypeFilter(chat_type=["group", "supergroup"])) async def set_channel(message: Message): - print(message) args = message.text.split() if len(args) < 2: diff --git a/events/join_chat.py b/events/join_chat.py index 774f606..cd89ed2 100644 --- a/events/join_chat.py +++ b/events/join_chat.py @@ -2,15 +2,13 @@ from aiogram import types, Router from aiogram.filters import ChatMemberUpdatedFilter, IS_MEMBER from database.database import pg_con +from filters.chat_type import ChatTypeFilter router = Router() -@router.my_chat_member(ChatMemberUpdatedFilter(IS_MEMBER)) +@router.my_chat_member(ChatMemberUpdatedFilter(IS_MEMBER), ChatTypeFilter(chat_type=["group", "supergroup"])) async def join_chat(event: types.Message): - if event.chat.type not in {'group', 'supergroup'}: - return - conn = await pg_con() data = await conn.fetch('SELECT emoji_list FROM chat WHERE chat_id = $1', event.chat.id) diff --git a/events/reactions.py b/events/reactions.py index 258df9a..d57b45f 100644 --- a/events/reactions.py +++ b/events/reactions.py @@ -3,6 +3,7 @@ from aiogram.exceptions import TelegramBadRequest from bot import bot from database.database import pg_con +from filters.chat_type import ChatTypeFilter router = Router() @@ -25,11 +26,8 @@ async def update_reaction_count(conn, chat_id, message_id, delta): ) -@router.message_reaction() +@router.message_reaction(ChatTypeFilter(chat_type=["group", "supergroup"])) async def register_message_reaction(event: types.MessageReactionUpdated): - if event.chat.type not in {'group', 'supergroup'}: - return - conn = await pg_con() data_reaction = await conn.fetchval('SELECT emoji_list FROM chat WHERE chat_id = $1', event.chat.id) diff --git a/filters/chat_type.py b/filters/chat_type.py new file mode 100644 index 0000000..b46fb48 --- /dev/null +++ b/filters/chat_type.py @@ -0,0 +1,13 @@ +from aiogram.filters import BaseFilter +from aiogram.types import Message + +# https://stackoverflow.com/a/78272229/20781634 +class ChatTypeFilter(BaseFilter): + def __init__(self, chat_type: str | list): + self.chat_type = chat_type + + async def __call__(self, message: Message) -> bool: + if isinstance(self.chat_type, str): + return message.chat.type == self.chat_type + else: + return message.chat.type in self.chat_type \ No newline at end of file