diff --git a/vetting_bot/bot_commands.py b/vetting_bot/bot_commands.py index 0e0929d..cf00b84 100644 --- a/vetting_bot/bot_commands.py +++ b/vetting_bot/bot_commands.py @@ -1,6 +1,7 @@ import logging import random import re +import time from nio import ( AsyncClient, @@ -12,7 +13,7 @@ from nio import ( from vetting_bot.chat_functions import react_to_event, send_text_to_room from vetting_bot.config import Config -from vetting_bot.storage import Storage +from vetting_bot.storage import Storage, IntegrityError logger = logging.getLogger(__name__) @@ -118,6 +119,16 @@ class Command: await send_text_to_room(self.client, self.room.room_id, text) return + # Check if vetting room already exists for user + self.store.cursor.execute( + "SELECT room_id FROM vetting WHERE mxid=?", (vetted_user_id,) + ) + row = self.store.cursor.fetchone() + if row is not None: + text = f"A vetting room already exists for this user: `{row[0]}`" + await send_text_to_room(self.client, self.room.room_id, text) + return + # Get members to invite invitees = [member_id for member_id in self.room.users.keys()] invitees.append(vetted_user_id) @@ -144,7 +155,11 @@ class Command: logging.error(room_resp, stack_info=True) return - self.store.conn + # Create new vetting entry + self.store.cursor.execute( + "INSERT INTO vetting (mxid, room_id, vetting_create_time) VALUES (?, ?, ?)", + (vetted_user_id, room_resp.room_id, time.time()), + ) # Add newly created room to space space_child_content = { diff --git a/vetting_bot/storage.py b/vetting_bot/storage.py index 580ebd1..819ab25 100644 --- a/vetting_bot/storage.py +++ b/vetting_bot/storage.py @@ -1,6 +1,9 @@ import logging +import sqlite3 from typing import Any, Dict +import psycopg2 + # The latest migration version of the database. # # Database migrations are applied starting from the number specified in the database's @@ -8,7 +11,7 @@ from typing import Any, Dict # the version specified here. # # When a migration is performed, the `migration_version` table should be incremented. -latest_migration_version = 0 +latest_migration_version = 1 logger = logging.getLogger(__name__) @@ -51,13 +54,9 @@ class Storage: ) -> Any: """Creates and returns a connection to the database""" if database_type == "sqlite": - import sqlite3 - # Initialize a connection to the database, with autocommit on return sqlite3.connect(connection_string, isolation_level=None) elif database_type == "postgres": - import psycopg2 - conn = psycopg2.connect(connection_string) # Autocommit on @@ -102,15 +101,25 @@ class Storage: """ logger.debug("Checking for necessary database migrations...") - # if current_migration_version < 1: - # logger.info("Migrating the database from v0 to v1...") - # - # # Add new table, delete old ones, etc. - # - # # Update the stored migration version - # self._execute("UPDATE migration_version SET version = 1") - # - # logger.info("Database migrated to v1") + if current_migration_version < 1: + logger.info("Migrating the database from v0 to v1...") + + self._execute( + """ + CREATE TABLE vetting ( + mxid VARCHAR(255) NOT NULL PRIMARY KEY UNIQUE, + room_id VARCHAR(255) NOT NULL UNIQUE, + vetting_create_time INT(12), + voting_start_time INT(12), + poll_event_id VARCHAR(255) + ) + """ + ) + + # Update the stored migration version + self._execute("UPDATE migration_version SET version = 1") + + logger.info("Database migrated to v1") def _execute(self, *args) -> None: """A wrapper around cursor.execute that transforms placeholder ?'s to %s for postgres. @@ -124,3 +133,6 @@ class Storage: self.cursor.execute(args[0].replace("?", "%s"), *args[1:]) else: self.cursor.execute(*args) + + +class IntegrityError(psycopg2.IntegrityError, sqlite3.IntegrityError): ...