Skip to content

Commit

Permalink
adding endpoint suspensions
Browse files Browse the repository at this point in the history
  • Loading branch information
Deutscher775 committed Oct 4, 2024
1 parent 32320e1 commit 45574d7
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 10 deletions.
52 changes: 46 additions & 6 deletions src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import astroidapi.read_handler
import astroidapi.surrealdb_handler
import astroidapi.statistics
import astroidapi.suspension_handler
import beta_users
from fastapi.middleware.cors import CORSMiddleware
import requests
Expand All @@ -28,7 +29,6 @@
from slowapi.util import get_remote_address
import logging
import astroidapi
import datetime

# Configure logging to log to a file
logFormatter = logging.Formatter("%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s")
Expand Down Expand Up @@ -63,8 +63,7 @@
version="2.1.4",
docs_url=None
)
api.state.limiter = limiter
api.add_exception_handler(RateLimitExceeded, slowapi._rate_limit_exceeded_handler)




Expand Down Expand Up @@ -127,7 +126,7 @@ def root():
"description": "Astroid API for getting and modifying endpoints.",
"website": "https://astroid.cc",
"privacy": "https://astroid.cc/privacy",
"terms": "https://astroid.cce/terms",
"terms": "https://astroid.cc/terms",
"imprint": "https://deutscher775.de/imprint.html",
"docs": "https://astroid.cc/docs",
"discord": "https://discord.gg/DbrFADj6Xw",
Expand Down Expand Up @@ -238,6 +237,10 @@ def get_server_structure(id: int, token: Annotated[str, fastapi.Query(max_length
@api.get("/{endpoint}", description="Get an endpoint.")
async def get_endpoint(endpoint: int,
token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None, download: bool = False):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})

global data_token
try:
data_token = json.load(open(f"{pathlib.Path(__file__).parent.resolve()}/tokens.json", "r"))[f"{endpoint}"]
Expand All @@ -264,6 +267,9 @@ async def get_endpoint(endpoint: int,
@api.get("/bridges/{endpoint}", description="Get an endpoint.")
async def get_bridges(endpoint: int,
token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})
global data_token
try:
data_token = json.load(open(f"{pathlib.Path(__file__).parent.resolve()}/tokens.json", "r"))[f"{endpoint}"]
Expand Down Expand Up @@ -297,8 +303,12 @@ async def get_bridges(endpoint: int,


@api.post("/token/{endpoint}", description="Generate a new token. (Only works with astroid-Bot)")
def new_token(endpoint: int,
async def new_token(endpoint: int,
master_token: Annotated[str, fastapi.Query(max_length=85, min_length=85)]):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})

if master_token == Bot.config.MASTER_TOKEN:
with open(f"{pathlib.Path(__file__).parent.resolve()}/tokens.json", "r+") as tokens:
token = secrets.token_urlsafe(53)
Expand Down Expand Up @@ -360,6 +370,10 @@ async def post_endpoint(
beta: bool = False,
only_check = False,
):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})

await astroidapi.endpoint_update_handler.UpdateHandler.update_endpoint(
endpoint=endpoint,
index=index,
Expand Down Expand Up @@ -391,7 +405,7 @@ async def post_endpoint(
beta=beta,
only_check=only_check,
)
return astroidapi.surrealdb_handler.get_endpoint(endpoint)
return await astroidapi.surrealdb_handler.get_endpoint(endpoint)


@api.patch("/sync", description="Sync the local files with the database.")
Expand All @@ -415,6 +429,10 @@ async def mark_read(endpoint: int,
read_guilded: bool = None,
read_revolt: bool = None,
read_nerimity: bool = None):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})

if token == data_token or token == Bot.config.MASTER_TOKEN:
try:
if read_discord:
Expand All @@ -436,6 +454,10 @@ async def mark_read(endpoint: int,

@api.get("/healthcheck/{endpoint}", description="Validate the endpoints strucuture.")
async def endpoint_healthcheck(endpoint: int, token: str):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})

if token == Bot.config.MASTER_TOKEN:
try:
healty = await astroidapi.health_check.HealthCheck.EndpointCheck.check(endpoint)
Expand All @@ -462,6 +484,12 @@ async def endpoint_healthcheck(endpoint: int, token: str):
@api.post("/create", description="Create an endpoint.",
response_description="Endpoints data.")
async def create_endpoint(endpoint: int):
try:
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})
except:
pass
try:
data = {
"config": {
Expand Down Expand Up @@ -569,6 +597,10 @@ async def delete_enpoint_data(endpoint: int,
message_content: bool = None,
message_attachments: Annotated[str, fastapi.Query(max_length=1550, min_length=20)] = None,
token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})

data_token = json.load(open(f"{pathlib.Path(__file__).parent.resolve()}/tokens.json", "r"))[f"{endpoint}"]
if token is not None:
if token == data_token or token == Bot.config.MASTER_TOKEN:
Expand Down Expand Up @@ -650,6 +682,10 @@ async def delete_enpoint_data(endpoint: int,

@api.get("/getendpoint/{platform}", description="Get an endpoint via a platform server id.")
async def get_endpoint_platform(platform: str, id: str, token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})

if not token == Bot.config.MASTER_TOKEN:
return fastapi.responses.JSONResponse(status_code=401, content={"message": "The provided token is invalid. (Only the master token can be used to view or create relations.)"})
try:
Expand All @@ -667,6 +703,10 @@ async def get_endpoint_platform(platform: str, id: str, token: Annotated[str, fa

@api.post("/createendpoint/{platform}", description="Create an endpoint via a platform server id.")
async def create_endpoint_platform(platform: str, endpoint: int, id: str, token: Annotated[str, fastapi.Query(max_length=85, min_length=71)] = None):
suspend_status = await astroidapi.suspension_handler.Endpoint.is_suspended(endpoint)
if suspend_status:
return fastapi.responses.JSONResponse(status_code=403, content={"message": "This endpoint is suspended."})

if not token == Bot.config.MASTER_TOKEN:
return fastapi.responses.JSONResponse(status_code=401, content={"message": "The provided token is invalid. (Only the master token can be used to view or create relations.)"})
try:
Expand Down
42 changes: 42 additions & 0 deletions src/astroidapi/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@

# This file contains all the custom exceptions that are raised in the api or during handling.
# The exceptions are divided into different classes based on the module they are raised in.
# Exceptions listed here aren't handeled yet. Currently they are just for raising and logging purposes.
# This will be updated and exceptions will be handeled in the future.


class SendingError(Exception):
Expand Down Expand Up @@ -204,6 +208,26 @@ def __init__(self, message):
super().__init__(self.message)


class GetSuspensionStatusError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)

class SuspendEndpointError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)

class UnsuspendEndpointError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)

class GetSuspensionError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)


class ReadHandlerError:

Expand Down Expand Up @@ -268,6 +292,24 @@ def __init__(self, message):

class ProfileProcessorError:
class ProfileNotFoundError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)


class SuspensionHandlerError:

class GetSuspensionStatusError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)

class SuspendEndpointError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)

class UnsuspendEndpointError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
72 changes: 68 additions & 4 deletions src/astroidapi/surrealdb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,6 @@ async def for_nerimity(cls, endpoint: int, nerimity_id: str):
raise errors.SurrealDBHandler.CreateEndpointError(e)





class Statistics:

@staticmethod
Expand Down Expand Up @@ -334,4 +331,71 @@ async def update_messages(increment: int, start_period: bool = False):
return await db.select("statistics:messages")
except Exception as e:
traceback.print_exc()
raise errors.SurrealDBHandler.GetStatisticsError(e)
raise errors.SurrealDBHandler.GetStatisticsError(e)


class Suspension:

@classmethod
async def get_suspend_status(cls, endpoint_id):
try:
async with Surreal(config.SDB_URL) as db:
await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS})
await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE)
return await db.select(f"suspensions:`{endpoint_id}`")
except Exception as e:
raise errors.SurrealDBHandler.GetSuspensionStatusError(e)


class Endpoints:
@classmethod
async def suspend(cls, endpoint_id, reason, suspended_by: int, expire_at: int = None):
data = {
"reason": reason,
"expireAt": expire_at,
"type": "endpoint",
"suspended": True,
"suspendedAt": datetime.datetime.now().timestamp(),
"suspendedBy": suspended_by
}
try:
async with Surreal(config.SDB_URL) as db:
await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS})
await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE)
await db.create(f"suspensions:`{endpoint_id}`", data)
return await db.select(f"suspensions:`{endpoint_id}`")
except Exception as e:
raise errors.SurrealDBHandler.SuspendEndpointError(e)

@classmethod
async def unsuspend(cls, endpoint_id):
try:
async with Surreal(config.SDB_URL) as db:
await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS})
await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE)
await db.delete(f"suspensions:`{endpoint_id}`")
return True
except Exception as e:
raise errors.SurrealDBHandler.UnsuspendEndpointError(e)

@classmethod
async def update(cls, endpoint_id, reason: str = None, suspended_by: int = None, expire_at: int = None):
data = {}
if reason:
data["reason"] = reason
if suspended_by:
data["suspendedBy"] = suspended_by
if expire_at:
data["expireAt"] = expire_at

try:
async with Surreal(config.SDB_URL) as db:
await db.signin({"user": config.SDB_USER, "pass": config.SDB_PASS})
await db.use(config.SDB_NAMESPACE, config.SDB_DATABASE)
current_data = await db.select(f"suspensions:`{endpoint_id}`")
for key in data:
current_data[key] = data[key]
await db.update(f"suspensions:`{endpoint_id}`", current_data)
return await db.select(f"suspensions:`{endpoint_id}`")
except Exception as e:
raise errors.SurrealDBHandler.SuspendEndpointError(e)
35 changes: 35 additions & 0 deletions src/astroidapi/suspension_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import astroidapi.surrealdb_handler as surrealdb_handler
import astroidapi.errors as errors


class Endpoint():
def __init__(self, endpoint_id):
self.endpoint_id = endpoint_id

@classmethod
async def is_suspended(cls, endpoint_id):
try:
suspended = await surrealdb_handler.Suspension.get_suspend_status(endpoint_id)
print(suspended)
except errors.SurrealDBHandler.GetSuspensionStatusError as e:
raise errors.SuspensionHandlerError.GetSuspensionStatusError(e)
try:
return suspended["suspended"]
except KeyError:
return False
except TypeError:
return False

@classmethod
async def suspend(cls, endpoint_id, reason, suspended_by: int, expire_at: int = None):
try:
await surrealdb_handler.Suspension.Endpoints.suspend(endpoint_id, reason, suspended_by, expire_at)
except errors.SurrealDBHandler.SuspendEndpointError as e:
raise errors.SuspensionHandlerError.SuspendEndpointError(e)

@classmethod
async def unsuspend(cls, endpoint_id):
try:
await surrealdb_handler.Suspension.Endpoints.unsuspend(endpoint_id)
except errors.SurrealDBHandler.UnsuspendEndpointError as e:
raise errors.SuspensionHandlerError.UnsuspendEndpointError(e)

0 comments on commit 45574d7

Please sign in to comment.