From 9303fac46a5e18bc2b5805ea9431c25808adba38 Mon Sep 17 00:00:00 2001 From: maxachis Date: Thu, 11 Jul 2024 07:22:01 -0400 Subject: [PATCH] Address security issues in SQL creation. --- database_client/constants.py | 7 + database_client/database_client.py | 155 +++++++++---------- database_client/dynamic_query_constructor.py | 61 +++++++- 3 files changed, 132 insertions(+), 91 deletions(-) diff --git a/database_client/constants.py b/database_client/constants.py index 6a201f8e..6cc6bebf 100644 --- a/database_client/constants.py +++ b/database_client/constants.py @@ -86,3 +86,10 @@ "lat", "lng", ] +RESTRICTED_COLUMNS = [ + "rejection_note", + "data_source_request", + "approval_status", + "airtable_uid", + "airtable_source_last_modified", +] \ No newline at end of file diff --git a/database_client/database_client.py b/database_client/database_client.py index 1accce2b..66c2f4f8 100644 --- a/database_client/database_client.py +++ b/database_client/database_client.py @@ -6,7 +6,7 @@ import uuid import psycopg2 -from psycopg2.extras import DictCursor +from psycopg2 import sql from database_client.dynamic_query_constructor import DynamicQueryConstructor from middleware.custom_exceptions import ( @@ -15,14 +15,6 @@ AccessTokenNotFoundError, ) -RESTRICTED_COLUMNS = [ - "rejection_note", - "data_source_request", - "approval_status", - "airtable_uid", - "airtable_source_last_modified", -] - DATA_SOURCES_MAP_COLUMN = [ "data_source_id", "name", @@ -70,8 +62,6 @@ """ - - class DatabaseClient: def __init__(self, cursor: psycopg2.extensions.cursor): @@ -84,10 +74,13 @@ def add_new_user(self, email: str, password_digest: str): :param password_digest: :return: """ - self.cursor.execute( - f"insert into users (email, password_digest) values (%s, %s)", - (email, password_digest), + query = sql.SQL( + "insert into users (email, password_digest) values ({}, {})" + ).format( + sql.Literal(email), + sql.Literal(password_digest), ) + self.cursor.execute(query) def get_user_id(self, email: str) -> Optional[int]: """ @@ -95,7 +88,10 @@ def get_user_id(self, email: str) -> Optional[int]: :param email: :return: """ - self.cursor.execute(f"select id from users where email = %s", (email,)) + query = sql.SQL("select id from users where email = {}").format( + sql.Literal(email) + ) + self.cursor.execute(query) if self.cursor.rowcount == 0: return None return self.cursor.fetchone()[0] @@ -107,10 +103,13 @@ def set_user_password_digest(self, email: str, password_digest: str): :param password_digest: :return: """ - self.cursor.execute( - f"update users set password_digest = %s where email = %s", - (password_digest, email), + query = sql.SQL( + "update users set password_digest = {} where email = {}" + ).format( + sql.Literal(password_digest), + sql.Literal(email), ) + self.cursor.execute(query) ResetTokenInfo = namedtuple("ResetTokenInfo", ["id", "email", "create_date"]) @@ -121,10 +120,10 @@ def get_reset_token_info(self, token: str) -> Optional[ResetTokenInfo]: :param token: The reset token to check. :return: ResetTokenInfo if the token exists; otherwise, None. """ - self.cursor.execute( - f"select id, email, create_date from reset_tokens where token = %s", - (token,), - ) + query = sql.SQL( + "select id, email, create_date from reset_tokens where token = {}" + ).format(sql.Literal(token)) + self.cursor.execute(query) row = self.cursor.fetchone() if row is None: return None @@ -137,9 +136,10 @@ def add_reset_token(self, email: str, token: str): :param email: The email to associate with the reset token. :param token: The reset token to add. """ - self.cursor.execute( - f"insert into reset_tokens (email, token) values (%s, %s)", (email, token) - ) + query = sql.SQL( + "insert into reset_tokens (email, token) values ({}, {})" + ).format(sql.Literal(email), sql.Literal(token)) + self.cursor.execute(query) def delete_reset_token(self, email: str, token: str): """ @@ -148,11 +148,14 @@ def delete_reset_token(self, email: str, token: str): :param email: The email associated with the reset token to delete. :param token: The reset token to delete. """ - self.cursor.execute( - f"delete from reset_tokens where email = %s and token = %s", (email, token) - ) + query = sql.SQL( + "delete from reset_tokens where email = {} and token = {}" + ).format(sql.Literal(email), sql.Literal(token)) + self.cursor.execute(query) - SessionTokenInfo = namedtuple("SessionTokenInfo", ["id", "email", "expiration_date"]) + SessionTokenInfo = namedtuple( + "SessionTokenInfo", ["id", "email", "expiration_date"] + ) def get_session_token_info(self, api_key: str) -> Optional[SessionTokenInfo]: """ @@ -161,10 +164,10 @@ def get_session_token_info(self, api_key: str) -> Optional[SessionTokenInfo]: :param api_key: The session token to check. :return: SessionTokenInfo if the token exists; otherwise, None. """ - self.cursor.execute( - f"select id, email, expiration_date from session_tokens where token = %s", - (api_key,), - ) + query = sql.SQL( + "select id, email, expiration_date from session_tokens where token = {}" + ).format(sql.Literal(api_key)) + self.cursor.execute(query) row = self.cursor.fetchone() if row is None: return None @@ -178,10 +181,10 @@ def get_role_by_api_key(self, api_key: str) -> Optional[RoleInfo]: :param api_key: The api key to check. :return: RoleInfo if the token exists; otherwise, None. """ - self.cursor.execute( - f"select id, role from users where api_key = %s", - (api_key,), + query = sql.SQL("select id, role from users where api_key = {}").format( + sql.Literal(api_key) ) + self.cursor.execute(query) row = self.cursor.fetchone() if row is None: return None @@ -195,7 +198,9 @@ def get_role_by_email(self, email: str) -> RoleInfo: :raises UserNotFoundError: If no user is found. :return: RoleInfo namedtuple containing the user's role. """ - self.cursor.execute(f"select role from users where email = %s", (email,)) + query = sql.SQL("select role from users where email = {}") + query = query.format(sql.Literal(email)) + self.cursor.execute(query) results = self.cursor.fetchone() if len(results) == 0: raise UserNotFoundError(email) @@ -208,10 +213,10 @@ def update_user_api_key(self, api_key: str, user_id: int): :param api_key: The api key to check. :param user_id: The user id to update. """ - self.cursor.execute( - f"update users set api_key = %s where id = %s", - (api_key, user_id), - ) + query = sql.SQL("update users set api_key = {} where id = {}") + query = query.format(sql.Literal(api_key), sql.Literal(user_id)) + + self.cursor.execute(query) def get_data_source_by_id(self, data_source_id: str) -> Optional[tuple[Any, ...]]: """ @@ -256,31 +261,6 @@ def get_needs_identification_data_sources(self) -> list[tuple[Any, ...]]: self.cursor.execute(sql_query) return self.cursor.fetchall() - def create_new_data_source_query(self, data: dict) -> str: - """ - Creates a query to add a new data source to the database. - - :param data: A dictionary containing the data source details. - """ - column_names = "" - column_values = "" - for key, value in data.items(): - if key not in RESTRICTED_COLUMNS: - column_names += f"{key}, " - if type(value) == str: - column_values += f"'{value}', " - else: - column_values += f"{value}, " - - now = datetime.now().strftime("%Y-%m-%d") - airtable_uid = str(uuid.uuid4()) - - column_names += "approval_status, url_status, data_source_created, airtable_uid" - column_values += f"False, '[\"ok\"]', '{now}', '{airtable_uid}'" - - sql_query = f"INSERT INTO data_sources ({column_names}) VALUES ({column_values}) RETURNING *" - - return sql_query def add_new_data_source(self, data: dict) -> None: """ @@ -288,7 +268,7 @@ def add_new_data_source(self, data: dict) -> None: :param data: A dictionary containing the updated data source details. """ - sql_query = self.create_new_data_source_query(data) + sql_query = DynamicQueryConstructor.create_new_data_source_query(data) self.cursor.execute(sql_query) def update_data_source(self, data: dict, data_source_id: str) -> None: @@ -569,10 +549,13 @@ def add_quick_search_log( def add_new_access_token(self, token: str, expiration: datetime) -> None: """Inserts a new access token into the database.""" - self.cursor.execute( - f"insert into access_tokens (token, expiration_date) values (%s, %s)", - (token, expiration), + query = sql.SQL( + "insert into access_tokens (token, expiration_date) values ({token}, {expiration})" + ).format( + token=sql.Literal(token), + expiration=sql.Literal(expiration), ) + self.cursor.execute(query) UserInfo = namedtuple("UserInfo", ["id", "password_digest", "api_key"]) @@ -584,9 +567,10 @@ def get_user_info(self, email: str) -> UserInfo: :raise UserNotFoundError: If no user is found. :return: UserInfo namedtuple containing the user's information. """ - self.cursor.execute( - f"select id, password_digest, api_key from users where email = %s", (email,) - ) + query = sql.SQL( + "select id, password_digest, api_key from users where email = {email}" + ).format(email=sql.Literal(email)) + self.cursor.execute(query) results = self.cursor.fetchone() if results is None: raise UserNotFoundError(email) @@ -605,10 +589,14 @@ def add_new_session_token(self, session_token, email: str, expiration) -> None: :param email: User's email. :param expiration: The session token's expiration. """ - self.cursor.execute( - f"insert into session_tokens (token, email, expiration_date) values (%s, %s, %s)", - (session_token, email, expiration), + query = sql.SQL( + "insert into session_tokens (token, email, expiration_date) values ({token}, {email}, {expiration})" + ).format( + token=sql.Literal(session_token), + email=sql.Literal(email), + expiration=sql.Literal(expiration), ) + self.cursor.execute(query) SessionTokenUserData = namedtuple("SessionTokenUserData", ["id", "email"]) @@ -618,9 +606,10 @@ def delete_session_token(self, old_token: str) -> None: :param old_token: The session token. """ - self.cursor.execute( - f"delete from session_tokens where token = %s", (old_token,) + query = sql.SQL("delete from session_tokens where token = {token}").format( + token=sql.Literal(old_token) ) + self.cursor.execute(query) AccessToken = namedtuple("AccessToken", ["id", "token"]) @@ -632,9 +621,10 @@ def get_access_token(self, api_key: str) -> AccessToken: :raise AccessTokenNotFoundError: If the access token is not found. :returns: AccessToken namedtuple with the ID and the access token. """ - self.cursor.execute( - f"select id, token from access_tokens where token = %s", (api_key,) - ) + query = sql.SQL( + "select id, token from access_tokens where token = {token}" + ).format(token=sql.Literal(api_key)) + self.cursor.execute(query) results = self.cursor.fetchone() if not results: raise AccessTokenNotFoundError("Access token not found") @@ -642,5 +632,4 @@ def get_access_token(self, api_key: str) -> AccessToken: def delete_expired_access_tokens(self) -> None: """Deletes all expired access tokens from the database.""" - self.cursor.execute(f"delete from access_tokens where expiration_date < NOW()") - + self.cursor.execute("delete from access_tokens where expiration_date < NOW()") diff --git a/database_client/dynamic_query_constructor.py b/database_client/dynamic_query_constructor.py index be74348e..0154fc58 100644 --- a/database_client/dynamic_query_constructor.py +++ b/database_client/dynamic_query_constructor.py @@ -5,8 +5,12 @@ from psycopg2 import sql -from database_client.constants import AGENCY_APPROVED_COLUMNS, DATA_SOURCES_APPROVED_COLUMNS, \ - RESTRICTED_DATA_SOURCE_COLUMNS +from database_client.constants import ( + AGENCY_APPROVED_COLUMNS, + DATA_SOURCES_APPROVED_COLUMNS, + RESTRICTED_DATA_SOURCE_COLUMNS, + RESTRICTED_COLUMNS, +) TableColumn = namedtuple("TableColumn", ["table", "column"]) TableColumnAlias = namedtuple("TableColumnAlias", ["table", "column", "alias"]) @@ -21,7 +25,8 @@ class DynamicQueryConstructor: @staticmethod def build_fields( - columns_only: list[TableColumn], columns_and_alias: Optional[list[TableColumnAlias]] = None + columns_only: list[TableColumn], + columns_and_alias: Optional[list[TableColumnAlias]] = None, ): # Process columns without alias fields_only = [ @@ -140,10 +145,10 @@ def build_needs_identification_data_source_query(): return sql_query @staticmethod - def zip_needs_identification_data_source_results(results: list[tuple]) -> list[dict]: - return [ - dict(zip(DATA_SOURCES_APPROVED_COLUMNS, result)) for result in results - ] + def zip_needs_identification_data_source_results( + results: list[tuple], + ) -> list[dict]: + return [dict(zip(DATA_SOURCES_APPROVED_COLUMNS, result)) for result in results] @staticmethod def create_data_source_update_query( @@ -225,4 +230,44 @@ def create_new_data_source_query(data: dict) -> sql.Composed: """ ).format(columns=columns_sql, values=values_sql) - return query \ No newline at end of file + return query + + @staticmethod + def create_new_data_source_query(data: dict) -> sql.Composed: + """ + Creates a query to add a new data source to the database. + + :param data: A dictionary containing the data source details. + """ + columns = [] + values = [] + for key, value in data.items(): + if key not in RESTRICTED_COLUMNS: + columns.append(sql.Identifier(key)) + values.append(sql.Literal(value)) + + now = datetime.now().strftime("%Y-%m-%d") + airtable_uid = str(uuid.uuid4()) + + columns.extend( + [ + sql.Identifier("approval_status"), + sql.Identifier("url_status"), + sql.Identifier("data_source_created"), + sql.Identifier("airtable_uid"), + ] + ) + values.extend( + [ + sql.Literal(False), + sql.Literal(["ok"]), + sql.Literal(now), + sql.Literal(airtable_uid), + ] + ) + + query = sql.SQL("INSERT INTO data_sources ({}) VALUES ({}) RETURNING *").format( + sql.SQL(", ").join(columns), sql.SQL(", ").join(values) + ) + + return query