From 45574d71173dca0f88c6d0d54fc14bf51da0b2a4 Mon Sep 17 00:00:00 2001 From: Jason <81298350+Deutscher775@users.noreply.github.com> Date: Fri, 4 Oct 2024 20:49:36 +0200 Subject: [PATCH] adding endpoint suspensions --- src/api.py | 52 +++++++++++++++++--- src/astroidapi/errors.py | 42 ++++++++++++++++ src/astroidapi/surrealdb_handler.py | 72 ++++++++++++++++++++++++++-- src/astroidapi/suspension_handler.py | 35 ++++++++++++++ 4 files changed, 191 insertions(+), 10 deletions(-) create mode 100644 src/astroidapi/suspension_handler.py diff --git a/src/api.py b/src/api.py index cce46ba..405e716 100644 --- a/src/api.py +++ b/src/api.py @@ -16,6 +16,7 @@ import astroidapi.read_handler import astroidapi.surrealdb_handler import astroidapi.statistics +import astroidapi.suspension_handler import beta_users from fastapi.middleware.cors import CORSMiddleware import requests @@ -28,7 +29,6 @@ from slowapi.util import get_remote_address import logging import astroidapi -import datetime # Configure logging to log to a file logFormatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s") @@ -63,8 +63,7 @@ version="2.1.4", docs_url=None ) -api.state.limiter = limiter -api.add_exception_handler(RateLimitExceeded, slowapi._rate_limit_exceeded_handler) + @@ -127,7 +126,7 @@ def root(): "description": "Astroid API for getting and modifying endpoints.", "website": "https://astroid.cc", "privacy": "https://astroid.cc/privacy", - "terms": "https://astroid.cce/terms", + "terms": "https://astroid.cc/terms", "imprint": "https://deutscher775.de/imprint.html", "docs": "https://astroid.cc/docs", "discord": "https://discord.gg/DbrFADj6Xw", @@ -238,6 +237,10 @@ def get_server_structure(id: int, token: Annotated[str, fastapi.Query(max_length @api.get("/{endpoint}", description="Get an endpoint.") async def get_endpoint(endpoint: int, token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None, download: bool = False): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + global data_token try: data_token = json.load(open(f"{pathlib.Path(__file__).parent.resolve()}/tokens.json", "r"))[f"{endpoint}"] @@ -264,6 +267,9 @@ async def get_endpoint(endpoint: int, @api.get("/bridges/{endpoint}", description="Get an endpoint.") async def get_bridges(endpoint: int, token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) global data_token try: data_token = json.load(open(f"{pathlib.Path(__file__).parent.resolve()}/tokens.json", "r"))[f"{endpoint}"] @@ -297,8 +303,12 @@ async def get_bridges(endpoint: int, @api.post("/token/{endpoint}", description="Generate a new token. (Only works with astroid-Bot)") -def new_token(endpoint: int, +async def new_token(endpoint: int, master_token: Annotated[str, fastapi.Query(max_length=85, min_length=85)]): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + if master_token == Bot.config.MASTER_TOKEN: with open(f"{pathlib.Path(__file__).parent.resolve()}/tokens.json", "r+") as tokens: token = secrets.token_urlsafe(53) @@ -360,6 +370,10 @@ async def post_endpoint( beta: bool = False, only_check = False, ): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + await astroidapi.endpoint_update_handler.UpdateHandler.update_endpoint( endpoint=endpoint, index=index, @@ -391,7 +405,7 @@ async def post_endpoint( beta=beta, only_check=only_check, ) - return astroidapi.surrealdb_handler.get_endpoint(endpoint) + return await astroidapi.surrealdb_handler.get_endpoint(endpoint) @api.patch("/sync", description="Sync the local files with the database.") @@ -415,6 +429,10 @@ async def mark_read(endpoint: int, read_guilded: bool = None, read_revolt: bool = None, read_nerimity: bool = None): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + if token == data_token or token == Bot.config.MASTER_TOKEN: try: if read_discord: @@ -436,6 +454,10 @@ async def mark_read(endpoint: int, @api.get("/healthcheck/{endpoint}", description="Validate the endpoints strucuture.") async def endpoint_healthcheck(endpoint: int, token: str): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + if token == Bot.config.MASTER_TOKEN: try: healty = await astroidapi.health_check.HealthCheck.EndpointCheck.check(endpoint) @@ -462,6 +484,12 @@ async def endpoint_healthcheck(endpoint: int, token: str): @api.post("/create", description="Create an endpoint.", response_description="Endpoints data.") async def create_endpoint(endpoint: int): + try: + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + except: + pass try: data = { "config": { @@ -569,6 +597,10 @@ async def delete_enpoint_data(endpoint: int, message_content: bool = None, message_attachments: Annotated[str, fastapi.Query(max_length=1550, min_length=20)] = None, token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + data_token = json.load(open(f"{pathlib.Path(__file__).parent.resolve()}/tokens.json", "r"))[f"{endpoint}"] if token is not None: if token == data_token or token == Bot.config.MASTER_TOKEN: @@ -650,6 +682,10 @@ async def delete_enpoint_data(endpoint: int, @api.get("/getendpoint/{platform}", description="Get an endpoint via a platform server id.") async def get_endpoint_platform(platform: str, id: str, token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + if not token == Bot.config.MASTER_TOKEN: return fastapi.responses.JSONResponse(status_code=401, content={"message": "The provided token is invalid. (Only the master token can be used to view or create relations.)"}) try: @@ -667,6 +703,10 @@ async def get_endpoint_platform(platform: str, id: str, token: Annotated[str, fa @api.post("/createendpoint/{platform}", description="Create an endpoint via a platform server id.") async def create_endpoint_platform(platform: str, endpoint: int, id: str, token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None): + suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint) + if suspend_status: + return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."}) + if not token == Bot.config.MASTER_TOKEN: return fastapi.responses.JSONResponse(status_code=401, content={"message": "The provided token is invalid. (Only the master token can be used to view or create relations.)"}) try: diff --git a/src/astroidapi/errors.py b/src/astroidapi/errors.py index 227060c..9de86d7 100644 --- a/src/astroidapi/errors.py +++ b/src/astroidapi/errors.py @@ -1,4 +1,8 @@ +# This file contains all the custom exceptions that are raised in the api or during handling. +# The exceptions are divided into different classes based on the module they are raised in. +# Exceptions listed here aren't handeled yet. Currently they are just for raising and logging purposes. +# This will be updated and exceptions will be handeled in the future. class SendingError(Exception): @@ -204,6 +208,26 @@ def __init__(self, message): super().__init__(self.message) + class GetSuspensionStatusError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + class SuspendEndpointError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + class UnsuspendEndpointError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + class GetSuspensionError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + class ReadHandlerError: @@ -268,6 +292,24 @@ def __init__(self, message): class ProfileProcessorError: class ProfileNotFoundError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class SuspensionHandlerError: + + class GetSuspensionStatusError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + class SuspendEndpointError(Exception): + def __init__(self, message): + self.message = message + super().__init__(self.message) + + class UnsuspendEndpointError(Exception): def __init__(self, message): self.message = message super().__init__(self.message) \ No newline at end of file diff --git a/src/astroidapi/surrealdb_handler.py b/src/astroidapi/surrealdb_handler.py index 59be32e..1066399 100644 --- a/src/astroidapi/surrealdb_handler.py +++ b/src/astroidapi/surrealdb_handler.py @@ -276,9 +276,6 @@ async def for_nerimity(cls, endpoint: int, nerimity_id: str): raise errors.SurrealDBHandler.CreateEndpointError(e) - - - class Statistics: @staticmethod @@ -334,4 +331,71 @@ async def update_messages(increment: int, start_period: bool = False): return await db.select("statistics:messages") except Exception as e: traceback.print_exc() - raise errors.SurrealDBHandler.GetStatisticsError(e) \ No newline at end of file + raise errors.SurrealDBHandler.GetStatisticsError(e) + + +class Suspension: + + @classmethod + async def get_suspend_status(cls, endpoint_id): + try: + async with Surreal(config.SDB_URL) as db: + await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS}) + await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE) + return await db.select(f"suspensions:`{endpoint_id}`") + except Exception as e: + raise errors.SurrealDBHandler.GetSuspensionStatusError(e) + + + class Endpoints: + @classmethod + async def suspend(cls, endpoint_id, reason, suspended_by: int, expire_at: int = None): + data = { + "reason": reason, + "expireAt": expire_at, + "type": "endpoint", + "suspended": True, + "suspendedAt": datetime.datetime.now().timestamp(), + "suspendedBy": suspended_by + } + try: + async with Surreal(config.SDB_URL) as db: + await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS}) + await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE) + await db.create(f"suspensions:`{endpoint_id}`", data) + return await db.select(f"suspensions:`{endpoint_id}`") + except Exception as e: + raise errors.SurrealDBHandler.SuspendEndpointError(e) + + @classmethod + async def unsuspend(cls, endpoint_id): + try: + async with Surreal(config.SDB_URL) as db: + await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS}) + await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE) + await db.delete(f"suspensions:`{endpoint_id}`") + return True + except Exception as e: + raise errors.SurrealDBHandler.UnsuspendEndpointError(e) + + @classmethod + async def update(cls, endpoint_id, reason: str = None, suspended_by: int = None, expire_at: int = None): + data = {} + if reason: + data["reason"] = reason + if suspended_by: + data["suspendedBy"] = suspended_by + if expire_at: + data["expireAt"] = expire_at + + try: + async with Surreal(config.SDB_URL) as db: + await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS}) + await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE) + current_data = await db.select(f"suspensions:`{endpoint_id}`") + for key in data: + current_data[key] = data[key] + await db.update(f"suspensions:`{endpoint_id}`", current_data) + return await db.select(f"suspensions:`{endpoint_id}`") + except Exception as e: + raise errors.SurrealDBHandler.SuspendEndpointError(e) \ No newline at end of file diff --git a/src/astroidapi/suspension_handler.py b/src/astroidapi/suspension_handler.py new file mode 100644 index 0000000..408214c --- /dev/null +++ b/src/astroidapi/suspension_handler.py @@ -0,0 +1,35 @@ +import astroidapi.surrealdb_handler as surrealdb_handler +import astroidapi.errors as errors + + +class Endpoint(): + def __init__(self, endpoint_id): + self.endpoint_id = endpoint_id + + @classmethod + async def is_suspended(cls, endpoint_id): + try: + suspended = await surrealdb_handler.Suspension.get_suspend_status(endpoint_id) + print(suspended) + except errors.SurrealDBHandler.GetSuspensionStatusError as e: + raise errors.SuspensionHandlerError.GetSuspensionStatusError(e) + try: + return suspended["suspended"] + except KeyError: + return False + except TypeError: + return False + + @classmethod + async def suspend(cls, endpoint_id, reason, suspended_by: int, expire_at: int = None): + try: + await surrealdb_handler.Suspension.Endpoints.suspend(endpoint_id, reason, suspended_by, expire_at) + except errors.SurrealDBHandler.SuspendEndpointError as e: + raise errors.SuspensionHandlerError.SuspendEndpointError(e) + + @classmethod + async def unsuspend(cls, endpoint_id): + try: + await surrealdb_handler.Suspension.Endpoints.unsuspend(endpoint_id) + except errors.SurrealDBHandler.UnsuspendEndpointError as e: + raise errors.SuspensionHandlerError.UnsuspendEndpointError(e) \ No newline at end of file