diff --git a/database/schemas/data.sql b/database/schemas/data.sql new file mode 100644 index 0000000..e90f37a --- /dev/null +++ b/database/schemas/data.sql @@ -0,0 +1,5 @@ +CREATE TABLE IF NOT EXISTS users ( + user_id INTEGER UNIQUE NOT NULL, + chart BOOLEAN DEFAULT 1, + lang TEXT DEFAULT 'en' +); diff --git a/database/server.py b/database/server.py new file mode 100644 index 0000000..b7ec3f6 --- /dev/null +++ b/database/server.py @@ -0,0 +1,189 @@ +import json +from datetime import date, datetime +from pathlib import Path +from typing import Optional, List, Dict, Any + +import aiosqlite +import yaml + +config = yaml.safe_load(open('../config.yaml', 'r', encoding='utf-8')) + +def custom_encoder(obj): + """ + Custom JSON encoder for objects not serializable by default. + + Converts date and datetime objects to ISO 8601 string format. + Raises TypeError for unsupported types. + """ + if isinstance(obj, (date, datetime)): + return obj.isoformat() + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + +class Database: + """ + Asynchronous SQLite database handler using aiosqlite. + + This class manages a single asynchronous connection to an SQLite database file + and provides common methods for executing queries, including: + + - Connecting and disconnecting to the database. + - Fetching a single row or multiple rows. + - Inserting data and returning the last inserted row ID. + - Updating or deleting data and returning the count of affected rows. + + Attributes: + db_path (str): Path to the SQLite database file. + conn (Optional[aiosqlite.Connection]): The active database connection, or None if disconnected. + + Methods: + connect(): Asynchronously open a connection to the database. + disconnect(): Asynchronously close the database connection. + fetch(query, *args): Execute a query and return a single row as a dictionary. + fetchmany(query, *args): Execute a query and return multiple rows as a list of dictionaries. + insert(query, *args): Execute an INSERT query and return the last inserted row ID. + update(query, *args): Execute an UPDATE or DELETE query and return the number of affected rows. + + Raises: + RuntimeError: If any method is called before the database connection is established. + + Example: + ```python + db = Database("example.db") + await db.connect() + user = await db.fetch("SELECT * FROM users WHERE id = ?", user_id) + await db.disconnect() + ``` + """ + + def __init__(self, db_path: str): + """ + Initialize Database instance. + + Args: + db_path (str): Path to SQLite database file. + """ + self.db_path = db_path + self.conn: Optional[aiosqlite.Connection] = None + + + async def _create_table(self) -> None: + """ + Create table from SQL file using aiosqlite. + + Reads SQL commands from 'schemas/data.sql' + and executes them as a script. + """ + sql_file = Path(__file__).parent / "schemas" / "data.sql" + sql = sql_file.read_text(encoding='utf-8') + + async with self.conn.execute("BEGIN"): + await self.conn.executescript(sql) + await self.conn.commit() + + async def connect(self) -> None: + """ + Open SQLite database connection asynchronously. + """ + self.conn = await aiosqlite.connect(self.db_path) + self.conn.row_factory = aiosqlite.Row # return dict-like rows + + async def disconnect(self) -> None: + """ + Close SQLite database connection asynchronously. + """ + if self.conn: + await self.conn.close() + self.conn = None + + async def fetch(self, query: str, *args) -> Dict[str, Any]: + """ + Execute a query and fetch a single row as a dictionary. + + Args: + query (str): SQL query string. + *args: Parameters for the SQL query. + + Returns: + Dict[str, Any]: The first row returned by the query as a dict, + or an empty dict if no row is found. + + Raises: + RuntimeError: If the database connection is not initialized. + """ + if not self.conn: + raise RuntimeError("Database connection is not initialized.") + + async with self.conn.execute(query, args) as cursor: + row = await cursor.fetchone() + if row: + return json.loads(json.dumps(dict(row), default=custom_encoder)) + return {} + + async def fetchmany(self, query: str, *args) -> List[Dict[str, Any]]: + """ + Execute a query and fetch multiple rows as a list of dictionaries. + + Args: + query (str): SQL query string. + *args: Parameters for the SQL query. + + Returns: + List[Dict[str, Any]]: List of rows as dictionaries, + or empty list if no rows found. + + Raises: + RuntimeError: If the database connection is not initialized. + """ + if not self.conn: + raise RuntimeError("Database connection is not initialized.") + + async with self.conn.execute(query, args) as cursor: + rows = await cursor.fetchall() + return json.loads( + json.dumps( + [dict(row) for row in rows], + default=custom_encoder + ) + ) if rows else [] + + async def insert(self, query: str, *args) -> Dict[str, Any]: + """ + Execute an INSERT query and return the last inserted row ID. + + Args: + query (str): SQL INSERT query string. + *args: Parameters for the SQL query. + + Returns: + Dict[str, Any]: Dictionary containing the last inserted row ID. + + Raises: + RuntimeError: If the database connection is not initialized. + """ + if not self.conn: + raise RuntimeError("Database connection is not initialized.") + + async with self.conn.execute(query, args) as cursor: + await self.conn.commit() + return {"last_row_id": cursor.lastrowid} + + async def update(self, query: str, *args) -> Dict[str, Any]: + """ + Execute an UPDATE or DELETE query and return affected rows count. + + Args: + query (str): SQL UPDATE or DELETE query string. + *args: Parameters for the SQL query. + + Returns: + Dict[str, Any]: Dictionary containing the number of affected rows. + + Raises: + RuntimeError: If the database connection is not initialized. + """ + if not self.conn: + raise RuntimeError("Database connection is not initialized.") + + async with self.conn.execute(query, args) as cursor: + await self.conn.commit() + return {"rows_affected": cursor.rowcount} diff --git a/main.py b/main.py index d695abb..aa3ff14 100644 --- a/main.py +++ b/main.py @@ -13,12 +13,14 @@ from functions.convert import Converter from functions.create_chart import create_chart from utils.format_number import format_number from utils.inline_query import reply +from database.server import Database config = yaml.safe_load(open('../config.yaml', 'r', encoding='utf-8')) bot = Bot( token=config['telegram_token'], default=DefaultBotProperties(parse_mode=ParseMode.HTML) ) +db = Database('shirino.db') router = Router() @@ -127,6 +129,8 @@ async def currency(query: types.InlineQuery) -> None: async def on_startup(bot: Bot) -> None: + await db.connect() + await db._create_table() await bot.set_webhook( f"{config['webhook']['base_url']}{config['webhook']['path']}", secret_token=config['webhook']['secret_token'], @@ -134,11 +138,16 @@ async def on_startup(bot: Bot) -> None: ) +async def on_shutdown(): + await db.disconnect() + + def main() -> None: dp = Dispatcher() dp.include_router(router) dp.startup.register(on_startup) + dp.shutdown.register(on_shutdown) app = web.Application() webhook_requests_handler = SimpleRequestHandler( diff --git a/requirements.txt b/requirements.txt index 21259b5..6ad7c0d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ aiohttp~=3.9.5 PyYAML~=6.0.1 -aiogram~=3.15.0 \ No newline at end of file +aiogram~=3.15.0 +aiosqlite~=0.21.0 \ No newline at end of file