diff --git a/sample.config.yaml b/sample.config.yaml index a3aed63..a4ba95b 100644 --- a/sample.config.yaml +++ b/sample.config.yaml @@ -39,6 +39,8 @@ vetting: room_id: "!xxx:xxx" # A space which will house all the vetting rooms space_id: "!xxx:xxx" + # Voting time in seconds + voting_time: 172800 # Logging setup logging: diff --git a/vetting_bot/bot_commands.py b/vetting_bot/bot_commands.py index ab305ac..d8c4478 100644 --- a/vetting_bot/bot_commands.py +++ b/vetting_bot/bot_commands.py @@ -15,6 +15,7 @@ from nio import ( from vetting_bot.chat_functions import react_to_event, send_text_to_room from vetting_bot.config import Config from vetting_bot.storage import Storage +from vetting_bot.timer import Timer logger = logging.getLogger(__name__) @@ -249,9 +250,16 @@ class Command: await send_text_to_room(self.client, self.room.room_id, text) return + voting_start_time = time.time() + self.store.cursor.execute( "UPDATE vetting SET poll_event_id = ?, voting_start_time = ? WHERE mxid = ?", - (poll_resp.event_id, time.time(), vetted_user_id), + (poll_resp.event_id, voting_start_time, vetted_user_id), + ) + + timer = Timer(self.client, self.store, self.config) + await timer.wait_for_poll_end( + vetted_user_id, poll_resp.event_id, voting_start_time ) async def _unknown_command(self): diff --git a/vetting_bot/config.py b/vetting_bot/config.py index eaf3378..ef8a856 100644 --- a/vetting_bot/config.py +++ b/vetting_bot/config.py @@ -115,6 +115,8 @@ class Config: if not re.match("!.*:.*", self.vetting_space_id): raise ConfigError("vetting.space_id must be in the form !xxx:domain") + self.voting_time = int(self._get_cfg(["vetting", "voting_time"], required=True)) + def _get_cfg( self, path: List[str], diff --git a/vetting_bot/main.py b/vetting_bot/main.py index 121212e..3745ad2 100644 --- a/vetting_bot/main.py +++ b/vetting_bot/main.py @@ -19,6 +19,7 @@ from nio import ( from vetting_bot.callbacks import Callbacks from vetting_bot.config import Config from vetting_bot.storage import Storage +from vetting_bot.timer import Timer logger = logging.getLogger(__name__) @@ -108,6 +109,14 @@ async def main(): client.server = client.user_id.split(":", maxsplit=1)[1] + timer = Timer(client, store, config) + + async def start_timer_with_delay(): + await asyncio.sleep(5) + await timer.start_all_timers() + + asyncio.create_task(start_timer_with_delay()) + await client.sync_forever(timeout=30000, full_state=True) except (ClientConnectionError, ServerDisconnectedError): diff --git a/vetting_bot/storage.py b/vetting_bot/storage.py index cbb3993..b912688 100644 --- a/vetting_bot/storage.py +++ b/vetting_bot/storage.py @@ -111,7 +111,8 @@ class Storage: room_id VARCHAR(255) NOT NULL UNIQUE, vetting_create_time INT(12), voting_start_time INT(12), - poll_event_id VARCHAR(255) + poll_event_id VARCHAR(255), + vote_ended BOOLEAN NOT NULL DEFAULT FALSE ) """ ) diff --git a/vetting_bot/timer.py b/vetting_bot/timer.py new file mode 100644 index 0000000..30e8a67 --- /dev/null +++ b/vetting_bot/timer.py @@ -0,0 +1,71 @@ +import asyncio +import logging +import time + +from nio import AsyncClient, RoomSendError + +from vetting_bot.config import Config +from vetting_bot.storage import Storage + +logger = logging.getLogger(__name__) + + +class Timer: + def __init__( + self, + client: AsyncClient, + store: Storage, + config: Config, + ): + self.client = client + self.store = store + self.config = config + + async def start_all_timers(self): + self.store.cursor.execute( + """ + SELECT mxid, poll_event_id, voting_start_time, vote_ended + FROM vetting + WHERE voting_start_time IS NOT NULL + """ + ) + + rows = self.store.cursor.fetchall() + for row in rows: + if row[3]: + continue + + await self.wait_for_poll_end( + mxid=row[0], poll_event_id=row[1], start_time=row[2] + ) + + async def wait_for_poll_end(self, mxid: str, poll_event_id: str, start_time: int): + async def _task(): + time_left = start_time + self.config.voting_time - time.time() + await asyncio.sleep(time_left) + await self._end_poll(mxid, poll_event_id) + + asyncio.create_task(_task()) + + async def _end_poll(self, mxid: str, poll_event_id: str): + event_content = { + "m.relates_to": { + "rel_type": "m.reference", + "event_id": poll_event_id, + } + } + + poll_resp = await self.client.room_send( + self.config.vetting_room_id, + message_type="org.matrix.msc3381.poll.end", + content=event_content, + ) + + if isinstance(poll_resp, RoomSendError): + logger.error(poll_resp, stack_info=True) + return + + self.store.cursor.execute( + "UPDATE vetting SET vote_ended = 1 WHERE mxid = ?", + (mxid,), + )