Store vetting info in the database

This commit is contained in:
Panoramic 2024-06-11 11:14:42 +03:00
parent 67259a7f93
commit ffcc7e4685
Signed by: Panoramic
GPG Key ID: 29FEDD73E66D32F1
2 changed files with 43 additions and 16 deletions

View File

@ -1,6 +1,7 @@
import logging import logging
import random import random
import re import re
import time
from nio import ( from nio import (
AsyncClient, AsyncClient,
@ -12,7 +13,7 @@ from nio import (
from vetting_bot.chat_functions import react_to_event, send_text_to_room from vetting_bot.chat_functions import react_to_event, send_text_to_room
from vetting_bot.config import Config from vetting_bot.config import Config
from vetting_bot.storage import Storage from vetting_bot.storage import Storage, IntegrityError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -118,6 +119,16 @@ class Command:
await send_text_to_room(self.client, self.room.room_id, text) await send_text_to_room(self.client, self.room.room_id, text)
return return
# Check if vetting room already exists for user
self.store.cursor.execute(
"SELECT room_id FROM vetting WHERE mxid=?", (vetted_user_id,)
)
row = self.store.cursor.fetchone()
if row is not None:
text = f"A vetting room already exists for this user: `{row[0]}`"
await send_text_to_room(self.client, self.room.room_id, text)
return
# Get members to invite # Get members to invite
invitees = [member_id for member_id in self.room.users.keys()] invitees = [member_id for member_id in self.room.users.keys()]
invitees.append(vetted_user_id) invitees.append(vetted_user_id)
@ -144,7 +155,11 @@ class Command:
logging.error(room_resp, stack_info=True) logging.error(room_resp, stack_info=True)
return return
self.store.conn # Create new vetting entry
self.store.cursor.execute(
"INSERT INTO vetting (mxid, room_id, vetting_create_time) VALUES (?, ?, ?)",
(vetted_user_id, room_resp.room_id, time.time()),
)
# Add newly created room to space # Add newly created room to space
space_child_content = { space_child_content = {

View File

@ -1,6 +1,9 @@
import logging import logging
import sqlite3
from typing import Any, Dict from typing import Any, Dict
import psycopg2
# The latest migration version of the database. # The latest migration version of the database.
# #
# Database migrations are applied starting from the number specified in the database's # Database migrations are applied starting from the number specified in the database's
@ -8,7 +11,7 @@ from typing import Any, Dict
# the version specified here. # the version specified here.
# #
# When a migration is performed, the `migration_version` table should be incremented. # When a migration is performed, the `migration_version` table should be incremented.
latest_migration_version = 0 latest_migration_version = 1
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -51,13 +54,9 @@ class Storage:
) -> Any: ) -> Any:
"""Creates and returns a connection to the database""" """Creates and returns a connection to the database"""
if database_type == "sqlite": if database_type == "sqlite":
import sqlite3
# Initialize a connection to the database, with autocommit on # Initialize a connection to the database, with autocommit on
return sqlite3.connect(connection_string, isolation_level=None) return sqlite3.connect(connection_string, isolation_level=None)
elif database_type == "postgres": elif database_type == "postgres":
import psycopg2
conn = psycopg2.connect(connection_string) conn = psycopg2.connect(connection_string)
# Autocommit on # Autocommit on
@ -102,15 +101,25 @@ class Storage:
""" """
logger.debug("Checking for necessary database migrations...") logger.debug("Checking for necessary database migrations...")
# if current_migration_version < 1: if current_migration_version < 1:
# logger.info("Migrating the database from v0 to v1...") logger.info("Migrating the database from v0 to v1...")
#
# # Add new table, delete old ones, etc. self._execute(
# """
# # Update the stored migration version CREATE TABLE vetting (
# self._execute("UPDATE migration_version SET version = 1") mxid VARCHAR(255) NOT NULL PRIMARY KEY UNIQUE,
# room_id VARCHAR(255) NOT NULL UNIQUE,
# logger.info("Database migrated to v1") vetting_create_time INT(12),
voting_start_time INT(12),
poll_event_id VARCHAR(255)
)
"""
)
# Update the stored migration version
self._execute("UPDATE migration_version SET version = 1")
logger.info("Database migrated to v1")
def _execute(self, *args) -> None: def _execute(self, *args) -> None:
"""A wrapper around cursor.execute that transforms placeholder ?'s to %s for postgres. """A wrapper around cursor.execute that transforms placeholder ?'s to %s for postgres.
@ -124,3 +133,6 @@ class Storage:
self.cursor.execute(args[0].replace("?", "%s"), *args[1:]) self.cursor.execute(args[0].replace("?", "%s"), *args[1:])
else: else:
self.cursor.execute(*args) self.cursor.execute(*args)
class IntegrityError(psycopg2.IntegrityError, sqlite3.IntegrityError): ...