Skip to content

Commit

Permalink
Address security issues in SQL creation.
Browse files Browse the repository at this point in the history
  • Loading branch information
maxachis committed Jul 11, 2024
1 parent 9839de4 commit 9303fac
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 91 deletions.
7 changes: 7 additions & 0 deletions database_client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,10 @@
"lat",
"lng",
]
RESTRICTED_COLUMNS = [
"rejection_note",
"data_source_request",
"approval_status",
"airtable_uid",
"airtable_source_last_modified",
]

Check warning on line 95 in database_client/constants.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] database_client/constants.py#L95 <292>

no newline at end of file
Raw output
./database_client/constants.py:95:2: W292 no newline at end of file
155 changes: 72 additions & 83 deletions database_client/database_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid

import psycopg2
from psycopg2.extras import DictCursor
from psycopg2 import sql

from database_client.dynamic_query_constructor import DynamicQueryConstructor
from middleware.custom_exceptions import (
Expand All @@ -15,14 +15,6 @@
AccessTokenNotFoundError,
)

RESTRICTED_COLUMNS = [
"rejection_note",
"data_source_request",
"approval_status",
"airtable_uid",
"airtable_source_last_modified",
]

DATA_SOURCES_MAP_COLUMN = [
"data_source_id",
"name",
Expand Down Expand Up @@ -70,8 +62,6 @@
"""




class DatabaseClient:

def __init__(self, cursor: psycopg2.extensions.cursor):
Expand All @@ -84,18 +74,24 @@ def add_new_user(self, email: str, password_digest: str):
:param password_digest:
:return:
"""
self.cursor.execute(
f"insert into users (email, password_digest) values (%s, %s)",
(email, password_digest),
query = sql.SQL(
"insert into users (email, password_digest) values ({}, {})"
).format(
sql.Literal(email),
sql.Literal(password_digest),
)
self.cursor.execute(query)

def get_user_id(self, email: str) -> Optional[int]:
"""
Gets the ID of a user in the database based on their email.
:param email:
:return:
"""
self.cursor.execute(f"select id from users where email = %s", (email,))
query = sql.SQL("select id from users where email = {}").format(
sql.Literal(email)
)
self.cursor.execute(query)
if self.cursor.rowcount == 0:
return None
return self.cursor.fetchone()[0]
Expand All @@ -107,10 +103,13 @@ def set_user_password_digest(self, email: str, password_digest: str):
:param password_digest:
:return:
"""
self.cursor.execute(
f"update users set password_digest = %s where email = %s",
(password_digest, email),
query = sql.SQL(
"update users set password_digest = {} where email = {}"
).format(
sql.Literal(password_digest),
sql.Literal(email),
)
self.cursor.execute(query)

ResetTokenInfo = namedtuple("ResetTokenInfo", ["id", "email", "create_date"])

Expand All @@ -121,10 +120,10 @@ def get_reset_token_info(self, token: str) -> Optional[ResetTokenInfo]:
:param token: The reset token to check.
:return: ResetTokenInfo if the token exists; otherwise, None.
"""
self.cursor.execute(
f"select id, email, create_date from reset_tokens where token = %s",
(token,),
)
query = sql.SQL(
"select id, email, create_date from reset_tokens where token = {}"
).format(sql.Literal(token))
self.cursor.execute(query)
row = self.cursor.fetchone()
if row is None:
return None
Expand All @@ -137,9 +136,10 @@ def add_reset_token(self, email: str, token: str):
:param email: The email to associate with the reset token.
:param token: The reset token to add.
"""
self.cursor.execute(
f"insert into reset_tokens (email, token) values (%s, %s)", (email, token)
)
query = sql.SQL(
"insert into reset_tokens (email, token) values ({}, {})"
).format(sql.Literal(email), sql.Literal(token))
self.cursor.execute(query)

def delete_reset_token(self, email: str, token: str):
"""
Expand All @@ -148,11 +148,14 @@ def delete_reset_token(self, email: str, token: str):
:param email: The email associated with the reset token to delete.
:param token: The reset token to delete.
"""
self.cursor.execute(
f"delete from reset_tokens where email = %s and token = %s", (email, token)
)
query = sql.SQL(
"delete from reset_tokens where email = {} and token = {}"
).format(sql.Literal(email), sql.Literal(token))
self.cursor.execute(query)

SessionTokenInfo = namedtuple("SessionTokenInfo", ["id", "email", "expiration_date"])
SessionTokenInfo = namedtuple(
"SessionTokenInfo", ["id", "email", "expiration_date"]
)

def get_session_token_info(self, api_key: str) -> Optional[SessionTokenInfo]:
"""
Expand All @@ -161,10 +164,10 @@ def get_session_token_info(self, api_key: str) -> Optional[SessionTokenInfo]:
:param api_key: The session token to check.
:return: SessionTokenInfo if the token exists; otherwise, None.
"""
self.cursor.execute(
f"select id, email, expiration_date from session_tokens where token = %s",
(api_key,),
)
query = sql.SQL(
"select id, email, expiration_date from session_tokens where token = {}"
).format(sql.Literal(api_key))
self.cursor.execute(query)
row = self.cursor.fetchone()
if row is None:
return None
Expand All @@ -178,10 +181,10 @@ def get_role_by_api_key(self, api_key: str) -> Optional[RoleInfo]:
:param api_key: The api key to check.
:return: RoleInfo if the token exists; otherwise, None.
"""
self.cursor.execute(
f"select id, role from users where api_key = %s",
(api_key,),
query = sql.SQL("select id, role from users where api_key = {}").format(
sql.Literal(api_key)
)
self.cursor.execute(query)
row = self.cursor.fetchone()
if row is None:
return None
Expand All @@ -195,7 +198,9 @@ def get_role_by_email(self, email: str) -> RoleInfo:
:raises UserNotFoundError: If no user is found.
:return: RoleInfo namedtuple containing the user's role.
"""
self.cursor.execute(f"select role from users where email = %s", (email,))
query = sql.SQL("select role from users where email = {}")
query = query.format(sql.Literal(email))
self.cursor.execute(query)
results = self.cursor.fetchone()
if len(results) == 0:
raise UserNotFoundError(email)
Expand All @@ -208,10 +213,10 @@ def update_user_api_key(self, api_key: str, user_id: int):
:param api_key: The api key to check.
:param user_id: The user id to update.
"""
self.cursor.execute(
f"update users set api_key = %s where id = %s",
(api_key, user_id),
)
query = sql.SQL("update users set api_key = {} where id = {}")
query = query.format(sql.Literal(api_key), sql.Literal(user_id))

self.cursor.execute(query)

def get_data_source_by_id(self, data_source_id: str) -> Optional[tuple[Any, ...]]:
"""
Expand Down Expand Up @@ -256,39 +261,14 @@ def get_needs_identification_data_sources(self) -> list[tuple[Any, ...]]:
self.cursor.execute(sql_query)
return self.cursor.fetchall()

def create_new_data_source_query(self, data: dict) -> str:
"""
Creates a query to add a new data source to the database.
:param data: A dictionary containing the data source details.
"""
column_names = ""
column_values = ""
for key, value in data.items():
if key not in RESTRICTED_COLUMNS:
column_names += f"{key}, "
if type(value) == str:
column_values += f"'{value}', "
else:
column_values += f"{value}, "

now = datetime.now().strftime("%Y-%m-%d")
airtable_uid = str(uuid.uuid4())

column_names += "approval_status, url_status, data_source_created, airtable_uid"
column_values += f"False, '[\"ok\"]', '{now}', '{airtable_uid}'"

sql_query = f"INSERT INTO data_sources ({column_names}) VALUES ({column_values}) RETURNING *"

return sql_query

def add_new_data_source(self, data: dict) -> None:
"""
Processes a request to add a new data source.
:param data: A dictionary containing the updated data source details.
"""
sql_query = self.create_new_data_source_query(data)
sql_query = DynamicQueryConstructor.create_new_data_source_query(data)
self.cursor.execute(sql_query)

def update_data_source(self, data: dict, data_source_id: str) -> None:
Expand Down Expand Up @@ -569,10 +549,13 @@ def add_quick_search_log(

def add_new_access_token(self, token: str, expiration: datetime) -> None:
"""Inserts a new access token into the database."""
self.cursor.execute(
f"insert into access_tokens (token, expiration_date) values (%s, %s)",
(token, expiration),
query = sql.SQL(
"insert into access_tokens (token, expiration_date) values ({token}, {expiration})"
).format(
token=sql.Literal(token),
expiration=sql.Literal(expiration),
)
self.cursor.execute(query)

UserInfo = namedtuple("UserInfo", ["id", "password_digest", "api_key"])

Expand All @@ -584,9 +567,10 @@ def get_user_info(self, email: str) -> UserInfo:
:raise UserNotFoundError: If no user is found.
:return: UserInfo namedtuple containing the user's information.
"""
self.cursor.execute(
f"select id, password_digest, api_key from users where email = %s", (email,)
)
query = sql.SQL(
"select id, password_digest, api_key from users where email = {email}"
).format(email=sql.Literal(email))
self.cursor.execute(query)
results = self.cursor.fetchone()
if results is None:
raise UserNotFoundError(email)
Expand All @@ -605,10 +589,14 @@ def add_new_session_token(self, session_token, email: str, expiration) -> None:
:param email: User's email.
:param expiration: The session token's expiration.
"""
self.cursor.execute(
f"insert into session_tokens (token, email, expiration_date) values (%s, %s, %s)",
(session_token, email, expiration),
query = sql.SQL(
"insert into session_tokens (token, email, expiration_date) values ({token}, {email}, {expiration})"
).format(
token=sql.Literal(session_token),
email=sql.Literal(email),
expiration=sql.Literal(expiration),
)
self.cursor.execute(query)

SessionTokenUserData = namedtuple("SessionTokenUserData", ["id", "email"])

Expand All @@ -618,9 +606,10 @@ def delete_session_token(self, old_token: str) -> None:
:param old_token: The session token.
"""
self.cursor.execute(
f"delete from session_tokens where token = %s", (old_token,)
query = sql.SQL("delete from session_tokens where token = {token}").format(
token=sql.Literal(old_token)
)
self.cursor.execute(query)

AccessToken = namedtuple("AccessToken", ["id", "token"])

Expand All @@ -632,15 +621,15 @@ def get_access_token(self, api_key: str) -> AccessToken:
:raise AccessTokenNotFoundError: If the access token is not found.
:returns: AccessToken namedtuple with the ID and the access token.
"""
self.cursor.execute(
f"select id, token from access_tokens where token = %s", (api_key,)
)
query = sql.SQL(
"select id, token from access_tokens where token = {token}"
).format(token=sql.Literal(api_key))
self.cursor.execute(query)
results = self.cursor.fetchone()
if not results:
raise AccessTokenNotFoundError("Access token not found")
return self.AccessToken(id=results[0], token=results[1])

def delete_expired_access_tokens(self) -> None:
"""Deletes all expired access tokens from the database."""
self.cursor.execute(f"delete from access_tokens where expiration_date < NOW()")

self.cursor.execute("delete from access_tokens where expiration_date < NOW()")
Loading

0 comments on commit 9303fac

Please sign in to comment.