From 93bd6e275dc00c4dc7f397e474995b63c9e0b377 Mon Sep 17 00:00:00 2001 From: elipaz Date: Fri, 1 Nov 2024 20:50:49 +0200 Subject: [PATCH 01/38] Add Client code files --- client/Communicator.py | 36 ++++++++++++++++++++++++++++++++++++ client/Controller.py | 10 ++++++++++ client/View.py | 10 ++++++++++ client/main.py | 7 +++++++ 4 files changed, 63 insertions(+) create mode 100644 client/Communicator.py create mode 100644 client/Controller.py create mode 100644 client/View.py create mode 100644 client/main.py diff --git a/client/Communicator.py b/client/Communicator.py new file mode 100644 index 0000000..04eda8e --- /dev/null +++ b/client/Communicator.py @@ -0,0 +1,36 @@ +import socket + +PORT = 5000 +ADDRESS = 'localhost' +HOST = ADDRESS + +class Communicator: + def __init__(self): + self._host = ADDRESS + self._port = PORT + self._socket = None + + # self.setup() + + def setup(self): + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.connect((self._host, self._port)) + + def send_message(self, message): + if not self._socket: + raise RuntimeError("Socket not set up. Call setup method first.") + + message_bytes = message.encode('utf-8') + self._socket.send(message_bytes) + + def receive_message(self): + if not self._socket: + raise RuntimeError("Socket not set up. Call setup method first.") + + message_bytes = self._socket.recv(1024) + return message_bytes.decode('utf-8') + + def close(self): + if self._socket: + self._socket.close() + self._socket = None diff --git a/client/Controller.py b/client/Controller.py new file mode 100644 index 0000000..b63f4f2 --- /dev/null +++ b/client/Controller.py @@ -0,0 +1,10 @@ +from View import Viewer +from Communicator import Communicator + +class Controller: + def __init__(self): + self._view = Viewer() + self._communicator = Communicator() + + def run(self): + self._view.run() diff --git a/client/View.py b/client/View.py new file mode 100644 index 0000000..04b2b0a --- /dev/null +++ b/client/View.py @@ -0,0 +1,10 @@ +import tkinter as tk + +class Viewer: + def __init__(self): + self.root = tk.Tk() + self.root.title("My Application") + self.root.geometry("800x600") + + def run(self): + self.root.mainloop() diff --git a/client/main.py b/client/main.py new file mode 100644 index 0000000..f39593b --- /dev/null +++ b/client/main.py @@ -0,0 +1,7 @@ +from Controller import Controller + +def main(): + Controller().run() + +if __name__ == "__main__": + main() From 94c2d2bea3c2ad912444fd3cba2ea4a4e9bf6daf Mon Sep 17 00:00:00 2001 From: elipaz Date: Sun, 3 Nov 2024 15:21:29 +0200 Subject: [PATCH 02/38] Add logger --- client/Controller.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/client/Controller.py b/client/Controller.py index b63f4f2..98a9ea7 100644 --- a/client/Controller.py +++ b/client/Controller.py @@ -1,10 +1,42 @@ from View import Viewer from Communicator import Communicator +import os +import logging +from datetime import datetime + +LOG_DIR = "client_logs" class Controller: def __init__(self): self._view = Viewer() self._communicator = Communicator() + self._logger = None + self._logger_setup() def run(self): - self._view.run() + self._logger.info("Starting application") + + try: + self._view.run() + + except Exception as e: + self._logger.error(f"Error during execution: {str(e)}", exc_info=True) + raise + + def _logger_setup(self): + if not os.path.exists(LOG_DIR): + os.makedirs(LOG_DIR) + + log_file = os.path.join(LOG_DIR, f"client_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log") + + logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler() + ] + ) + + self._logger = logging.getLogger(__name__) + self._logger.info("Logger setup complete") From e57f88cf6f0d1aa0285d2b23d86501cf5705504f Mon Sep 17 00:00:00 2001 From: elipaz Date: Sun, 3 Nov 2024 23:26:20 +0200 Subject: [PATCH 03/38] Add tests --- client/__init__.py | 0 client/main.py | 2 +- client/requirments.txt | 8 +++ client/setup.py | 7 ++ client/{ => src}/Communicator.py | 34 +++++----- client/{ => src}/Controller.py | 33 +++++---- client/{ => src}/View.py | 8 ++- client/src/__init__.py | 0 client/tests/__init__.py | 0 client/tests/test_communicator.py | 76 +++++++++++++++++++++ client/tests/test_controller.py | 107 ++++++++++++++++++++++++++++++ client/tests/test_view.py | 24 +++++++ 12 files changed, 265 insertions(+), 34 deletions(-) create mode 100644 client/__init__.py create mode 100644 client/requirments.txt create mode 100644 client/setup.py rename client/{ => src}/Communicator.py (65%) rename client/{ => src}/Controller.py (51%) rename client/{ => src}/View.py (53%) create mode 100644 client/src/__init__.py create mode 100644 client/tests/__init__.py create mode 100644 client/tests/test_communicator.py create mode 100644 client/tests/test_controller.py create mode 100644 client/tests/test_view.py diff --git a/client/__init__.py b/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/main.py b/client/main.py index f39593b..ad5b89a 100644 --- a/client/main.py +++ b/client/main.py @@ -1,4 +1,4 @@ -from Controller import Controller +from src.Controller import Controller def main(): Controller().run() diff --git a/client/requirments.txt b/client/requirments.txt new file mode 100644 index 0000000..ea7c08e --- /dev/null +++ b/client/requirments.txt @@ -0,0 +1,8 @@ +-e git+https://github.com/pazMenachem/My_Internet.git@94c2d2bea3c2ad912444fd3cba2ea4a4e9bf6daf#egg=client&subdirectory=client +colorama==0.4.6 +exceptiongroup==1.2.2 +iniconfig==2.0.0 +packaging==24.1 +pluggy==1.5.0 +pytest==8.3.3 +tomli==2.0.2 diff --git a/client/setup.py b/client/setup.py new file mode 100644 index 0000000..37ab367 --- /dev/null +++ b/client/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="client", + packages=find_packages(), + version="0.1", +) \ No newline at end of file diff --git a/client/Communicator.py b/client/src/Communicator.py similarity index 65% rename from client/Communicator.py rename to client/src/Communicator.py index 04eda8e..4d0e2e0 100644 --- a/client/Communicator.py +++ b/client/src/Communicator.py @@ -1,36 +1,36 @@ import socket +from typing import Optional -PORT = 5000 -ADDRESS = 'localhost' -HOST = ADDRESS +PORT = 65432 +HOST = '127.0.0.1' class Communicator: - def __init__(self): - self._host = ADDRESS - self._port = PORT - self._socket = None + def __init__(self) -> None: + self._host: str = HOST + self._port: int = PORT + self._socket: Optional[socket.socket] = None # self.setup() - - def setup(self): + + def setup(self) -> None: self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.connect((self._host, self._port)) - - def send_message(self, message): + + def send_message(self, message: str) -> None: if not self._socket: raise RuntimeError("Socket not set up. Call setup method first.") - + message_bytes = message.encode('utf-8') self._socket.send(message_bytes) - - def receive_message(self): + + def receive_message(self) -> str: if not self._socket: raise RuntimeError("Socket not set up. Call setup method first.") - + message_bytes = self._socket.recv(1024) return message_bytes.decode('utf-8') - - def close(self): + + def close(self) -> None: if self._socket: self._socket.close() self._socket = None diff --git a/client/Controller.py b/client/src/Controller.py similarity index 51% rename from client/Controller.py rename to client/src/Controller.py index 98a9ea7..574ae2f 100644 --- a/client/Controller.py +++ b/client/src/Controller.py @@ -1,19 +1,23 @@ -from View import Viewer -from Communicator import Communicator -import os import logging +import os from datetime import datetime +from typing import Optional + +from .Communicator import Communicator +from .View import Viewer LOG_DIR = "client_logs" class Controller: - def __init__(self): - self._view = Viewer() - self._communicator = Communicator() - self._logger = None + def __init__(self) -> None: + self._view: Viewer = Viewer() + self._communicator: Communicator = Communicator() + self._logger: Optional[logging.Logger] = None + self._logger_setup() - def run(self): + def run(self) -> None: + """Run the controller.""" self._logger.info("Starting application") try: @@ -23,19 +27,22 @@ def run(self): self._logger.error(f"Error during execution: {str(e)}", exc_info=True) raise - def _logger_setup(self): + def _logger_setup(self) -> None: + """Set up the logger.""" if not os.path.exists(LOG_DIR): os.makedirs(LOG_DIR) - log_file = os.path.join(LOG_DIR, f"client_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log") + log_file: str = os.path.join( + LOG_DIR, f"Client_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" + ) logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", handlers=[ logging.FileHandler(log_file), - logging.StreamHandler() - ] + logging.StreamHandler(), + ], ) self._logger = logging.getLogger(__name__) diff --git a/client/View.py b/client/src/View.py similarity index 53% rename from client/View.py rename to client/src/View.py index 04b2b0a..2b4c293 100644 --- a/client/View.py +++ b/client/src/View.py @@ -1,10 +1,12 @@ import tkinter as tk + class Viewer: - def __init__(self): - self.root = tk.Tk() + def __init__(self) -> None: + self.root: tk.Tk = tk.Tk() self.root.title("My Application") self.root.geometry("800x600") - def run(self): + def run(self) -> None: + """Run the viewer.""" self.root.mainloop() diff --git a/client/src/__init__.py b/client/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/__init__.py b/client/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/test_communicator.py b/client/tests/test_communicator.py new file mode 100644 index 0000000..bc561a2 --- /dev/null +++ b/client/tests/test_communicator.py @@ -0,0 +1,76 @@ +import socket +from unittest import mock +from typing import Optional + +import pytest + +from src.Communicator import Communicator, HOST, PORT + + +@pytest.fixture +def communicator() -> Communicator: + """Fixture to create a Communicator instance.""" + return Communicator() + + +@mock.patch('socket.socket') +def test_setup(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test the setup method initializes and connects the socket.""" + mock_socket_instance = mock_socket_class.return_value + communicator.setup() + mock_socket_class.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) + mock_socket_instance.connect.assert_called_once_with((HOST, PORT)) + assert communicator._socket is mock_socket_instance + + +@mock.patch('socket.socket') +def test_send_message_without_setup(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test sending a message without setting up the socket raises RuntimeError.""" + with pytest.raises(RuntimeError) as exc_info: + communicator.send_message("Hello") + assert str(exc_info.value) == "Socket not set up. Call setup method first." + + +@mock.patch('socket.socket') +def test_send_message(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test sending a message successfully.""" + mock_socket_instance = mock_socket_class.return_value + communicator._socket = mock_socket_instance + + message: str = "Hello, World!" + communicator.send_message(message) + + mock_socket_instance.send.assert_called_once_with(message.encode('utf-8')) + + +@mock.patch('socket.socket') +def test_receive_message_without_setup(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test receiving a message without setting up the socket raises RuntimeError.""" + with pytest.raises(RuntimeError) as exc_info: + communicator.receive_message() + assert str(exc_info.value) == "Socket not set up. Call setup method first." + + +@mock.patch('socket.socket') +def test_receive_message(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test receiving a message successfully.""" + mock_socket_instance = mock_socket_class.return_value + communicator._socket = mock_socket_instance + + mock_socket_instance.recv.return_value = b'Hello, Client!' + message: str = communicator.receive_message() + + mock_socket_instance.recv.assert_called_once_with(1024) + assert message == 'Hello, Client!' + + +@mock.patch('socket.socket') +def test_close_socket(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test closing the socket.""" + mock_socket_instance = mock_socket_class.return_value + communicator._socket = mock_socket_instance + + communicator.close() + + mock_socket_instance.close.assert_called_once() + assert communicator._socket is None \ No newline at end of file diff --git a/client/tests/test_controller.py b/client/tests/test_controller.py new file mode 100644 index 0000000..655cc9b --- /dev/null +++ b/client/tests/test_controller.py @@ -0,0 +1,107 @@ +import os +import logging +from unittest import mock +from datetime import datetime +from typing import Optional + +import pytest + +from src.Controller import Controller +from src.View import Viewer +from src.Communicator import Communicator, HOST, PORT + + +@pytest.fixture +def controller() -> Controller: + """Fixture to create a Controller instance.""" + with mock.patch('src.Controller.Viewer'), \ + mock.patch('src.Controller.Communicator'), \ + mock.patch('src.Controller.logging'): + yield Controller() + + +@mock.patch('src.Controller.os.path.exists') +@mock.patch('src.Controller.os.makedirs') +@mock.patch('src.Controller.logging.getLogger') +@mock.patch('src.Controller.logging.basicConfig') +@mock.patch('src.Controller.datetime') +def test_logger_setup( + mock_datetime: mock.Mock, + mock_basicConfig: mock.Mock, + mock_getLogger: mock.Mock, + mock_makedirs: mock.Mock, + mock_exists: mock.Mock, + controller: Controller +) -> None: + """Test the logger setup in Controller.""" + mock_exists.return_value = False + mock_datetime.now.return_value = datetime(2023, 10, 1, 12, 0, 0) + mock_logger = mock.Mock() + mock_logger.info = mock.Mock() + mock_getLogger.return_value = mock_logger + + controller._logger_setup() + + mock_exists.assert_called_once_with("client_logs") + mock_makedirs.assert_called_once_with("client_logs") + mock_datetime.now.assert_called_once() + expected_log_file = os.path.join("client_logs", "Client_20231001_120000.log") + mock_basicConfig.assert_called_once_with( + level=mock.ANY, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + mock.ANY, + mock.ANY, + ], + ) + mock_getLogger.assert_called_once_with('src.Controller') + mock_logger.info.assert_any_call("Logger setup complete") + + +@mock.patch('src.Controller.Viewer') +@mock.patch('src.Controller.logging.Logger.info') +@mock.patch('src.Controller.logging.Logger.error') +def test_run_success( + mock_error: mock.Mock, + mock_info: mock.Mock, + mock_viewer: mock.Mock, + controller: Controller +) -> None: + """Test the run method executes successfully.""" + controller._view = mock_viewer.return_value + controller._logger = mock.Mock(spec=logging.Logger) + + controller.run() + + controller._logger.info.assert_called_with("Starting application") + controller._view.run.assert_called_once() + mock_error.assert_not_called() + + +@mock.patch('src.Controller.logging.Logger.info') +@mock.patch('src.Controller.logging.Logger.error') +def test_run_exception( + mock_error: mock.Mock, + mock_info: mock.Mock, + controller: Controller +) -> None: + """Test the run method handles exceptions properly.""" + # Setup the mock viewer instance + mock_viewer = mock.Mock() + mock_viewer.run.side_effect = Exception("Test Exception") + + controller._view = mock_viewer + + mock_logger = mock.Mock(spec=logging.Logger) + controller._logger = mock_logger + + with pytest.raises(Exception) as exc_info: + controller.run() + + assert str(exc_info.value) == "Test Exception" + mock_logger.info.assert_called_with("Starting application") + mock_logger.error.assert_called_with( + "Error during execution: Test Exception", + exc_info=True + ) + mock_viewer.run.assert_called_once() diff --git a/client/tests/test_view.py b/client/tests/test_view.py new file mode 100644 index 0000000..813399c --- /dev/null +++ b/client/tests/test_view.py @@ -0,0 +1,24 @@ +import tkinter as tk +from unittest import mock + +import pytest + +from src.View import Viewer + +@pytest.fixture +def viewer() -> Viewer: + """Fixture to create a Viewer instance.""" + with mock.patch('src.View.tk.Tk'): + yield Viewer() + + +def test_init(viewer: Viewer) -> None: + """Test the initialization of Viewer.""" + viewer.root.title.assert_called_once_with("My Application") + viewer.root.geometry.assert_called_once_with("800x600") + + +def test_run(viewer: Viewer) -> None: + """Test running the viewer's main loop.""" + viewer.run() + viewer.root.mainloop.assert_called_once() From b67a51a248501bf900336a989919e902f6f54188 Mon Sep 17 00:00:00 2001 From: elipaz Date: Mon, 4 Nov 2024 12:17:05 +0200 Subject: [PATCH 04/38] Modify Controller name to Application --- client/src/Application.py | 51 +++++++++++++++ client/tests/test_application.py | 106 +++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 client/src/Application.py create mode 100644 client/tests/test_application.py diff --git a/client/src/Application.py b/client/src/Application.py new file mode 100644 index 0000000..c6986b0 --- /dev/null +++ b/client/src/Application.py @@ -0,0 +1,51 @@ +import logging +import os +from datetime import datetime +from typing import Optional + +from .Communicator import Communicator +from .View import Viewer + +LOG_DIR = "client_logs" + +class Application: + def __init__(self) -> None: + self._view: Viewer = Viewer() + self._communicator: Communicator = Communicator() + self._logger: Optional[logging.Logger] = None + + self._logger_setup() + + def run(self) -> None: + self._logger.info("Starting application") + + try: + self._view.run() + self._communicator.run() + + except Exception as e: + self._logger.error(f"Error during execution: {str(e)}", exc_info=True) + raise + + def send_message(self, message: str) -> None: + self._communicator.send_message(message) + + def _logger_setup(self) -> None: + if not os.path.exists(LOG_DIR): + os.makedirs(LOG_DIR) + + log_file: str = os.path.join( + LOG_DIR, f"Client_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" + ) + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler(), + ], + ) + + self._logger = logging.getLogger(__name__) + self._logger.info("Logger setup complete") \ No newline at end of file diff --git a/client/tests/test_application.py b/client/tests/test_application.py new file mode 100644 index 0000000..f32802b --- /dev/null +++ b/client/tests/test_application.py @@ -0,0 +1,106 @@ +import os +import logging +from unittest import mock +from datetime import datetime +from typing import Optional + +import pytest + +from src.Application import Application +from src.View import Viewer +from src.Communicator import Communicator, HOST, PORT + + +@pytest.fixture +def application() -> Application: + """Fixture to create an Application instance.""" + with mock.patch('src.Application.Viewer'), \ + mock.patch('src.Application.Communicator'), \ + mock.patch('src.Application.logging'): + yield Application() + + +@mock.patch('src.Application.os.path.exists') +@mock.patch('src.Application.os.makedirs') +@mock.patch('src.Application.logging.getLogger') +@mock.patch('src.Application.logging.basicConfig') +@mock.patch('src.Application.datetime') +def test_logger_setup( + mock_datetime: mock.Mock, + mock_basicConfig: mock.Mock, + mock_getLogger: mock.Mock, + mock_makedirs: mock.Mock, + mock_exists: mock.Mock, + application: Application +) -> None: + """Test the logger setup in Application.""" + mock_exists.return_value = False + mock_datetime.now.return_value = datetime(2023, 10, 1, 12, 0, 0) + mock_logger = mock.Mock() + mock_logger.info = mock.Mock() + mock_getLogger.return_value = mock_logger + + application._logger_setup() + + mock_exists.assert_called_once_with("client_logs") + mock_makedirs.assert_called_once_with("client_logs") + mock_datetime.now.assert_called_once() + expected_log_file = os.path.join("client_logs", "Client_20231001_120000.log") + mock_basicConfig.assert_called_once_with( + level=mock.ANY, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + mock.ANY, + mock.ANY, + ], + ) + mock_getLogger.assert_called_once_with('src.Application') + mock_logger.info.assert_any_call("Logger setup complete") + + +@mock.patch('src.Application.Viewer') +@mock.patch('src.Application.logging.Logger.info') +@mock.patch('src.Application.logging.Logger.error') +def test_run_success( + mock_error: mock.Mock, + mock_info: mock.Mock, + mock_viewer: mock.Mock, + application: Application +) -> None: + """Test the run method executes successfully.""" + application._view = mock_viewer.return_value + application._logger = mock.Mock(spec=logging.Logger) + + application.run() + + application._logger.info.assert_called_with("Starting application") + application._view.run.assert_called_once() + mock_error.assert_not_called() + + +@mock.patch('src.Application.logging.Logger.info') +@mock.patch('src.Application.logging.Logger.error') +def test_run_exception( + mock_error: mock.Mock, + mock_info: mock.Mock, + application: Application +) -> None: + """Test the run method handles exceptions properly.""" + mock_viewer = mock.Mock() + mock_viewer.run.side_effect = Exception("Test Exception") + + application._view = mock_viewer + + mock_logger = mock.Mock(spec=logging.Logger) + application._logger = mock_logger + + with pytest.raises(Exception) as exc_info: + application.run() + + assert str(exc_info.value) == "Test Exception" + mock_logger.info.assert_called_with("Starting application") + mock_logger.error.assert_called_with( + "Error during execution: Test Exception", + exc_info=True + ) + mock_viewer.run.assert_called_once() \ No newline at end of file From b6f1ef155c88421e06f631010e4d62aa0c2b0f73 Mon Sep 17 00:00:00 2001 From: elipaz Date: Mon, 4 Nov 2024 15:39:44 +0200 Subject: [PATCH 05/38] Modify Code Receiving requests from View and Communicator to Application now works. --- client/main.py | 7 ++- client/src/Application.py | 110 +++++++++++++++++++++++++------------ client/src/Communicator.py | 62 ++++++++++++++++----- client/src/Controller.py | 49 ----------------- client/src/Logger.py | 45 +++++++++++++++ client/src/View.py | 90 ++++++++++++++++++++++++++++-- 6 files changed, 259 insertions(+), 104 deletions(-) delete mode 100644 client/src/Controller.py create mode 100644 client/src/Logger.py diff --git a/client/main.py b/client/main.py index ad5b89a..7ef3d01 100644 --- a/client/main.py +++ b/client/main.py @@ -1,7 +1,8 @@ -from src.Controller import Controller +from src.Application import Application -def main(): - Controller().run() +def main() -> None: + application: Application = Application() + application.run() if __name__ == "__main__": main() diff --git a/client/src/Application.py b/client/src/Application.py index c6986b0..3fa2803 100644 --- a/client/src/Application.py +++ b/client/src/Application.py @@ -1,51 +1,93 @@ -import logging -import os -from datetime import datetime -from typing import Optional - +import json +import threading from .Communicator import Communicator from .View import Viewer - -LOG_DIR = "client_logs" +from .Logger import setup_logger class Application: + """ + Main application class that coordinates communication between UI and server. + + Uses threading to handle simultaneous GUI and network operations. + + Attributes: + _logger: Logger instance for application logging + _view: Viewer instance for GUI operations + _communicator: Communicator instance for network operations + """ + def __init__(self) -> None: - self._view: Viewer = Viewer() - self._communicator: Communicator = Communicator() - self._logger: Optional[logging.Logger] = None - - self._logger_setup() + """Initialize application components.""" + self._logger = setup_logger(__name__) + self._view = Viewer(message_callback=self._handle_request) + self._communicator = Communicator(message_callback=self._handle_request) def run(self) -> None: + """ + Start the application with threaded communication handling. + + Raises: + Exception: If there's an error during startup of either component. + """ self._logger.info("Starting application") try: - self._view.run() - self._communicator.run() - + self._start_communication() + self._start_gui() + except Exception as e: self._logger.error(f"Error during execution: {str(e)}", exc_info=True) raise - - def send_message(self, message: str) -> None: - self._communicator.send_message(message) + finally: + self._cleanup() + + def _start_communication(self) -> None: + """Initialize and start the communication thread.""" + try: + self._communicator.connect() + threading.Thread( + target=self._communicator.receive_message, + daemon=True + ).start() + + self._logger.info("Communication server started successfully") + except Exception as e: + self._logger.error(f"Failed to start communication: {str(e)}") + raise - def _logger_setup(self) -> None: - if not os.path.exists(LOG_DIR): - os.makedirs(LOG_DIR) + def _start_gui(self) -> None: + """Start the GUI main loop.""" + try: + self._logger.info("Starting GUI") + self._view.run() + + except Exception as e: + self._logger.error(f"Failed to start GUI: {str(e)}") + raise - log_file: str = os.path.join( - LOG_DIR, f"Client_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" - ) + def _handle_request(self, request: str) -> None: + """ + Handle outgoing messages from the UI and Server. - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[ - logging.FileHandler(log_file), - logging.StreamHandler(), - ], - ) + Args: + request: received request from server or user input from UI. + """ + try: + self._logger.debug(f"Processing request: {request}") + + pass ## TODO: Implement request handling from server or UI. + + except json.JSONDecodeError as e: + self._logger.error(f"Invalid JSON format: {str(e)}") + raise + except Exception as e: + self._logger.error(f"Error handling request: {str(e)}") + raise - self._logger = logging.getLogger(__name__) - self._logger.info("Logger setup complete") \ No newline at end of file + def _cleanup(self) -> None: + """Clean up resources and stop threads.""" + self._logger.info("Cleaning up application resources") + if self._communicator: + self._communicator.close() + if self._view: + self._view.root.destroy() diff --git a/client/src/Communicator.py b/client/src/Communicator.py index 4d0e2e0..037373e 100644 --- a/client/src/Communicator.py +++ b/client/src/Communicator.py @@ -1,36 +1,70 @@ import socket -from typing import Optional +from typing import Optional, Callable +import json PORT = 65432 HOST = '127.0.0.1' +RECEIVE_BUFFER_SIZE = 1024 class Communicator: - def __init__(self) -> None: - self._host: str = HOST - self._port: int = PORT + def __init__(self, message_callback: Callable[[str], None]) -> None: + """ + Initialize the communicator. + + Args: + message_callback: Callback function to handle received messages. + """ + self._host = HOST + self._port = PORT self._socket: Optional[socket.socket] = None + self._message_callback = message_callback - # self.setup() - - def setup(self) -> None: + def connect(self) -> None: + """ + Establish connection to the server. + + Raises: + socket.error: If connection cannot be established. + """ self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.connect((self._host, self._port)) def send_message(self, message: str) -> None: + """ + Send a json message to the server. + + Args: + message_json: The message to send to the server. + + Raises: + RuntimeError: If socket connection is not established. + """ if not self._socket: - raise RuntimeError("Socket not set up. Call setup method first.") + raise RuntimeError("Socket not set up. Call connect method first.") + + self._socket.send(message.encode('utf-8')) + + def receive_message(self) -> None: + """Continuously receive and process messages from the socket connection. - message_bytes = message.encode('utf-8') - self._socket.send(message_bytes) + This method runs in a loop to receive messages from the socket. Each received + message is decoded from UTF-8 and passed to the message callback function. - def receive_message(self) -> str: + Raises: + RuntimeError: If socket connection is not established. + socket.error: If there's an error receiving data from the socket. + UnicodeDecodeError: If received data cannot be decoded as UTF-8. + """ if not self._socket: - raise RuntimeError("Socket not set up. Call setup method first.") + raise RuntimeError("Socket not set up. Call connect method first.") - message_bytes = self._socket.recv(1024) - return message_bytes.decode('utf-8') + while message_bytes := self._socket.recv(RECEIVE_BUFFER_SIZE): + if not message_bytes: + break + self._message_callback(message_bytes.decode('utf-8')) def close(self) -> None: + """Close the socket connection and clean up resources.""" if self._socket: self._socket.close() self._socket = None diff --git a/client/src/Controller.py b/client/src/Controller.py deleted file mode 100644 index 574ae2f..0000000 --- a/client/src/Controller.py +++ /dev/null @@ -1,49 +0,0 @@ -import logging -import os -from datetime import datetime -from typing import Optional - -from .Communicator import Communicator -from .View import Viewer - -LOG_DIR = "client_logs" - -class Controller: - def __init__(self) -> None: - self._view: Viewer = Viewer() - self._communicator: Communicator = Communicator() - self._logger: Optional[logging.Logger] = None - - self._logger_setup() - - def run(self) -> None: - """Run the controller.""" - self._logger.info("Starting application") - - try: - self._view.run() - - except Exception as e: - self._logger.error(f"Error during execution: {str(e)}", exc_info=True) - raise - - def _logger_setup(self) -> None: - """Set up the logger.""" - if not os.path.exists(LOG_DIR): - os.makedirs(LOG_DIR) - - log_file: str = os.path.join( - LOG_DIR, f"Client_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" - ) - - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[ - logging.FileHandler(log_file), - logging.StreamHandler(), - ], - ) - - self._logger = logging.getLogger(__name__) - self._logger.info("Logger setup complete") diff --git a/client/src/Logger.py b/client/src/Logger.py new file mode 100644 index 0000000..034d9b7 --- /dev/null +++ b/client/src/Logger.py @@ -0,0 +1,45 @@ +"""Logger module for handling application-wide logging configuration.""" + +import logging +import os +from datetime import datetime +from typing import Optional + +LOG_DIR = "client_logs" +_logger: Optional[logging.Logger] = None + +def setup_logger(name: str) -> logging.Logger: + """ + Configure and return a logger instance. + + Args: + name: The name of the module requesting the logger. + + Returns: + logging.Logger: Configured logger instance. + """ + global _logger + + if _logger is not None: + return logging.getLogger(name) + + if not os.path.exists(LOG_DIR): + os.makedirs(LOG_DIR) + + log_file: str = os.path.join( + LOG_DIR, f"Client_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" + ) + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler(), + ], + ) + + _logger = logging.getLogger(name) + _logger.info("Logger setup complete") + + return _logger \ No newline at end of file diff --git a/client/src/View.py b/client/src/View.py index 2b4c293..72b660f 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -1,12 +1,94 @@ import tkinter as tk - +from tkinter import scrolledtext, ttk +from typing import Callable +import json class Viewer: - def __init__(self) -> None: + """ + Graphical user interface for the application. + """ + + def __init__(self, message_callback: Callable[[str], None]) -> None: + """ + Initialize the viewer window and its components. + + Args: + message_callback: Callback function to handle message sending. + """ self.root: tk.Tk = tk.Tk() - self.root.title("My Application") + self.root.title("Chat Application") self.root.geometry("800x600") + self._message_callback = message_callback + self._setup_ui() + + def _send_message(self) -> None: + """Handle the sending of messages from the input field.""" + message = self.input_field.get().strip() + if message: + message_json = json.dumps({"CODE": "100", "content": message}) + self._message_callback(message_json) + self.input_field.delete(0, tk.END) + self.display_message("You", message) def run(self) -> None: - """Run the viewer.""" + """Start the main event loop of the viewer.""" self.root.mainloop() + + def _setup_ui(self) -> None: + """Set up the UI components including text areas and buttons.""" + main_container = ttk.Frame(self.root, padding="5") + main_container.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) + self.root.columnconfigure(0, weight=1) + self.root.rowconfigure(0, weight=1) + + self.message_area = scrolledtext.ScrolledText( + main_container, + wrap=tk.WORD, + width=70, + height=30 + ) + + self.message_area.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S)) + self.message_area.config(state=tk.DISABLED) + + self.input_field = ttk.Entry(main_container) + self.input_field.grid(row=1, column=0, sticky=(tk.W, tk.E)) + self.input_field.bind("", lambda e: self._send_message()) + + self.send_button = ttk.Button( + main_container, + text="Send", + command=self._send_message + ) + self.send_button.grid(row=1, column=1) + + main_container.columnconfigure(0, weight=3) + main_container.columnconfigure(1, weight=1) + main_container.rowconfigure(0, weight=1) + + ## TODO: This method won't be relevant for the final version + def display_message(self, sender: str, message: str) -> None: + """ + Display a message in the message area. + + Args: + sender: The name of the message sender. + message: The message content to display. + """ + self.message_area.config(state=tk.NORMAL) + self.message_area.insert(tk.END, f"{sender}: {message}\n") + self.message_area.see(tk.END) + self.message_area.config(state=tk.DISABLED) + + ## TODO: This method won't be relevant for the final version + def display_error(self, error_message: str) -> None: + """ + Display an error message in the message area. + + Args: + error_message: The error message to display. + """ + self.message_area.config(state=tk.NORMAL) + self.message_area.insert(tk.END, f"Error: {error_message}\n") + self.message_area.see(tk.END) + self.message_area.config(state=tk.DISABLED) From bac6fd9d89bfcc855b1c7fd42628f4eead51aa14 Mon Sep 17 00:00:00 2001 From: elipaz Date: Mon, 4 Nov 2024 15:50:33 +0200 Subject: [PATCH 06/38] Modify tests --- client/pytest.ini | 2 + client/tests/test_application.py | 173 +++++++++++++++--------------- client/tests/test_communicator.py | 80 +++++++++++--- client/tests/test_controller.py | 107 ------------------ client/tests/test_view.py | 61 +++++++++-- 5 files changed, 203 insertions(+), 220 deletions(-) create mode 100644 client/pytest.ini delete mode 100644 client/tests/test_controller.py diff --git a/client/pytest.ini b/client/pytest.ini new file mode 100644 index 0000000..a635c5c --- /dev/null +++ b/client/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = . diff --git a/client/tests/test_application.py b/client/tests/test_application.py index f32802b..ca9df06 100644 --- a/client/tests/test_application.py +++ b/client/tests/test_application.py @@ -2,105 +2,104 @@ import logging from unittest import mock from datetime import datetime -from typing import Optional +from typing import Optional, Callable import pytest from src.Application import Application from src.View import Viewer -from src.Communicator import Communicator, HOST, PORT +from src.Communicator import Communicator @pytest.fixture -def application() -> Application: +def mock_callback() -> Callable[[str], None]: + """Fixture to provide a mock callback function.""" + return mock.Mock() + + +@pytest.fixture +def application(mock_callback: Callable[[str], None]) -> Application: """Fixture to create an Application instance.""" - with mock.patch('src.Application.Viewer'), \ - mock.patch('src.Application.Communicator'), \ - mock.patch('src.Application.logging'): - yield Application() - - -@mock.patch('src.Application.os.path.exists') -@mock.patch('src.Application.os.makedirs') -@mock.patch('src.Application.logging.getLogger') -@mock.patch('src.Application.logging.basicConfig') -@mock.patch('src.Application.datetime') -def test_logger_setup( - mock_datetime: mock.Mock, - mock_basicConfig: mock.Mock, - mock_getLogger: mock.Mock, - mock_makedirs: mock.Mock, - mock_exists: mock.Mock, - application: Application -) -> None: - """Test the logger setup in Application.""" - mock_exists.return_value = False - mock_datetime.now.return_value = datetime(2023, 10, 1, 12, 0, 0) - mock_logger = mock.Mock() - mock_logger.info = mock.Mock() - mock_getLogger.return_value = mock_logger - - application._logger_setup() - - mock_exists.assert_called_once_with("client_logs") - mock_makedirs.assert_called_once_with("client_logs") - mock_datetime.now.assert_called_once() - expected_log_file = os.path.join("client_logs", "Client_20231001_120000.log") - mock_basicConfig.assert_called_once_with( - level=mock.ANY, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[ - mock.ANY, - mock.ANY, - ], - ) - mock_getLogger.assert_called_once_with('src.Application') - mock_logger.info.assert_any_call("Logger setup complete") - - -@mock.patch('src.Application.Viewer') -@mock.patch('src.Application.logging.Logger.info') -@mock.patch('src.Application.logging.Logger.error') -def test_run_success( - mock_error: mock.Mock, - mock_info: mock.Mock, - mock_viewer: mock.Mock, - application: Application -) -> None: - """Test the run method executes successfully.""" - application._view = mock_viewer.return_value - application._logger = mock.Mock(spec=logging.Logger) - - application.run() + with mock.patch('src.Application.Viewer') as mock_viewer, \ + mock.patch('src.Application.Communicator') as mock_comm, \ + mock.patch('src.Application.setup_logger') as mock_logger: + app = Application() + app._logger = mock.Mock() + return app + + +def test_init(application: Application) -> None: + """Test the initialization of Application.""" + assert hasattr(application, '_logger') + assert hasattr(application, '_view') + assert hasattr(application, '_communicator') + + +@mock.patch('src.Application.threading.Thread') +def test_start_communication(mock_thread: mock.Mock, application: Application) -> None: + """Test the communication startup.""" + application._start_communication() - application._logger.info.assert_called_with("Starting application") + application._communicator.connect.assert_called_once() + mock_thread.assert_called_once_with( + target=application._communicator.receive_message, + daemon=True + ) + mock_thread.return_value.start.assert_called_once() + + +def test_start_gui(application: Application) -> None: + """Test the GUI startup.""" + application._start_gui() application._view.run.assert_called_once() - mock_error.assert_not_called() - - -@mock.patch('src.Application.logging.Logger.info') -@mock.patch('src.Application.logging.Logger.error') -def test_run_exception( - mock_error: mock.Mock, - mock_info: mock.Mock, - application: Application -) -> None: - """Test the run method handles exceptions properly.""" - mock_viewer = mock.Mock() - mock_viewer.run.side_effect = Exception("Test Exception") - - application._view = mock_viewer + + +def test_handle_request(application: Application) -> None: + """Test request handling.""" + test_request = '{"type": "test", "content": "message"}' - mock_logger = mock.Mock(spec=logging.Logger) - application._logger = mock_logger + # Currently just testing logging as implementation is pending + application._handle_request(test_request) + application._logger.debug.assert_called_once_with(f"Processing request: {test_request}") + + +def test_cleanup(application: Application) -> None: + """Test cleanup process.""" + application._cleanup() - with pytest.raises(Exception) as exc_info: + application._communicator.close.assert_called_once() + application._view.root.destroy.assert_called_once() + + +def test_run_success(application: Application) -> None: + """Test successful application run.""" + with mock.patch.object(application, '_start_communication'), \ + mock.patch.object(application, '_start_gui'), \ + mock.patch.object(application, '_cleanup'): + application.run() + + application._start_communication.assert_called_once() + application._start_gui.assert_called_once() + application._cleanup.assert_called_once() + + +def test_run_exception(application: Application) -> None: + """Test application run with exception.""" + error_msg = "Test error" - assert str(exc_info.value) == "Test Exception" - mock_logger.info.assert_called_with("Starting application") - mock_logger.error.assert_called_with( - "Error during execution: Test Exception", - exc_info=True - ) - mock_viewer.run.assert_called_once() \ No newline at end of file + with mock.patch.object(application, '_start_communication') as mock_start_comm, \ + mock.patch.object(application, '_cleanup') as mock_cleanup: + + mock_start_comm.side_effect = Exception(error_msg) + + with pytest.raises(Exception) as exc_info: + application.run() + + assert str(exc_info.value) == error_msg + application._logger.error.assert_called_with( + f"Error during execution: {error_msg}", + exc_info=True + ) + mock_cleanup.assert_called_once() + \ No newline at end of file diff --git a/client/tests/test_communicator.py b/client/tests/test_communicator.py index bc561a2..1ed8e81 100644 --- a/client/tests/test_communicator.py +++ b/client/tests/test_communicator.py @@ -1,34 +1,52 @@ import socket from unittest import mock -from typing import Optional +from typing import Optional, Callable import pytest -from src.Communicator import Communicator, HOST, PORT +from src.Communicator import Communicator, HOST, PORT, RECEIVE_BUFFER_SIZE @pytest.fixture -def communicator() -> Communicator: +def mock_callback() -> Callable[[str], None]: + """Fixture to provide a mock callback function.""" + return mock.Mock() + + +@pytest.fixture +def communicator(mock_callback: Callable[[str], None]) -> Communicator: """Fixture to create a Communicator instance.""" - return Communicator() + return Communicator(message_callback=mock_callback) + + +def test_init(communicator: Communicator, mock_callback: Callable[[str], None]) -> None: + """Test the initialization of Communicator.""" + assert communicator._host == HOST + assert communicator._port == PORT + assert communicator._socket is None + assert communicator._message_callback == mock_callback @mock.patch('socket.socket') -def test_setup(mock_socket_class: mock.Mock, communicator: Communicator) -> None: - """Test the setup method initializes and connects the socket.""" +def test_connect(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test the connect method initializes and connects the socket.""" mock_socket_instance = mock_socket_class.return_value - communicator.setup() + communicator.connect() + mock_socket_class.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) mock_socket_instance.connect.assert_called_once_with((HOST, PORT)) assert communicator._socket is mock_socket_instance @mock.patch('socket.socket') -def test_send_message_without_setup(mock_socket_class: mock.Mock, communicator: Communicator) -> None: +def test_send_message_without_setup( + mock_socket_class: mock.Mock, + communicator: Communicator +) -> None: """Test sending a message without setting up the socket raises RuntimeError.""" with pytest.raises(RuntimeError) as exc_info: communicator.send_message("Hello") - assert str(exc_info.value) == "Socket not set up. Call setup method first." + assert str(exc_info.value) == "Socket not set up. Call connect method first." @mock.patch('socket.socket') @@ -44,24 +62,33 @@ def test_send_message(mock_socket_class: mock.Mock, communicator: Communicator) @mock.patch('socket.socket') -def test_receive_message_without_setup(mock_socket_class: mock.Mock, communicator: Communicator) -> None: +def test_receive_message_without_setup( + mock_socket_class: mock.Mock, + communicator: Communicator +) -> None: """Test receiving a message without setting up the socket raises RuntimeError.""" with pytest.raises(RuntimeError) as exc_info: communicator.receive_message() - assert str(exc_info.value) == "Socket not set up. Call setup method first." + assert str(exc_info.value) == "Socket not set up. Call connect method first." @mock.patch('socket.socket') -def test_receive_message(mock_socket_class: mock.Mock, communicator: Communicator) -> None: +def test_receive_message( + mock_socket_class: mock.Mock, + communicator: Communicator, + mock_callback: Callable[[str], None] +) -> None: """Test receiving a message successfully.""" mock_socket_instance = mock_socket_class.return_value communicator._socket = mock_socket_instance - mock_socket_instance.recv.return_value = b'Hello, Client!' - message: str = communicator.receive_message() + # Setup mock to return a message once and then empty string to break the loop + mock_socket_instance.recv.side_effect = [b'Hello, Client!', b''] + + communicator.receive_message() - mock_socket_instance.recv.assert_called_once_with(1024) - assert message == 'Hello, Client!' + mock_socket_instance.recv.assert_called_with(RECEIVE_BUFFER_SIZE) + mock_callback.assert_called_once_with('Hello, Client!') @mock.patch('socket.socket') @@ -73,4 +100,23 @@ def test_close_socket(mock_socket_class: mock.Mock, communicator: Communicator) communicator.close() mock_socket_instance.close.assert_called_once() - assert communicator._socket is None \ No newline at end of file + assert communicator._socket is None + + +@mock.patch('socket.socket') +def test_receive_message_decode_error( + mock_socket_class: mock.Mock, + communicator: Communicator, + mock_callback: Callable[[str], None] +) -> None: + """Test handling of decode errors in receive_message.""" + mock_socket_instance = mock_socket_class.return_value + communicator._socket = mock_socket_instance + + # Setup mock to return invalid UTF-8 bytes + mock_socket_instance.recv.side_effect = [bytes([0xFF, 0xFE, 0xFD]), b''] + + with pytest.raises(UnicodeDecodeError): + communicator.receive_message() + + mock_callback.assert_not_called() diff --git a/client/tests/test_controller.py b/client/tests/test_controller.py deleted file mode 100644 index 655cc9b..0000000 --- a/client/tests/test_controller.py +++ /dev/null @@ -1,107 +0,0 @@ -import os -import logging -from unittest import mock -from datetime import datetime -from typing import Optional - -import pytest - -from src.Controller import Controller -from src.View import Viewer -from src.Communicator import Communicator, HOST, PORT - - -@pytest.fixture -def controller() -> Controller: - """Fixture to create a Controller instance.""" - with mock.patch('src.Controller.Viewer'), \ - mock.patch('src.Controller.Communicator'), \ - mock.patch('src.Controller.logging'): - yield Controller() - - -@mock.patch('src.Controller.os.path.exists') -@mock.patch('src.Controller.os.makedirs') -@mock.patch('src.Controller.logging.getLogger') -@mock.patch('src.Controller.logging.basicConfig') -@mock.patch('src.Controller.datetime') -def test_logger_setup( - mock_datetime: mock.Mock, - mock_basicConfig: mock.Mock, - mock_getLogger: mock.Mock, - mock_makedirs: mock.Mock, - mock_exists: mock.Mock, - controller: Controller -) -> None: - """Test the logger setup in Controller.""" - mock_exists.return_value = False - mock_datetime.now.return_value = datetime(2023, 10, 1, 12, 0, 0) - mock_logger = mock.Mock() - mock_logger.info = mock.Mock() - mock_getLogger.return_value = mock_logger - - controller._logger_setup() - - mock_exists.assert_called_once_with("client_logs") - mock_makedirs.assert_called_once_with("client_logs") - mock_datetime.now.assert_called_once() - expected_log_file = os.path.join("client_logs", "Client_20231001_120000.log") - mock_basicConfig.assert_called_once_with( - level=mock.ANY, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[ - mock.ANY, - mock.ANY, - ], - ) - mock_getLogger.assert_called_once_with('src.Controller') - mock_logger.info.assert_any_call("Logger setup complete") - - -@mock.patch('src.Controller.Viewer') -@mock.patch('src.Controller.logging.Logger.info') -@mock.patch('src.Controller.logging.Logger.error') -def test_run_success( - mock_error: mock.Mock, - mock_info: mock.Mock, - mock_viewer: mock.Mock, - controller: Controller -) -> None: - """Test the run method executes successfully.""" - controller._view = mock_viewer.return_value - controller._logger = mock.Mock(spec=logging.Logger) - - controller.run() - - controller._logger.info.assert_called_with("Starting application") - controller._view.run.assert_called_once() - mock_error.assert_not_called() - - -@mock.patch('src.Controller.logging.Logger.info') -@mock.patch('src.Controller.logging.Logger.error') -def test_run_exception( - mock_error: mock.Mock, - mock_info: mock.Mock, - controller: Controller -) -> None: - """Test the run method handles exceptions properly.""" - # Setup the mock viewer instance - mock_viewer = mock.Mock() - mock_viewer.run.side_effect = Exception("Test Exception") - - controller._view = mock_viewer - - mock_logger = mock.Mock(spec=logging.Logger) - controller._logger = mock_logger - - with pytest.raises(Exception) as exc_info: - controller.run() - - assert str(exc_info.value) == "Test Exception" - mock_logger.info.assert_called_with("Starting application") - mock_logger.error.assert_called_with( - "Error during execution: Test Exception", - exc_info=True - ) - mock_viewer.run.assert_called_once() diff --git a/client/tests/test_view.py b/client/tests/test_view.py index 813399c..888e78b 100644 --- a/client/tests/test_view.py +++ b/client/tests/test_view.py @@ -1,24 +1,67 @@ import tkinter as tk from unittest import mock +import json +from typing import Callable import pytest from src.View import Viewer + +@pytest.fixture +def mock_callback() -> Callable[[str], None]: + """Fixture to provide a mock callback function.""" + return mock.Mock() + + @pytest.fixture -def viewer() -> Viewer: +def viewer(mock_callback: Callable[[str], None]) -> Viewer: """Fixture to create a Viewer instance.""" - with mock.patch('src.View.tk.Tk'): - yield Viewer() + with mock.patch('src.View.tk.Tk') as mock_tk: + mock_tk.return_value.title = mock.Mock() + mock_tk.return_value.geometry = mock.Mock() + return Viewer(message_callback=mock_callback) -def test_init(viewer: Viewer) -> None: +def test_init(viewer: Viewer, mock_callback: Callable[[str], None]) -> None: """Test the initialization of Viewer.""" - viewer.root.title.assert_called_once_with("My Application") + viewer.root.title.assert_called_once_with("Chat Application") viewer.root.geometry.assert_called_once_with("800x600") + assert viewer._message_callback == mock_callback + + +def test_send_message(viewer: Viewer, mock_callback: Callable[[str], None]) -> None: + """Test sending a message.""" + test_message = "Hello, World!" + viewer.input_field = mock.Mock() + viewer.input_field.get.return_value = test_message + + viewer._send_message() + + expected_json = json.dumps({"CODE": "100", "content": test_message}) + mock_callback.assert_called_once_with(expected_json) + viewer.input_field.delete.assert_called_once_with(0, tk.END) + + +def test_display_message(viewer: Viewer) -> None: + """Test displaying a message.""" + viewer.message_area = mock.Mock() + + viewer.display_message("User", "Test message") + + viewer.message_area.config.assert_any_call(state=tk.NORMAL) + viewer.message_area.insert.assert_called_once_with(tk.END, "User: Test message\n") + viewer.message_area.see.assert_called_once_with(tk.END) + viewer.message_area.config.assert_any_call(state=tk.DISABLED) -def test_run(viewer: Viewer) -> None: - """Test running the viewer's main loop.""" - viewer.run() - viewer.root.mainloop.assert_called_once() +def test_display_error(viewer: Viewer) -> None: + """Test displaying an error message.""" + viewer.message_area = mock.Mock() + + viewer.display_error("Test error") + + viewer.message_area.config.assert_any_call(state=tk.NORMAL) + viewer.message_area.insert.assert_called_once_with(tk.END, "Error: Test error\n") + viewer.message_area.see.assert_called_once_with(tk.END) + viewer.message_area.config.assert_any_call(state=tk.DISABLED) From 953460c195258dc3dddd282da908fe084bd6312e Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Mon, 4 Nov 2024 19:01:15 +0200 Subject: [PATCH 07/38] server files --- server/main.py | 5 ++ server/src/config.py | 6 ++ server/src/db_manager.py | 86 +++++++++++++++++++++++++ server/src/handlers.py | 115 ++++++++++++++++++++++++++++++++++ server/src/response_codes.py | 21 +++++++ server/src/server.py | 117 +++++++++++++++++++++++++++++++++++ 6 files changed, 350 insertions(+) create mode 100644 server/main.py create mode 100644 server/src/config.py create mode 100644 server/src/db_manager.py create mode 100644 server/src/handlers.py create mode 100644 server/src/response_codes.py create mode 100644 server/src/server.py diff --git a/server/main.py b/server/main.py new file mode 100644 index 0000000..07fe9ac --- /dev/null +++ b/server/main.py @@ -0,0 +1,5 @@ +from My_Internet.server.src.server import run +from My_Internet.server.src.config import DB_FILE + +if __name__ == '__main__': + run(DB_FILE) \ No newline at end of file diff --git a/server/src/config.py b/server/src/config.py new file mode 100644 index 0000000..9f4955a --- /dev/null +++ b/server/src/config.py @@ -0,0 +1,6 @@ +# config.py + +HOST: str = '127.0.0.1' +CLIENT_PORT: int = 65432 +KERNEL_PORT: int = 65433 +DB_FILE: str = 'my_internet.db' \ No newline at end of file diff --git a/server/src/db_manager.py b/server/src/db_manager.py new file mode 100644 index 0000000..dfdff94 --- /dev/null +++ b/server/src/db_manager.py @@ -0,0 +1,86 @@ +# db_manager.py + +import sqlite3 +from typing import List, Tuple + + +class DatabaseManager: + def __init__(self, db_file: str): + self.db_file = db_file + self.create_tables() + + def create_tables(self) -> None: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS blocked_domains ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + domain TEXT UNIQUE + ) + """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS easylist ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + entry TEXT UNIQUE + ) + """) + conn.commit() + + def add_blocked_domain(self, domain: str) -> None: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + try: + cursor.execute(""" + INSERT INTO blocked_domains (domain) + VALUES (?) + """, (domain,)) + conn.commit() + except sqlite3.IntegrityError: + print(f"Domain {domain} already exists in the database.") + + def remove_blocked_domain(self, domain: str) -> None: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + DELETE FROM blocked_domains + WHERE domain = ? + """, (domain,)) + conn.commit() + + def is_domain_blocked(self, domain: str) -> bool: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT domain FROM blocked_domains + WHERE domain = ? + """, (domain,)) + result = cursor.fetchone() + return result is not None + + def store_easylist_entries(self, entries: List[Tuple[str]]) -> None: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.executemany(""" + INSERT OR IGNORE INTO easylist (entry) + VALUES (?) + """, entries) + conn.commit() + + def is_easylist_blocked(self, domain: str) -> bool: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT 1 FROM easylist + WHERE ? GLOB '*' || entry || '*' + """, (domain,)) + result = cursor.fetchone() + return result is not None + + def clear_easylist(self) -> None: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM easylist") + conn.commit() + + def close(self) -> None: + pass \ No newline at end of file diff --git a/server/src/handlers.py b/server/src/handlers.py new file mode 100644 index 0000000..1e46cc7 --- /dev/null +++ b/server/src/handlers.py @@ -0,0 +1,115 @@ +# handlers.py +import requests +import socket +from abc import ABC, abstractmethod +from typing import Dict, Any +from My_Internet.server.src.db_manager import DatabaseManager +from response_codes import ( + SUCCESS, INVALID_REQUEST, DOMAIN_BLOCKED, + DOMAIN_NOT_FOUND, AD_BLOCK_ENABLED, + ADULT_CONTENT_BLOCKED, RESPONSE_MESSAGES +) + + +class RequestHandler(ABC): + @abstractmethod + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + pass + +FAMILY_DNS_IP = "1.1.1.3" +EASYLIST_URL = "https://easylist.to/easylist/easylist.txt" + + +class AdBlockHandler(RequestHandler): + def __init__(self, db_manager: DatabaseManager): + self.db_manager = db_manager + self.load_easylist() + + def load_easylist(self) -> None: + try: + response = requests.get(EASYLIST_URL) + response.raise_for_status() + easylist_data = response.text + self.parse_and_store_easylist(easylist_data) + except requests.exceptions.RequestException as e: + print(f"Error loading EasyList: {e}") + + def parse_and_store_easylist(self, easylist_data: str) -> None: + entries = [] + for line in easylist_data.split("\n"): + line = line.strip() + if line and not line.startswith("!"): + entries.append((line,)) + self.db_manager.clear_easylist() + self.db_manager.store_easylist_entries(entries) + + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + domain = request_data.get('domain') + if self.is_domain_blocked(domain): + return { + 'code': AD_BLOCK_ENABLED, + 'message': RESPONSE_MESSAGES[AD_BLOCK_ENABLED] + } + else: + return { + 'code': SUCCESS, + 'message': RESPONSE_MESSAGES[SUCCESS] + } + + def is_domain_blocked(self, domain: str) -> bool: + return self.db_manager.is_easylist_blocked(domain) + + +class DomainBlockHandler(RequestHandler): + def __init__(self, db_manager: DatabaseManager): + self.db_manager = db_manager + + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + domain = request_data.get('domain') + action = request_data.get('action') + + if action == 'block': + self.db_manager.add_blocked_domain(domain) + return { + 'code': DOMAIN_BLOCKED, + 'message': RESPONSE_MESSAGES[DOMAIN_BLOCKED] + } + elif action == 'unblock': + if self.db_manager.is_domain_blocked(domain): + self.db_manager.remove_blocked_domain(domain) + return { + 'code': SUCCESS, + 'message': RESPONSE_MESSAGES[SUCCESS] + } + else: + return { + 'code': DOMAIN_NOT_FOUND, + 'message': RESPONSE_MESSAGES[DOMAIN_NOT_FOUND] + } + else: + return { + 'code': INVALID_REQUEST, + 'message': RESPONSE_MESSAGES[INVALID_REQUEST] + } + +# add bolean to cheack if to turn on or off the adult content block. +class AdultContentBlockHandler(RequestHandler): + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + domain = request_data.get('domain') + if self.is_adult_content(domain): + return { + 'code': ADULT_CONTENT_BLOCKED, + 'message': RESPONSE_MESSAGES[ADULT_CONTENT_BLOCKED] + } + else: + return { + 'code': SUCCESS, + 'message': RESPONSE_MESSAGES[SUCCESS] + } + + def is_adult_content(self, domain: str) -> bool: + try: + ip_address = socket.gethostbyname(domain) + return ip_address == FAMILY_DNS_IP + except socket.gaierror: + return False diff --git a/server/src/response_codes.py b/server/src/response_codes.py new file mode 100644 index 0000000..7095c06 --- /dev/null +++ b/server/src/response_codes.py @@ -0,0 +1,21 @@ +# response_codes.py + +from typing import Dict + +# Response codes +SUCCESS: int = 200 +INVALID_REQUEST: int = 400 +DOMAIN_BLOCKED: int = 201 +DOMAIN_NOT_FOUND: int = 404 +AD_BLOCK_ENABLED: int = 202 +ADULT_CONTENT_BLOCKED: int = 203 + +# Response messages +RESPONSE_MESSAGES: Dict[int, str] = { + SUCCESS: "Request processed successfully.", + INVALID_REQUEST: "Invalid request. Please check the request format.", + DOMAIN_BLOCKED: "Domain has been successfully blocked.", + DOMAIN_NOT_FOUND: "Domain not found in the block list.", + AD_BLOCK_ENABLED: "Ad blocking has been enabled for the domain.", + ADULT_CONTENT_BLOCKED: "Adult content has been blocked for the domain." +} \ No newline at end of file diff --git a/server/src/server.py b/server/src/server.py new file mode 100644 index 0000000..548ea5e --- /dev/null +++ b/server/src/server.py @@ -0,0 +1,117 @@ +# server.py + +import asyncio +import json +from typing import Dict, Any +from My_Internet.server.src.config import HOST, CLIENT_PORT, KERNEL_PORT, DB_FILE +from My_Internet.server.src.db_manager import DatabaseManager +from My_Internet.server.src.handlers import AdBlockHandler, DomainBlockHandler, AdultContentBlockHandler +from response_codes import INVALID_REQUEST, RESPONSE_MESSAGES + + +async def handle_client( + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + db_manager: DatabaseManager +) -> None: + while True: + try: + data = await reader.readline() + if not data: + break + + request_data = json.loads(data.decode('utf-8')) + response_data = route_request(request_data, db_manager) + + writer.write(json.dumps(response_data).encode('utf-8') + b'\n') + await writer.drain() + + except ConnectionResetError: + print("Client disconnected.") + break + + writer.close() + + +async def handle_kernel( + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + db_manager: DatabaseManager +) -> None: + while True: + try: + data = await reader.readline() + if not data: + break + + request_data = json.loads(data.decode('utf-8')) + response_data = route_kernel_request(request_data, db_manager) + + writer.write(json.dumps(response_data).encode('utf-8') + b'\n') + await writer.drain() + + except ConnectionResetError: + print("Kernel module disconnected.") + break + + writer.close() + +# create request factory class , and make an instance of it to handle the request base on the data send it to the right handler. +def route_request(request_data: Dict[str, Any], db_manager: DatabaseManager) -> Dict[str, Any]: + request_type = request_data.get('type') + + if request_type == 'ad_block': + handler = AdBlockHandler(db_manager) + elif request_type == 'domain_block': + handler = DomainBlockHandler(db_manager) + elif request_type == 'adult_content_block': + handler = AdultContentBlockHandler() + else: + return { + 'code': INVALID_REQUEST, + 'message': RESPONSE_MESSAGES[INVALID_REQUEST] + } + + return handler.handle_request(request_data) + + +def route_kernel_request(request_data: Dict[str, Any], db_manager: DatabaseManager) -> Dict[str, Any]: + domain = request_data.get('domain') + + if db_manager.is_domain_blocked(domain) or db_manager.is_easylist_blocked(domain): + return {'block': True} + else: + return {'block': False} + + +async def start_server(db_manager: DatabaseManager) -> None: + client_server = await asyncio.start_server( + lambda r, w: handle_client(r, w, db_manager), + HOST, + CLIENT_PORT + ) + kernel_server = await asyncio.start_server( + lambda r, w: handle_kernel(r, w, db_manager), + HOST, + KERNEL_PORT + ) + + print(f"Client server running on {HOST}:{CLIENT_PORT}") + print(f"Kernel server running on {HOST}:{KERNEL_PORT}") + + async with client_server, kernel_server: + await asyncio.gather( + client_server.serve_forever(), + kernel_server.serve_forever() + ) + + +def run(db_file: str) -> None: + """Initialize and run the server with the given database file.""" + db_manager = DatabaseManager(db_file) + try: + asyncio.run(start_server(db_manager)) + except KeyboardInterrupt: + print("Server stopped by user.") + finally: + db_manager.close() \ No newline at end of file From 294e5bcbbb0bfe3b49839020e4bade59f2df0c9a Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Mon, 4 Nov 2024 20:04:03 +0200 Subject: [PATCH 08/38] add Factory handler class in handlers file --- server/src/handlers.py | 57 +++++++++++++++++++++++++++++++++++++++--- server/src/server.py | 30 ++++++---------------- 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/server/src/handlers.py b/server/src/handlers.py index 1e46cc7..54d2e6a 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -2,7 +2,7 @@ import requests import socket from abc import ABC, abstractmethod -from typing import Dict, Any +from typing import Dict, Any, Optional from My_Internet.server.src.db_manager import DatabaseManager from response_codes import ( SUCCESS, INVALID_REQUEST, DOMAIN_BLOCKED, @@ -92,7 +92,7 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: 'message': RESPONSE_MESSAGES[INVALID_REQUEST] } -# add bolean to cheack if to turn on or off the adult content block. + class AdultContentBlockHandler(RequestHandler): def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: domain = request_data.get('domain') @@ -112,4 +112,55 @@ def is_adult_content(self, domain: str) -> bool: ip_address = socket.gethostbyname(domain) return ip_address == FAMILY_DNS_IP except socket.gaierror: - return False + return False + + +class RequestFactory: + def __init__(self, db_manager: DatabaseManager): + """ + Initialize the RequestFactory with a database manager instance. + + Args: + db_manager: DatabaseManager instance for handling database operations + """ + self.db_manager = db_manager + # Map request types to handler creator functions + self._handlers = { + 'ad_block': lambda: AdBlockHandler(self.db_manager), + 'domain_block': lambda: DomainBlockHandler(self.db_manager), + 'adult_content_block': lambda: AdultContentBlockHandler() + } + + def create_request_handler(self, request_type: str) -> Optional[RequestHandler]: + """ + Creates and returns the appropriate request handler based on request type. + + Args: + request_type: The type of request to handle + + Returns: + RequestHandler instance or None if request type is not supported + """ + handler_creator = self._handlers.get(request_type) + return handler_creator() if handler_creator else None + + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """ + Routes the request to appropriate handler and processes it. + + Args: + request_data: The request data containing type and other parameters + + Returns: + Dict containing response code and message + """ + request_type = request_data.get('type') + handler = self.create_request_handler(request_type) + + if handler: + return handler.handle_request(request_data) + else: + return { + 'code': INVALID_REQUEST, + 'message': RESPONSE_MESSAGES[INVALID_REQUEST] + } \ No newline at end of file diff --git a/server/src/server.py b/server/src/server.py index 548ea5e..c461965 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -5,14 +5,14 @@ from typing import Dict, Any from My_Internet.server.src.config import HOST, CLIENT_PORT, KERNEL_PORT, DB_FILE from My_Internet.server.src.db_manager import DatabaseManager -from My_Internet.server.src.handlers import AdBlockHandler, DomainBlockHandler, AdultContentBlockHandler -from response_codes import INVALID_REQUEST, RESPONSE_MESSAGES +from My_Internet.server.src.handlers import RequestFactory + async def handle_client( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, - db_manager: DatabaseManager + request_factory: RequestFactory ) -> None: while True: try: @@ -21,7 +21,7 @@ async def handle_client( break request_data = json.loads(data.decode('utf-8')) - response_data = route_request(request_data, db_manager) + response_data = request_factory.handle_request(request_data) writer.write(json.dumps(response_data).encode('utf-8') + b'\n') await writer.drain() @@ -56,24 +56,6 @@ async def handle_kernel( writer.close() -# create request factory class , and make an instance of it to handle the request base on the data send it to the right handler. -def route_request(request_data: Dict[str, Any], db_manager: DatabaseManager) -> Dict[str, Any]: - request_type = request_data.get('type') - - if request_type == 'ad_block': - handler = AdBlockHandler(db_manager) - elif request_type == 'domain_block': - handler = DomainBlockHandler(db_manager) - elif request_type == 'adult_content_block': - handler = AdultContentBlockHandler() - else: - return { - 'code': INVALID_REQUEST, - 'message': RESPONSE_MESSAGES[INVALID_REQUEST] - } - - return handler.handle_request(request_data) - def route_kernel_request(request_data: Dict[str, Any], db_manager: DatabaseManager) -> Dict[str, Any]: domain = request_data.get('domain') @@ -85,8 +67,10 @@ def route_kernel_request(request_data: Dict[str, Any], db_manager: DatabaseManag async def start_server(db_manager: DatabaseManager) -> None: + request_factory = RequestFactory(db_manager) + client_server = await asyncio.start_server( - lambda r, w: handle_client(r, w, db_manager), + lambda r, w: handle_client(r, w, request_factory), HOST, CLIENT_PORT ) From e5c58db284446e54dee2665f68a17427de91889d Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Mon, 4 Nov 2024 20:26:13 +0200 Subject: [PATCH 09/38] change the adult content block logic, and server route kernal request function --- server/src/handlers.py | 85 +++++++++++++++++++++++++++++++++++------- server/src/server.py | 15 +++++--- 2 files changed, 81 insertions(+), 19 deletions(-) diff --git a/server/src/handlers.py b/server/src/handlers.py index 54d2e6a..5fd6fa3 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -16,7 +16,6 @@ class RequestHandler(ABC): def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: pass -FAMILY_DNS_IP = "1.1.1.3" EASYLIST_URL = "https://easylist.to/easylist/easylist.txt" @@ -94,25 +93,83 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: class AdultContentBlockHandler(RequestHandler): + # Class-level variable to track status across all instances + # We use a class variable so all instances share the same state + _is_enabled: bool = False + + def __init__(self, db_manager: DatabaseManager): + self.db_manager = db_manager + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - domain = request_data.get('domain') - if self.is_adult_content(domain): + action = request_data.get('action') + + try: + if action == 'status': + return self._get_status() + + elif action in ['enable', 'disable']: + return self._toggle_blocking(action) + + elif action == 'check': + return self._check_domain(request_data.get('domain')) + + else: + return { + 'code': INVALID_REQUEST, + 'message': RESPONSE_MESSAGES[INVALID_REQUEST] + } + + except Exception as e: + print(f"Error in adult content handler: {e}") return { - 'code': ADULT_CONTENT_BLOCKED, - 'message': RESPONSE_MESSAGES[ADULT_CONTENT_BLOCKED] + 'code': INVALID_REQUEST, + 'message': "An error occurred processing the request" } - else: + + def _get_status(self) -> Dict[str, Any]: + """Get current blocking status.""" + return { + 'code': SUCCESS, + 'message': RESPONSE_MESSAGES[SUCCESS], + 'adult_content_block': 'on' if self._is_enabled else 'off' + } + + def _toggle_blocking(self, action: str) -> Dict[str, Any]: + """Enable or disable blocking.""" + self.__class__._is_enabled = (action == 'enable') + status = 'enabled' if self._is_enabled else 'disabled' + + print(f"Adult content blocking {status}") + + return { + 'code': SUCCESS, + 'message': f"Adult content blocking has been {status}.", + 'adult_content_block': 'on' if self._is_enabled else 'off' + } + + def _check_domain(self, domain: str) -> Dict[str, Any]: + """Check if a domain should be blocked.""" + if not domain: return { - 'code': SUCCESS, - 'message': RESPONSE_MESSAGES[SUCCESS] + 'code': INVALID_REQUEST, + 'message': RESPONSE_MESSAGES[INVALID_REQUEST] } + + if self._is_enabled: + return { + 'code': ADULT_CONTENT_BLOCKED, + 'message': RESPONSE_MESSAGES[ADULT_CONTENT_BLOCKED] + } + + return { + 'code': SUCCESS, + 'message': RESPONSE_MESSAGES[SUCCESS] + } - def is_adult_content(self, domain: str) -> bool: - try: - ip_address = socket.gethostbyname(domain) - return ip_address == FAMILY_DNS_IP - except socket.gaierror: - return False + @classmethod + def is_blocking_enabled(cls) -> bool: + """Public method to check if blocking is enabled.""" + return cls._is_enabled class RequestFactory: diff --git a/server/src/server.py b/server/src/server.py index c461965..0125167 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -5,7 +5,7 @@ from typing import Dict, Any from My_Internet.server.src.config import HOST, CLIENT_PORT, KERNEL_PORT, DB_FILE from My_Internet.server.src.db_manager import DatabaseManager -from My_Internet.server.src.handlers import RequestFactory +from My_Internet.server.src.handlers import RequestFactory, AdultContentBlockHandler @@ -59,11 +59,16 @@ async def handle_kernel( def route_kernel_request(request_data: Dict[str, Any], db_manager: DatabaseManager) -> Dict[str, Any]: domain = request_data.get('domain') + categories = request_data.get('categories', []) - if db_manager.is_domain_blocked(domain) or db_manager.is_easylist_blocked(domain): - return {'block': True} - else: - return {'block': False} + # Fast checks in order of most common to least common + should_block = ( + db_manager.is_domain_blocked(domain) or + db_manager.is_easylist_blocked(domain) or + (AdultContentBlockHandler.is_blocking_enabled() and 'adult' in categories) + ) + + return {'block': should_block} async def start_server(db_manager: DatabaseManager) -> None: From f1ebd70041b62495f7948f2415287e71bf4e83e9 Mon Sep 17 00:00:00 2001 From: elipaz Date: Tue, 5 Nov 2024 16:44:38 +0200 Subject: [PATCH 10/38] Add config file and ConfigManager --- client/config.json | 19 +++++++ client/src/ConfigManager.py | 102 ++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 client/config.json create mode 100644 client/src/ConfigManager.py diff --git a/client/config.json b/client/config.json new file mode 100644 index 0000000..b662c6b --- /dev/null +++ b/client/config.json @@ -0,0 +1,19 @@ +{ + "network": { + "host": "127.0.0.1", + "port": 65432, + "receive_buffer_size": 1024 + }, + "blocked_domains": { + "example.com": true, + "ads.example.com": true + }, + "settings": { + "ad_block": "on", + "adult_block": "on" + }, + "logging": { + "level": "INFO", + "log_dir": "client_logs" + } +} \ No newline at end of file diff --git a/client/src/ConfigManager.py b/client/src/ConfigManager.py new file mode 100644 index 0000000..f7fa92a --- /dev/null +++ b/client/src/ConfigManager.py @@ -0,0 +1,102 @@ +"""Configuration management module for the application.""" + +import json +import os +from typing import Dict, Any +from .Logger import setup_logger + +DEFAULT_CONFIG = { + "network": { + "host": "127.0.0.1", + "port": 65432, + "receive_buffer_size": 1024 + }, + "blocked_domains": {}, + "settings": { + "ad_block": "off", + "adult_block": "off" + }, + "logging": { + "level": "INFO", + "log_dir": "client_logs" + } +} + +class ConfigManager: + """Manages application configuration loading and saving.""" + + def __init__(self, config_file: str = "config.json") -> None: + """ + Initialize the configuration manager. + + Args: + config_file: Path to the configuration file. + """ + self.logger = setup_logger(__name__) + self.config_file = config_file + self.config = self._load_config() + + def _load_config(self) -> Dict[str, Any]: + """ + Load configuration from JSON file. + + Returns: + Dict containing configuration settings. + """ + try: + if os.path.exists(self.config_file): + self.logger.info(f"Loading configuration from {self.config_file}") + with open(self.config_file, 'r') as f: + user_config = json.load(f) + return self._merge_configs(DEFAULT_CONFIG, user_config) + + self.logger.warning(f"Configuration file not found, using default configuration") + + except json.JSONDecodeError: + self.logger.error(f"Error decoding {self.config_file}, using default configuration") + + return DEFAULT_CONFIG.copy() + + def _merge_configs(self, default: Dict[str, Any], user: Dict[str, Any]) -> Dict[str, Any]: + """ + Recursively merge user configuration with default configuration. + + Args: + default: Default configuration dictionary + user: User configuration dictionary + + Returns: + Merged configuration dictionary + """ + result = default.copy() + + for key, value in user.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = self._merge_configs(result[key], value) + else: + result[key] = value + + return result + + def save_config(self, config: Dict[str, Any]) -> None: + """ + Save configuration to JSON file. + + Args: + config: Configuration dictionary to save + """ + try: + with open(self.config_file, 'w') as f: + json.dump(config, f, indent=4) + self.logger.info("Configuration saved successfully") + except Exception as e: + self.logger.error(f"Error saving configuration: {str(e)}") + + def get_config(self) -> Dict[str, Any]: + """ + Get the current configuration. + + Returns: + Current configuration dictionary + """ + return self.config From c4a2a4a1ccbbcf88051174857859e58abd8dd015 Mon Sep 17 00:00:00 2001 From: elipaz Date: Tue, 5 Nov 2024 16:45:57 +0200 Subject: [PATCH 11/38] Add config manager to files --- client/src/Application.py | 23 +++-- client/src/Communicator.py | 62 +++++++++---- client/src/View.py | 180 +++++++++++++++++++++++++++---------- 3 files changed, 194 insertions(+), 71 deletions(-) diff --git a/client/src/Application.py b/client/src/Application.py index 3fa2803..8a06a97 100644 --- a/client/src/Application.py +++ b/client/src/Application.py @@ -3,6 +3,7 @@ from .Communicator import Communicator from .View import Viewer from .Logger import setup_logger +from .ConfigManager import ConfigManager class Application: """ @@ -19,8 +20,9 @@ class Application: def __init__(self) -> None: """Initialize application components.""" self._logger = setup_logger(__name__) - self._view = Viewer(message_callback=self._handle_request) - self._communicator = Communicator(message_callback=self._handle_request) + self._config_manager = ConfigManager() + self._view = Viewer(config_manager=self._config_manager, message_callback=self._handle_request) + self._communicator = Communicator(config_manager=self._config_manager, message_callback=self._handle_request) def run(self) -> None: """ @@ -32,7 +34,7 @@ def run(self) -> None: self._logger.info("Starting application") try: - self._start_communication() + # self._start_communication() self._start_gui() except Exception as e: @@ -73,7 +75,7 @@ def _handle_request(self, request: str) -> None: request: received request from server or user input from UI. """ try: - self._logger.debug(f"Processing request: {request}") + self._logger.info(f"Processing request: {request}") pass ## TODO: Implement request handling from server or UI. @@ -87,7 +89,12 @@ def _handle_request(self, request: str) -> None: def _cleanup(self) -> None: """Clean up resources and stop threads.""" self._logger.info("Cleaning up application resources") - if self._communicator: - self._communicator.close() - if self._view: - self._view.root.destroy() + try: + if self._communicator: + self._communicator.close() + + if self._view and self._view.root.winfo_exists(): + self._view.root.destroy() + + except Exception as e: + self._logger.warning(f"Cleanup encountered an error: {str(e)}") diff --git a/client/src/Communicator.py b/client/src/Communicator.py index 037373e..0590ace 100644 --- a/client/src/Communicator.py +++ b/client/src/Communicator.py @@ -1,23 +1,26 @@ import socket from typing import Optional, Callable import json - -PORT = 65432 -HOST = '127.0.0.1' -RECEIVE_BUFFER_SIZE = 1024 +from .Logger import setup_logger class Communicator: - def __init__(self, message_callback: Callable[[str], None]) -> None: + def __init__(self, config_manager, message_callback: Callable[[str], None]) -> None: """ Initialize the communicator. Args: + config_manager: Configuration manager instance message_callback: Callback function to handle received messages. """ - self._host = HOST - self._port = PORT - self._socket: Optional[socket.socket] = None + self.logger = setup_logger(__name__) + self.logger.info("Initializing Communicator") + self.config = config_manager.get_config() self._message_callback = message_callback + + self._host = self.config["network"]["host"] + self._port = self.config["network"]["port"] + self._receive_buffer_size = self.config["network"]["receive_buffer_size"] + self._socket: Optional[socket.socket] = None def connect(self) -> None: """ @@ -26,8 +29,13 @@ def connect(self) -> None: Raises: socket.error: If connection cannot be established. """ - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._socket.connect((self._host, self._port)) + try: + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.connect((self._host, self._port)) + self.logger.info(f"Connected to server at {self._host}:{self._port}") + except socket.error as e: + self.logger.error(f"Failed to connect to server: {str(e)}") + raise def send_message(self, message: str) -> None: """ @@ -40,9 +48,15 @@ def send_message(self, message: str) -> None: RuntimeError: If socket connection is not established. """ if not self._socket: + self.logger.error("Attempted to send message without connection") raise RuntimeError("Socket not set up. Call connect method first.") - self._socket.send(message.encode('utf-8')) + try: + self._socket.send(message.encode('utf-8')) + self.logger.info(f"Message sent: {message}") + except Exception as e: + self.logger.error(f"Failed to send message: {str(e)}") + raise def receive_message(self) -> None: """Continuously receive and process messages from the socket connection. @@ -56,15 +70,29 @@ def receive_message(self) -> None: UnicodeDecodeError: If received data cannot be decoded as UTF-8. """ if not self._socket: + self.logger.error("Attempted to receive message without connection") raise RuntimeError("Socket not set up. Call connect method first.") - while message_bytes := self._socket.recv(RECEIVE_BUFFER_SIZE): - if not message_bytes: - break - self._message_callback(message_bytes.decode('utf-8')) + self.logger.info("Starting message receive loop") + try: + while message_bytes := self._socket.recv(self._receive_buffer_size): + if not message_bytes: + self.logger.warning("Received empty message, breaking receive loop") + break + message = message_bytes.decode('utf-8') + self.logger.info(f"Received message: {message}") + self._message_callback(message) + except Exception as e: + self.logger.error(f"Error receiving message: {str(e)}") + raise def close(self) -> None: """Close the socket connection and clean up resources.""" if self._socket: - self._socket.close() - self._socket = None + try: + self._socket.close() + self.logger.info("Socket connection closed") + except Exception as e: + self.logger.error(f"Error closing socket: {str(e)}") + finally: + self._socket = None diff --git a/client/src/View.py b/client/src/View.py index 72b660f..d9a1530 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -1,25 +1,37 @@ import tkinter as tk -from tkinter import scrolledtext, ttk -from typing import Callable +from tkinter import scrolledtext, ttk, messagebox +from typing import Callable, Dict, List, Any import json +import os +from .Logger import setup_logger +from .ConfigManager import ConfigManager + class Viewer: """ Graphical user interface for the application. """ - def __init__(self, message_callback: Callable[[str], None]) -> None: + def __init__(self, config_manager: ConfigManager, message_callback: Callable[[str], None]) -> None: """ Initialize the viewer window and its components. Args: + config_manager: Configuration manager instance message_callback: Callback function to handle message sending. """ + self.logger = setup_logger(__name__) + self.logger.info("Initializing Viewer") + self.config_manager = config_manager + self.config = config_manager.get_config() + self._message_callback = message_callback + self.root: tk.Tk = tk.Tk() - self.root.title("Chat Application") + self.root.title("Site Blocker") self.root.geometry("800x600") - self._message_callback = message_callback + self._setup_ui() + self.logger.info("Viewer initialization complete") def _send_message(self) -> None: """Handle the sending of messages from the input field.""" @@ -32,63 +44,139 @@ def _send_message(self) -> None: def run(self) -> None: """Start the main event loop of the viewer.""" + self.logger.info("Starting main event loop") self.root.mainloop() def _setup_ui(self) -> None: - """Set up the UI components including text areas and buttons.""" - main_container = ttk.Frame(self.root, padding="5") + """Set up the UI components including block controls and domain list.""" + main_container = ttk.Frame(self.root, padding="10") main_container.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) self.root.columnconfigure(0, weight=1) self.root.rowconfigure(0, weight=1) - self.message_area = scrolledtext.ScrolledText( - main_container, - wrap=tk.WORD, - width=70, - height=30 - ) - - self.message_area.grid(row=0, column=0, columnspan=2, sticky=(tk.W, tk.E, tk.N, tk.S)) - self.message_area.config(state=tk.DISABLED) + # Left side - Specific sites block + sites_frame = ttk.LabelFrame(main_container, text="Specific sites block", padding="5") + sites_frame.grid(row=0, column=0, rowspan=3, padx=5, sticky=(tk.W, tk.E, tk.N, tk.S)) + + # Domains listbox + self.domains_listbox = tk.Listbox(sites_frame, width=40, height=15) + self.domains_listbox.grid(row=0, column=0, pady=5, sticky=(tk.W, tk.E, tk.N, tk.S)) + + # Add domain entry + domain_entry_frame = ttk.Frame(sites_frame) + domain_entry_frame.grid(row=1, column=0, sticky=(tk.W, tk.E)) + + ttk.Label(domain_entry_frame, text="Add Domain:").grid(row=0, column=0, padx=5) + self.domain_entry = ttk.Entry(domain_entry_frame) + self.domain_entry.grid(row=0, column=1, padx=5, sticky=(tk.W, tk.E)) + + # Add buttons for domain management + button_frame = ttk.Frame(sites_frame) + button_frame.grid(row=2, column=0, pady=5, sticky=(tk.W, tk.E)) + + ttk.Button(button_frame, text="Add", command=self._add_domain).grid(row=0, column=0, padx=5) + ttk.Button(button_frame, text="Remove", command=self._remove_domain).grid(row=0, column=1, padx=5) + + # Bind double-click event for removing domains + self.domains_listbox.bind('', lambda e: self._remove_domain()) + + # Load saved domains into listbox + for domain in self.config["blocked_domains"].keys(): + self.domains_listbox.insert(tk.END, domain) + + # Ad Block controls + ad_frame = ttk.LabelFrame(main_container, text="Ad Block", padding="5") + ad_frame.grid(row=0, column=1, pady=10, sticky=(tk.W, tk.E)) + + self.ad_var = tk.StringVar(value=self.config["settings"]["ad_block"]) + ttk.Radiobutton(ad_frame, text="on", value="on", variable=self.ad_var).grid(row=0, column=0, padx=10) + ttk.Radiobutton(ad_frame, text="off", value="off", variable=self.ad_var).grid(row=0, column=1, padx=10) + + # Adult sites Block controls + adult_frame = ttk.LabelFrame(main_container, text="Adult sites Block", padding="5") + adult_frame.grid(row=1, column=1, pady=10, sticky=(tk.W, tk.E)) + + self.adult_var = tk.StringVar(value=self.config["settings"]["adult_block"]) + ttk.Radiobutton(adult_frame, text="on", value="on", variable=self.adult_var).grid(row=0, column=0, padx=10) + ttk.Radiobutton(adult_frame, text="off", value="off", variable=self.adult_var).grid(row=0, column=1, padx=10) + + # Bind radio button commands + self.ad_var.trace_add('write', lambda *args: self._handle_ad_block()) + self.adult_var.trace_add('write', lambda *args: self._handle_adult_block()) + + # Configure grid weights + main_container.columnconfigure(0, weight=1) + sites_frame.columnconfigure(0, weight=1) + domain_entry_frame.columnconfigure(1, weight=1) + + def _add_domain(self) -> None: + """Add a domain to the blocked sites list.""" + domain = self.domain_entry.get().strip() + if domain: + if domain not in self.config["blocked_domains"]: + self.domains_listbox.insert(tk.END, domain) + self.config["blocked_domains"][domain] = True + self.domain_entry.delete(0, tk.END) + self.config_manager.save_config(self.config) + self.logger.info(f"Domain added: {domain}") + else: + self.logger.warning(f"Attempted to add duplicate domain: {domain}") + self._show_error("Domain already exists in the list") - self.input_field = ttk.Entry(main_container) - self.input_field.grid(row=1, column=0, sticky=(tk.W, tk.E)) - self.input_field.bind("", lambda e: self._send_message()) + def _remove_domain(self) -> None: + """Remove the selected domain from the blocked sites list.""" + selection = self.domains_listbox.curselection() + if selection: + domain = self.domains_listbox.get(selection) + self.domains_listbox.delete(selection) + del self.config["blocked_domains"][domain] + self.config_manager.save_config(self.config) + self.logger.info(f"Domain removed: {domain}") + else: + self.logger.warning("Attempted to remove domain without selection") + self._show_error("Please select a domain to remove") - self.send_button = ttk.Button( - main_container, - text="Send", - command=self._send_message - ) - self.send_button.grid(row=1, column=1) + def _handle_ad_block(self) -> None: + """Handle changes to the ad block setting.""" + state = self.ad_var.get() + self.config["settings"]["ad_block"] = state + self.config_manager.save_config(self.config) + self.logger.info(f"Ad blocking state changed to: {state}") - main_container.columnconfigure(0, weight=3) - main_container.columnconfigure(1, weight=1) - main_container.rowconfigure(0, weight=1) + def _handle_adult_block(self) -> None: + """Handle changes to the adult sites block setting.""" + state = self.adult_var.get() + self.config["settings"]["adult_block"] = state + self.config_manager.save_config(self.config) + self.logger.info(f"Adult site blocking state changed to: {state}") - ## TODO: This method won't be relevant for the final version - def display_message(self, sender: str, message: str) -> None: + def _show_error(self, message: str) -> None: """ - Display a message in the message area. + Display an error message in a popup window. Args: - sender: The name of the message sender. - message: The message content to display. + message: The error message to display. """ - self.message_area.config(state=tk.NORMAL) - self.message_area.insert(tk.END, f"{sender}: {message}\n") - self.message_area.see(tk.END) - self.message_area.config(state=tk.DISABLED) + self.logger.error(f"Error message displayed: {message}") + tk.messagebox.showerror("Error", message) - ## TODO: This method won't be relevant for the final version - def display_error(self, error_message: str) -> None: + def get_blocked_domains(self) -> tuple[str, ...]: """ - Display an error message in the message area. + Get the list of currently blocked domains. - Args: - error_message: The error message to display. + Returns: + A tuple containing all blocked domains. + """ + return self.domains_listbox.get(0, tk.END) + + def get_block_settings(self) -> dict[str, str]: + """ + Get the current state of blocking settings. + + Returns: + A dictionary containing the current state of ad and adult content blocking. """ - self.message_area.config(state=tk.NORMAL) - self.message_area.insert(tk.END, f"Error: {error_message}\n") - self.message_area.see(tk.END) - self.message_area.config(state=tk.DISABLED) + return { + "ad_block": self.ad_var.get(), + "adult_block": self.adult_var.get() + } From 1dce5099a7e96fa0f1c964ea8552a6d25a9fbc5d Mon Sep 17 00:00:00 2001 From: elipaz Date: Tue, 5 Nov 2024 17:49:14 +0200 Subject: [PATCH 12/38] Add Update domain list func and Thread safety messures --- client/src/Application.py | 14 ++- client/src/View.py | 179 ++++++++++++++++++++++---------------- 2 files changed, 115 insertions(+), 78 deletions(-) diff --git a/client/src/Application.py b/client/src/Application.py index 8a06a97..054ce79 100644 --- a/client/src/Application.py +++ b/client/src/Application.py @@ -21,6 +21,8 @@ def __init__(self) -> None: """Initialize application components.""" self._logger = setup_logger(__name__) self._config_manager = ConfigManager() + self._request_lock = threading.Lock() + self._view = Viewer(config_manager=self._config_manager, message_callback=self._handle_request) self._communicator = Communicator(config_manager=self._config_manager, message_callback=self._handle_request) @@ -77,8 +79,16 @@ def _handle_request(self, request: str) -> None: try: self._logger.info(f"Processing request: {request}") - pass ## TODO: Implement request handling from server or UI. - + with self._request_lock: + match request["CODE"]: + case "50" | "51" | "52": + self._communicator.send_message(request) + case "53": + if not isinstance(request["content"], list): + self._logger.error("Invalid content format for domain list update") + return + self._view.update_domain_list(request["content"]) + except json.JSONDecodeError as e: self._logger.error(f"Invalid JSON format: {str(e)}") raise diff --git a/client/src/View.py b/client/src/View.py index d9a1530..bc2926b 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -33,6 +33,51 @@ def __init__(self, config_manager: ConfigManager, message_callback: Callable[[st self._setup_ui() self.logger.info("Viewer initialization complete") + def run(self) -> None: + """Start the main event loop of the viewer.""" + self.logger.info("Starting main event loop") + self.root.mainloop() + + def get_blocked_domains(self) -> tuple[str, ...]: + """ + Get the list of currently blocked domains. + + Returns: + A tuple containing all blocked domains. + """ + return self.domains_listbox.get(0, tk.END) + + def get_block_settings(self) -> dict[str, str]: + """ + Get the current state of blocking settings. + + Returns: + A dictionary containing the current state of ad and adult content blocking. + """ + return { + "ad_block": self.ad_var.get(), + "adult_block": self.adult_var.get() + } + + def update_domain_list(self, domains: List[str]) -> None: + """ + Update the domains listbox with a new list of domains from the server. + + Args: + domains: List of domain strings to be displayed in the listbox. + """ + self.logger.info("Updating domain list from server") + try: + self.domains_listbox.delete(0, tk.END) + + for domain in domains: + self.domains_listbox.insert(tk.END, domain) + + self.logger.info(f"Updated domain list with {len(domains)} domains") + except Exception as e: + self.logger.error(f"Error updating domain list: {str(e)}") + self._show_error("Failed to update domain list") + def _send_message(self) -> None: """Handle the sending of messages from the input field.""" message = self.input_field.get().strip() @@ -42,10 +87,64 @@ def _send_message(self) -> None: self.input_field.delete(0, tk.END) self.display_message("You", message) - def run(self) -> None: - """Start the main event loop of the viewer.""" - self.logger.info("Starting main event loop") - self.root.mainloop() + def _add_domain(self) -> None: + """Add a domain to the blocked sites list.""" + domain = self.domain_entry.get().strip() + if domain: + if domain not in self.config["blocked_domains"]: + self.domains_listbox.insert(tk.END, domain) + self.config["blocked_domains"][domain] = True + self.domain_entry.delete(0, tk.END) + self.config_manager.save_config(self.config) + self.logger.info(f"Domain added: {domain}") + + self._message_callback(json.dumps({"CODE": "53", "content": domain})) + else: + self.logger.warning(f"Attempted to add duplicate domain: {domain}") + self._show_error("Domain already exists in the list") + + def _remove_domain(self) -> None: + """Remove the selected domain from the blocked sites list.""" + selection = self.domains_listbox.curselection() + if selection: + domain = self.domains_listbox.get(selection) + self.domains_listbox.delete(selection) + del self.config["blocked_domains"][domain] + self.config_manager.save_config(self.config) + self.logger.info(f"Domain removed: {domain}") + + self._message_callback(json.dumps({"CODE": "54", "content": domain})) + else: + self.logger.warning("Attempted to remove domain without selection") + self._show_error("Please select a domain to remove") + + def _handle_ad_block(self) -> None: + """Handle changes to the ad block setting.""" + state = self.ad_var.get() + self.config["settings"]["ad_block"] = state + self.config_manager.save_config(self.config) + self.logger.info(f"Ad blocking state changed to: {state}") + + self._message_callback(json.dumps({"CODE": "50", "content": state})) + + def _handle_adult_block(self) -> None: + """Handle changes to the adult sites block setting.""" + state = self.adult_var.get() + self.config["settings"]["adult_block"] = state + self.config_manager.save_config(self.config) + self.logger.info(f"Adult site blocking state changed to: {state}") + + self._message_callback(json.dumps({"CODE": "51", "content": state})) + + def _show_error(self, message: str) -> None: + """ + Display an error message in a popup window. + + Args: + message: The error message to display. + """ + self.logger.error(f"Error message displayed: {message}") + tk.messagebox.showerror("Error", message) def _setup_ui(self) -> None: """Set up the UI components including block controls and domain list.""" @@ -108,75 +207,3 @@ def _setup_ui(self) -> None: main_container.columnconfigure(0, weight=1) sites_frame.columnconfigure(0, weight=1) domain_entry_frame.columnconfigure(1, weight=1) - - def _add_domain(self) -> None: - """Add a domain to the blocked sites list.""" - domain = self.domain_entry.get().strip() - if domain: - if domain not in self.config["blocked_domains"]: - self.domains_listbox.insert(tk.END, domain) - self.config["blocked_domains"][domain] = True - self.domain_entry.delete(0, tk.END) - self.config_manager.save_config(self.config) - self.logger.info(f"Domain added: {domain}") - else: - self.logger.warning(f"Attempted to add duplicate domain: {domain}") - self._show_error("Domain already exists in the list") - - def _remove_domain(self) -> None: - """Remove the selected domain from the blocked sites list.""" - selection = self.domains_listbox.curselection() - if selection: - domain = self.domains_listbox.get(selection) - self.domains_listbox.delete(selection) - del self.config["blocked_domains"][domain] - self.config_manager.save_config(self.config) - self.logger.info(f"Domain removed: {domain}") - else: - self.logger.warning("Attempted to remove domain without selection") - self._show_error("Please select a domain to remove") - - def _handle_ad_block(self) -> None: - """Handle changes to the ad block setting.""" - state = self.ad_var.get() - self.config["settings"]["ad_block"] = state - self.config_manager.save_config(self.config) - self.logger.info(f"Ad blocking state changed to: {state}") - - def _handle_adult_block(self) -> None: - """Handle changes to the adult sites block setting.""" - state = self.adult_var.get() - self.config["settings"]["adult_block"] = state - self.config_manager.save_config(self.config) - self.logger.info(f"Adult site blocking state changed to: {state}") - - def _show_error(self, message: str) -> None: - """ - Display an error message in a popup window. - - Args: - message: The error message to display. - """ - self.logger.error(f"Error message displayed: {message}") - tk.messagebox.showerror("Error", message) - - def get_blocked_domains(self) -> tuple[str, ...]: - """ - Get the list of currently blocked domains. - - Returns: - A tuple containing all blocked domains. - """ - return self.domains_listbox.get(0, tk.END) - - def get_block_settings(self) -> dict[str, str]: - """ - Get the current state of blocking settings. - - Returns: - A dictionary containing the current state of ad and adult content blocking. - """ - return { - "ad_block": self.ad_var.get(), - "adult_block": self.adult_var.get() - } From 298ed61e9c462d0e41b99c817e8050d3828b717b Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:50:30 +0200 Subject: [PATCH 13/38] fix RequestFactory class - send to Adult Content class the db manager --- server/src/handlers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server/src/handlers.py b/server/src/handlers.py index 5fd6fa3..e817167 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from typing import Dict, Any, Optional from My_Internet.server.src.db_manager import DatabaseManager -from response_codes import ( +from My_Internet.server.src.response_codes import ( SUCCESS, INVALID_REQUEST, DOMAIN_BLOCKED, DOMAIN_NOT_FOUND, AD_BLOCK_ENABLED, ADULT_CONTENT_BLOCKED, RESPONSE_MESSAGES @@ -185,7 +185,7 @@ def __init__(self, db_manager: DatabaseManager): self._handlers = { 'ad_block': lambda: AdBlockHandler(self.db_manager), 'domain_block': lambda: DomainBlockHandler(self.db_manager), - 'adult_content_block': lambda: AdultContentBlockHandler() + 'adult_content_block': lambda: AdultContentBlockHandler(self.db_manager) } def create_request_handler(self, request_type: str) -> Optional[RequestHandler]: From 5ecfb93b638110596ea7351231b2af1c79b6c6b1 Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:52:25 +0200 Subject: [PATCH 14/38] add tests for the handlers class --- server/src/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 server/src/__init__.py diff --git a/server/src/__init__.py b/server/src/__init__.py new file mode 100644 index 0000000..e69de29 From 24b5695e43b2b48435a221f86e0e32733c19468f Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:54:46 +0200 Subject: [PATCH 15/38] add test for handlers --- __init__.py | 0 server/__init__.py | 0 server/requirements.txt | 4 + server/tests/__init__.py | 0 server/tests/conftest.py | 62 +++++++++ server/tests/test_handlers.py | 229 ++++++++++++++++++++++++++++++++++ setup.py | 7 ++ 7 files changed, 302 insertions(+) create mode 100644 __init__.py create mode 100644 server/__init__.py create mode 100644 server/requirements.txt create mode 100644 server/tests/__init__.py create mode 100644 server/tests/conftest.py create mode 100644 server/tests/test_handlers.py create mode 100644 setup.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/__init__.py b/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/requirements.txt b/server/requirements.txt new file mode 100644 index 0000000..b3211cb --- /dev/null +++ b/server/requirements.txt @@ -0,0 +1,4 @@ +pytest==7.4.0 +pytest-mock==3.11.1 +pytest-asyncio==0.21.1 +pytest-cov==4.1.0 \ No newline at end of file diff --git a/server/tests/__init__.py b/server/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/conftest.py b/server/tests/conftest.py new file mode 100644 index 0000000..2a2cd28 --- /dev/null +++ b/server/tests/conftest.py @@ -0,0 +1,62 @@ +import pytest +from unittest import mock +from typing import Generator +from My_Internet.server.src.db_manager import DatabaseManager +from My_Internet.server.src.handlers import RequestFactory + +@pytest.fixture(scope="function") +def mock_db_manager() -> mock.Mock: + """Fixture to provide a mock database manager.""" + mock_db = mock.Mock(spec=DatabaseManager) + + # Setup default returns for common methods + mock_db.is_domain_blocked.return_value = False + mock_db.is_easylist_blocked.return_value = False + + return mock_db + +@pytest.fixture +def mock_requests() -> Generator[mock.Mock, None, None]: + """Fixture to mock requests library.""" + with mock.patch('My_Internet.server.src.handlers.requests') as mock_req: + # Create a mock response + mock_response = mock.Mock() + mock_response.text = "test.com\n!comment\nexample.com" + mock_req.get.return_value = mock_response + + # Reset the mock to clear any previous calls + mock_req.reset_mock() + yield mock_req + +@pytest.fixture(scope="function") +def request_factory(mock_db_manager: mock.Mock) -> RequestFactory: + """Fixture to create a RequestFactory instance.""" + return RequestFactory(mock_db_manager) + +@pytest.fixture(scope="session") +def sample_domains() -> list[str]: + """Fixture to provide test domains.""" + return [ + "example.com", + "test.com", + "sample.org" + ] + +@pytest.fixture(scope="session") +def sample_requests() -> dict: + """Fixture to provide sample request data.""" + return { + "adult_block": { + "type": "adult_content_block", + "action": "enable" + }, + "domain_block": { + "type": "domain_block", + "action": "block", + "domain": "example.com" + }, + "ad_block": { + "type": "ad_block", + "domain": "test.com" + } + } \ No newline at end of file diff --git a/server/tests/test_handlers.py b/server/tests/test_handlers.py new file mode 100644 index 0000000..3fe012d --- /dev/null +++ b/server/tests/test_handlers.py @@ -0,0 +1,229 @@ +import pytest +from unittest import mock +from typing import Dict, Any +from My_Internet.server.src.handlers import EASYLIST_URL +from My_Internet.server.src.handlers import ( + RequestHandler, + AdultContentBlockHandler, + DomainBlockHandler, + AdBlockHandler, + RequestFactory +) +from My_Internet.server.src.response_codes import ( + SUCCESS, + INVALID_REQUEST, + DOMAIN_BLOCKED, + DOMAIN_NOT_FOUND, + AD_BLOCK_ENABLED, + ADULT_CONTENT_BLOCKED, + RESPONSE_MESSAGES +) + +class TestAdultContentBlockHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock) -> AdultContentBlockHandler: + """Create handler instance and reset state.""" + handler = AdultContentBlockHandler(mock_db_manager) + # Reset class-level state before each test + AdultContentBlockHandler._is_enabled = False + return handler + + def test_init(self, handler: AdultContentBlockHandler, mock_db_manager: mock.Mock) -> None: + """Test handler initialization.""" + assert handler.db_manager == mock_db_manager + assert not handler._is_enabled + + def test_handle_enable_request(self, handler: AdultContentBlockHandler) -> None: + """Test enabling adult content blocking.""" + request_data: Dict[str, Any] = {'action': 'enable'} + response = handler.handle_request(request_data) + + assert response['code'] == SUCCESS + assert response['adult_content_block'] == 'on' + assert AdultContentBlockHandler.is_blocking_enabled() + + def test_handle_disable_request(self, handler: AdultContentBlockHandler) -> None: + """Test disabling adult content blocking.""" + AdultContentBlockHandler._is_enabled = True + request_data: Dict[str, Any] = {'action': 'disable'} + + response = handler.handle_request(request_data) + + assert response['code'] == SUCCESS + assert response['adult_content_block'] == 'off' + assert not AdultContentBlockHandler.is_blocking_enabled() + + def test_handle_check_request(self, handler: AdultContentBlockHandler) -> None: + """Test checking domain with blocking enabled.""" + AdultContentBlockHandler._is_enabled = True + request_data = {'action': 'check', 'domain': 'example.com'} + + response = handler.handle_request(request_data) + + assert response['code'] == ADULT_CONTENT_BLOCKED + + def test_handle_check_request_disabled(self, handler: AdultContentBlockHandler) -> None: + """Test checking domain with blocking disabled.""" + request_data = {'action': 'check', 'domain': 'example.com'} + + response = handler.handle_request(request_data) + + assert response['code'] == SUCCESS + + def test_handle_invalid_request(self, handler: AdultContentBlockHandler) -> None: + """Test handling invalid requests.""" + invalid_requests = [ + {'action': 'invalid_action'}, + {'action': 'check'}, # Missing domain + {} # Empty request + ] + + for request_data in invalid_requests: + response = handler.handle_request(request_data) + assert response['code'] == INVALID_REQUEST + + +class TestDomainBlockHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock) -> DomainBlockHandler: + """Create handler instance.""" + return DomainBlockHandler(mock_db_manager) + + def test_block_domain(self, handler: DomainBlockHandler, sample_domains: list[str]) -> None: + """Test blocking a domain.""" + domain = sample_domains[0] + response = handler.handle_request({ + 'action': 'block', + 'domain': domain + }) + + handler.db_manager.add_blocked_domain.assert_called_once_with(domain) + assert response['code'] == DOMAIN_BLOCKED + + def test_unblock_domain(self, handler: DomainBlockHandler, sample_domains: list[str]) -> None: + """Test unblocking a domain.""" + domain = sample_domains[0] + handler.db_manager.is_domain_blocked.return_value = True + + response = handler.handle_request({ + 'action': 'unblock', + 'domain': domain + }) + + handler.db_manager.remove_blocked_domain.assert_called_once_with(domain) + assert response['code'] == SUCCESS + + def test_unblock_nonexistent_domain(self, handler: DomainBlockHandler) -> None: + """Test unblocking a domain that isn't blocked.""" + handler.db_manager.is_domain_blocked.return_value = False + + response = handler.handle_request({ + 'action': 'unblock', + 'domain': 'nonexistent.com' + }) + + assert response['code'] == DOMAIN_NOT_FOUND + + +class TestAdBlockHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock, mock_requests: mock.Mock) -> AdBlockHandler: + """Create handler instance with loading disabled.""" + with mock.patch('My_Internet.server.src.handlers.AdBlockHandler.load_easylist'): + # Create handler without loading easylist during initialization + handler = AdBlockHandler(mock_db_manager) + # Reset the mock before the test + mock_requests.reset_mock() + return handler + + def test_load_easylist(self, handler: AdBlockHandler, mock_requests: mock.Mock) -> None: + """Test loading the easylist.""" + # Configure mock response + mock_response = mock.Mock() + mock_response.text = "test.com\n!comment\nexample.com" + mock_requests.get.return_value = mock_response + + # Call method + handler.load_easylist() + + # Verify calls using the imported constant + mock_requests.get.assert_called_once_with(EASYLIST_URL) + mock_response.raise_for_status.assert_called_once() + handler.db_manager.clear_easylist.assert_called_once() + + # Verify that store_easylist_entries was called with correct data + expected_entries = [('test.com',), ('example.com',)] + handler.db_manager.store_easylist_entries.assert_called_once_with(expected_entries) + + def test_handle_check_request(self, handler: AdBlockHandler) -> None: + """Test checking a domain against easylist.""" + handler.db_manager.is_easylist_blocked.return_value = True + response = handler.handle_request({'domain': 'example.com'}) + + assert response['code'] == AD_BLOCK_ENABLED + + def test_handle_check_request_not_blocked(self, handler: AdBlockHandler) -> None: + """Test checking an unblocked domain.""" + handler.db_manager.is_easylist_blocked.return_value = False + response = handler.handle_request({'domain': 'example.com'}) + + assert response['code'] == SUCCESS + + +class TestRequestFactory: + def test_create_handlers(self, request_factory: RequestFactory) -> None: + """Test creating different types of handlers.""" + handlers = { + 'ad_block': AdBlockHandler, + 'domain_block': DomainBlockHandler, + 'adult_content_block': AdultContentBlockHandler + } + + for handler_type, handler_class in handlers.items(): + handler = request_factory.create_request_handler(handler_type) + assert isinstance(handler, handler_class) + + @mock.patch.object(AdultContentBlockHandler, 'handle_request') + def test_request_delegation( + self, + mock_handle: mock.Mock, + request_factory: RequestFactory + ) -> None: + """Test request delegation to appropriate handler.""" + expected_response = {'code': SUCCESS, 'message': 'Test response'} + mock_handle.return_value = expected_response + + request_data = { + 'type': 'adult_content_block', + 'action': 'enable' + } + + response = request_factory.handle_request(request_data) + mock_handle.assert_called_once_with(request_data) + assert response == expected_response + + def test_handle_invalid_request_type(self, request_factory: RequestFactory) -> None: + """Test handling invalid request type.""" + response = request_factory.handle_request({'type': 'invalid_type'}) + assert response['code'] == INVALID_REQUEST + + def test_factory_handler_integration(self, request_factory: RequestFactory) -> None: + """Test integration between factory and handlers.""" + test_cases = [ + { + 'request': {'type': 'adult_content_block', 'action': 'enable'}, + 'expected_code': SUCCESS + }, + { + 'request': {'type': 'domain_block', 'action': 'block', 'domain': 'example.com'}, + 'expected_code': DOMAIN_BLOCKED + }, + { + 'request': {'type': 'ad_block', 'domain': 'test.com'}, + 'expected_code': SUCCESS + } + ] + + for test_case in test_cases: + response = request_factory.handle_request(test_case['request']) + assert response['code'] == test_case['expected_code'] \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b660429 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="My_Internet", + version="0.1", + packages=find_packages(), +) \ No newline at end of file From 7822b04347f4866b8dd120072076504ca5485d23 Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Tue, 5 Nov 2024 18:19:11 +0200 Subject: [PATCH 16/38] tests for the server --- server/tests/conftest.py | 16 +++- server/tests/test_server.py | 145 ++++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 server/tests/test_server.py diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 2a2cd28..14fb24c 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -59,4 +59,18 @@ def sample_requests() -> dict: "type": "ad_block", "domain": "test.com" } - } \ No newline at end of file + } + +@pytest.fixture +async def mock_stream_reader() -> mock.AsyncMock: + """Global fixture for AsyncMock StreamReader.""" + reader = mock.AsyncMock() + return reader + +@pytest.fixture +async def mock_stream_writer() -> mock.AsyncMock: + """Global fixture for AsyncMock StreamWriter.""" + writer = mock.AsyncMock() + writer.write = mock.Mock() # write is usually synchronous + writer.drain = mock.AsyncMock() + return writer \ No newline at end of file diff --git a/server/tests/test_server.py b/server/tests/test_server.py new file mode 100644 index 0000000..3fcadd1 --- /dev/null +++ b/server/tests/test_server.py @@ -0,0 +1,145 @@ +import pytest +import json +import asyncio +from unittest import mock +from typing import AsyncGenerator, Dict, Any + +from My_Internet.server.src.server import ( + handle_client, + handle_kernel, + route_kernel_request, + start_server +) +from My_Internet.server.src.config import HOST, CLIENT_PORT, KERNEL_PORT +from My_Internet.server.src.handlers import RequestFactory +from My_Internet.server.src.response_codes import SUCCESS, INVALID_REQUEST + +class TestServer: + @pytest.fixture + def mock_stream_reader(self) -> mock.AsyncMock: + reader = mock.AsyncMock() + reader.readline = mock.AsyncMock() + return reader + + @pytest.fixture + def mock_stream_writer(self) -> mock.Mock: + writer = mock.Mock() + writer.write = mock.Mock() + writer.drain = mock.AsyncMock() + writer.close = mock.Mock() + return writer + + @pytest.mark.asyncio # Only for async functions + async def test_handle_client( + self, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock, + request_factory: RequestFactory + ) -> None: + """Test client request handling.""" + test_request = { + 'type': 'domain_block', + 'action': 'block', + 'domain': 'example.com' + } + + mock_stream_reader.readline.side_effect = [ + json.dumps(test_request).encode() + b'\n', + b'' + ] + + await handle_client(mock_stream_reader, mock_stream_writer, request_factory) + + assert mock_stream_writer.write.called + assert mock_stream_writer.drain.called + assert mock_stream_writer.close.called + + @pytest.mark.asyncio # Only for async functions + async def test_handle_kernel( + self, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock, + mock_db_manager: mock.Mock + ) -> None: + """Test kernel request handling.""" + test_request = { + 'domain': 'example.com', + 'categories': ['adult'] + } + + mock_stream_reader.readline.side_effect = [ + json.dumps(test_request).encode() + b'\n', + b'' + ] + + await handle_kernel(mock_stream_reader, mock_stream_writer, mock_db_manager) + + assert mock_stream_writer.write.called + assert mock_stream_writer.drain.called + assert mock_stream_writer.close.called + + # No asyncio marker for synchronous functions + def test_route_kernel_request(self, mock_db_manager: mock.Mock) -> None: + """Test kernel request routing.""" + # Test blocked domain + mock_db_manager.is_domain_blocked.return_value = True + response = route_kernel_request({'domain': 'example.com'}, mock_db_manager) + assert response['block'] is True + + # Test allowed domain + mock_db_manager.is_domain_blocked.return_value = False + mock_db_manager.is_easylist_blocked.return_value = False + response = route_kernel_request({'domain': 'example.com'}, mock_db_manager) + assert response['block'] is False + + @pytest.mark.asyncio # Only for async functions + async def test_start_server(self, mock_db_manager: mock.Mock) -> None: + """Test server startup.""" + mock_client_server = mock.AsyncMock() + mock_kernel_server = mock.AsyncMock() + + with mock.patch('asyncio.start_server', + side_effect=[mock_client_server, mock_kernel_server]) as mock_start: + task = asyncio.create_task(start_server(mock_db_manager)) + await asyncio.sleep(0.1) + task.cancel() + + try: + await task + except asyncio.CancelledError: + pass + + assert mock_start.call_count == 2 + assert mock_client_server.serve_forever.called + assert mock_kernel_server.serve_forever.called + + @pytest.mark.asyncio # Only for async functions + async def test_client_connection_error( + self, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock, + request_factory: RequestFactory + ) -> None: + """Test handling of client connection errors.""" + mock_stream_reader.readline.side_effect = ConnectionResetError() + + await handle_client(mock_stream_reader, mock_stream_writer, request_factory) + assert mock_stream_writer.close.called + + # No asyncio marker for synchronous functions + @pytest.mark.parametrize("request_data,expected_block", [ + ({'domain': 'example.com', 'categories': []}, False), + ({'domain': 'blocked.com', 'categories': ['adult']}, True), + ]) + def test_kernel_request_scenarios( + self, + request_data: Dict[str, Any], + expected_block: bool, + mock_db_manager: mock.Mock + ) -> None: + """Test various kernel request scenarios.""" + is_blocked = 'blocked.com' in request_data['domain'] + mock_db_manager.is_domain_blocked.return_value = is_blocked + + response = route_kernel_request(request_data, mock_db_manager) + assert response['block'] is expected_block \ No newline at end of file From 42169ce53e5f5b56160bbfb78076f74b308d1d1d Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Tue, 5 Nov 2024 19:05:09 +0200 Subject: [PATCH 17/38] change imports in handlers,server,main --- .gitignore | 2 ++ server/main.py | 4 ++-- server/src/handlers.py | 4 ++-- server/src/server.py | 6 +++--- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 594cee2..5319942 100644 --- a/.gitignore +++ b/.gitignore @@ -845,3 +845,5 @@ FodyWeavers.xsd # Additional files built by Visual Studio # End of https://www.toptal.com/developers/gitignore/api/python,c,windows,linux,visualstudio,pycharm,clion +server/client.py +server/my_internet.db diff --git a/server/main.py b/server/main.py index 07fe9ac..409bf27 100644 --- a/server/main.py +++ b/server/main.py @@ -1,5 +1,5 @@ -from My_Internet.server.src.server import run -from My_Internet.server.src.config import DB_FILE +from src.server import run +from src.config import DB_FILE if __name__ == '__main__': run(DB_FILE) \ No newline at end of file diff --git a/server/src/handlers.py b/server/src/handlers.py index e817167..907e140 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -3,8 +3,8 @@ import socket from abc import ABC, abstractmethod from typing import Dict, Any, Optional -from My_Internet.server.src.db_manager import DatabaseManager -from My_Internet.server.src.response_codes import ( +from .db_manager import DatabaseManager +from .response_codes import ( SUCCESS, INVALID_REQUEST, DOMAIN_BLOCKED, DOMAIN_NOT_FOUND, AD_BLOCK_ENABLED, ADULT_CONTENT_BLOCKED, RESPONSE_MESSAGES diff --git a/server/src/server.py b/server/src/server.py index 0125167..001481d 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -3,9 +3,9 @@ import asyncio import json from typing import Dict, Any -from My_Internet.server.src.config import HOST, CLIENT_PORT, KERNEL_PORT, DB_FILE -from My_Internet.server.src.db_manager import DatabaseManager -from My_Internet.server.src.handlers import RequestFactory, AdultContentBlockHandler +from .config import HOST, CLIENT_PORT, KERNEL_PORT, DB_FILE +from .db_manager import DatabaseManager +from .handlers import RequestFactory, AdultContentBlockHandler From f82015fb8340eaa9d79ecf6b92252df38b793cbc Mon Sep 17 00:00:00 2001 From: elipaz Date: Wed, 6 Nov 2024 13:44:51 +0200 Subject: [PATCH 18/38] Add utlis and consts --- client/src/Application.py | 22 ++++--- client/src/Communicator.py | 25 +++++--- client/src/ConfigManager.py | 17 +----- client/src/Logger.py | 6 +- client/src/View.py | 116 ++++++++++++++++++++++-------------- client/src/utils.py | 64 ++++++++++++++++++++ 6 files changed, 168 insertions(+), 82 deletions(-) create mode 100644 client/src/utils.py diff --git a/client/src/Application.py b/client/src/Application.py index 054ce79..9108050 100644 --- a/client/src/Application.py +++ b/client/src/Application.py @@ -5,6 +5,11 @@ from .Logger import setup_logger from .ConfigManager import ConfigManager +from .utils import ( + STR_CODE, STR_CONTENT, + Codes +) + class Application: """ Main application class that coordinates communication between UI and server. @@ -78,16 +83,17 @@ def _handle_request(self, request: str) -> None: """ try: self._logger.info(f"Processing request: {request}") + request_dict = json.loads(request) with self._request_lock: - match request["CODE"]: - case "50" | "51" | "52": - self._communicator.send_message(request) - case "53": - if not isinstance(request["content"], list): - self._logger.error("Invalid content format for domain list update") - return - self._view.update_domain_list(request["content"]) + match request_dict[STR_CODE]: + case Codes.CODE_AD_BLOCK | \ + Codes.CODE_ADULT_BLOCK | \ + Codes.CODE_ADD_DOMAIN | \ + Codes.CODE_REMOVE_DOMAIN: + self._communicator.send_message(json.dumps(request)) + case Codes.CODE_DOMAIN_LIST_UPDATE: + self._view.update_domain_list(request_dict[STR_CONTENT]) except json.JSONDecodeError as e: self._logger.error(f"Invalid JSON format: {str(e)}") diff --git a/client/src/Communicator.py b/client/src/Communicator.py index 0590ace..e6275b7 100644 --- a/client/src/Communicator.py +++ b/client/src/Communicator.py @@ -2,6 +2,11 @@ from typing import Optional, Callable import json from .Logger import setup_logger +from .utils import ( + DEFAULT_HOST, DEFAULT_PORT, DEFAULT_BUFFER_SIZE, + ERR_SOCKET_NOT_SETUP, + STR_NETWORK +) class Communicator: def __init__(self, config_manager, message_callback: Callable[[str], None]) -> None: @@ -17,9 +22,9 @@ def __init__(self, config_manager, message_callback: Callable[[str], None]) -> N self.config = config_manager.get_config() self._message_callback = message_callback - self._host = self.config["network"]["host"] - self._port = self.config["network"]["port"] - self._receive_buffer_size = self.config["network"]["receive_buffer_size"] + self._host = self.config[STR_NETWORK][DEFAULT_HOST] + self._port = self.config[STR_NETWORK][DEFAULT_PORT] + self._receive_buffer_size = self.config[STR_NETWORK][DEFAULT_BUFFER_SIZE] self._socket: Optional[socket.socket] = None def connect(self) -> None: @@ -47,9 +52,7 @@ def send_message(self, message: str) -> None: Raises: RuntimeError: If socket connection is not established. """ - if not self._socket: - self.logger.error("Attempted to send message without connection") - raise RuntimeError("Socket not set up. Call connect method first.") + self._validate_connection() try: self._socket.send(message.encode('utf-8')) @@ -69,9 +72,7 @@ def receive_message(self) -> None: socket.error: If there's an error receiving data from the socket. UnicodeDecodeError: If received data cannot be decoded as UTF-8. """ - if not self._socket: - self.logger.error("Attempted to receive message without connection") - raise RuntimeError("Socket not set up. Call connect method first.") + self._validate_connection() self.logger.info("Starting message receive loop") try: @@ -96,3 +97,9 @@ def close(self) -> None: self.logger.error(f"Error closing socket: {str(e)}") finally: self._socket = None + + def _validate_connection(self) -> None: + """Validate the socket connection.""" + if not self._socket: + self.logger.error(ERR_SOCKET_NOT_SETUP) + raise RuntimeError(ERR_SOCKET_NOT_SETUP) diff --git a/client/src/ConfigManager.py b/client/src/ConfigManager.py index f7fa92a..62e67df 100644 --- a/client/src/ConfigManager.py +++ b/client/src/ConfigManager.py @@ -4,23 +4,8 @@ import os from typing import Dict, Any from .Logger import setup_logger +from .utils import DEFAULT_CONFIG -DEFAULT_CONFIG = { - "network": { - "host": "127.0.0.1", - "port": 65432, - "receive_buffer_size": 1024 - }, - "blocked_domains": {}, - "settings": { - "ad_block": "off", - "adult_block": "off" - }, - "logging": { - "level": "INFO", - "log_dir": "client_logs" - } -} class ConfigManager: """Manages application configuration loading and saving.""" diff --git a/client/src/Logger.py b/client/src/Logger.py index 034d9b7..a6c5709 100644 --- a/client/src/Logger.py +++ b/client/src/Logger.py @@ -4,8 +4,8 @@ import os from datetime import datetime from typing import Optional +from .utils import LOG_DIR, LOG_FORMAT, LOG_DATE_FORMAT -LOG_DIR = "client_logs" _logger: Optional[logging.Logger] = None def setup_logger(name: str) -> logging.Logger: @@ -27,12 +27,12 @@ def setup_logger(name: str) -> logging.Logger: os.makedirs(LOG_DIR) log_file: str = os.path.join( - LOG_DIR, f"Client_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" + LOG_DIR, f"Client_{datetime.now().strftime(LOG_DATE_FORMAT)}.log" ) logging.basicConfig( level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + format=LOG_FORMAT, handlers=[ logging.FileHandler(log_file), logging.StreamHandler(), diff --git a/client/src/View.py b/client/src/View.py index bc2926b..6037bb6 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -1,11 +1,19 @@ import tkinter as tk -from tkinter import scrolledtext, ttk, messagebox -from typing import Callable, Dict, List, Any +from tkinter import ttk, messagebox +from typing import Callable, List import json -import os +import threading from .Logger import setup_logger from .ConfigManager import ConfigManager +from .utils import ( + Codes, + WINDOW_SIZE, WINDOW_TITLE, + ERR_DUPLICATE_DOMAIN, ERR_NO_DOMAIN_SELECTED, ERR_DOMAIN_LIST_UPDATE_FAILED, + STR_AD_BLOCK, STR_ADULT_BLOCK, STR_CODE, + STR_BLOCKED_DOMAINS, STR_CONTENT, STR_SETTINGS, STR_ERROR +) + class Viewer: """ @@ -25,10 +33,12 @@ def __init__(self, config_manager: ConfigManager, message_callback: Callable[[st self.config_manager = config_manager self.config = config_manager.get_config() self._message_callback = message_callback + self._update_list_lock = threading.Lock() + self.root: tk.Tk = tk.Tk() - self.root.title("Site Blocker") - self.root.geometry("800x600") + self.root.title(WINDOW_TITLE) + self.root.geometry(WINDOW_SIZE) self._setup_ui() self.logger.info("Viewer initialization complete") @@ -55,8 +65,8 @@ def get_block_settings(self) -> dict[str, str]: A dictionary containing the current state of ad and adult content blocking. """ return { - "ad_block": self.ad_var.get(), - "adult_block": self.adult_var.get() + STR_AD_BLOCK: self.ad_var.get(), + STR_ADULT_BLOCK: self.adult_var.get() } def update_domain_list(self, domains: List[str]) -> None: @@ -66,75 +76,89 @@ def update_domain_list(self, domains: List[str]) -> None: Args: domains: List of domain strings to be displayed in the listbox. """ - self.logger.info("Updating domain list from server") - try: - self.domains_listbox.delete(0, tk.END) - - for domain in domains: - self.domains_listbox.insert(tk.END, domain) + with self._update_list_lock: + self.logger.info("Updating domain list from server") + + try: + self.domains_listbox.delete(0, tk.END) - self.logger.info(f"Updated domain list with {len(domains)} domains") - except Exception as e: - self.logger.error(f"Error updating domain list: {str(e)}") - self._show_error("Failed to update domain list") - - def _send_message(self) -> None: - """Handle the sending of messages from the input field.""" - message = self.input_field.get().strip() - if message: - message_json = json.dumps({"CODE": "100", "content": message}) - self._message_callback(message_json) - self.input_field.delete(0, tk.END) - self.display_message("You", message) + for domain in domains: + self.domains_listbox.insert(tk.END, domain) + + self.logger.info(f"Updated domain list with {len(domains)} domains") + + except Exception as e: + self.logger.error(f"Error updating domain list: {str(e)}") + self._show_error(ERR_DOMAIN_LIST_UPDATE_FAILED) def _add_domain(self) -> None: """Add a domain to the blocked sites list.""" domain = self.domain_entry.get().strip() + if domain: - if domain not in self.config["blocked_domains"]: + if domain not in self.config[STR_BLOCKED_DOMAINS]: self.domains_listbox.insert(tk.END, domain) - self.config["blocked_domains"][domain] = True self.domain_entry.delete(0, tk.END) + + self.config[STR_BLOCKED_DOMAINS][domain] = True self.config_manager.save_config(self.config) + + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_ADD_DOMAIN, + STR_CONTENT: domain + })) + self.logger.info(f"Domain added: {domain}") - - self._message_callback(json.dumps({"CODE": "53", "content": domain})) else: self.logger.warning(f"Attempted to add duplicate domain: {domain}") - self._show_error("Domain already exists in the list") + self._show_error(ERR_DUPLICATE_DOMAIN) def _remove_domain(self) -> None: """Remove the selected domain from the blocked sites list.""" selection = self.domains_listbox.curselection() + if selection: domain = self.domains_listbox.get(selection) self.domains_listbox.delete(selection) - del self.config["blocked_domains"][domain] + + del self.config[STR_BLOCKED_DOMAINS][domain] self.config_manager.save_config(self.config) - self.logger.info(f"Domain removed: {domain}") - self._message_callback(json.dumps({"CODE": "54", "content": domain})) + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_REMOVE_DOMAIN, + STR_CONTENT: domain + })) + + self.logger.info(f"Domain removed: {domain}") else: self.logger.warning("Attempted to remove domain without selection") - self._show_error("Please select a domain to remove") + self._show_error(ERR_NO_DOMAIN_SELECTED) def _handle_ad_block(self) -> None: """Handle changes to the ad block setting.""" state = self.ad_var.get() - self.config["settings"]["ad_block"] = state + self.config[STR_SETTINGS][STR_AD_BLOCK] = state self.config_manager.save_config(self.config) - self.logger.info(f"Ad blocking state changed to: {state}") - self._message_callback(json.dumps({"CODE": "50", "content": state})) + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_AD_BLOCK, + STR_CONTENT: state + })) + + self.logger.info(f"Ad blocking state changed to: {state}") def _handle_adult_block(self) -> None: """Handle changes to the adult sites block setting.""" state = self.adult_var.get() - self.config["settings"]["adult_block"] = state + self.config[STR_SETTINGS][STR_ADULT_BLOCK] = state self.config_manager.save_config(self.config) - self.logger.info(f"Adult site blocking state changed to: {state}") - self._message_callback(json.dumps({"CODE": "51", "content": state})) + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_ADULT_BLOCK, + STR_CONTENT: state + })) + + self.logger.info(f"Adult site blocking state changed to: {state}") def _show_error(self, message: str) -> None: """ @@ -144,7 +168,7 @@ def _show_error(self, message: str) -> None: message: The error message to display. """ self.logger.error(f"Error message displayed: {message}") - tk.messagebox.showerror("Error", message) + tk.messagebox.showerror(STR_ERROR, message) def _setup_ui(self) -> None: """Set up the UI components including block controls and domain list.""" @@ -180,14 +204,14 @@ def _setup_ui(self) -> None: self.domains_listbox.bind('', lambda e: self._remove_domain()) # Load saved domains into listbox - for domain in self.config["blocked_domains"].keys(): + for domain in self.config[STR_BLOCKED_DOMAINS].keys(): self.domains_listbox.insert(tk.END, domain) - + # Ad Block controls ad_frame = ttk.LabelFrame(main_container, text="Ad Block", padding="5") ad_frame.grid(row=0, column=1, pady=10, sticky=(tk.W, tk.E)) - self.ad_var = tk.StringVar(value=self.config["settings"]["ad_block"]) + self.ad_var = tk.StringVar(value=self.config[STR_SETTINGS][STR_AD_BLOCK]) ttk.Radiobutton(ad_frame, text="on", value="on", variable=self.ad_var).grid(row=0, column=0, padx=10) ttk.Radiobutton(ad_frame, text="off", value="off", variable=self.ad_var).grid(row=0, column=1, padx=10) @@ -195,7 +219,7 @@ def _setup_ui(self) -> None: adult_frame = ttk.LabelFrame(main_container, text="Adult sites Block", padding="5") adult_frame.grid(row=1, column=1, pady=10, sticky=(tk.W, tk.E)) - self.adult_var = tk.StringVar(value=self.config["settings"]["adult_block"]) + self.adult_var = tk.StringVar(value=self.config[STR_SETTINGS][STR_ADULT_BLOCK]) ttk.Radiobutton(adult_frame, text="on", value="on", variable=self.adult_var).grid(row=0, column=0, padx=10) ttk.Radiobutton(adult_frame, text="off", value="off", variable=self.adult_var).grid(row=0, column=1, padx=10) diff --git a/client/src/utils.py b/client/src/utils.py new file mode 100644 index 0000000..31b59f2 --- /dev/null +++ b/client/src/utils.py @@ -0,0 +1,64 @@ +"""Utility module containing constants and common functions for the application.""" + +# Network related constants +DEFAULT_HOST = "host" +DEFAULT_PORT = "port" +DEFAULT_BUFFER_SIZE = "receive_buffer_size" + +# GUI constants +WINDOW_TITLE = "Site Blocker" +WINDOW_SIZE = "800x600" +PADDING_SMALL = "5" +PADDING_MEDIUM = "10" + +# Message codes +class Codes: + """Constants for message codes used in communication.""" + CODE_AD_BLOCK = "50" + CODE_ADULT_BLOCK = "51" + CODE_ADD_DOMAIN = "52" + CODE_REMOVE_DOMAIN = "53" + CODE_DOMAIN_LIST_UPDATE = "54" + +# Default settings +DEFAULT_CONFIG = { + "network": { + "host": DEFAULT_HOST, + "port": DEFAULT_PORT, + "receive_buffer_size": DEFAULT_BUFFER_SIZE + }, + "blocked_domains": {}, + "settings": { + "ad_block": "off", + "adult_block": "off" + }, + "logging": { + "level": "INFO", + "log_dir": "client_logs" + } +} + +# Logging constants +LOG_DIR = "client_logs" +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +LOG_DATE_FORMAT = "%Y%m%d_%H%M%S" + +# Error messages +ERR_SOCKET_NOT_SETUP = "Socket not set up. Call connect method first." +ERR_NO_CONNECTION = "Attempted to send message without connection" +ERR_DUPLICATE_DOMAIN = "Domain already exists in the list" +ERR_NO_DOMAIN_SELECTED = "Please select a domain to remove" +ERR_DOMAIN_LIST_UPDATE_FAILED = "Failed to update domain list" + +# String Constants +STR_AD_BLOCK = "ad_block" +STR_ADULT_BLOCK = "adult_block" +STR_CODE = "code" +STR_CONTENT = "content" +STR_ERROR = "Error" + +# Config Constants +STR_BLOCKED_DOMAINS = "blocked_domains" +STR_NETWORK = "network" +STR_SETTINGS = "settings" +STR_LOGGING = "logging" From 957850bd222760984ad6b5af77e2d3a86955a60d Mon Sep 17 00:00:00 2001 From: elipaz Date: Wed, 6 Nov 2024 15:21:51 +0200 Subject: [PATCH 19/38] Finish without tests --- client/config.json | 7 +- client/src/Application.py | 2 +- client/src/View.py | 221 +++++++++++++++++++++++++++++++------- 3 files changed, 188 insertions(+), 42 deletions(-) diff --git a/client/config.json b/client/config.json index b662c6b..a0dfe4e 100644 --- a/client/config.json +++ b/client/config.json @@ -6,11 +6,12 @@ }, "blocked_domains": { "example.com": true, - "ads.example.com": true + "ads.example.com": true, + "fxp.co.il": true }, "settings": { - "ad_block": "on", - "adult_block": "on" + "ad_block": "off", + "adult_block": "off" }, "logging": { "level": "INFO", diff --git a/client/src/Application.py b/client/src/Application.py index 9108050..d4ec6bf 100644 --- a/client/src/Application.py +++ b/client/src/Application.py @@ -41,7 +41,7 @@ def run(self) -> None: self._logger.info("Starting application") try: - # self._start_communication() + self._start_communication() self._start_gui() except Exception as e: diff --git a/client/src/View.py b/client/src/View.py index 6037bb6..2943421 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -35,12 +35,25 @@ def __init__(self, config_manager: ConfigManager, message_callback: Callable[[st self._message_callback = message_callback self._update_list_lock = threading.Lock() - + # Initialize root window first self.root: tk.Tk = tk.Tk() self.root.title(WINDOW_TITLE) self.root.geometry(WINDOW_SIZE) + self.root.withdraw() # Hide the window temporarily + + # Configure styles + style = ttk.Style() + style.configure('TLabelframe', padding=10) + style.configure('TLabelframe.Label', font=('Arial', 10, 'bold')) + style.configure('TButton', padding=5) + style.configure('TRadiobutton', font=('Arial', 10)) + style.configure('TLabel', font=('Arial', 10)) + self._setup_ui() + + # Show the window after setup is complete + self.root.deiconify() self.logger.info("Viewer initialization complete") def run(self) -> None: @@ -172,62 +185,194 @@ def _show_error(self, message: str) -> None: def _setup_ui(self) -> None: """Set up the UI components including block controls and domain list.""" - main_container = ttk.Frame(self.root, padding="10") + # Main container with increased padding + main_container = ttk.Frame(self.root, padding="20") main_container.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) self.root.columnconfigure(0, weight=1) self.root.rowconfigure(0, weight=1) - # Left side - Specific sites block - sites_frame = ttk.LabelFrame(main_container, text="Specific sites block", padding="5") - sites_frame.grid(row=0, column=0, rowspan=3, padx=5, sticky=(tk.W, tk.E, tk.N, tk.S)) + # Left side - Specific sites block (now with better proportions) + sites_frame = ttk.LabelFrame( + main_container, + text="Specific Sites Block", + padding="15" + ) + sites_frame.grid( + row=0, + column=0, + rowspan=3, + padx=10, + sticky=(tk.W, tk.E, tk.N, tk.S) + ) + + # Create a frame for listbox and scrollbar + listbox_frame = ttk.Frame(sites_frame) + listbox_frame.grid(row=0, column=0, pady=5, sticky=(tk.W, tk.E, tk.N, tk.S)) - # Domains listbox - self.domains_listbox = tk.Listbox(sites_frame, width=40, height=15) - self.domains_listbox.grid(row=0, column=0, pady=5, sticky=(tk.W, tk.E, tk.N, tk.S)) + # Domains listbox with scrollbars + self.domains_listbox = tk.Listbox( + listbox_frame, + width=40, + height=15, + selectmode=tk.SINGLE, + activestyle='dotbox', + font=('Arial', 10) + ) + scrollbar_y = ttk.Scrollbar( + listbox_frame, + orient=tk.VERTICAL, + command=self.domains_listbox.yview + ) + scrollbar_x = ttk.Scrollbar( + listbox_frame, + orient=tk.HORIZONTAL, + command=self.domains_listbox.xview + ) - # Add domain entry + self.domains_listbox.configure( + yscrollcommand=scrollbar_y.set, + xscrollcommand=scrollbar_x.set + ) + + # Grid layout for listbox and scrollbars + self.domains_listbox.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) + scrollbar_y.grid(row=0, column=1, sticky=(tk.N, tk.S)) + scrollbar_x.grid(row=1, column=0, sticky=(tk.W, tk.E)) + + # Add domain entry with improved layout domain_entry_frame = ttk.Frame(sites_frame) - domain_entry_frame.grid(row=1, column=0, sticky=(tk.W, tk.E)) + domain_entry_frame.grid( + row=1, + column=0, + pady=15, + sticky=(tk.W, tk.E) + ) + + ttk.Label( + domain_entry_frame, + text="Add Domain:", + font=('Arial', 10) + ).grid(row=0, column=0, padx=5) - ttk.Label(domain_entry_frame, text="Add Domain:").grid(row=0, column=0, padx=5) - self.domain_entry = ttk.Entry(domain_entry_frame) - self.domain_entry.grid(row=0, column=1, padx=5, sticky=(tk.W, tk.E)) + self.domain_entry = ttk.Entry( + domain_entry_frame, + font=('Arial', 10) + ) + self.domain_entry.grid( + row=0, + column=1, + padx=5, + sticky=(tk.W, tk.E) + ) - # Add buttons for domain management + # Buttons with improved styling button_frame = ttk.Frame(sites_frame) - button_frame.grid(row=2, column=0, pady=5, sticky=(tk.W, tk.E)) + button_frame.grid( + row=2, + column=0, + pady=10, + sticky=(tk.W, tk.E) + ) - ttk.Button(button_frame, text="Add", command=self._add_domain).grid(row=0, column=0, padx=5) - ttk.Button(button_frame, text="Remove", command=self._remove_domain).grid(row=0, column=1, padx=5) + style = ttk.Style() + style.configure('Action.TButton', padding=5) - # Bind double-click event for removing domains - self.domains_listbox.bind('', lambda e: self._remove_domain()) + ttk.Button( + button_frame, + text="Add Domain", + style='Action.TButton', + command=self._add_domain + ).grid(row=0, column=0, padx=5) - # Load saved domains into listbox - for domain in self.config[STR_BLOCKED_DOMAINS].keys(): - self.domains_listbox.insert(tk.END, domain) + ttk.Button( + button_frame, + text="Remove Domain", + style='Action.TButton', + command=self._remove_domain + ).grid(row=0, column=1, padx=5) + + # Right side controls with improved spacing + controls_frame = ttk.Frame(main_container) + controls_frame.grid( + row=0, + column=1, + padx=20, + sticky=(tk.N, tk.S) + ) - # Ad Block controls - ad_frame = ttk.LabelFrame(main_container, text="Ad Block", padding="5") - ad_frame.grid(row=0, column=1, pady=10, sticky=(tk.W, tk.E)) + # Ad Block controls with better styling + ad_frame = ttk.LabelFrame( + controls_frame, + text="Ad Blocking", + padding="15" + ) + ad_frame.grid( + row=0, + column=0, + pady=10, + sticky=(tk.W, tk.E) + ) + # Initialize with config value self.ad_var = tk.StringVar(value=self.config[STR_SETTINGS][STR_AD_BLOCK]) - ttk.Radiobutton(ad_frame, text="on", value="on", variable=self.ad_var).grid(row=0, column=0, padx=10) - ttk.Radiobutton(ad_frame, text="off", value="off", variable=self.ad_var).grid(row=0, column=1, padx=10) + ttk.Radiobutton( + ad_frame, + text="Enable", + value="on", + variable=self.ad_var, + command=self._handle_ad_block + ).grid(row=0, column=0, padx=10) + ttk.Radiobutton( + ad_frame, + text="Disable", + value="off", + variable=self.ad_var, + command=self._handle_ad_block + ).grid(row=0, column=1, padx=10) # Adult sites Block controls - adult_frame = ttk.LabelFrame(main_container, text="Adult sites Block", padding="5") - adult_frame.grid(row=1, column=1, pady=10, sticky=(tk.W, tk.E)) + adult_frame = ttk.LabelFrame( + controls_frame, + text="Adult Content Blocking", + padding="15" + ) + adult_frame.grid( + row=1, + column=0, + pady=10, + sticky=(tk.W, tk.E) + ) + # Initialize with config value self.adult_var = tk.StringVar(value=self.config[STR_SETTINGS][STR_ADULT_BLOCK]) - ttk.Radiobutton(adult_frame, text="on", value="on", variable=self.adult_var).grid(row=0, column=0, padx=10) - ttk.Radiobutton(adult_frame, text="off", value="off", variable=self.adult_var).grid(row=0, column=1, padx=10) - - # Bind radio button commands - self.ad_var.trace_add('write', lambda *args: self._handle_ad_block()) - self.adult_var.trace_add('write', lambda *args: self._handle_adult_block()) - - # Configure grid weights - main_container.columnconfigure(0, weight=1) + ttk.Radiobutton( + adult_frame, + text="Enable", + value="on", + variable=self.adult_var, + command=self._handle_adult_block + ).grid(row=0, column=0, padx=10) + ttk.Radiobutton( + adult_frame, + text="Disable", + value="off", + variable=self.adult_var, + command=self._handle_adult_block + ).grid(row=0, column=1, padx=10) + + # Configure grid weights for better resizing + main_container.columnconfigure(0, weight=3) + main_container.columnconfigure(1, weight=1) sites_frame.columnconfigure(0, weight=1) + listbox_frame.columnconfigure(0, weight=1) + listbox_frame.rowconfigure(0, weight=1) domain_entry_frame.columnconfigure(1, weight=1) + button_frame.columnconfigure(0, weight=1) + button_frame.columnconfigure(1, weight=1) + + # Bind events + self.domains_listbox.bind('', lambda e: self._remove_domain()) + + # Load saved domains + for domain in self.config[STR_BLOCKED_DOMAINS].keys(): + self.domains_listbox.insert(tk.END, domain) From b0c04b626f09baa0dace19fb70902cc8189f7ce0 Mon Sep 17 00:00:00 2001 From: elipaz Date: Wed, 6 Nov 2024 15:56:48 +0200 Subject: [PATCH 20/38] Finish tests --- client/tests/test_application.py | 64 +++++++++-- client/tests/test_communicator.py | 73 ++++++++++--- client/tests/test_view.py | 175 +++++++++++++++++++++++------- 3 files changed, 241 insertions(+), 71 deletions(-) diff --git a/client/tests/test_application.py b/client/tests/test_application.py index ca9df06..f55c206 100644 --- a/client/tests/test_application.py +++ b/client/tests/test_application.py @@ -1,30 +1,36 @@ -import os import logging from unittest import mock -from datetime import datetime from typing import Optional, Callable +import json import pytest from src.Application import Application from src.View import Viewer from src.Communicator import Communicator +from src.utils import ( + STR_CODE, STR_CONTENT, + Codes, DEFAULT_CONFIG +) @pytest.fixture -def mock_callback() -> Callable[[str], None]: - """Fixture to provide a mock callback function.""" - return mock.Mock() +def mock_config_manager() -> mock.Mock: + """Fixture to provide a mock configuration manager.""" + config_manager = mock.Mock() + config_manager.get_config.return_value = DEFAULT_CONFIG + return config_manager @pytest.fixture -def application(mock_callback: Callable[[str], None]) -> Application: +def application(mock_config_manager: mock.Mock) -> Application: """Fixture to create an Application instance.""" with mock.patch('src.Application.Viewer') as mock_viewer, \ mock.patch('src.Application.Communicator') as mock_comm, \ mock.patch('src.Application.setup_logger') as mock_logger: app = Application() app._logger = mock.Mock() + app._config_manager = mock_config_manager return app @@ -33,10 +39,15 @@ def test_init(application: Application) -> None: assert hasattr(application, '_logger') assert hasattr(application, '_view') assert hasattr(application, '_communicator') + assert hasattr(application, '_request_lock') + assert hasattr(application, '_config_manager') @mock.patch('src.Application.threading.Thread') -def test_start_communication(mock_thread: mock.Mock, application: Application) -> None: +def test_start_communication( + mock_thread: mock.Mock, + application: Application +) -> None: """Test the communication startup.""" application._start_communication() @@ -54,13 +65,32 @@ def test_start_gui(application: Application) -> None: application._view.run.assert_called_once() -def test_handle_request(application: Application) -> None: - """Test request handling.""" - test_request = '{"type": "test", "content": "message"}' +def test_handle_request_ad_block(application: Application) -> None: + """Test handling ad block request.""" + test_request = { + STR_CODE: Codes.CODE_AD_BLOCK, + STR_CONTENT: "test" + } + + application._communicator.send_message = mock.Mock() + + application._handle_request(json.dumps(test_request)) + + actual_arg = application._communicator.send_message.call_args[0][0] + + assert json.loads(json.loads(actual_arg)) == test_request + + +def test_handle_request_domain_list_update(application: Application) -> None: + """Test handling domain list update request.""" + test_content = ["domain1.com", "domain2.com"] + test_request = json.dumps({ + STR_CODE: Codes.CODE_DOMAIN_LIST_UPDATE, + STR_CONTENT: test_content + }) - # Currently just testing logging as implementation is pending application._handle_request(test_request) - application._logger.debug.assert_called_once_with(f"Processing request: {test_request}") + application._view.update_domain_list.assert_called_once_with(test_content) def test_cleanup(application: Application) -> None: @@ -102,4 +132,14 @@ def test_run_exception(application: Application) -> None: exc_info=True ) mock_cleanup.assert_called_once() + + +def test_handle_request_json_error(application: Application) -> None: + """Test handling of invalid JSON in request.""" + invalid_json = "{" + + with pytest.raises(json.JSONDecodeError): + application._handle_request(invalid_json) + + application._logger.error.assert_called() \ No newline at end of file diff --git a/client/tests/test_communicator.py b/client/tests/test_communicator.py index 1ed8e81..2a761e4 100644 --- a/client/tests/test_communicator.py +++ b/client/tests/test_communicator.py @@ -4,7 +4,20 @@ import pytest -from src.Communicator import Communicator, HOST, PORT, RECEIVE_BUFFER_SIZE +from src.Communicator import Communicator +from src.utils import ( + DEFAULT_HOST, DEFAULT_PORT, DEFAULT_BUFFER_SIZE, + ERR_SOCKET_NOT_SETUP, STR_NETWORK, + DEFAULT_CONFIG +) + + +@pytest.fixture +def mock_config_manager() -> mock.Mock: + """Fixture to provide a mock configuration manager.""" + config_manager = mock.Mock() + config_manager.get_config.return_value = DEFAULT_CONFIG + return config_manager @pytest.fixture @@ -14,27 +27,45 @@ def mock_callback() -> Callable[[str], None]: @pytest.fixture -def communicator(mock_callback: Callable[[str], None]) -> Communicator: +def communicator( + mock_config_manager: mock.Mock, + mock_callback: Callable[[str], None] +) -> Communicator: """Fixture to create a Communicator instance.""" - return Communicator(message_callback=mock_callback) + return Communicator( + config_manager=mock_config_manager, + message_callback=mock_callback + ) -def test_init(communicator: Communicator, mock_callback: Callable[[str], None]) -> None: +def test_init( + communicator: Communicator, + mock_callback: Callable[[str], None] +) -> None: """Test the initialization of Communicator.""" - assert communicator._host == HOST - assert communicator._port == PORT + assert communicator._host == DEFAULT_CONFIG[STR_NETWORK][DEFAULT_HOST] + assert communicator._port == DEFAULT_CONFIG[STR_NETWORK][DEFAULT_PORT] + assert communicator._receive_buffer_size == DEFAULT_CONFIG[STR_NETWORK][DEFAULT_BUFFER_SIZE] assert communicator._socket is None assert communicator._message_callback == mock_callback @mock.patch('socket.socket') -def test_connect(mock_socket_class: mock.Mock, communicator: Communicator) -> None: +def test_connect( + mock_socket_class: mock.Mock, + communicator: Communicator +) -> None: """Test the connect method initializes and connects the socket.""" mock_socket_instance = mock_socket_class.return_value communicator.connect() - mock_socket_class.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) - mock_socket_instance.connect.assert_called_once_with((HOST, PORT)) + mock_socket_class.assert_called_once_with( + socket.AF_INET, + socket.SOCK_STREAM + ) + mock_socket_instance.connect.assert_called_once_with( + (communicator._host, communicator._port) + ) assert communicator._socket is mock_socket_instance @@ -46,11 +77,14 @@ def test_send_message_without_setup( """Test sending a message without setting up the socket raises RuntimeError.""" with pytest.raises(RuntimeError) as exc_info: communicator.send_message("Hello") - assert str(exc_info.value) == "Socket not set up. Call connect method first." + assert str(exc_info.value) == ERR_SOCKET_NOT_SETUP @mock.patch('socket.socket') -def test_send_message(mock_socket_class: mock.Mock, communicator: Communicator) -> None: +def test_send_message( + mock_socket_class: mock.Mock, + communicator: Communicator +) -> None: """Test sending a message successfully.""" mock_socket_instance = mock_socket_class.return_value communicator._socket = mock_socket_instance @@ -58,7 +92,9 @@ def test_send_message(mock_socket_class: mock.Mock, communicator: Communicator) message: str = "Hello, World!" communicator.send_message(message) - mock_socket_instance.send.assert_called_once_with(message.encode('utf-8')) + mock_socket_instance.send.assert_called_once_with( + message.encode('utf-8') + ) @mock.patch('socket.socket') @@ -69,7 +105,7 @@ def test_receive_message_without_setup( """Test receiving a message without setting up the socket raises RuntimeError.""" with pytest.raises(RuntimeError) as exc_info: communicator.receive_message() - assert str(exc_info.value) == "Socket not set up. Call connect method first." + assert str(exc_info.value) == ERR_SOCKET_NOT_SETUP @mock.patch('socket.socket') @@ -82,17 +118,21 @@ def test_receive_message( mock_socket_instance = mock_socket_class.return_value communicator._socket = mock_socket_instance - # Setup mock to return a message once and then empty string to break the loop mock_socket_instance.recv.side_effect = [b'Hello, Client!', b''] communicator.receive_message() - mock_socket_instance.recv.assert_called_with(RECEIVE_BUFFER_SIZE) + mock_socket_instance.recv.assert_called_with( + DEFAULT_CONFIG[STR_NETWORK][DEFAULT_BUFFER_SIZE] + ) mock_callback.assert_called_once_with('Hello, Client!') @mock.patch('socket.socket') -def test_close_socket(mock_socket_class: mock.Mock, communicator: Communicator) -> None: +def test_close_socket( + mock_socket_class: mock.Mock, + communicator: Communicator +) -> None: """Test closing the socket.""" mock_socket_instance = mock_socket_class.return_value communicator._socket = mock_socket_instance @@ -113,7 +153,6 @@ def test_receive_message_decode_error( mock_socket_instance = mock_socket_class.return_value communicator._socket = mock_socket_instance - # Setup mock to return invalid UTF-8 bytes mock_socket_instance.recv.side_effect = [bytes([0xFF, 0xFE, 0xFD]), b''] with pytest.raises(UnicodeDecodeError): diff --git a/client/tests/test_view.py b/client/tests/test_view.py index 888e78b..f83f0b9 100644 --- a/client/tests/test_view.py +++ b/client/tests/test_view.py @@ -1,67 +1,158 @@ -import tkinter as tk +import pytest from unittest import mock import json from typing import Callable -import pytest - from src.View import Viewer +from src.utils import ( + Codes, STR_CODE, STR_CONTENT, + STR_SETTINGS, STR_AD_BLOCK, STR_ADULT_BLOCK, + STR_BLOCKED_DOMAINS, DEFAULT_CONFIG, + ERR_DUPLICATE_DOMAIN +) +@pytest.fixture +def mock_config_manager() -> mock.Mock: + """Fixture to provide a mock configuration manager.""" + config_manager = mock.Mock() + config_manager.get_config.return_value = DEFAULT_CONFIG.copy() + return config_manager @pytest.fixture def mock_callback() -> Callable[[str], None]: """Fixture to provide a mock callback function.""" return mock.Mock() - @pytest.fixture -def viewer(mock_callback: Callable[[str], None]) -> Viewer: - """Fixture to create a Viewer instance.""" - with mock.patch('src.View.tk.Tk') as mock_tk: - mock_tk.return_value.title = mock.Mock() - mock_tk.return_value.geometry = mock.Mock() - return Viewer(message_callback=mock_callback) - - -def test_init(viewer: Viewer, mock_callback: Callable[[str], None]) -> None: - """Test the initialization of Viewer.""" - viewer.root.title.assert_called_once_with("Chat Application") - viewer.root.geometry.assert_called_once_with("800x600") - assert viewer._message_callback == mock_callback +def viewer(mock_config_manager: mock.Mock, mock_callback: mock.Mock) -> Viewer: + """Fixture to create a Viewer instance with mocked components.""" + with mock.patch('tkinter.Tk') as mock_tk, \ + mock.patch('tkinter.ttk.Style'): + # Create a mock Tk instance + root = mock_tk.return_value + + # Set up the mock root properly + mock_tk._default_root = root + root._default_root = root + + # Create StringVar mock that returns string values + with mock.patch('tkinter.StringVar') as mock_string_var: + string_var_instance = mock.Mock() + string_var_instance.get.return_value = "off" + mock_string_var.return_value = string_var_instance + + # Create Entry and Listbox mocks + with mock.patch('tkinter.Entry') as mock_entry, \ + mock.patch('tkinter.Listbox') as mock_listbox: + + # Setup Entry mock + entry_instance = mock.Mock() + entry_instance.get.return_value = "" + mock_entry.return_value = entry_instance + + # Setup Listbox mock + listbox_instance = mock.Mock() + listbox_instance.curselection.return_value = () + listbox_instance.get.return_value = "" + mock_listbox.return_value = listbox_instance + + viewer = Viewer( + config_manager=mock_config_manager, + message_callback=mock_callback + ) + + # Store mock instances for easy access in tests + viewer.domain_entry = entry_instance + viewer.domains_listbox = listbox_instance + + # Mock the _show_error method + viewer._show_error = mock.Mock() + + return viewer +def test_get_block_settings(viewer: Viewer) -> None: + """Test getting block settings.""" + # Configure the mock StringVar to return specific values + viewer.ad_var.get.return_value = "off" + viewer.adult_var.get.return_value = "off" + + settings = viewer.get_block_settings() + assert STR_AD_BLOCK in settings + assert STR_ADULT_BLOCK in settings + assert isinstance(settings[STR_AD_BLOCK], str) + assert isinstance(settings[STR_ADULT_BLOCK], str) -def test_send_message(viewer: Viewer, mock_callback: Callable[[str], None]) -> None: - """Test sending a message.""" - test_message = "Hello, World!" - viewer.input_field = mock.Mock() - viewer.input_field.get.return_value = test_message +def test_handle_ad_block(viewer: Viewer) -> None: + """Test handling ad block setting changes.""" + # Configure the mock StringVar to return "on" + viewer.ad_var.get.return_value = "on" + viewer._handle_ad_block() - viewer._send_message() + expected_json = json.dumps({ + STR_CODE: Codes.CODE_AD_BLOCK, + STR_CONTENT: "on" + }) - expected_json = json.dumps({"CODE": "100", "content": test_message}) - mock_callback.assert_called_once_with(expected_json) - viewer.input_field.delete.assert_called_once_with(0, tk.END) + viewer._message_callback.assert_called_once_with(expected_json) + viewer.config_manager.save_config.assert_called_once_with(viewer.config) + assert viewer.config[STR_SETTINGS][STR_AD_BLOCK] == "on" +def test_handle_adult_block(viewer: Viewer) -> None: + """Test handling adult block setting changes.""" + # Configure the mock StringVar to return "on" + viewer.adult_var.get.return_value = "on" + viewer._handle_adult_block() + + expected_json = json.dumps({ + STR_CODE: Codes.CODE_ADULT_BLOCK, + STR_CONTENT: "on" + }) + + viewer._message_callback.assert_called_once_with(expected_json) + viewer.config_manager.save_config.assert_called_once_with(viewer.config) + assert viewer.config[STR_SETTINGS][STR_ADULT_BLOCK] == "on" -def test_display_message(viewer: Viewer) -> None: - """Test displaying a message.""" - viewer.message_area = mock.Mock() +def test_add_domain(viewer: Viewer) -> None: + """Test adding a domain.""" + domain = "test.com" + viewer.domain_entry.get.return_value = domain + viewer._add_domain() - viewer.display_message("User", "Test message") + expected_json = json.dumps({ + STR_CODE: Codes.CODE_ADD_DOMAIN, + STR_CONTENT: domain + }) - viewer.message_area.config.assert_any_call(state=tk.NORMAL) - viewer.message_area.insert.assert_called_once_with(tk.END, "User: Test message\n") - viewer.message_area.see.assert_called_once_with(tk.END) - viewer.message_area.config.assert_any_call(state=tk.DISABLED) + viewer._message_callback.assert_called_once_with(expected_json) + viewer.config_manager.save_config.assert_called_once_with(viewer.config) + assert viewer.config[STR_BLOCKED_DOMAINS][domain] is True +def test_add_duplicate_domain(viewer: Viewer) -> None: + """Test adding a duplicate domain.""" + domain = "test.com" + viewer.config[STR_BLOCKED_DOMAINS][domain] = True + viewer.domain_entry.get.return_value = domain + + viewer._add_domain() + + viewer._message_callback.assert_not_called() + viewer._show_error.assert_called_once_with(ERR_DUPLICATE_DOMAIN) + assert len(viewer.config[STR_BLOCKED_DOMAINS]) == 1 -def test_display_error(viewer: Viewer) -> None: - """Test displaying an error message.""" - viewer.message_area = mock.Mock() +def test_remove_domain(viewer: Viewer) -> None: + """Test removing a domain.""" + domain = "test.com" + viewer.config[STR_BLOCKED_DOMAINS][domain] = True + viewer.domains_listbox.curselection.return_value = (0,) + viewer.domains_listbox.get.return_value = domain + + viewer._remove_domain() - viewer.display_error("Test error") + expected_json = json.dumps({ + STR_CODE: Codes.CODE_REMOVE_DOMAIN, + STR_CONTENT: domain + }) - viewer.message_area.config.assert_any_call(state=tk.NORMAL) - viewer.message_area.insert.assert_called_once_with(tk.END, "Error: Test error\n") - viewer.message_area.see.assert_called_once_with(tk.END) - viewer.message_area.config.assert_any_call(state=tk.DISABLED) + viewer._message_callback.assert_called_once_with(expected_json) + viewer.config_manager.save_config.assert_called_once_with(viewer.config) + assert domain not in viewer.config[STR_BLOCKED_DOMAINS] From 22bdea473534787bceaccd0fb63554de9c745036 Mon Sep 17 00:00:00 2001 From: elipaz Date: Wed, 6 Nov 2024 15:59:31 +0200 Subject: [PATCH 21/38] Update requirments --- client/requirments.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/requirments.txt b/client/requirments.txt index ea7c08e..0a08159 100644 --- a/client/requirments.txt +++ b/client/requirments.txt @@ -1,4 +1,4 @@ --e git+https://github.com/pazMenachem/My_Internet.git@94c2d2bea3c2ad912444fd3cba2ea4a4e9bf6daf#egg=client&subdirectory=client +-e git+https://github.com/pazMenachem/My_Internet.git@b0c04b626f09baa0dace19fb70902cc8189f7ce0#egg=client&subdirectory=client colorama==0.4.6 exceptiongroup==1.2.2 iniconfig==2.0.0 From 1ea52e8269def8e75caf37a108a60043db6ba8d8 Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:06:41 +0200 Subject: [PATCH 22/38] Restructure response code system to match client --- server/src/response_codes.py | 77 ++++++++++++++++++++++++++++-------- 1 file changed, 60 insertions(+), 17 deletions(-) diff --git a/server/src/response_codes.py b/server/src/response_codes.py index 7095c06..ba43d09 100644 --- a/server/src/response_codes.py +++ b/server/src/response_codes.py @@ -1,21 +1,64 @@ -# response_codes.py +from typing import Dict, Any -from typing import Dict +# Client Command Codes (incoming requests) +class ClientCodes: + AD_BLOCK = "50" # Toggle ad blocking + ADULT_BLOCK = "51" # Toggle adult content blocking + ADD_DOMAIN = "52" # Add domain to block list + REMOVE_DOMAIN = "53" # Remove domain from block list + DOMAIN_LIST_UPDATE = "54" # Update domain list -# Response codes -SUCCESS: int = 200 -INVALID_REQUEST: int = 400 -DOMAIN_BLOCKED: int = 201 -DOMAIN_NOT_FOUND: int = 404 -AD_BLOCK_ENABLED: int = 202 -ADULT_CONTENT_BLOCKED: int = 203 +# Server Response Codes +class ServerCodes: + SUCCESS = 200 + INVALID_REQUEST = 400 + DOMAIN_BLOCKED = 201 + DOMAIN_NOT_FOUND = 404 + AD_BLOCK_ENABLED = 202 + ADULT_CONTENT_BLOCKED = 203 -# Response messages +# Response Messages RESPONSE_MESSAGES: Dict[int, str] = { - SUCCESS: "Request processed successfully.", - INVALID_REQUEST: "Invalid request. Please check the request format.", - DOMAIN_BLOCKED: "Domain has been successfully blocked.", - DOMAIN_NOT_FOUND: "Domain not found in the block list.", - AD_BLOCK_ENABLED: "Ad blocking has been enabled for the domain.", - ADULT_CONTENT_BLOCKED: "Adult content has been blocked for the domain." -} \ No newline at end of file + ServerCodes.SUCCESS: "Request processed successfully.", + ServerCodes.INVALID_REQUEST: "Invalid request. Please check the request format.", + ServerCodes.DOMAIN_BLOCKED: "Domain has been successfully blocked.", + ServerCodes.DOMAIN_NOT_FOUND: "Domain not found in the block list.", + ServerCodes.AD_BLOCK_ENABLED: "Ad blocking has been enabled for the domain.", + ServerCodes.ADULT_CONTENT_BLOCKED: "Adult content has been blocked for the domain." +} + +def create_response( + server_code: int, + client_code: str = None, + content: Any = None, + message: str = None +) -> Dict[str, Any]: + """ + Create a standardized response format. + + Args: + server_code: Internal server response code + client_code: Client's command code (if responding to specific command) + content: Optional response payload + message: Custom message (uses default if None) + + Returns: + Formatted response dictionary + """ + response = { + "code": client_code if client_code else str(server_code), + "message": message if message else RESPONSE_MESSAGES.get(server_code, ""), + "status": server_code < 400 # True for success codes, False for error codes + } + + if content is not None: + response["content"] = content + + return response + +def create_error_response(message: str) -> Dict[str, Any]: + """Create a standardized error response.""" + return create_response( + ServerCodes.INVALID_REQUEST, + message=message + ) \ No newline at end of file From 367cb99cfe8bee0e4ab5c365a494894700812dba Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:16:55 +0200 Subject: [PATCH 23/38] Add state management system and enhance database operations --- server/src/db_manager.py | 236 ++++++++++++++++++++++++++++-------- server/src/state_manager.py | 135 +++++++++++++++++++++ 2 files changed, 318 insertions(+), 53 deletions(-) create mode 100644 server/src/state_manager.py diff --git a/server/src/db_manager.py b/server/src/db_manager.py index dfdff94..ef14f5d 100644 --- a/server/src/db_manager.py +++ b/server/src/db_manager.py @@ -1,86 +1,216 @@ -# db_manager.py - import sqlite3 -from typing import List, Tuple - +from typing import List, Dict, Any, Optional +from datetime import datetime +import threading class DatabaseManager: def __init__(self, db_file: str): self.db_file = db_file + self._connection_lock = threading.Lock() self.create_tables() def create_tables(self) -> None: + """Create necessary tables if they don't exist.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() + + # Existing tables with enhancements cursor.execute(""" CREATE TABLE IF NOT EXISTS blocked_domains ( id INTEGER PRIMARY KEY AUTOINCREMENT, - domain TEXT UNIQUE + domain TEXT UNIQUE, + timestamp TEXT, + reason TEXT, + active BOOLEAN DEFAULT TRUE ) """) + cursor.execute(""" CREATE TABLE IF NOT EXISTS easylist ( id INTEGER PRIMARY KEY AUTOINCREMENT, - entry TEXT UNIQUE + entry TEXT UNIQUE, + category TEXT, + timestamp TEXT ) """) + + # New tables for state management + cursor.execute(""" + CREATE TABLE IF NOT EXISTS feature_states ( + feature TEXT PRIMARY KEY, + state TEXT, + last_updated TEXT + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS client_sessions ( + client_id TEXT PRIMARY KEY, + last_sync TEXT, + connected BOOLEAN, + last_updated TEXT + ) + """) + + # Initialize feature states if not exists + cursor.execute(""" + INSERT OR IGNORE INTO feature_states (feature, state, last_updated) + VALUES + ('ad_block', 'off', ?), + ('adult_block', 'off', ?) + """, (datetime.now().isoformat(), datetime.now().isoformat())) + conn.commit() - def add_blocked_domain(self, domain: str) -> None: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - try: + def get_feature_state(self, feature: str) -> Optional[Dict[str, Any]]: + """Get the current state of a feature.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() cursor.execute(""" - INSERT INTO blocked_domains (domain) - VALUES (?) - """, (domain,)) + SELECT state, last_updated + FROM feature_states + WHERE feature = ? + """, (feature,)) + + result = cursor.fetchone() + if result: + return { + "state": result[0], + "last_updated": result[1] + } + return None + + def update_feature_state(self, feature: str, state: str) -> None: + """Update the state of a feature.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE feature_states + SET state = ?, last_updated = ? + WHERE feature = ? + """, (state, datetime.now().isoformat(), feature)) conn.commit() - except sqlite3.IntegrityError: - print(f"Domain {domain} already exists in the database.") - def remove_blocked_domain(self, domain: str) -> None: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - DELETE FROM blocked_domains - WHERE domain = ? - """, (domain,)) - conn.commit() + def add_blocked_domain(self, domain: str, reason: str = "manual") -> None: + """Add a domain to blocked list with metadata.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + try: + cursor.execute(""" + INSERT INTO blocked_domains (domain, timestamp, reason, active) + VALUES (?, ?, ?, TRUE) + """, (domain, datetime.now().isoformat(), reason)) + conn.commit() + except sqlite3.IntegrityError: + cursor.execute(""" + UPDATE blocked_domains + SET active = TRUE, timestamp = ?, reason = ? + WHERE domain = ? + """, (datetime.now().isoformat(), reason, domain)) + conn.commit() - def is_domain_blocked(self, domain: str) -> bool: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT domain FROM blocked_domains - WHERE domain = ? - """, (domain,)) - result = cursor.fetchone() - return result is not None + def remove_blocked_domain(self, domain: str) -> bool: + """Soft delete a domain from blocked list.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE blocked_domains + SET active = FALSE, timestamp = ? + WHERE domain = ? + """, (datetime.now().isoformat(), domain)) + conn.commit() + return cursor.rowcount > 0 - def store_easylist_entries(self, entries: List[Tuple[str]]) -> None: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.executemany(""" - INSERT OR IGNORE INTO easylist (entry) - VALUES (?) - """, entries) - conn.commit() + def get_blocked_domains(self) -> List[Dict[str, Any]]: + """Get all active blocked domains with their metadata.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT domain, timestamp, reason + FROM blocked_domains + WHERE active = TRUE + """) + return [ + { + "domain": row[0], + "timestamp": row[1], + "reason": row[2] + } + for row in cursor.fetchall() + ] + + def store_easylist_entries(self, entries: List[tuple[str, str]]) -> None: + """Store easylist entries with categories.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + timestamp = datetime.now().isoformat() + cursor.executemany(""" + INSERT OR REPLACE INTO easylist (entry, category, timestamp) + VALUES (?, ?, ?) + """, [(entry, category, timestamp) for entry, category in entries]) + conn.commit() + + def is_domain_blocked(self, domain: str) -> bool: + """Check if a domain is manually blocked.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT 1 FROM blocked_domains + WHERE domain = ? AND active = TRUE + """, (domain,)) + return cursor.fetchone() is not None def is_easylist_blocked(self, domain: str) -> bool: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT 1 FROM easylist - WHERE ? GLOB '*' || entry || '*' - """, (domain,)) - result = cursor.fetchone() - return result is not None + """Check if a domain matches any easylist pattern.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + SELECT 1 FROM easylist + WHERE ? GLOB '*' || entry || '*' + """, (domain,)) + return cursor.fetchone() is not None + + def register_client(self, client_id: str) -> None: + """Register or update a client session.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + timestamp = datetime.now().isoformat() + cursor.execute(""" + INSERT OR REPLACE INTO client_sessions + (client_id, last_sync, connected, last_updated) + VALUES (?, ?, TRUE, ?) + """, (client_id, timestamp, timestamp)) + conn.commit() + + def unregister_client(self, client_id: str) -> None: + """Mark a client as disconnected.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE client_sessions + SET connected = FALSE, last_updated = ? + WHERE client_id = ? + """, (datetime.now().isoformat(), client_id)) + conn.commit() def clear_easylist(self) -> None: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute("DELETE FROM easylist") - conn.commit() + """Clear all easylist entries.""" + with self._connection_lock: + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM easylist") + conn.commit() def close(self) -> None: - pass \ No newline at end of file + """Clean up resources.""" + pass \ No newline at end of file diff --git a/server/src/state_manager.py b/server/src/state_manager.py new file mode 100644 index 0000000..1da9e78 --- /dev/null +++ b/server/src/state_manager.py @@ -0,0 +1,135 @@ +from typing import Dict, Any, Optional +from datetime import datetime +import threading +from .db_manager import DatabaseManager + +class StateManager: + """Manages application state including feature toggles and domain states.""" + + def __init__(self, db_manager: DatabaseManager): + self._db_manager = db_manager + self._state_lock = threading.Lock() + self._states = { + "settings": { + "ad_block": { + "state": "off", + "last_updated": datetime.now().isoformat() + }, + "adult_block": { + "state": "off", + "last_updated": datetime.now().isoformat() + } + }, + "domains": {}, # Will be populated from database + "clients": {} # Will track connected clients + } + self._load_initial_state() + + def _load_initial_state(self) -> None: + """Load initial state from database.""" + # Load blocked domains + domains = self._db_manager.get_blocked_domains() + for domain in domains: + self._states["domains"][domain] = { + "blocked": True, + "timestamp": datetime.now().isoformat(), + "reason": "manual" + } + + def update_feature_state(self, feature: str, state: str) -> Dict[str, Any]: + """ + Update the state of a feature (ad_block or adult_block). + + Args: + feature: Feature to update ('ad_block' or 'adult_block') + state: New state ('on' or 'off') + + Returns: + Dict containing the updated state information + """ + with self._state_lock: + if feature not in self._states["settings"]: + raise ValueError(f"Invalid feature: {feature}") + + if state not in ["on", "off"]: + raise ValueError(f"Invalid state: {state}") + + self._states["settings"][feature] = { + "state": state, + "last_updated": datetime.now().isoformat() + } + + return self._states["settings"][feature] + + def get_feature_state(self, feature: str) -> Dict[str, Any]: + """Get the current state of a feature.""" + with self._state_lock: + if feature not in self._states["settings"]: + raise ValueError(f"Invalid feature: {feature}") + + return self._states["settings"][feature] + + def add_domain(self, domain: str, reason: str = "manual") -> Dict[str, Any]: + """ + Add a domain to blocked list. + + Args: + domain: Domain to block + reason: Reason for blocking ('manual', 'easylist', 'adult') + + Returns: + Dict containing the domain state information + """ + with self._state_lock: + domain_state = { + "blocked": True, + "timestamp": datetime.now().isoformat(), + "reason": reason + } + self._states["domains"][domain] = domain_state + self._db_manager.add_blocked_domain(domain) + + return domain_state + + def remove_domain(self, domain: str) -> bool: + """ + Remove a domain from blocked list. + + Returns: + bool indicating if domain was removed + """ + with self._state_lock: + if domain in self._states["domains"]: + del self._states["domains"][domain] + self._db_manager.remove_blocked_domain(domain) + return True + return False + + def get_domain_state(self, domain: str) -> Optional[Dict[str, Any]]: + """Get the current state of a domain.""" + with self._state_lock: + return self._states["domains"].get(domain) + + def get_all_domains(self) -> Dict[str, Dict[str, Any]]: + """Get all domain states.""" + with self._state_lock: + return self._states["domains"].copy() + + def register_client(self, client_id: str) -> None: + """Register a new client connection.""" + with self._state_lock: + self._states["clients"][client_id] = { + "last_sync": datetime.now().isoformat(), + "connected": True + } + + def unregister_client(self, client_id: str) -> None: + """Unregister a client connection.""" + with self._state_lock: + if client_id in self._states["clients"]: + self._states["clients"][client_id]["connected"] = False + + def get_full_state(self) -> Dict[str, Any]: + """Get the complete current state.""" + with self._state_lock: + return self._states.copy() \ No newline at end of file From ffc8a3cf85e57aa1f985d4ec44a812648562d51c Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:24:13 +0200 Subject: [PATCH 24/38] update handlers with new state management --- server/src/handlers.py | 312 +++++++++++++++++------------------------ 1 file changed, 132 insertions(+), 180 deletions(-) diff --git a/server/src/handlers.py b/server/src/handlers.py index 907e140..7734d1f 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -1,223 +1,175 @@ -# handlers.py -import requests -import socket from abc import ABC, abstractmethod from typing import Dict, Any, Optional +import requests from .db_manager import DatabaseManager +from .state_manager import StateManager from .response_codes import ( - SUCCESS, INVALID_REQUEST, DOMAIN_BLOCKED, - DOMAIN_NOT_FOUND, AD_BLOCK_ENABLED, - ADULT_CONTENT_BLOCKED, RESPONSE_MESSAGES + create_response, + create_error_response, + ServerCodes, + ClientCodes ) - class RequestHandler(ABC): @abstractmethod def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: pass -EASYLIST_URL = "https://easylist.to/easylist/easylist.txt" - - class AdBlockHandler(RequestHandler): - def __init__(self, db_manager: DatabaseManager): + def __init__(self, db_manager: DatabaseManager, state_manager: StateManager): self.db_manager = db_manager + self.state_manager = state_manager self.load_easylist() - def load_easylist(self) -> None: - try: - response = requests.get(EASYLIST_URL) - response.raise_for_status() - easylist_data = response.text - self.parse_and_store_easylist(easylist_data) - except requests.exceptions.RequestException as e: - print(f"Error loading EasyList: {e}") - - def parse_and_store_easylist(self, easylist_data: str) -> None: - entries = [] - for line in easylist_data.split("\n"): - line = line.strip() - if line and not line.startswith("!"): - entries.append((line,)) - self.db_manager.clear_easylist() - self.db_manager.store_easylist_entries(entries) - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - domain = request_data.get('domain') - if self.is_domain_blocked(domain): - return { - 'code': AD_BLOCK_ENABLED, - 'message': RESPONSE_MESSAGES[AD_BLOCK_ENABLED] - } - else: - return { - 'code': SUCCESS, - 'message': RESPONSE_MESSAGES[SUCCESS] - } + try: + if 'action' in request_data: + # Handle toggle request + state = request_data['action'] # 'on' or 'off' + self.state_manager.update_feature_state('ad_block', state) + return create_response( + ServerCodes.SUCCESS, + ClientCodes.AD_BLOCK, + content={"state": state} + ) + elif 'domain' in request_data: + # Handle domain check + domain = request_data['domain'] + if self.is_domain_blocked(domain): + return create_response( + ServerCodes.AD_BLOCK_ENABLED, + ClientCodes.AD_BLOCK, + content={"domain": domain} + ) + + return create_response(ServerCodes.SUCCESS) + + except Exception as e: + return create_error_response(str(e)) def is_domain_blocked(self, domain: str) -> bool: + feature_state = self.state_manager.get_feature_state('ad_block') + if feature_state['state'] == 'off': + return False return self.db_manager.is_easylist_blocked(domain) - -class DomainBlockHandler(RequestHandler): - def __init__(self, db_manager: DatabaseManager): - self.db_manager = db_manager - - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - domain = request_data.get('domain') - action = request_data.get('action') - - if action == 'block': - self.db_manager.add_blocked_domain(domain) - return { - 'code': DOMAIN_BLOCKED, - 'message': RESPONSE_MESSAGES[DOMAIN_BLOCKED] - } - elif action == 'unblock': - if self.db_manager.is_domain_blocked(domain): - self.db_manager.remove_blocked_domain(domain) - return { - 'code': SUCCESS, - 'message': RESPONSE_MESSAGES[SUCCESS] - } - else: - return { - 'code': DOMAIN_NOT_FOUND, - 'message': RESPONSE_MESSAGES[DOMAIN_NOT_FOUND] - } - else: - return { - 'code': INVALID_REQUEST, - 'message': RESPONSE_MESSAGES[INVALID_REQUEST] - } - - class AdultContentBlockHandler(RequestHandler): - # Class-level variable to track status across all instances - # We use a class variable so all instances share the same state - _is_enabled: bool = False - - def __init__(self, db_manager: DatabaseManager): + def __init__(self, db_manager: DatabaseManager, state_manager: StateManager): self.db_manager = db_manager + self.state_manager = state_manager def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - action = request_data.get('action') - try: - if action == 'status': - return self._get_status() - - elif action in ['enable', 'disable']: - return self._toggle_blocking(action) + action = request_data.get('action') + if not action: + return create_error_response("Missing action parameter") + + if action in ['enable', 'disable']: + # Convert enable/disable to on/off + state = 'on' if action == 'enable' else 'off' + self.state_manager.update_feature_state('adult_block', state) + return create_response( + ServerCodes.SUCCESS, + ClientCodes.ADULT_BLOCK, + content={"state": state} + ) elif action == 'check': - return self._check_domain(request_data.get('domain')) + domain = request_data.get('domain') + if not domain: + return create_error_response("Missing domain parameter") + + feature_state = self.state_manager.get_feature_state('adult_block') + if feature_state['state'] == 'on': + return create_response( + ServerCodes.ADULT_CONTENT_BLOCKED, + ClientCodes.ADULT_BLOCK, + content={"domain": domain} + ) + + return create_response(ServerCodes.SUCCESS) - else: - return { - 'code': INVALID_REQUEST, - 'message': RESPONSE_MESSAGES[INVALID_REQUEST] - } - except Exception as e: - print(f"Error in adult content handler: {e}") - return { - 'code': INVALID_REQUEST, - 'message': "An error occurred processing the request" - } + return create_error_response(str(e)) - def _get_status(self) -> Dict[str, Any]: - """Get current blocking status.""" - return { - 'code': SUCCESS, - 'message': RESPONSE_MESSAGES[SUCCESS], - 'adult_content_block': 'on' if self._is_enabled else 'off' - } - - def _toggle_blocking(self, action: str) -> Dict[str, Any]: - """Enable or disable blocking.""" - self.__class__._is_enabled = (action == 'enable') - status = 'enabled' if self._is_enabled else 'disabled' - - print(f"Adult content blocking {status}") - - return { - 'code': SUCCESS, - 'message': f"Adult content blocking has been {status}.", - 'adult_content_block': 'on' if self._is_enabled else 'off' - } +class DomainBlockHandler(RequestHandler): + def __init__(self, db_manager: DatabaseManager, state_manager: StateManager): + self.db_manager = db_manager + self.state_manager = state_manager - def _check_domain(self, domain: str) -> Dict[str, Any]: - """Check if a domain should be blocked.""" - if not domain: - return { - 'code': INVALID_REQUEST, - 'message': RESPONSE_MESSAGES[INVALID_REQUEST] - } + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + try: + action = request_data.get('action') + domain = request_data.get('domain') + + if not action or not domain: + return create_error_response("Missing action or domain parameter") + + if action == 'block': + self.state_manager.add_domain(domain, reason='manual') + return create_response( + ServerCodes.DOMAIN_BLOCKED, + ClientCodes.ADD_DOMAIN, + content={"domain": domain} + ) + + elif action == 'unblock': + if self.state_manager.remove_domain(domain): + return create_response( + ServerCodes.SUCCESS, + ClientCodes.REMOVE_DOMAIN, + content={"domain": domain} + ) + else: + return create_response( + ServerCodes.DOMAIN_NOT_FOUND, + ClientCodes.REMOVE_DOMAIN, + content={"domain": domain} + ) + + return create_error_response("Invalid action") - if self._is_enabled: - return { - 'code': ADULT_CONTENT_BLOCKED, - 'message': RESPONSE_MESSAGES[ADULT_CONTENT_BLOCKED] - } - - return { - 'code': SUCCESS, - 'message': RESPONSE_MESSAGES[SUCCESS] - } - - @classmethod - def is_blocking_enabled(cls) -> bool: - """Public method to check if blocking is enabled.""" - return cls._is_enabled - + except Exception as e: + return create_error_response(str(e)) class RequestFactory: - def __init__(self, db_manager: DatabaseManager): - """ - Initialize the RequestFactory with a database manager instance. - - Args: - db_manager: DatabaseManager instance for handling database operations - """ + def __init__(self, db_manager: DatabaseManager, state_manager: StateManager): self.db_manager = db_manager - # Map request types to handler creator functions + self.state_manager = state_manager self._handlers = { - 'ad_block': lambda: AdBlockHandler(self.db_manager), - 'domain_block': lambda: DomainBlockHandler(self.db_manager), - 'adult_content_block': lambda: AdultContentBlockHandler(self.db_manager) + ClientCodes.AD_BLOCK: lambda: AdBlockHandler(self.db_manager, self.state_manager), + ClientCodes.ADULT_BLOCK: lambda: AdultContentBlockHandler(self.db_manager, self.state_manager), + ClientCodes.ADD_DOMAIN: lambda: DomainBlockHandler(self.db_manager, self.state_manager), + ClientCodes.REMOVE_DOMAIN: lambda: DomainBlockHandler(self.db_manager, self.state_manager) } - + def create_request_handler(self, request_type: str) -> Optional[RequestHandler]: - """ - Creates and returns the appropriate request handler based on request type. - - Args: - request_type: The type of request to handle - - Returns: - RequestHandler instance or None if request type is not supported - """ handler_creator = self._handlers.get(request_type) return handler_creator() if handler_creator else None def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """ - Routes the request to appropriate handler and processes it. - - Args: - request_data: The request data containing type and other parameters + try: + request_type = request_data.get('code') + handler = self.create_request_handler(request_type) - Returns: - Dict containing response code and message - """ - request_type = request_data.get('type') - handler = self.create_request_handler(request_type) - - if handler: - return handler.handle_request(request_data) - else: - return { - 'code': INVALID_REQUEST, - 'message': RESPONSE_MESSAGES[INVALID_REQUEST] - } \ No newline at end of file + if handler: + response = handler.handle_request(request_data) + # After handling request, check if we need to broadcast updates + if request_type in [ClientCodes.ADD_DOMAIN, ClientCodes.REMOVE_DOMAIN]: + self._broadcast_domain_update() + return response + + return create_error_response("Invalid request type") + + except Exception as e: + return create_error_response(str(e)) + + def _broadcast_domain_update(self) -> None: + """Prepare domain list update for broadcasting.""" + domains = list(self.state_manager.get_all_domains().keys()) + # The server will use this to broadcast to all clients + return create_response( + ServerCodes.SUCCESS, + ClientCodes.DOMAIN_LIST_UPDATE, + content=domains + ) \ No newline at end of file From c2e761b090b9891ba60da287eaa9302d3a38c605 Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Wed, 6 Nov 2024 19:31:18 +0200 Subject: [PATCH 25/38] Update server with comprehensive connection handling and state management --- server/src/server.py | 245 +++++++++++++++++++++++++++++-------------- 1 file changed, 168 insertions(+), 77 deletions(-) diff --git a/server/src/server.py b/server/src/server.py index 001481d..71bb349 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -1,87 +1,180 @@ -# server.py - import asyncio import json -from typing import Dict, Any +import uuid +from typing import Dict, Any, Set, Optional from .config import HOST, CLIENT_PORT, KERNEL_PORT, DB_FILE from .db_manager import DatabaseManager -from .handlers import RequestFactory, AdultContentBlockHandler - - - -async def handle_client( - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - request_factory: RequestFactory -) -> None: - while True: +from .state_manager import StateManager +from .handlers import RequestFactory +from .response_codes import ( + create_response, + create_error_response, + ServerCodes, + ClientCodes +) + +class Server: + def __init__(self, db_file: str): + """Initialize server with all necessary components.""" + self.db_manager = DatabaseManager(db_file) + self.state_manager = StateManager(self.db_manager) + self.request_factory = RequestFactory(self.db_manager, self.state_manager) + + # Track client connections + self.clients: Dict[str, asyncio.StreamWriter] = {} + self.kernel_writers: Set[asyncio.StreamWriter] = set() + + async def handle_client( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter + ) -> None: + """Handle client connections with unique ID and state tracking.""" + client_id = str(uuid.uuid4()) + self.clients[client_id] = writer + self.state_manager.register_client(client_id) + try: - data = await reader.readline() - if not data: - break - - request_data = json.loads(data.decode('utf-8')) - response_data = request_factory.handle_request(request_data) - - writer.write(json.dumps(response_data).encode('utf-8') + b'\n') - await writer.drain() - - except ConnectionResetError: - print("Client disconnected.") - break - - writer.close() - - -async def handle_kernel( - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - db_manager: DatabaseManager -) -> None: - while True: + # Send initial state to client + await self.send_initial_state(writer, client_id) + + while True: + data = await reader.readline() + if not data: + break + + request_data = json.loads(data.decode('utf-8')) + response_data = self.request_factory.handle_request(request_data) + + # Send response to the requesting client + writer.write(json.dumps(response_data).encode('utf-8') + b'\n') + await writer.drain() + + # Broadcast updates if necessary + if self._should_broadcast(request_data): + await self.broadcast_state_update(client_id) + + except Exception as e: + print(f"Error handling client {client_id}: {str(e)}") + finally: + self.cleanup_client(client_id, writer) + + async def handle_kernel( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter + ) -> None: + """Handle kernel module connections.""" + self.kernel_writers.add(writer) try: - data = await reader.readline() - if not data: - break - - request_data = json.loads(data.decode('utf-8')) - response_data = route_kernel_request(request_data, db_manager) - - writer.write(json.dumps(response_data).encode('utf-8') + b'\n') - await writer.drain() - - except ConnectionResetError: - print("Kernel module disconnected.") - break - - writer.close() - - -def route_kernel_request(request_data: Dict[str, Any], db_manager: DatabaseManager) -> Dict[str, Any]: - domain = request_data.get('domain') - categories = request_data.get('categories', []) - - # Fast checks in order of most common to least common - should_block = ( - db_manager.is_domain_blocked(domain) or - db_manager.is_easylist_blocked(domain) or - (AdultContentBlockHandler.is_blocking_enabled() and 'adult' in categories) - ) - - return {'block': should_block} + while True: + data = await reader.readline() + if not data: + break + + request_data = json.loads(data.decode('utf-8')) + response = self.handle_kernel_request(request_data) + + writer.write(json.dumps(response).encode('utf-8') + b'\n') + await writer.drain() + + except Exception as e: + print(f"Error handling kernel request: {str(e)}") + finally: + self.kernel_writers.remove(writer) + writer.close() + await writer.wait_closed() + + def handle_kernel_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Process kernel module requests.""" + domain = request_data.get('domain') + categories = request_data.get('categories', []) + + # Check various blocking conditions + should_block = ( + self.db_manager.is_domain_blocked(domain) or + (self.db_manager.is_easylist_blocked(domain) and + self.state_manager.get_feature_state('ad_block')['state'] == 'on') or + ( + 'adult' in categories and + self.state_manager.get_feature_state('adult_block')['state'] == 'on' + ) + ) + return {'block': should_block} + + async def send_initial_state( + self, + writer: asyncio.StreamWriter, + client_id: str + ) -> None: + """Send initial state to new client connections.""" + initial_state = { + 'settings': { + 'ad_block': self.state_manager.get_feature_state('ad_block'), + 'adult_block': self.state_manager.get_feature_state('adult_block') + }, + 'domains': list(self.state_manager.get_all_domains().keys()) + } + + response = create_response( + ServerCodes.SUCCESS, + ClientCodes.DOMAIN_LIST_UPDATE, + content=initial_state + ) + + writer.write(json.dumps(response).encode('utf-8') + b'\n') + await writer.drain() + + async def broadcast_state_update(self, exclude_client: Optional[str] = None) -> None: + """Broadcast state updates to all connected clients except the sender.""" + domains = list(self.state_manager.get_all_domains().keys()) + update_message = create_response( + ServerCodes.SUCCESS, + ClientCodes.DOMAIN_LIST_UPDATE, + content=domains + ) + + message_data = json.dumps(update_message).encode('utf-8') + b'\n' + + for client_id, writer in self.clients.items(): + if client_id != exclude_client: + try: + writer.write(message_data) + await writer.drain() + except Exception as e: + print(f"Error broadcasting to client {client_id}: {str(e)}") + + def cleanup_client(self, client_id: str, writer: asyncio.StreamWriter) -> None: + """Clean up client connection resources.""" + if client_id in self.clients: + del self.clients[client_id] + self.state_manager.unregister_client(client_id) + writer.close() + + def _should_broadcast(self, request_data: Dict[str, Any]) -> bool: + """Determine if a request should trigger a broadcast.""" + broadcast_codes = { + ClientCodes.ADD_DOMAIN, + ClientCodes.REMOVE_DOMAIN, + ClientCodes.AD_BLOCK, + ClientCodes.ADULT_BLOCK + } + return request_data.get('code') in broadcast_codes async def start_server(db_manager: DatabaseManager) -> None: - request_factory = RequestFactory(db_manager) + """Start the server with both client and kernel handlers.""" + server = Server(db_manager.db_file) client_server = await asyncio.start_server( - lambda r, w: handle_client(r, w, request_factory), - HOST, + server.handle_client, + HOST, CLIENT_PORT ) + kernel_server = await asyncio.start_server( - lambda r, w: handle_kernel(r, w, db_manager), - HOST, + server.handle_kernel, + HOST, KERNEL_PORT ) @@ -94,13 +187,11 @@ async def start_server(db_manager: DatabaseManager) -> None: kernel_server.serve_forever() ) - def run(db_file: str) -> None: - """Initialize and run the server with the given database file.""" - db_manager = DatabaseManager(db_file) + """Initialize and run the server.""" try: - asyncio.run(start_server(db_manager)) + asyncio.run(start_server(DatabaseManager(db_file))) except KeyboardInterrupt: - print("Server stopped by user.") - finally: - db_manager.close() \ No newline at end of file + print("Server stopped by user") + except Exception as e: + print(f"Server error: {str(e)}") \ No newline at end of file From b93cd74c6b46c696c1c8d7c868859e9288229dcb Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Wed, 6 Nov 2024 23:11:25 +0200 Subject: [PATCH 26/38] change logic - start ingreate with client logic --- server/src/config.py | 8 +- server/src/db_manager.py | 251 +++++++++++------------------------ server/src/handlers.py | 236 ++++++++++++++++---------------- server/src/response_codes.py | 78 +++-------- server/src/server.py | 202 +++++++++++++--------------- 5 files changed, 312 insertions(+), 463 deletions(-) diff --git a/server/src/config.py b/server/src/config.py index 9f4955a..76737c8 100644 --- a/server/src/config.py +++ b/server/src/config.py @@ -1,6 +1,8 @@ -# config.py - +# Network Configuration HOST: str = '127.0.0.1' CLIENT_PORT: int = 65432 KERNEL_PORT: int = 65433 -DB_FILE: str = 'my_internet.db' \ No newline at end of file +DB_FILE: str = 'my_internet.db' + +# EasyList URL +EASYLIST_URL = "https://easylist.to/easylist/easylist.txt" \ No newline at end of file diff --git a/server/src/db_manager.py b/server/src/db_manager.py index ef14f5d..29ab392 100644 --- a/server/src/db_manager.py +++ b/server/src/db_manager.py @@ -1,216 +1,125 @@ import sqlite3 from typing import List, Dict, Any, Optional -from datetime import datetime -import threading +import requests class DatabaseManager: def __init__(self, db_file: str): self.db_file = db_file - self._connection_lock = threading.Lock() self.create_tables() def create_tables(self) -> None: - """Create necessary tables if they don't exist.""" + """Create the necessary tables if they don't exist.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - # Existing tables with enhancements + # Blocked domains table cursor.execute(""" CREATE TABLE IF NOT EXISTS blocked_domains ( id INTEGER PRIMARY KEY AUTOINCREMENT, - domain TEXT UNIQUE, - timestamp TEXT, - reason TEXT, - active BOOLEAN DEFAULT TRUE + domain TEXT UNIQUE ) """) + # Easylist entries table cursor.execute(""" CREATE TABLE IF NOT EXISTS easylist ( id INTEGER PRIMARY KEY AUTOINCREMENT, - entry TEXT UNIQUE, - category TEXT, - timestamp TEXT + entry TEXT UNIQUE ) """) - # New tables for state management + # Settings table for toggles cursor.execute(""" - CREATE TABLE IF NOT EXISTS feature_states ( - feature TEXT PRIMARY KEY, - state TEXT, - last_updated TEXT + CREATE TABLE IF NOT EXISTS settings ( + setting TEXT PRIMARY KEY, + value TEXT ) """) + # Initialize settings if not exists cursor.execute(""" - CREATE TABLE IF NOT EXISTS client_sessions ( - client_id TEXT PRIMARY KEY, - last_sync TEXT, - connected BOOLEAN, - last_updated TEXT - ) - """) - - # Initialize feature states if not exists - cursor.execute(""" - INSERT OR IGNORE INTO feature_states (feature, state, last_updated) + INSERT OR IGNORE INTO settings (setting, value) VALUES - ('ad_block', 'off', ?), - ('adult_block', 'off', ?) - """, (datetime.now().isoformat(), datetime.now().isoformat())) + ('ad_block', 'off'), + ('adult_block', 'off') + """) conn.commit() - def get_feature_state(self, feature: str) -> Optional[Dict[str, Any]]: - """Get the current state of a feature.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT state, last_updated - FROM feature_states - WHERE feature = ? - """, (feature,)) - - result = cursor.fetchone() - if result: - return { - "state": result[0], - "last_updated": result[1] - } - return None - - def update_feature_state(self, feature: str, state: str) -> None: - """Update the state of a feature.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - UPDATE feature_states - SET state = ?, last_updated = ? - WHERE feature = ? - """, (state, datetime.now().isoformat(), feature)) - conn.commit() + def get_setting(self, setting: str) -> str: + """Get setting value.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("SELECT value FROM settings WHERE setting = ?", (setting,)) + result = cursor.fetchone() + return result[0] if result else 'off' - def add_blocked_domain(self, domain: str, reason: str = "manual") -> None: - """Add a domain to blocked list with metadata.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - try: - cursor.execute(""" - INSERT INTO blocked_domains (domain, timestamp, reason, active) - VALUES (?, ?, ?, TRUE) - """, (domain, datetime.now().isoformat(), reason)) - conn.commit() - except sqlite3.IntegrityError: - cursor.execute(""" - UPDATE blocked_domains - SET active = TRUE, timestamp = ?, reason = ? - WHERE domain = ? - """, (datetime.now().isoformat(), reason, domain)) - conn.commit() + def update_setting(self, setting: str, value: str) -> None: + """Update setting value.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE settings + SET value = ? + WHERE setting = ? + """, (value, setting)) + conn.commit() - def remove_blocked_domain(self, domain: str) -> bool: - """Soft delete a domain from blocked list.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - UPDATE blocked_domains - SET active = FALSE, timestamp = ? - WHERE domain = ? - """, (datetime.now().isoformat(), domain)) + def add_blocked_domain(self, domain: str) -> None: + """Add a domain to blocked list.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + try: + cursor.execute("INSERT INTO blocked_domains (domain) VALUES (?)", (domain,)) conn.commit() - return cursor.rowcount > 0 + except sqlite3.IntegrityError: + print(f"Domain {domain} already exists in the database.") - def get_blocked_domains(self) -> List[Dict[str, Any]]: - """Get all active blocked domains with their metadata.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT domain, timestamp, reason - FROM blocked_domains - WHERE active = TRUE - """) - return [ - { - "domain": row[0], - "timestamp": row[1], - "reason": row[2] - } - for row in cursor.fetchall() - ] + def remove_blocked_domain(self, domain: str) -> bool: + """Remove a domain from blocked list.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM blocked_domains WHERE domain = ?", (domain,)) + conn.commit() + return cursor.rowcount > 0 - def store_easylist_entries(self, entries: List[tuple[str, str]]) -> None: - """Store easylist entries with categories.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - timestamp = datetime.now().isoformat() - cursor.executemany(""" - INSERT OR REPLACE INTO easylist (entry, category, timestamp) - VALUES (?, ?, ?) - """, [(entry, category, timestamp) for entry, category in entries]) - conn.commit() + def get_blocked_domains(self) -> List[str]: + """Get list of all blocked domains.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("SELECT domain FROM blocked_domains") + return [row[0] for row in cursor.fetchall()] def is_domain_blocked(self, domain: str) -> bool: - """Check if a domain is manually blocked.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT 1 FROM blocked_domains - WHERE domain = ? AND active = TRUE - """, (domain,)) - return cursor.fetchone() is not None - - def is_easylist_blocked(self, domain: str) -> bool: - """Check if a domain matches any easylist pattern.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - SELECT 1 FROM easylist - WHERE ? GLOB '*' || entry || '*' - """, (domain,)) - return cursor.fetchone() is not None + """Check if domain is in blocked list.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1 FROM blocked_domains WHERE domain = ?", (domain,)) + return cursor.fetchone() is not None - def register_client(self, client_id: str) -> None: - """Register or update a client session.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - timestamp = datetime.now().isoformat() - cursor.execute(""" - INSERT OR REPLACE INTO client_sessions - (client_id, last_sync, connected, last_updated) - VALUES (?, ?, TRUE, ?) - """, (client_id, timestamp, timestamp)) - conn.commit() + def store_easylist_entries(self, entries: List[tuple]) -> None: + """Store easylist entries.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.executemany( + "INSERT OR IGNORE INTO easylist (entry) VALUES (?)", + entries + ) + conn.commit() - def unregister_client(self, client_id: str) -> None: - """Mark a client as disconnected.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - UPDATE client_sessions - SET connected = FALSE, last_updated = ? - WHERE client_id = ? - """, (datetime.now().isoformat(), client_id)) - conn.commit() + def is_easylist_blocked(self, domain: str) -> bool: + """Check if domain matches any easylist pattern.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute( + "SELECT 1 FROM easylist WHERE ? GLOB '*' || entry || '*'", + (domain,) + ) + return cursor.fetchone() is not None def clear_easylist(self) -> None: """Clear all easylist entries.""" - with self._connection_lock: - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute("DELETE FROM easylist") - conn.commit() - - def close(self) -> None: - """Clean up resources.""" - pass \ No newline at end of file + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("DELETE FROM easylist") + conn.commit() \ No newline at end of file diff --git a/server/src/handlers.py b/server/src/handlers.py index 7734d1f..ee1fad6 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -1,14 +1,9 @@ -from abc import ABC, abstractmethod -from typing import Dict, Any, Optional import requests +from abc import ABC, abstractmethod +from typing import Dict, Any from .db_manager import DatabaseManager -from .state_manager import StateManager -from .response_codes import ( - create_response, - create_error_response, - ServerCodes, - ClientCodes -) +from .config import EASYLIST_URL +from .response_codes import Codes, RESPONSE_MESSAGES class RequestHandler(ABC): @abstractmethod @@ -16,160 +11,173 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: pass class AdBlockHandler(RequestHandler): - def __init__(self, db_manager: DatabaseManager, state_manager: StateManager): + def __init__(self, db_manager: DatabaseManager): self.db_manager = db_manager - self.state_manager = state_manager self.load_easylist() + def load_easylist(self) -> None: + """Load and parse easylist.""" + try: + response = requests.get(EASYLIST_URL) + response.raise_for_status() + easylist_data = response.text + + # Parse and store valid entries + entries = [] + for line in easylist_data.split("\n"): + line = line.strip() + if line and not line.startswith("!"): + entries.append((line,)) + + self.db_manager.clear_easylist() + self.db_manager.store_easylist_entries(entries) + + except requests.exceptions.RequestException as e: + print(f"Error loading EasyList: {e}") + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle ad block requests.""" try: if 'action' in request_data: # Handle toggle request state = request_data['action'] # 'on' or 'off' - self.state_manager.update_feature_state('ad_block', state) - return create_response( - ServerCodes.SUCCESS, - ClientCodes.AD_BLOCK, - content={"state": state} - ) - elif 'domain' in request_data: - # Handle domain check - domain = request_data['domain'] - if self.is_domain_blocked(domain): - return create_response( - ServerCodes.AD_BLOCK_ENABLED, - ClientCodes.AD_BLOCK, - content={"domain": domain} - ) + self.db_manager.update_setting('ad_block', state) + return { + 'code': Codes.CODE_AD_BLOCK, + 'message': RESPONSE_MESSAGES['success'] + } - return create_response(ServerCodes.SUCCESS) + elif 'domain' in request_data: + # Check if domain should be blocked + if self.is_domain_blocked(request_data['domain']): + return { + 'code': Codes.CODE_AD_BLOCK, + 'message': "Domain contains ads" + } + + return { + 'code': Codes.CODE_AD_BLOCK, + 'message': RESPONSE_MESSAGES['success'] + } except Exception as e: - return create_error_response(str(e)) + return { + 'code': Codes.CODE_AD_BLOCK, + 'message': str(e) + } def is_domain_blocked(self, domain: str) -> bool: - feature_state = self.state_manager.get_feature_state('ad_block') - if feature_state['state'] == 'off': + """Check if domain should be blocked based on easylist.""" + if self.db_manager.get_setting('ad_block') == 'off': return False return self.db_manager.is_easylist_blocked(domain) class AdultContentBlockHandler(RequestHandler): - def __init__(self, db_manager: DatabaseManager, state_manager: StateManager): + def __init__(self, db_manager: DatabaseManager): self.db_manager = db_manager - self.state_manager = state_manager def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle adult content block requests.""" try: - action = request_data.get('action') - if not action: - return create_error_response("Missing action parameter") - - if action in ['enable', 'disable']: - # Convert enable/disable to on/off - state = 'on' if action == 'enable' else 'off' - self.state_manager.update_feature_state('adult_block', state) - return create_response( - ServerCodes.SUCCESS, - ClientCodes.ADULT_BLOCK, - content={"state": state} - ) - - elif action == 'check': - domain = request_data.get('domain') - if not domain: - return create_error_response("Missing domain parameter") - - feature_state = self.state_manager.get_feature_state('adult_block') - if feature_state['state'] == 'on': - return create_response( - ServerCodes.ADULT_CONTENT_BLOCKED, - ClientCodes.ADULT_BLOCK, - content={"domain": domain} - ) - - return create_response(ServerCodes.SUCCESS) + if 'action' in request_data: + # Handle toggle request + state = request_data['action'] # 'on' or 'off' + self.db_manager.update_setting('adult_block', state) + return { + 'code': Codes.CODE_ADULT_BLOCK, + 'message': RESPONSE_MESSAGES['success'] + } + + elif 'domain' in request_data: + # Check if adult blocking is enabled + if self.db_manager.get_setting('adult_block') == 'on': + return { + 'code': Codes.CODE_ADULT_BLOCK, + 'message': "Adult content blocked" + } + + return { + 'code': Codes.CODE_ADULT_BLOCK, + 'message': RESPONSE_MESSAGES['success'] + } except Exception as e: - return create_error_response(str(e)) + return { + 'code': Codes.CODE_ADULT_BLOCK, + 'message': str(e) + } class DomainBlockHandler(RequestHandler): - def __init__(self, db_manager: DatabaseManager, state_manager: StateManager): + def __init__(self, db_manager: DatabaseManager): self.db_manager = db_manager - self.state_manager = state_manager def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle domain block/unblock requests.""" try: action = request_data.get('action') domain = request_data.get('domain') - if not action or not domain: - return create_error_response("Missing action or domain parameter") + if not domain: + return { + 'code': request_data.get('code'), + 'message': RESPONSE_MESSAGES['invalid_request'] + } if action == 'block': - self.state_manager.add_domain(domain, reason='manual') - return create_response( - ServerCodes.DOMAIN_BLOCKED, - ClientCodes.ADD_DOMAIN, - content={"domain": domain} - ) + self.db_manager.add_blocked_domain(domain) + return { + 'code': Codes.CODE_ADD_DOMAIN, + 'message': RESPONSE_MESSAGES['domain_blocked'] + } elif action == 'unblock': - if self.state_manager.remove_domain(domain): - return create_response( - ServerCodes.SUCCESS, - ClientCodes.REMOVE_DOMAIN, - content={"domain": domain} - ) + if self.db_manager.remove_blocked_domain(domain): + return { + 'code': Codes.CODE_REMOVE_DOMAIN, + 'message': RESPONSE_MESSAGES['success'] + } else: - return create_response( - ServerCodes.DOMAIN_NOT_FOUND, - ClientCodes.REMOVE_DOMAIN, - content={"domain": domain} - ) - - return create_error_response("Invalid action") + return { + 'code': Codes.CODE_REMOVE_DOMAIN, + 'message': RESPONSE_MESSAGES['domain_not_found'] + } + + return { + 'code': request_data.get('code'), + 'message': RESPONSE_MESSAGES['invalid_request'] + } except Exception as e: - return create_error_response(str(e)) + return { + 'code': request_data.get('code'), + 'message': str(e) + } class RequestFactory: - def __init__(self, db_manager: DatabaseManager, state_manager: StateManager): + def __init__(self, db_manager: DatabaseManager): self.db_manager = db_manager - self.state_manager = state_manager self._handlers = { - ClientCodes.AD_BLOCK: lambda: AdBlockHandler(self.db_manager, self.state_manager), - ClientCodes.ADULT_BLOCK: lambda: AdultContentBlockHandler(self.db_manager, self.state_manager), - ClientCodes.ADD_DOMAIN: lambda: DomainBlockHandler(self.db_manager, self.state_manager), - ClientCodes.REMOVE_DOMAIN: lambda: DomainBlockHandler(self.db_manager, self.state_manager) + Codes.CODE_AD_BLOCK: lambda: AdBlockHandler(self.db_manager), + Codes.CODE_ADULT_BLOCK: lambda: AdultContentBlockHandler(self.db_manager), + Codes.CODE_ADD_DOMAIN: lambda: DomainBlockHandler(self.db_manager), + Codes.CODE_REMOVE_DOMAIN: lambda: DomainBlockHandler(self.db_manager) } - def create_request_handler(self, request_type: str) -> Optional[RequestHandler]: + def create_request_handler(self, request_type: str) -> RequestHandler: + """Create appropriate handler based on request type.""" handler_creator = self._handlers.get(request_type) - return handler_creator() if handler_creator else None + if handler_creator: + return handler_creator() + raise ValueError(f"Invalid request type: {request_type}") def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle incoming request using appropriate handler.""" try: request_type = request_data.get('code') handler = self.create_request_handler(request_type) - - if handler: - response = handler.handle_request(request_data) - # After handling request, check if we need to broadcast updates - if request_type in [ClientCodes.ADD_DOMAIN, ClientCodes.REMOVE_DOMAIN]: - self._broadcast_domain_update() - return response - - return create_error_response("Invalid request type") - + return handler.handle_request(request_data) except Exception as e: - return create_error_response(str(e)) - - def _broadcast_domain_update(self) -> None: - """Prepare domain list update for broadcasting.""" - domains = list(self.state_manager.get_all_domains().keys()) - # The server will use this to broadcast to all clients - return create_response( - ServerCodes.SUCCESS, - ClientCodes.DOMAIN_LIST_UPDATE, - content=domains - ) \ No newline at end of file + return { + 'code': request_data.get('code', ''), + 'message': str(e) + } \ No newline at end of file diff --git a/server/src/response_codes.py b/server/src/response_codes.py index ba43d09..d318126 100644 --- a/server/src/response_codes.py +++ b/server/src/response_codes.py @@ -1,64 +1,18 @@ -from typing import Dict, Any +from typing import Dict -# Client Command Codes (incoming requests) -class ClientCodes: - AD_BLOCK = "50" # Toggle ad blocking - ADULT_BLOCK = "51" # Toggle adult content blocking - ADD_DOMAIN = "52" # Add domain to block list - REMOVE_DOMAIN = "53" # Remove domain from block list - DOMAIN_LIST_UPDATE = "54" # Update domain list +# Client Command Codes (exact match with client's codes) +class Codes: + CODE_AD_BLOCK = "50" + CODE_ADULT_BLOCK = "51" + CODE_ADD_DOMAIN = "52" + CODE_REMOVE_DOMAIN = "53" + CODE_DOMAIN_LIST_UPDATE = "54" -# Server Response Codes -class ServerCodes: - SUCCESS = 200 - INVALID_REQUEST = 400 - DOMAIN_BLOCKED = 201 - DOMAIN_NOT_FOUND = 404 - AD_BLOCK_ENABLED = 202 - ADULT_CONTENT_BLOCKED = 203 - -# Response Messages -RESPONSE_MESSAGES: Dict[int, str] = { - ServerCodes.SUCCESS: "Request processed successfully.", - ServerCodes.INVALID_REQUEST: "Invalid request. Please check the request format.", - ServerCodes.DOMAIN_BLOCKED: "Domain has been successfully blocked.", - ServerCodes.DOMAIN_NOT_FOUND: "Domain not found in the block list.", - ServerCodes.AD_BLOCK_ENABLED: "Ad blocking has been enabled for the domain.", - ServerCodes.ADULT_CONTENT_BLOCKED: "Adult content has been blocked for the domain." -} - -def create_response( - server_code: int, - client_code: str = None, - content: Any = None, - message: str = None -) -> Dict[str, Any]: - """ - Create a standardized response format. - - Args: - server_code: Internal server response code - client_code: Client's command code (if responding to specific command) - content: Optional response payload - message: Custom message (uses default if None) - - Returns: - Formatted response dictionary - """ - response = { - "code": client_code if client_code else str(server_code), - "message": message if message else RESPONSE_MESSAGES.get(server_code, ""), - "status": server_code < 400 # True for success codes, False for error codes - } - - if content is not None: - response["content"] = content - - return response - -def create_error_response(message: str) -> Dict[str, Any]: - """Create a standardized error response.""" - return create_response( - ServerCodes.INVALID_REQUEST, - message=message - ) \ No newline at end of file +# Response messages +RESPONSE_MESSAGES = { + 'success': "Request processed successfully.", + 'invalid_request': "Invalid request format.", + 'domain_blocked': "Domain has been successfully blocked.", + 'domain_not_found': "Domain not found in block list.", + 'domain_exists': "Domain already exists in block list." +} \ No newline at end of file diff --git a/server/src/server.py b/server/src/server.py index 71bb349..53fb5cc 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -1,63 +1,69 @@ import asyncio import json -import uuid -from typing import Dict, Any, Set, Optional -from .config import HOST, CLIENT_PORT, KERNEL_PORT, DB_FILE +from typing import Dict, Any +from .config import HOST, CLIENT_PORT, KERNEL_PORT from .db_manager import DatabaseManager -from .state_manager import StateManager from .handlers import RequestFactory -from .response_codes import ( - create_response, - create_error_response, - ServerCodes, - ClientCodes -) +from .response_codes import Codes, RESPONSE_MESSAGES class Server: def __init__(self, db_file: str): - """Initialize server with all necessary components.""" + """Initialize server with database and request factory.""" self.db_manager = DatabaseManager(db_file) - self.state_manager = StateManager(self.db_manager) - self.request_factory = RequestFactory(self.db_manager, self.state_manager) - - # Track client connections - self.clients: Dict[str, asyncio.StreamWriter] = {} - self.kernel_writers: Set[asyncio.StreamWriter] = set() + self.request_factory = RequestFactory(self.db_manager) + self.client_writer = None # Store the single client connection async def handle_client( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter ) -> None: - """Handle client connections with unique ID and state tracking.""" - client_id = str(uuid.uuid4()) - self.clients[client_id] = writer - self.state_manager.register_client(client_id) - + """Handle the client connection.""" + # Store the client writer for potential updates + self.client_writer = writer + print(f"Client connected from {writer.get_extra_info('peername')}") + try: - # Send initial state to client - await self.send_initial_state(writer, client_id) - + # Send initial domain list to client + await self._send_domain_list() + while True: data = await reader.readline() if not data: break - request_data = json.loads(data.decode('utf-8')) - response_data = self.request_factory.handle_request(request_data) + try: + request_data = json.loads(data.decode('utf-8')) + print(f"Received from client: {request_data}") + + # Process request + response = self.request_factory.handle_request(request_data) + print(f"Sending response: {response}") - # Send response to the requesting client - writer.write(json.dumps(response_data).encode('utf-8') + b'\n') - await writer.drain() + # Send response + writer.write(json.dumps(response).encode('utf-8') + b'\n') + await writer.drain() - # Broadcast updates if necessary - if self._should_broadcast(request_data): - await self.broadcast_state_update(client_id) + # If domain list was modified, send update + if request_data.get('code') in [Codes.CODE_ADD_DOMAIN, Codes.CODE_REMOVE_DOMAIN]: + await self._send_domain_list() + + except json.JSONDecodeError: + print("Invalid JSON received") + error_response = { + 'code': request_data.get('code', ''), + 'message': RESPONSE_MESSAGES['invalid_request'] + } + writer.write(json.dumps(error_response).encode('utf-8') + b'\n') + await writer.drain() except Exception as e: - print(f"Error handling client {client_id}: {str(e)}") + print(f"Error handling client: {e}") finally: - self.cleanup_client(client_id, writer) + self.client_writer = None + writer.close() + await writer.wait_closed() + print("Client disconnected") async def handle_kernel( self, @@ -65,123 +71,93 @@ async def handle_kernel( writer: asyncio.StreamWriter ) -> None: """Handle kernel module connections.""" - self.kernel_writers.add(writer) + print(f"Kernel module connected from {writer.get_extra_info('peername')}") + try: while True: data = await reader.readline() if not data: break - request_data = json.loads(data.decode('utf-8')) - response = self.handle_kernel_request(request_data) - - writer.write(json.dumps(response).encode('utf-8') + b'\n') - await writer.drain() + try: + request_data = json.loads(data.decode('utf-8')) + print(f"Received from kernel: {request_data}") + + # Process kernel request + response = self.handle_kernel_request(request_data) + print(f"Sending to kernel: {response}") + + writer.write(json.dumps(response).encode('utf-8') + b'\n') + await writer.drain() + + except json.JSONDecodeError: + print("Invalid JSON received from kernel") except Exception as e: - print(f"Error handling kernel request: {str(e)}") + print(f"Error handling kernel module: {e}") finally: - self.kernel_writers.remove(writer) writer.close() await writer.wait_closed() + print("Kernel module disconnected") def handle_kernel_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Process kernel module requests.""" domain = request_data.get('domain') - categories = request_data.get('categories', []) + if not domain: + return {'block': False} - # Check various blocking conditions + # Check if domain should be blocked should_block = ( + # Check manually blocked domains self.db_manager.is_domain_blocked(domain) or - (self.db_manager.is_easylist_blocked(domain) and - self.state_manager.get_feature_state('ad_block')['state'] == 'on') or + # Check ad blocking ( - 'adult' in categories and - self.state_manager.get_feature_state('adult_block')['state'] == 'on' + self.db_manager.get_setting('ad_block') == 'on' and + self.db_manager.is_easylist_blocked(domain) + ) or + # Check adult content blocking + ( + self.db_manager.get_setting('adult_block') == 'on' and + 'adult' in request_data.get('categories', []) ) ) return {'block': should_block} - async def send_initial_state( - self, - writer: asyncio.StreamWriter, - client_id: str - ) -> None: - """Send initial state to new client connections.""" - initial_state = { - 'settings': { - 'ad_block': self.state_manager.get_feature_state('ad_block'), - 'adult_block': self.state_manager.get_feature_state('adult_block') - }, - 'domains': list(self.state_manager.get_all_domains().keys()) - } - - response = create_response( - ServerCodes.SUCCESS, - ClientCodes.DOMAIN_LIST_UPDATE, - content=initial_state - ) - - writer.write(json.dumps(response).encode('utf-8') + b'\n') - await writer.drain() - - async def broadcast_state_update(self, exclude_client: Optional[str] = None) -> None: - """Broadcast state updates to all connected clients except the sender.""" - domains = list(self.state_manager.get_all_domains().keys()) - update_message = create_response( - ServerCodes.SUCCESS, - ClientCodes.DOMAIN_LIST_UPDATE, - content=domains - ) - - message_data = json.dumps(update_message).encode('utf-8') + b'\n' - - for client_id, writer in self.clients.items(): - if client_id != exclude_client: - try: - writer.write(message_data) - await writer.drain() - except Exception as e: - print(f"Error broadcasting to client {client_id}: {str(e)}") - - def cleanup_client(self, client_id: str, writer: asyncio.StreamWriter) -> None: - """Clean up client connection resources.""" - if client_id in self.clients: - del self.clients[client_id] - self.state_manager.unregister_client(client_id) - writer.close() - - def _should_broadcast(self, request_data: Dict[str, Any]) -> bool: - """Determine if a request should trigger a broadcast.""" - broadcast_codes = { - ClientCodes.ADD_DOMAIN, - ClientCodes.REMOVE_DOMAIN, - ClientCodes.AD_BLOCK, - ClientCodes.ADULT_BLOCK - } - return request_data.get('code') in broadcast_codes + async def _send_domain_list(self) -> None: + """Send updated domain list to client.""" + if self.client_writer: + domains = self.db_manager.get_blocked_domains() + update_message = { + 'code': Codes.CODE_DOMAIN_LIST_UPDATE, + 'content': domains + } + try: + self.client_writer.write(json.dumps(update_message).encode('utf-8') + b'\n') + await self.client_writer.drain() + except Exception as e: + print(f"Error sending domain list update: {e}") async def start_server(db_manager: DatabaseManager) -> None: """Start the server with both client and kernel handlers.""" server = Server(db_manager.db_file) - + client_server = await asyncio.start_server( server.handle_client, HOST, CLIENT_PORT ) - + kernel_server = await asyncio.start_server( server.handle_kernel, HOST, KERNEL_PORT ) - print(f"Client server running on {HOST}:{CLIENT_PORT}") - print(f"Kernel server running on {HOST}:{KERNEL_PORT}") - async with client_server, kernel_server: + print(f"Client server running on {HOST}:{CLIENT_PORT}") + print(f"Kernel server running on {HOST}:{KERNEL_PORT}") + await asyncio.gather( client_server.serve_forever(), kernel_server.serve_forever() @@ -192,6 +168,6 @@ def run(db_file: str) -> None: try: asyncio.run(start_server(DatabaseManager(db_file))) except KeyboardInterrupt: - print("Server stopped by user") + print("\nServer stopped by user") except Exception as e: print(f"Server error: {str(e)}") \ No newline at end of file From a2e3788da004d247d4894bda4ee5401385bb2fc8 Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Wed, 6 Nov 2024 23:24:21 +0200 Subject: [PATCH 27/38] remove state manager --- server/src/state_manager.py | 135 ------------------------------------ 1 file changed, 135 deletions(-) delete mode 100644 server/src/state_manager.py diff --git a/server/src/state_manager.py b/server/src/state_manager.py deleted file mode 100644 index 1da9e78..0000000 --- a/server/src/state_manager.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import Dict, Any, Optional -from datetime import datetime -import threading -from .db_manager import DatabaseManager - -class StateManager: - """Manages application state including feature toggles and domain states.""" - - def __init__(self, db_manager: DatabaseManager): - self._db_manager = db_manager - self._state_lock = threading.Lock() - self._states = { - "settings": { - "ad_block": { - "state": "off", - "last_updated": datetime.now().isoformat() - }, - "adult_block": { - "state": "off", - "last_updated": datetime.now().isoformat() - } - }, - "domains": {}, # Will be populated from database - "clients": {} # Will track connected clients - } - self._load_initial_state() - - def _load_initial_state(self) -> None: - """Load initial state from database.""" - # Load blocked domains - domains = self._db_manager.get_blocked_domains() - for domain in domains: - self._states["domains"][domain] = { - "blocked": True, - "timestamp": datetime.now().isoformat(), - "reason": "manual" - } - - def update_feature_state(self, feature: str, state: str) -> Dict[str, Any]: - """ - Update the state of a feature (ad_block or adult_block). - - Args: - feature: Feature to update ('ad_block' or 'adult_block') - state: New state ('on' or 'off') - - Returns: - Dict containing the updated state information - """ - with self._state_lock: - if feature not in self._states["settings"]: - raise ValueError(f"Invalid feature: {feature}") - - if state not in ["on", "off"]: - raise ValueError(f"Invalid state: {state}") - - self._states["settings"][feature] = { - "state": state, - "last_updated": datetime.now().isoformat() - } - - return self._states["settings"][feature] - - def get_feature_state(self, feature: str) -> Dict[str, Any]: - """Get the current state of a feature.""" - with self._state_lock: - if feature not in self._states["settings"]: - raise ValueError(f"Invalid feature: {feature}") - - return self._states["settings"][feature] - - def add_domain(self, domain: str, reason: str = "manual") -> Dict[str, Any]: - """ - Add a domain to blocked list. - - Args: - domain: Domain to block - reason: Reason for blocking ('manual', 'easylist', 'adult') - - Returns: - Dict containing the domain state information - """ - with self._state_lock: - domain_state = { - "blocked": True, - "timestamp": datetime.now().isoformat(), - "reason": reason - } - self._states["domains"][domain] = domain_state - self._db_manager.add_blocked_domain(domain) - - return domain_state - - def remove_domain(self, domain: str) -> bool: - """ - Remove a domain from blocked list. - - Returns: - bool indicating if domain was removed - """ - with self._state_lock: - if domain in self._states["domains"]: - del self._states["domains"][domain] - self._db_manager.remove_blocked_domain(domain) - return True - return False - - def get_domain_state(self, domain: str) -> Optional[Dict[str, Any]]: - """Get the current state of a domain.""" - with self._state_lock: - return self._states["domains"].get(domain) - - def get_all_domains(self) -> Dict[str, Dict[str, Any]]: - """Get all domain states.""" - with self._state_lock: - return self._states["domains"].copy() - - def register_client(self, client_id: str) -> None: - """Register a new client connection.""" - with self._state_lock: - self._states["clients"][client_id] = { - "last_sync": datetime.now().isoformat(), - "connected": True - } - - def unregister_client(self, client_id: str) -> None: - """Unregister a client connection.""" - with self._state_lock: - if client_id in self._states["clients"]: - self._states["clients"][client_id]["connected"] = False - - def get_full_state(self) -> Dict[str, Any]: - """Get the complete current state.""" - with self._state_lock: - return self._states.copy() \ No newline at end of file From 5f1dfbfb28c1cb7025eb9ab936ce21fa4879cea3 Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Sun, 10 Nov 2024 15:10:59 +0200 Subject: [PATCH 28/38] server files --- server/main.py | 4 +- server/src/db_manager.py | 108 +++++++++++-- server/src/easylist_manager.py | 86 ++++++++++ server/src/filter_rules.py | 82 ++++++++++ server/src/handlers.py | 62 +++++--- server/src/server.py | 238 +++++++++++----------------- server/tests/conftest.py | 34 ++-- server/tests/test_handlers.py | 281 +++++++++++++-------------------- server/tests/test_server.py | 247 ++++++++++++++++++++--------- 9 files changed, 693 insertions(+), 449 deletions(-) create mode 100644 server/src/easylist_manager.py create mode 100644 server/src/filter_rules.py diff --git a/server/main.py b/server/main.py index 409bf27..8409cff 100644 --- a/server/main.py +++ b/server/main.py @@ -1,5 +1,5 @@ -from src.server import run +from src.server import initialize_server from src.config import DB_FILE if __name__ == '__main__': - run(DB_FILE) \ No newline at end of file + initialize_server(DB_FILE) \ No newline at end of file diff --git a/server/src/db_manager.py b/server/src/db_manager.py index 29ab392..e2537bb 100644 --- a/server/src/db_manager.py +++ b/server/src/db_manager.py @@ -1,6 +1,8 @@ import sqlite3 +import json from typing import List, Dict, Any, Optional import requests +from .filter_rules import FilterRule, PatternType class DatabaseManager: def __init__(self, db_file: str): @@ -12,7 +14,14 @@ def create_tables(self) -> None: with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - # Blocked domains table + # Drop existing indices first to avoid conflicts + cursor.execute("DROP INDEX IF EXISTS idx_pattern_type") + cursor.execute("DROP INDEX IF EXISTS idx_processed_pattern") + + # Drop existing tables to ensure clean schema + cursor.execute("DROP TABLE IF EXISTS easylist") + + # Create tables cursor.execute(""" CREATE TABLE IF NOT EXISTS blocked_domains ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -20,15 +29,16 @@ def create_tables(self) -> None: ) """) - # Easylist entries table cursor.execute(""" CREATE TABLE IF NOT EXISTS easylist ( id INTEGER PRIMARY KEY AUTOINCREMENT, - entry TEXT UNIQUE + raw_pattern TEXT UNIQUE, + pattern_type TEXT NOT NULL, + processed_pattern TEXT NOT NULL, + options TEXT ) """) - # Settings table for toggles cursor.execute(""" CREATE TABLE IF NOT EXISTS settings ( setting TEXT PRIMARY KEY, @@ -36,7 +46,7 @@ def create_tables(self) -> None: ) """) - # Initialize settings if not exists + # Initialize settings cursor.execute(""" INSERT OR IGNORE INTO settings (setting, value) VALUES @@ -44,6 +54,17 @@ def create_tables(self) -> None: ('adult_block', 'off') """) + # Create indices after tables are created + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_pattern_type + ON easylist(pattern_type) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_processed_pattern + ON easylist(processed_pattern) + """) + conn.commit() def get_setting(self, setting: str) -> str: @@ -97,25 +118,80 @@ def is_domain_blocked(self, domain: str) -> bool: cursor.execute("SELECT 1 FROM blocked_domains WHERE domain = ?", (domain,)) return cursor.fetchone() is not None - def store_easylist_entries(self, entries: List[tuple]) -> None: - """Store easylist entries.""" + def store_filter_rule(self, rule: FilterRule) -> None: + """Store a single filter rule in easylist table.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - cursor.executemany( - "INSERT OR IGNORE INTO easylist (entry) VALUES (?)", - entries - ) + cursor.execute(""" + INSERT OR REPLACE INTO easylist + (raw_pattern, pattern_type, processed_pattern, options) + VALUES (?, ?, ?, ?) + """, ( + rule.raw_pattern, + rule.pattern_type.value, + rule.processed_pattern, + json.dumps(rule.options) + )) + conn.commit() + + def store_easylist_entries(self, entries: List[str]) -> None: + """Store easylist entries with proper pattern parsing.""" + rules = [] + for entry in entries: + try: + rule = FilterRule(entry) + rules.append(( + rule.raw_pattern, + rule.pattern_type.value, + rule.processed_pattern, + json.dumps(rule.options) + )) + except Exception as e: + print(f"Error parsing rule '{entry}': {e}") + continue + + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.executemany(""" + INSERT OR IGNORE INTO easylist + (raw_pattern, pattern_type, processed_pattern, options) + VALUES (?, ?, ?, ?) + """, rules) conn.commit() def is_easylist_blocked(self, domain: str) -> bool: """Check if domain matches any easylist pattern.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - cursor.execute( - "SELECT 1 FROM easylist WHERE ? GLOB '*' || entry || '*'", - (domain,) - ) - return cursor.fetchone() is not None + + # First check exceptions + cursor.execute(""" + SELECT raw_pattern, pattern_type, processed_pattern, options + FROM easylist + WHERE pattern_type = ? + """, (PatternType.EXCEPTION.value,)) + + for row in cursor.fetchall(): + rule = FilterRule(row[0]) + if rule.matches(domain, domain): # Using domain as both URL and domain + print(f"Domain {domain} matched exception rule: {row[0]}") + return False + + # Then check blocking rules + cursor.execute(""" + SELECT raw_pattern, pattern_type, processed_pattern, options + FROM easylist + WHERE pattern_type != ? + """, (PatternType.EXCEPTION.value,)) + + for row in cursor.fetchall(): + rule = FilterRule(row[0]) + if rule.matches(domain, domain): + print(f"Domain {domain} matched blocking rule: {row[0]}") + return True + + print(f"Domain {domain} did not match any patterns") + return False def clear_easylist(self) -> None: """Clear all easylist entries.""" diff --git a/server/src/easylist_manager.py b/server/src/easylist_manager.py new file mode 100644 index 0000000..933a589 --- /dev/null +++ b/server/src/easylist_manager.py @@ -0,0 +1,86 @@ +import threading +import requests +import json +from datetime import datetime, timedelta +from typing import Optional +from .db_manager import DatabaseManager +from .config import EASYLIST_URL +from .filter_rules import FilterRule + +class EasyListManager: + def __init__(self, db_manager: DatabaseManager, update_interval: int = 24) -> None: + """ + Initialize EasyList manager. + + Args: + db_manager: Database manager instance + update_interval: Update interval in hours (default: 24) + """ + self.db_manager = db_manager + self.update_interval = update_interval + self.update_timer: Optional[threading.Timer] = None + self.running = True + + def start_update_scheduler(self) -> None: + """Start the update scheduler.""" + self.schedule_next_update() + + def stop_update_scheduler(self) -> None: + """Stop the update scheduler.""" + self.running = False + if self.update_timer: + self.update_timer.cancel() + + def schedule_next_update(self) -> None: + """Schedule the next update.""" + if not self.running: + return + + # Schedule next update + self.update_timer = threading.Timer( + self.update_interval * 3600, # Convert hours to seconds + self._perform_update + ) + self.update_timer.daemon = True + self.update_timer.start() + + def _perform_update(self) -> None: + """Perform the EasyList update.""" + try: + print("Starting EasyList update...") + + # Download new EasyList + response = requests.get(EASYLIST_URL) + response.raise_for_status() + + # Parse rules + rules = [] + for line in response.text.split('\n'): + line = line.strip() + if line and not line.startswith('!') and not line.startswith('['): + try: + rule = FilterRule(line) + rules.append(( + rule.raw_pattern, + rule.pattern_type.value, + rule.processed_pattern, + json.dumps(rule.options) + )) + except Exception as e: + print(f"Error parsing rule '{line}': {e}") + continue + + # Update database + self.db_manager.clear_easylist() + self.db_manager.store_easylist_entries(rules) + + print(f"EasyList updated with {len(rules)} rules") + + except Exception as e: + print(f"Error updating EasyList: {e}") + finally: + self.schedule_next_update() + + def force_update(self) -> None: + """Force an immediate update.""" + self._perform_update() \ No newline at end of file diff --git a/server/src/filter_rules.py b/server/src/filter_rules.py new file mode 100644 index 0000000..408c7bb --- /dev/null +++ b/server/src/filter_rules.py @@ -0,0 +1,82 @@ +from enum import Enum +from typing import Optional, List, Dict, Any +import re + +class PatternType(Enum): + DOMAIN = "domain" # ||example.com^ + EXACT = "exact" # |http://example.com/| + WILDCARD = "wildcard" # /banner/*/img^ + EXCEPTION = "exception" # @@||example.com^ + +class FilterRule: + def __init__(self, raw_pattern: str): + self.raw_pattern = raw_pattern + self.pattern_type = self._determine_pattern_type() + self.processed_pattern = self._process_pattern() + self.options: Dict[str, Any] = self._parse_options() + + def _determine_pattern_type(self) -> PatternType: + pattern = self.raw_pattern.strip() + + if pattern.startswith("@@"): + return PatternType.EXCEPTION + elif pattern.startswith("||"): + return PatternType.DOMAIN + elif pattern.startswith("|") and pattern.endswith("|"): + return PatternType.EXACT + else: + return PatternType.WILDCARD + + def _process_pattern(self) -> str: + """Process the raw pattern into a normalized form.""" + pattern = self.raw_pattern.strip() + + # Remove options part if exists + if "$" in pattern: + pattern = pattern.split("$")[0] + + # Process based on type + if self.pattern_type == PatternType.EXCEPTION: + return pattern[2:] # Remove @@ + elif self.pattern_type == PatternType.DOMAIN: + return pattern[2:-1] # Remove || and ^ + elif self.pattern_type == PatternType.EXACT: + return pattern[1:-1] # Remove leading and trailing | + else: + return pattern + + def _parse_options(self) -> Dict[str, Any]: + """Parse filter options like $script,image,domain=example.com.""" + options = {} + if "$" in self.raw_pattern: + options_part = self.raw_pattern.split("$")[1] + for opt in options_part.split(","): + if "=" in opt: + key, value = opt.split("=") + options[key] = value + else: + options[opt] = True + return options + + def matches(self, url: str, domain: str) -> bool: + """Check if URL matches this filter rule.""" + # Check domain restrictions if any + if "domain" in self.options: + allowed_domains = self.options["domain"].split("|") + if not any(domain.endswith(d) for d in allowed_domains): + return False + + # Match based on pattern type + if self.pattern_type == PatternType.DOMAIN: + return domain.endswith(self.processed_pattern) + elif self.pattern_type == PatternType.EXACT: + return url == self.processed_pattern + else: + # Convert wildcard pattern to regex + regex_pattern = ( + self.processed_pattern + .replace(".", r"\.") + .replace("*", ".*") + .replace("?", ".") + ) + return bool(re.search(regex_pattern, url)) \ No newline at end of file diff --git a/server/src/handlers.py b/server/src/handlers.py index ee1fad6..45f03c9 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -1,9 +1,8 @@ -import requests from abc import ABC, abstractmethod from typing import Dict, Any from .db_manager import DatabaseManager -from .config import EASYLIST_URL from .response_codes import Codes, RESPONSE_MESSAGES +from .easylist_manager import EasyListManager class RequestHandler(ABC): @abstractmethod @@ -12,28 +11,13 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: class AdBlockHandler(RequestHandler): def __init__(self, db_manager: DatabaseManager): + """Initialize AdBlockHandler with database manager and EasyList manager.""" self.db_manager = db_manager - self.load_easylist() - - def load_easylist(self) -> None: - """Load and parse easylist.""" - try: - response = requests.get(EASYLIST_URL) - response.raise_for_status() - easylist_data = response.text - - # Parse and store valid entries - entries = [] - for line in easylist_data.split("\n"): - line = line.strip() - if line and not line.startswith("!"): - entries.append((line,)) - - self.db_manager.clear_easylist() - self.db_manager.store_easylist_entries(entries) - - except requests.exceptions.RequestException as e: - print(f"Error loading EasyList: {e}") + self.easylist_manager = EasyListManager(db_manager) + # Start the automatic update scheduler + self.easylist_manager.start_update_scheduler() + # Perform initial load + self.easylist_manager.force_update() def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle ad block requests.""" @@ -42,6 +26,11 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: # Handle toggle request state = request_data['action'] # 'on' or 'off' self.db_manager.update_setting('ad_block', state) + + # If turning on, ensure EasyList is loaded + if state == 'on': + self.easylist_manager.force_update() + return { 'code': Codes.CODE_AD_BLOCK, 'message': RESPONSE_MESSAGES['success'] @@ -72,6 +61,11 @@ def is_domain_blocked(self, domain: str) -> bool: return False return self.db_manager.is_easylist_blocked(domain) + def __del__(self) -> None: + """Cleanup when handler is destroyed.""" + if hasattr(self, 'easylist_manager'): + self.easylist_manager.stop_update_scheduler() + class AdultContentBlockHandler(RequestHandler): def __init__(self, db_manager: DatabaseManager): self.db_manager = db_manager @@ -153,6 +147,25 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: 'message': str(e) } +class DomainListHandler(RequestHandler): + def __init__(self, db_manager: DatabaseManager): + self.db_manager = db_manager + + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle domain list requests.""" + try: + domains = self.db_manager.get_blocked_domains() + return { + 'code': Codes.CODE_DOMAIN_LIST_UPDATE, + 'domains': domains, + 'message': RESPONSE_MESSAGES['success'] + } + except Exception as e: + return { + 'code': Codes.CODE_DOMAIN_LIST_UPDATE, + 'message': str(e) + } + class RequestFactory: def __init__(self, db_manager: DatabaseManager): self.db_manager = db_manager @@ -160,7 +173,8 @@ def __init__(self, db_manager: DatabaseManager): Codes.CODE_AD_BLOCK: lambda: AdBlockHandler(self.db_manager), Codes.CODE_ADULT_BLOCK: lambda: AdultContentBlockHandler(self.db_manager), Codes.CODE_ADD_DOMAIN: lambda: DomainBlockHandler(self.db_manager), - Codes.CODE_REMOVE_DOMAIN: lambda: DomainBlockHandler(self.db_manager) + Codes.CODE_REMOVE_DOMAIN: lambda: DomainBlockHandler(self.db_manager), + Codes.CODE_DOMAIN_LIST_UPDATE: lambda: DomainListHandler(self.db_manager) } def create_request_handler(self, request_type: str) -> RequestHandler: diff --git a/server/src/server.py b/server/src/server.py index 53fb5cc..4a9302d 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -1,77 +1,67 @@ -import asyncio +import socket +import threading import json +import asyncio from typing import Dict, Any from .config import HOST, CLIENT_PORT, KERNEL_PORT from .db_manager import DatabaseManager from .handlers import RequestFactory -from .response_codes import Codes, RESPONSE_MESSAGES class Server: - def __init__(self, db_file: str): - """Initialize server with database and request factory.""" - self.db_manager = DatabaseManager(db_file) + def __init__(self, db_manager: DatabaseManager) -> None: + self.db_manager = db_manager self.request_factory = RequestFactory(self.db_manager) - self.client_writer = None # Store the single client connection - - async def handle_client( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter - ) -> None: - """Handle the client connection.""" - # Store the client writer for potential updates - self.client_writer = writer - print(f"Client connected from {writer.get_extra_info('peername')}") - - try: - # Send initial domain list to client - await self._send_domain_list() - - while True: - data = await reader.readline() - if not data: - break - - try: - request_data = json.loads(data.decode('utf-8')) - print(f"Received from client: {request_data}") - - # Process request - response = self.request_factory.handle_request(request_data) - print(f"Sending response: {response}") + self.running = True - # Send response - writer.write(json.dumps(response).encode('utf-8') + b'\n') - await writer.drain() - - # If domain list was modified, send update - if request_data.get('code') in [Codes.CODE_ADD_DOMAIN, Codes.CODE_REMOVE_DOMAIN]: - await self._send_domain_list() + def handle_client_thread(self) -> None: + """Handle client connections using traditional socket.""" + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.bind((HOST, CLIENT_PORT)) + client_socket.listen(1) # Only one client needed + print(f"Client server running on {HOST}:{CLIENT_PORT}") - except json.JSONDecodeError: - print("Invalid JSON received") - error_response = { - 'code': request_data.get('code', ''), - 'message': RESPONSE_MESSAGES['invalid_request'] - } - writer.write(json.dumps(error_response).encode('utf-8') + b'\n') - await writer.drain() + while self.running: + try: + conn, addr = client_socket.accept() + print(f"Client connected from {addr}") + + # Send initial domain list + domains = self.db_manager.get_blocked_domains() + conn.send(json.dumps({ + 'type': 'domain_list', + 'domains': domains + }).encode() + b'\n') + + while True: + data = conn.recv(1024) + if not data: + break + + try: + request_data = json.loads(data.decode()) + response = self.request_factory.handle_request(request_data) + conn.send(json.dumps(response).encode() + b'\n') + except json.JSONDecodeError: + conn.send(json.dumps({ + 'status': 'error', + 'message': 'Invalid JSON format' + }).encode() + b'\n') + except Exception as e: + conn.send(json.dumps({ + 'status': 'error', + 'message': str(e) + }).encode() + b'\n') - except Exception as e: - print(f"Error handling client: {e}") - finally: - self.client_writer = None - writer.close() - await writer.wait_closed() - print("Client disconnected") + except Exception as e: + print(f"Client error: {e}") + finally: + if 'conn' in locals(): + conn.close() - async def handle_kernel( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter - ) -> None: - """Handle kernel module connections.""" - print(f"Kernel module connected from {writer.get_extra_info('peername')}") + async def handle_kernel_requests(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + """Handle kernel requests using asyncio for better performance.""" + addr = writer.get_extra_info('peername') + print(f"Kernel module connected from {addr}") try: while True: @@ -79,95 +69,55 @@ async def handle_kernel( if not data: break - try: - request_data = json.loads(data.decode('utf-8')) - print(f"Received from kernel: {request_data}") - - # Process kernel request - response = self.handle_kernel_request(request_data) - print(f"Sending to kernel: {response}") + request_data = json.loads(data.decode()) + domain = request_data.get('domain') - writer.write(json.dumps(response).encode('utf-8') + b'\n') - await writer.drain() + # Fast domain check + should_block = ( + self.db_manager.is_domain_blocked(domain) or + (self.db_manager.get_setting('ad_block') == 'on' and + self.db_manager.is_easylist_blocked(domain)) or + (self.db_manager.get_setting('adult_block') == 'on' and + 'adult' in request_data.get('categories', [])) + ) - except json.JSONDecodeError: - print("Invalid JSON received from kernel") + writer.write(json.dumps({'block': should_block}).encode() + b'\n') + await writer.drain() except Exception as e: - print(f"Error handling kernel module: {e}") + print(f"Kernel error: {e}") finally: writer.close() await writer.wait_closed() - print("Kernel module disconnected") - - def handle_kernel_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """Process kernel module requests.""" - domain = request_data.get('domain') - if not domain: - return {'block': False} - - # Check if domain should be blocked - should_block = ( - # Check manually blocked domains - self.db_manager.is_domain_blocked(domain) or - # Check ad blocking - ( - self.db_manager.get_setting('ad_block') == 'on' and - self.db_manager.is_easylist_blocked(domain) - ) or - # Check adult content blocking - ( - self.db_manager.get_setting('adult_block') == 'on' and - 'adult' in request_data.get('categories', []) - ) - ) - - return {'block': should_block} - - async def _send_domain_list(self) -> None: - """Send updated domain list to client.""" - if self.client_writer: - domains = self.db_manager.get_blocked_domains() - update_message = { - 'code': Codes.CODE_DOMAIN_LIST_UPDATE, - 'content': domains - } - try: - self.client_writer.write(json.dumps(update_message).encode('utf-8') + b'\n') - await self.client_writer.drain() - except Exception as e: - print(f"Error sending domain list update: {e}") - -async def start_server(db_manager: DatabaseManager) -> None: - """Start the server with both client and kernel handlers.""" - server = Server(db_manager.db_file) - - client_server = await asyncio.start_server( - server.handle_client, - HOST, - CLIENT_PORT - ) - kernel_server = await asyncio.start_server( - server.handle_kernel, - HOST, - KERNEL_PORT - ) + def start_server(self) -> None: + """Run both client and kernel handlers.""" + try: + # Start client handler in a separate thread + client_thread = threading.Thread(target=self.handle_client_thread) + client_thread.start() + + # Run kernel handler with asyncio + async def start_kernel_server() -> None: + kernel_server = await asyncio.start_server( + self.handle_kernel_requests, + HOST, + KERNEL_PORT + ) + print(f"Kernel server running on {HOST}:{KERNEL_PORT}") + await kernel_server.serve_forever() + + # Run the asyncio event loop for kernel handler + asyncio.run(start_kernel_server()) + + except KeyboardInterrupt: + self.running = False + print("\nServer stopping...") + except Exception as e: + print(f"Server error: {e}") - async with client_server, kernel_server: - print(f"Client server running on {HOST}:{CLIENT_PORT}") - print(f"Kernel server running on {HOST}:{KERNEL_PORT}") - - await asyncio.gather( - client_server.serve_forever(), - kernel_server.serve_forever() - ) - -def run(db_file: str) -> None: +def initialize_server(db_file: str) -> None: """Initialize and run the server.""" - try: - asyncio.run(start_server(DatabaseManager(db_file))) - except KeyboardInterrupt: - print("\nServer stopped by user") - except Exception as e: - print(f"Server error: {str(e)}") \ No newline at end of file + db_manager = DatabaseManager(db_file) + server = Server(db_manager) + server.start_server() \ No newline at end of file diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 14fb24c..a1cfeff 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -12,17 +12,20 @@ def mock_db_manager() -> mock.Mock: # Setup default returns for common methods mock_db.is_domain_blocked.return_value = False mock_db.is_easylist_blocked.return_value = False + mock_db.get_setting.return_value = 'off' # Default setting state + mock_db.get_blocked_domains.return_value = [] # Default empty domain list return mock_db @pytest.fixture def mock_requests() -> Generator[mock.Mock, None, None]: - """Fixture to mock requests library.""" + """Fixture to mock requests library for easylist downloading.""" with mock.patch('My_Internet.server.src.handlers.requests') as mock_req: # Create a mock response mock_response = mock.Mock() mock_response.text = "test.com\n!comment\nexample.com" mock_req.get.return_value = mock_response + mock_response.raise_for_status = mock.Mock() # Reset the mock to clear any previous calls mock_req.reset_mock() @@ -47,30 +50,37 @@ def sample_requests() -> dict: """Fixture to provide sample request data.""" return { "adult_block": { - "type": "adult_content_block", - "action": "enable" + "code": "51", + "action": "on" }, "domain_block": { - "type": "domain_block", + "code": "52", "action": "block", "domain": "example.com" }, "ad_block": { - "type": "ad_block", - "domain": "test.com" + "code": "50", + "action": "on" + }, + "check_domain": { + "code": "50", + "domain": "ads.example.com" } } @pytest.fixture -async def mock_stream_reader() -> mock.AsyncMock: - """Global fixture for AsyncMock StreamReader.""" +def mock_stream_reader() -> mock.AsyncMock: + """Mock for asyncio StreamReader.""" reader = mock.AsyncMock() + reader.readline = mock.AsyncMock() return reader @pytest.fixture -async def mock_stream_writer() -> mock.AsyncMock: - """Global fixture for AsyncMock StreamWriter.""" - writer = mock.AsyncMock() - writer.write = mock.Mock() # write is usually synchronous +def mock_stream_writer() -> mock.Mock: + """Mock for asyncio StreamWriter.""" + writer = mock.Mock() + writer.write = mock.Mock() writer.drain = mock.AsyncMock() + writer.close = mock.Mock() + writer.wait_closed = mock.AsyncMock() return writer \ No newline at end of file diff --git a/server/tests/test_handlers.py b/server/tests/test_handlers.py index 3fe012d..3d31cb4 100644 --- a/server/tests/test_handlers.py +++ b/server/tests/test_handlers.py @@ -1,87 +1,97 @@ import pytest from unittest import mock from typing import Dict, Any -from My_Internet.server.src.handlers import EASYLIST_URL from My_Internet.server.src.handlers import ( RequestHandler, AdultContentBlockHandler, DomainBlockHandler, AdBlockHandler, - RequestFactory + RequestFactory, + EASYLIST_URL ) from My_Internet.server.src.response_codes import ( - SUCCESS, - INVALID_REQUEST, - DOMAIN_BLOCKED, - DOMAIN_NOT_FOUND, - AD_BLOCK_ENABLED, - ADULT_CONTENT_BLOCKED, + Codes, RESPONSE_MESSAGES ) -class TestAdultContentBlockHandler: +class TestAdBlockHandler: @pytest.fixture - def handler(self, mock_db_manager: mock.Mock) -> AdultContentBlockHandler: - """Create handler instance and reset state.""" - handler = AdultContentBlockHandler(mock_db_manager) - # Reset class-level state before each test - AdultContentBlockHandler._is_enabled = False - return handler - - def test_init(self, handler: AdultContentBlockHandler, mock_db_manager: mock.Mock) -> None: - """Test handler initialization.""" - assert handler.db_manager == mock_db_manager - assert not handler._is_enabled - - def test_handle_enable_request(self, handler: AdultContentBlockHandler) -> None: - """Test enabling adult content blocking.""" - request_data: Dict[str, Any] = {'action': 'enable'} - response = handler.handle_request(request_data) - - assert response['code'] == SUCCESS - assert response['adult_content_block'] == 'on' - assert AdultContentBlockHandler.is_blocking_enabled() - - def test_handle_disable_request(self, handler: AdultContentBlockHandler) -> None: - """Test disabling adult content blocking.""" - AdultContentBlockHandler._is_enabled = True - request_data: Dict[str, Any] = {'action': 'disable'} - - response = handler.handle_request(request_data) - - assert response['code'] == SUCCESS - assert response['adult_content_block'] == 'off' - assert not AdultContentBlockHandler.is_blocking_enabled() + def handler(self, mock_db_manager: mock.Mock, mock_requests: mock.Mock) -> AdBlockHandler: + """Create handler instance.""" + with mock.patch('My_Internet.server.src.handlers.AdBlockHandler.load_easylist'): + handler = AdBlockHandler(mock_db_manager) + mock_requests.reset_mock() + return handler - def test_handle_check_request(self, handler: AdultContentBlockHandler) -> None: - """Test checking domain with blocking enabled.""" - AdultContentBlockHandler._is_enabled = True - request_data = {'action': 'check', 'domain': 'example.com'} - - response = handler.handle_request(request_data) - - assert response['code'] == ADULT_CONTENT_BLOCKED + def test_load_easylist(self, handler: AdBlockHandler, mock_requests: mock.Mock) -> None: + """Test loading and parsing easylist.""" + # Configure mock response + mock_response = mock.Mock() + mock_response.text = "test.com\n!comment\nexample.com" + mock_requests.get.return_value = mock_response - def test_handle_check_request_disabled(self, handler: AdultContentBlockHandler) -> None: - """Test checking domain with blocking disabled.""" - request_data = {'action': 'check', 'domain': 'example.com'} + # Call method + handler.load_easylist() - response = handler.handle_request(request_data) + # Verify calls + mock_requests.get.assert_called_once_with(EASYLIST_URL) + mock_response.raise_for_status.assert_called_once() + handler.db_manager.clear_easylist.assert_called_once() - assert response['code'] == SUCCESS - - def test_handle_invalid_request(self, handler: AdultContentBlockHandler) -> None: - """Test handling invalid requests.""" - invalid_requests = [ - {'action': 'invalid_action'}, - {'action': 'check'}, # Missing domain - {} # Empty request - ] + # Verify easylist storage + expected_entries = [('test.com',), ('example.com',)] + handler.db_manager.store_easylist_entries.assert_called_once_with(expected_entries) + + def test_handle_toggle_request(self, handler: AdBlockHandler) -> None: + """Test toggling ad blocking on/off.""" + # Test enabling + response = handler.handle_request({'action': 'on'}) + handler.db_manager.update_setting.assert_called_with('ad_block', 'on') + assert response['code'] == Codes.CODE_AD_BLOCK + + # Test disabling + response = handler.handle_request({'action': 'off'}) + handler.db_manager.update_setting.assert_called_with('ad_block', 'off') + assert response['code'] == Codes.CODE_AD_BLOCK + + def test_handle_check_domain(self, handler: AdBlockHandler) -> None: + """Test checking domain with ad blocking.""" + # Setup: ad blocking enabled and domain matched in easylist + handler.db_manager.get_setting.return_value = 'on' + handler.db_manager.is_easylist_blocked.return_value = True - for request_data in invalid_requests: - response = handler.handle_request(request_data) - assert response['code'] == INVALID_REQUEST + response = handler.handle_request({'domain': 'ads.example.com'}) + assert response['code'] == Codes.CODE_AD_BLOCK + assert handler.db_manager.is_easylist_blocked.called +class TestAdultContentBlockHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock) -> AdultContentBlockHandler: + """Create handler instance.""" + return AdultContentBlockHandler(mock_db_manager) + + def test_handle_toggle_request(self, handler: AdultContentBlockHandler) -> None: + """Test toggling adult content blocking.""" + # Test enabling + response = handler.handle_request({'action': 'on'}) + handler.db_manager.update_setting.assert_called_with('adult_block', 'on') + assert response['code'] == Codes.CODE_ADULT_BLOCK + + # Test disabling + response = handler.handle_request({'action': 'off'}) + handler.db_manager.update_setting.assert_called_with('adult_block', 'off') + assert response['code'] == Codes.CODE_ADULT_BLOCK + + def test_handle_check_request(self, handler: AdultContentBlockHandler) -> None: + """Test checking domain with adult content blocking.""" + # Setup: adult blocking enabled + handler.db_manager.get_setting.return_value = 'on' + + response = handler.handle_request({ + 'action': 'check', + 'domain': 'example.com' + }) + assert response['code'] == Codes.CODE_ADULT_BLOCK class TestDomainBlockHandler: @pytest.fixture @@ -89,141 +99,68 @@ def handler(self, mock_db_manager: mock.Mock) -> DomainBlockHandler: """Create handler instance.""" return DomainBlockHandler(mock_db_manager) - def test_block_domain(self, handler: DomainBlockHandler, sample_domains: list[str]) -> None: + def test_block_domain(self, handler: DomainBlockHandler) -> None: """Test blocking a domain.""" - domain = sample_domains[0] response = handler.handle_request({ 'action': 'block', - 'domain': domain + 'domain': 'example.com' }) - handler.db_manager.add_blocked_domain.assert_called_once_with(domain) - assert response['code'] == DOMAIN_BLOCKED + handler.db_manager.add_blocked_domain.assert_called_once_with('example.com') + assert response['code'] == Codes.CODE_ADD_DOMAIN - def test_unblock_domain(self, handler: DomainBlockHandler, sample_domains: list[str]) -> None: + def test_unblock_domain(self, handler: DomainBlockHandler) -> None: """Test unblocking a domain.""" - domain = sample_domains[0] + # Setup: domain exists handler.db_manager.is_domain_blocked.return_value = True response = handler.handle_request({ 'action': 'unblock', - 'domain': domain + 'domain': 'example.com' }) - handler.db_manager.remove_blocked_domain.assert_called_once_with(domain) - assert response['code'] == SUCCESS + handler.db_manager.remove_blocked_domain.assert_called_once_with('example.com') + assert response['code'] == Codes.CODE_REMOVE_DOMAIN def test_unblock_nonexistent_domain(self, handler: DomainBlockHandler) -> None: - """Test unblocking a domain that isn't blocked.""" - handler.db_manager.is_domain_blocked.return_value = False + """Test unblocking a nonexistent domain.""" + handler.db_manager.remove_blocked_domain.return_value = False response = handler.handle_request({ 'action': 'unblock', 'domain': 'nonexistent.com' }) - - assert response['code'] == DOMAIN_NOT_FOUND + + assert 'domain not found' in response['message'].lower() -class TestAdBlockHandler: +class TestRequestFactory: @pytest.fixture - def handler(self, mock_db_manager: mock.Mock, mock_requests: mock.Mock) -> AdBlockHandler: - """Create handler instance with loading disabled.""" - with mock.patch('My_Internet.server.src.handlers.AdBlockHandler.load_easylist'): - # Create handler without loading easylist during initialization - handler = AdBlockHandler(mock_db_manager) - # Reset the mock before the test - mock_requests.reset_mock() - return handler - - def test_load_easylist(self, handler: AdBlockHandler, mock_requests: mock.Mock) -> None: - """Test loading the easylist.""" - # Configure mock response - mock_response = mock.Mock() - mock_response.text = "test.com\n!comment\nexample.com" - mock_requests.get.return_value = mock_response - - # Call method - handler.load_easylist() - - # Verify calls using the imported constant - mock_requests.get.assert_called_once_with(EASYLIST_URL) - mock_response.raise_for_status.assert_called_once() - handler.db_manager.clear_easylist.assert_called_once() - - # Verify that store_easylist_entries was called with correct data - expected_entries = [('test.com',), ('example.com',)] - handler.db_manager.store_easylist_entries.assert_called_once_with(expected_entries) - - def test_handle_check_request(self, handler: AdBlockHandler) -> None: - """Test checking a domain against easylist.""" - handler.db_manager.is_easylist_blocked.return_value = True - response = handler.handle_request({'domain': 'example.com'}) - - assert response['code'] == AD_BLOCK_ENABLED - - def test_handle_check_request_not_blocked(self, handler: AdBlockHandler) -> None: - """Test checking an unblocked domain.""" - handler.db_manager.is_easylist_blocked.return_value = False - response = handler.handle_request({'domain': 'example.com'}) - - assert response['code'] == SUCCESS + def factory(self, mock_db_manager: mock.Mock) -> RequestFactory: + """Create factory instance.""" + return RequestFactory(mock_db_manager) - -class TestRequestFactory: - def test_create_handlers(self, request_factory: RequestFactory) -> None: + def test_create_handlers(self, factory: RequestFactory) -> None: """Test creating different types of handlers.""" - handlers = { - 'ad_block': AdBlockHandler, - 'domain_block': DomainBlockHandler, - 'adult_content_block': AdultContentBlockHandler - } + test_cases = [ + (Codes.CODE_AD_BLOCK, AdBlockHandler), + (Codes.CODE_ADULT_BLOCK, AdultContentBlockHandler), + (Codes.CODE_ADD_DOMAIN, DomainBlockHandler), + (Codes.CODE_REMOVE_DOMAIN, DomainBlockHandler) + ] - for handler_type, handler_class in handlers.items(): - handler = request_factory.create_request_handler(handler_type) + for code, handler_class in test_cases: + handler = factory.create_request_handler(code) assert isinstance(handler, handler_class) - @mock.patch.object(AdultContentBlockHandler, 'handle_request') - def test_request_delegation( - self, - mock_handle: mock.Mock, - request_factory: RequestFactory - ) -> None: - """Test request delegation to appropriate handler.""" - expected_response = {'code': SUCCESS, 'message': 'Test response'} - mock_handle.return_value = expected_response - - request_data = { - 'type': 'adult_content_block', - 'action': 'enable' - } - - response = request_factory.handle_request(request_data) - mock_handle.assert_called_once_with(request_data) - assert response == expected_response + def test_handle_request(self, factory: RequestFactory, sample_requests: dict) -> None: + """Test handling different types of requests.""" + for request in sample_requests.values(): + response = factory.handle_request(request) + assert 'code' in response + assert 'message' in response - def test_handle_invalid_request_type(self, request_factory: RequestFactory) -> None: + def test_invalid_request_type(self, factory: RequestFactory) -> None: """Test handling invalid request type.""" - response = request_factory.handle_request({'type': 'invalid_type'}) - assert response['code'] == INVALID_REQUEST - - def test_factory_handler_integration(self, request_factory: RequestFactory) -> None: - """Test integration between factory and handlers.""" - test_cases = [ - { - 'request': {'type': 'adult_content_block', 'action': 'enable'}, - 'expected_code': SUCCESS - }, - { - 'request': {'type': 'domain_block', 'action': 'block', 'domain': 'example.com'}, - 'expected_code': DOMAIN_BLOCKED - }, - { - 'request': {'type': 'ad_block', 'domain': 'test.com'}, - 'expected_code': SUCCESS - } - ] - - for test_case in test_cases: - response = request_factory.handle_request(test_case['request']) - assert response['code'] == test_case['expected_code'] \ No newline at end of file + response = factory.handle_request({'code': 'invalid'}) + assert 'invalid' in response['message'].lower() \ No newline at end of file diff --git a/server/tests/test_server.py b/server/tests/test_server.py index 3fcadd1..3cdaf5f 100644 --- a/server/tests/test_server.py +++ b/server/tests/test_server.py @@ -2,144 +2,233 @@ import json import asyncio from unittest import mock -from typing import AsyncGenerator, Dict, Any - +from typing import Dict, Any from My_Internet.server.src.server import ( - handle_client, - handle_kernel, - route_kernel_request, + Server, start_server ) from My_Internet.server.src.config import HOST, CLIENT_PORT, KERNEL_PORT -from My_Internet.server.src.handlers import RequestFactory -from My_Internet.server.src.response_codes import SUCCESS, INVALID_REQUEST +from My_Internet.server.src.response_codes import Codes +from My_Internet.server.src.db_manager import DatabaseManager class TestServer: + @pytest.fixture + def db_manager(self) -> mock.Mock: + """Create a mock database manager.""" + mock_db = mock.Mock(spec=DatabaseManager) + mock_db.db_file = "test.db" + mock_db.is_domain_blocked.return_value = True + mock_db.get_setting.return_value = "on" + return mock_db + + @pytest.fixture + def server(self, db_manager: mock.Mock) -> Server: + """Create server instance for testing with a mocked DatabaseManager.""" + return Server(db_manager) + @pytest.fixture def mock_stream_reader(self) -> mock.AsyncMock: + """Mock for asyncio StreamReader.""" reader = mock.AsyncMock() reader.readline = mock.AsyncMock() return reader @pytest.fixture def mock_stream_writer(self) -> mock.Mock: + """Mock for asyncio StreamWriter.""" writer = mock.Mock() writer.write = mock.Mock() writer.drain = mock.AsyncMock() writer.close = mock.Mock() + writer.wait_closed = mock.AsyncMock() return writer - @pytest.mark.asyncio # Only for async functions + @pytest.mark.asyncio async def test_handle_client( self, + server: Server, mock_stream_reader: mock.AsyncMock, - mock_stream_writer: mock.Mock, - request_factory: RequestFactory + mock_stream_writer: mock.Mock ) -> None: """Test client request handling.""" + # Setup mocks + mock_stream_writer.write = mock.Mock() + mock_stream_writer.drain = mock.AsyncMock() + mock_stream_writer.get_extra_info = mock.Mock(return_value="test_client") + server.db_manager.get_blocked_domains.return_value = [] + + # Setup test request test_request = { - 'type': 'domain_block', + 'code': Codes.CODE_ADD_DOMAIN, 'action': 'block', 'domain': 'example.com' } - + + # Configure mock reader responses mock_stream_reader.readline.side_effect = [ json.dumps(test_request).encode() + b'\n', - b'' + b'' # End connection after request ] - - await handle_client(mock_stream_reader, mock_stream_writer, request_factory) - - assert mock_stream_writer.write.called - assert mock_stream_writer.drain.called - assert mock_stream_writer.close.called - @pytest.mark.asyncio # Only for async functions - async def test_handle_kernel( - self, - mock_stream_reader: mock.AsyncMock, - mock_stream_writer: mock.Mock, - mock_db_manager: mock.Mock - ) -> None: - """Test kernel request handling.""" + # Handle client connection + await server.handle_client(mock_stream_reader, mock_stream_writer) + + # Verify response was sent at least twice (domain list + request response) + assert mock_stream_writer.write.call_count >= 2 + + @pytest.mark.asyncio + async def test_handle_kernel(self, server: Server, mock_stream_reader: mock.AsyncMock, mock_stream_writer: mock.Mock) -> None: + """Test kernel module request handling.""" + # Setup test request test_request = { 'domain': 'example.com', 'categories': ['adult'] } + # Configure mocks + server.db_manager.is_domain_blocked.return_value = True + mock_stream_writer.write = mock.Mock() + mock_stream_writer.drain = mock.AsyncMock() + + # Configure mock reader responses mock_stream_reader.readline.side_effect = [ json.dumps(test_request).encode() + b'\n', - b'' + b'' # End connection after request ] - await handle_kernel(mock_stream_reader, mock_stream_writer, mock_db_manager) - + await server.handle_kernel(mock_stream_reader, mock_stream_writer) assert mock_stream_writer.write.called - assert mock_stream_writer.drain.called - assert mock_stream_writer.close.called - # No asyncio marker for synchronous functions - def test_route_kernel_request(self, mock_db_manager: mock.Mock) -> None: - """Test kernel request routing.""" - # Test blocked domain - mock_db_manager.is_domain_blocked.return_value = True - response = route_kernel_request({'domain': 'example.com'}, mock_db_manager) - assert response['block'] is True - - # Test allowed domain - mock_db_manager.is_domain_blocked.return_value = False - mock_db_manager.is_easylist_blocked.return_value = False - response = route_kernel_request({'domain': 'example.com'}, mock_db_manager) - assert response['block'] is False - - @pytest.mark.asyncio # Only for async functions - async def test_start_server(self, mock_db_manager: mock.Mock) -> None: + def test_handle_kernel_request(self, server: Server) -> None: + """Test kernel request processing.""" + test_cases = [ + # Test manually blocked domain + { + 'setup': { + 'is_domain_blocked': True, + 'get_setting': 'off' + }, + 'request': { + 'domain': 'blocked.com' + }, + 'expected': True + }, + # Test ad blocking + { + 'setup': { + 'is_domain_blocked': False, + 'get_setting': 'on', + 'is_easylist_blocked': True + }, + 'request': { + 'domain': 'ads.example.com' + }, + 'expected': True + }, + # Test adult content blocking + { + 'setup': { + 'is_domain_blocked': False, + 'get_setting': 'on' + }, + 'request': { + 'domain': 'example.com', + 'categories': ['adult'] + }, + 'expected': True + }, + # Test allowed domain + { + 'setup': { + 'is_domain_blocked': False, + 'get_setting': 'off', + 'is_easylist_blocked': False + }, + 'request': { + 'domain': 'example.com' + }, + 'expected': False + } + ] + + for case in test_cases: + # Setup mocks + server.db_manager.is_domain_blocked.return_value = case['setup'].get('is_domain_blocked', False) + server.db_manager.get_setting.return_value = case['setup'].get('get_setting', 'off') + server.db_manager.is_easylist_blocked.return_value = case['setup'].get('is_easylist_blocked', False) + + # Test request + response = server.handle_kernel_request(case['request']) + assert response['block'] is case['expected'] + + @pytest.mark.asyncio + async def test_start_server(self, db_manager: mock.Mock) -> None: """Test server startup.""" + # Configure mock + db_manager.db_file = "test.db" mock_client_server = mock.AsyncMock() mock_kernel_server = mock.AsyncMock() - with mock.patch('asyncio.start_server', - side_effect=[mock_client_server, mock_kernel_server]) as mock_start: - task = asyncio.create_task(start_server(mock_db_manager)) + with mock.patch('asyncio.start_server', side_effect=[mock_client_server, mock_kernel_server]) as mock_start: + task = asyncio.create_task(start_server(db_manager)) await asyncio.sleep(0.1) task.cancel() - try: await task except asyncio.CancelledError: pass - - assert mock_start.call_count == 2 - assert mock_client_server.serve_forever.called - assert mock_kernel_server.serve_forever.called - @pytest.mark.asyncio # Only for async functions - async def test_client_connection_error( + @pytest.mark.asyncio + async def test_client_error_handling( self, + server: Server, mock_stream_reader: mock.AsyncMock, - mock_stream_writer: mock.Mock, - request_factory: RequestFactory + mock_stream_writer: mock.Mock ) -> None: - """Test handling of client connection errors.""" - mock_stream_reader.readline.side_effect = ConnectionResetError() - - await handle_client(mock_stream_reader, mock_stream_writer, request_factory) - assert mock_stream_writer.close.called + """Test handling of client errors.""" + # Configure mocks + mock_stream_writer.write = mock.Mock() + mock_stream_writer.drain = mock.AsyncMock() + mock_stream_writer.get_extra_info = mock.Mock(return_value="test_client") + server.db_manager.get_blocked_domains.return_value = [] + + # Set up the mock to return invalid JSON + mock_stream_reader.readline.side_effect = [ + b'invalid json\n', + b'' # End connection after invalid request + ] + + # Handle the client connection + await server.handle_client(mock_stream_reader, mock_stream_writer) + + # Verify responses were sent (domain list + error response) + assert mock_stream_writer.write.call_count >= 2 - # No asyncio marker for synchronous functions - @pytest.mark.parametrize("request_data,expected_block", [ - ({'domain': 'example.com', 'categories': []}, False), - ({'domain': 'blocked.com', 'categories': ['adult']}, True), - ]) - def test_kernel_request_scenarios( + @pytest.mark.asyncio + async def test_kernel_error_handling( self, - request_data: Dict[str, Any], - expected_block: bool, - mock_db_manager: mock.Mock + server: Server, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock ) -> None: - """Test various kernel request scenarios.""" - is_blocked = 'blocked.com' in request_data['domain'] - mock_db_manager.is_domain_blocked.return_value = is_blocked + """Test handling of kernel errors.""" + # Test connection error + mock_stream_reader.readline.side_effect = ConnectionError() + + await server.handle_kernel(mock_stream_reader, mock_stream_writer) - response = route_kernel_request(request_data, mock_db_manager) - assert response['block'] is expected_block \ No newline at end of file + # Verify connection was closed + assert mock_stream_writer.close.called + + def test_multiple_kernel_requests(self, server: Server) -> None: + """Test handling multiple kernel requests.""" + # Setup mocks + server.db_manager.get_setting.side_effect = ['on', 'off'] + server.db_manager.is_domain_blocked.return_value = True + + # Test first request (blocking enabled) + response1 = server.handle_kernel_request({ + 'domain': 'example.com', + 'categories': ['adult'] + }) + assert response1['block'] is True From cc9c70b7456535499619a2e8dbf8468f55297b47 Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Sun, 10 Nov 2024 20:03:49 +0200 Subject: [PATCH 29/38] changed logic , add tests , loger , utils --- server/main.py | 2 +- server/src/config.py | 8 - server/src/db_manager.py | 171 ++------ server/src/easylist_manager.py | 86 ---- server/src/filter_rules.py | 82 ---- server/src/handlers.py | 148 +++---- server/src/logger.py | 45 +++ server/src/server.py | 204 ++++++---- server/src/{response_codes.py => utils.py} | 18 +- server/tests/conftest.py | 95 ++--- server/tests/test_handlers.py | 198 ++++------ server/tests/test_server.py | 440 ++++++++++----------- 12 files changed, 620 insertions(+), 877 deletions(-) delete mode 100644 server/src/config.py delete mode 100644 server/src/easylist_manager.py delete mode 100644 server/src/filter_rules.py create mode 100644 server/src/logger.py rename server/src/{response_codes.py => utils.py} (54%) diff --git a/server/main.py b/server/main.py index 8409cff..53ce51e 100644 --- a/server/main.py +++ b/server/main.py @@ -1,5 +1,5 @@ from src.server import initialize_server -from src.config import DB_FILE +from src.utils import DB_FILE if __name__ == '__main__': initialize_server(DB_FILE) \ No newline at end of file diff --git a/server/src/config.py b/server/src/config.py deleted file mode 100644 index 76737c8..0000000 --- a/server/src/config.py +++ /dev/null @@ -1,8 +0,0 @@ -# Network Configuration -HOST: str = '127.0.0.1' -CLIENT_PORT: int = 65432 -KERNEL_PORT: int = 65433 -DB_FILE: str = 'my_internet.db' - -# EasyList URL -EASYLIST_URL = "https://easylist.to/easylist/easylist.txt" \ No newline at end of file diff --git a/server/src/db_manager.py b/server/src/db_manager.py index e2537bb..3a8f26e 100644 --- a/server/src/db_manager.py +++ b/server/src/db_manager.py @@ -1,79 +1,55 @@ import sqlite3 -import json -from typing import List, Dict, Any, Optional -import requests -from .filter_rules import FilterRule, PatternType +from typing import List +from .logger import setup_logger class DatabaseManager: def __init__(self, db_file: str): + """Initialize database manager.""" self.db_file = db_file - self.create_tables() + self.logger = setup_logger(__name__) + self._create_tables() + self.logger.info(f"Database initialized at {db_file}") - def create_tables(self) -> None: - """Create the necessary tables if they don't exist.""" + def _create_tables(self) -> None: + """Create necessary database tables.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - # Drop existing indices first to avoid conflicts - cursor.execute("DROP INDEX IF EXISTS idx_pattern_type") - cursor.execute("DROP INDEX IF EXISTS idx_processed_pattern") - - # Drop existing tables to ensure clean schema - cursor.execute("DROP TABLE IF EXISTS easylist") - - # Create tables - cursor.execute(""" - CREATE TABLE IF NOT EXISTS blocked_domains ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - domain TEXT UNIQUE - ) - """) - + # Create settings table cursor.execute(""" - CREATE TABLE IF NOT EXISTS easylist ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - raw_pattern TEXT UNIQUE, - pattern_type TEXT NOT NULL, - processed_pattern TEXT NOT NULL, - options TEXT + CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL ) """) + # Create blocked_domains table - simplified without timestamp cursor.execute(""" - CREATE TABLE IF NOT EXISTS settings ( - setting TEXT PRIMARY KEY, - value TEXT + CREATE TABLE IF NOT EXISTS blocked_domains ( + domain TEXT PRIMARY KEY ) """) - # Initialize settings + # Initialize settings if not exists cursor.execute(""" - INSERT OR IGNORE INTO settings (setting, value) + INSERT OR IGNORE INTO settings (key, value) VALUES ('ad_block', 'off'), ('adult_block', 'off') """) - # Create indices after tables are created - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_pattern_type - ON easylist(pattern_type) - """) - - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_processed_pattern - ON easylist(processed_pattern) - """) - conn.commit() + self.logger.info("Database tables created/verified") def get_setting(self, setting: str) -> str: """Get setting value.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - cursor.execute("SELECT value FROM settings WHERE setting = ?", (setting,)) + cursor.execute("SELECT value FROM settings WHERE key = ?", (setting,)) result = cursor.fetchone() - return result[0] if result else 'off' + value = result[0] if result else 'off' + self.logger.debug(f"Retrieved setting {setting}: {value}") + return value def update_setting(self, setting: str, value: str) -> None: """Update setting value.""" @@ -82,9 +58,10 @@ def update_setting(self, setting: str, value: str) -> None: cursor.execute(""" UPDATE settings SET value = ? - WHERE setting = ? + WHERE key = ? """, (value, setting)) conn.commit() + self.logger.info(f"Updated setting {setting} to {value}") def add_blocked_domain(self, domain: str) -> None: """Add a domain to blocked list.""" @@ -93,8 +70,9 @@ def add_blocked_domain(self, domain: str) -> None: try: cursor.execute("INSERT INTO blocked_domains (domain) VALUES (?)", (domain,)) conn.commit() + self.logger.info(f"Domain {domain} added to block list") except sqlite3.IntegrityError: - print(f"Domain {domain} already exists in the database.") + self.logger.warning(f"Domain {domain} already exists in the database") def remove_blocked_domain(self, domain: str) -> bool: """Remove a domain from blocked list.""" @@ -102,100 +80,27 @@ def remove_blocked_domain(self, domain: str) -> bool: cursor = conn.cursor() cursor.execute("DELETE FROM blocked_domains WHERE domain = ?", (domain,)) conn.commit() - return cursor.rowcount > 0 + success = cursor.rowcount > 0 + if success: + self.logger.info(f"Domain {domain} removed from block list") + else: + self.logger.warning(f"Domain {domain} not found in block list") + return success def get_blocked_domains(self) -> List[str]: """Get list of all blocked domains.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() cursor.execute("SELECT domain FROM blocked_domains") - return [row[0] for row in cursor.fetchall()] + domains = [row[0] for row in cursor.fetchall()] + self.logger.debug(f"Retrieved {len(domains)} blocked domains") + return domains def is_domain_blocked(self, domain: str) -> bool: """Check if domain is in blocked list.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() cursor.execute("SELECT 1 FROM blocked_domains WHERE domain = ?", (domain,)) - return cursor.fetchone() is not None - - def store_filter_rule(self, rule: FilterRule) -> None: - """Store a single filter rule in easylist table.""" - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute(""" - INSERT OR REPLACE INTO easylist - (raw_pattern, pattern_type, processed_pattern, options) - VALUES (?, ?, ?, ?) - """, ( - rule.raw_pattern, - rule.pattern_type.value, - rule.processed_pattern, - json.dumps(rule.options) - )) - conn.commit() - - def store_easylist_entries(self, entries: List[str]) -> None: - """Store easylist entries with proper pattern parsing.""" - rules = [] - for entry in entries: - try: - rule = FilterRule(entry) - rules.append(( - rule.raw_pattern, - rule.pattern_type.value, - rule.processed_pattern, - json.dumps(rule.options) - )) - except Exception as e: - print(f"Error parsing rule '{entry}': {e}") - continue - - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.executemany(""" - INSERT OR IGNORE INTO easylist - (raw_pattern, pattern_type, processed_pattern, options) - VALUES (?, ?, ?, ?) - """, rules) - conn.commit() - - def is_easylist_blocked(self, domain: str) -> bool: - """Check if domain matches any easylist pattern.""" - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - - # First check exceptions - cursor.execute(""" - SELECT raw_pattern, pattern_type, processed_pattern, options - FROM easylist - WHERE pattern_type = ? - """, (PatternType.EXCEPTION.value,)) - - for row in cursor.fetchall(): - rule = FilterRule(row[0]) - if rule.matches(domain, domain): # Using domain as both URL and domain - print(f"Domain {domain} matched exception rule: {row[0]}") - return False - - # Then check blocking rules - cursor.execute(""" - SELECT raw_pattern, pattern_type, processed_pattern, options - FROM easylist - WHERE pattern_type != ? - """, (PatternType.EXCEPTION.value,)) - - for row in cursor.fetchall(): - rule = FilterRule(row[0]) - if rule.matches(domain, domain): - print(f"Domain {domain} matched blocking rule: {row[0]}") - return True - - print(f"Domain {domain} did not match any patterns") - return False - - def clear_easylist(self) -> None: - """Clear all easylist entries.""" - with sqlite3.connect(self.db_file) as conn: - cursor = conn.cursor() - cursor.execute("DELETE FROM easylist") - conn.commit() \ No newline at end of file + is_blocked = cursor.fetchone() is not None + self.logger.debug(f"Domain {domain} blocked status: {is_blocked}") + return is_blocked \ No newline at end of file diff --git a/server/src/easylist_manager.py b/server/src/easylist_manager.py deleted file mode 100644 index 933a589..0000000 --- a/server/src/easylist_manager.py +++ /dev/null @@ -1,86 +0,0 @@ -import threading -import requests -import json -from datetime import datetime, timedelta -from typing import Optional -from .db_manager import DatabaseManager -from .config import EASYLIST_URL -from .filter_rules import FilterRule - -class EasyListManager: - def __init__(self, db_manager: DatabaseManager, update_interval: int = 24) -> None: - """ - Initialize EasyList manager. - - Args: - db_manager: Database manager instance - update_interval: Update interval in hours (default: 24) - """ - self.db_manager = db_manager - self.update_interval = update_interval - self.update_timer: Optional[threading.Timer] = None - self.running = True - - def start_update_scheduler(self) -> None: - """Start the update scheduler.""" - self.schedule_next_update() - - def stop_update_scheduler(self) -> None: - """Stop the update scheduler.""" - self.running = False - if self.update_timer: - self.update_timer.cancel() - - def schedule_next_update(self) -> None: - """Schedule the next update.""" - if not self.running: - return - - # Schedule next update - self.update_timer = threading.Timer( - self.update_interval * 3600, # Convert hours to seconds - self._perform_update - ) - self.update_timer.daemon = True - self.update_timer.start() - - def _perform_update(self) -> None: - """Perform the EasyList update.""" - try: - print("Starting EasyList update...") - - # Download new EasyList - response = requests.get(EASYLIST_URL) - response.raise_for_status() - - # Parse rules - rules = [] - for line in response.text.split('\n'): - line = line.strip() - if line and not line.startswith('!') and not line.startswith('['): - try: - rule = FilterRule(line) - rules.append(( - rule.raw_pattern, - rule.pattern_type.value, - rule.processed_pattern, - json.dumps(rule.options) - )) - except Exception as e: - print(f"Error parsing rule '{line}': {e}") - continue - - # Update database - self.db_manager.clear_easylist() - self.db_manager.store_easylist_entries(rules) - - print(f"EasyList updated with {len(rules)} rules") - - except Exception as e: - print(f"Error updating EasyList: {e}") - finally: - self.schedule_next_update() - - def force_update(self) -> None: - """Force an immediate update.""" - self._perform_update() \ No newline at end of file diff --git a/server/src/filter_rules.py b/server/src/filter_rules.py deleted file mode 100644 index 408c7bb..0000000 --- a/server/src/filter_rules.py +++ /dev/null @@ -1,82 +0,0 @@ -from enum import Enum -from typing import Optional, List, Dict, Any -import re - -class PatternType(Enum): - DOMAIN = "domain" # ||example.com^ - EXACT = "exact" # |http://example.com/| - WILDCARD = "wildcard" # /banner/*/img^ - EXCEPTION = "exception" # @@||example.com^ - -class FilterRule: - def __init__(self, raw_pattern: str): - self.raw_pattern = raw_pattern - self.pattern_type = self._determine_pattern_type() - self.processed_pattern = self._process_pattern() - self.options: Dict[str, Any] = self._parse_options() - - def _determine_pattern_type(self) -> PatternType: - pattern = self.raw_pattern.strip() - - if pattern.startswith("@@"): - return PatternType.EXCEPTION - elif pattern.startswith("||"): - return PatternType.DOMAIN - elif pattern.startswith("|") and pattern.endswith("|"): - return PatternType.EXACT - else: - return PatternType.WILDCARD - - def _process_pattern(self) -> str: - """Process the raw pattern into a normalized form.""" - pattern = self.raw_pattern.strip() - - # Remove options part if exists - if "$" in pattern: - pattern = pattern.split("$")[0] - - # Process based on type - if self.pattern_type == PatternType.EXCEPTION: - return pattern[2:] # Remove @@ - elif self.pattern_type == PatternType.DOMAIN: - return pattern[2:-1] # Remove || and ^ - elif self.pattern_type == PatternType.EXACT: - return pattern[1:-1] # Remove leading and trailing | - else: - return pattern - - def _parse_options(self) -> Dict[str, Any]: - """Parse filter options like $script,image,domain=example.com.""" - options = {} - if "$" in self.raw_pattern: - options_part = self.raw_pattern.split("$")[1] - for opt in options_part.split(","): - if "=" in opt: - key, value = opt.split("=") - options[key] = value - else: - options[opt] = True - return options - - def matches(self, url: str, domain: str) -> bool: - """Check if URL matches this filter rule.""" - # Check domain restrictions if any - if "domain" in self.options: - allowed_domains = self.options["domain"].split("|") - if not any(domain.endswith(d) for d in allowed_domains): - return False - - # Match based on pattern type - if self.pattern_type == PatternType.DOMAIN: - return domain.endswith(self.processed_pattern) - elif self.pattern_type == PatternType.EXACT: - return url == self.processed_pattern - else: - # Convert wildcard pattern to regex - regex_pattern = ( - self.processed_pattern - .replace(".", r"\.") - .replace("*", ".*") - .replace("?", ".") - ) - return bool(re.search(regex_pattern, url)) \ No newline at end of file diff --git a/server/src/handlers.py b/server/src/handlers.py index 45f03c9..0213156 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -1,24 +1,15 @@ -from abc import ABC, abstractmethod from typing import Dict, Any from .db_manager import DatabaseManager -from .response_codes import Codes, RESPONSE_MESSAGES -from .easylist_manager import EasyListManager +from .utils import Codes, RESPONSE_MESSAGES +from .logger import setup_logger -class RequestHandler(ABC): - @abstractmethod - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - pass - -class AdBlockHandler(RequestHandler): +class RequestHandler: + """Base class for request handlers.""" def __init__(self, db_manager: DatabaseManager): - """Initialize AdBlockHandler with database manager and EasyList manager.""" self.db_manager = db_manager - self.easylist_manager = EasyListManager(db_manager) - # Start the automatic update scheduler - self.easylist_manager.start_update_scheduler() - # Perform initial load - self.easylist_manager.force_update() + self.logger = setup_logger(__name__) +class AdBlockHandler(RequestHandler): def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle ad block requests.""" try: @@ -26,50 +17,25 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: # Handle toggle request state = request_data['action'] # 'on' or 'off' self.db_manager.update_setting('ad_block', state) - - # If turning on, ensure EasyList is loaded - if state == 'on': - self.easylist_manager.force_update() - + self.logger.info(f"Ad blocking turned {state}") return { 'code': Codes.CODE_AD_BLOCK, - 'message': RESPONSE_MESSAGES['success'] + 'message': f"Ad blocking turned {state}" } - - elif 'domain' in request_data: - # Check if domain should be blocked - if self.is_domain_blocked(request_data['domain']): - return { - 'code': Codes.CODE_AD_BLOCK, - 'message': "Domain contains ads" - } - + return { 'code': Codes.CODE_AD_BLOCK, 'message': RESPONSE_MESSAGES['success'] } except Exception as e: + self.logger.error(f"Error in ad block handler: {e}") return { 'code': Codes.CODE_AD_BLOCK, 'message': str(e) } - def is_domain_blocked(self, domain: str) -> bool: - """Check if domain should be blocked based on easylist.""" - if self.db_manager.get_setting('ad_block') == 'off': - return False - return self.db_manager.is_easylist_blocked(domain) - - def __del__(self) -> None: - """Cleanup when handler is destroyed.""" - if hasattr(self, 'easylist_manager'): - self.easylist_manager.stop_update_scheduler() - class AdultContentBlockHandler(RequestHandler): - def __init__(self, db_manager: DatabaseManager): - self.db_manager = db_manager - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle adult content block requests.""" try: @@ -77,121 +43,119 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: # Handle toggle request state = request_data['action'] # 'on' or 'off' self.db_manager.update_setting('adult_block', state) + self.logger.info(f"Adult content blocking turned {state}") return { 'code': Codes.CODE_ADULT_BLOCK, - 'message': RESPONSE_MESSAGES['success'] + 'message': f"Adult content blocking turned {state}" } - - elif 'domain' in request_data: - # Check if adult blocking is enabled - if self.db_manager.get_setting('adult_block') == 'on': - return { - 'code': Codes.CODE_ADULT_BLOCK, - 'message': "Adult content blocked" - } - + return { 'code': Codes.CODE_ADULT_BLOCK, 'message': RESPONSE_MESSAGES['success'] } except Exception as e: + self.logger.error(f"Error in adult content block handler: {e}") return { 'code': Codes.CODE_ADULT_BLOCK, 'message': str(e) } class DomainBlockHandler(RequestHandler): - def __init__(self, db_manager: DatabaseManager): - self.db_manager = db_manager - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """Handle domain block/unblock requests.""" + """Handle domain blocking requests.""" try: - action = request_data.get('action') - domain = request_data.get('domain') - - if not domain: + if 'action' not in request_data or 'domain' not in request_data: + self.logger.warning("Invalid request format: missing action or domain") return { - 'code': request_data.get('code'), + 'code': Codes.CODE_ADD_DOMAIN, 'message': RESPONSE_MESSAGES['invalid_request'] } + domain = request_data['domain'] + action = request_data['action'] + if action == 'block': self.db_manager.add_blocked_domain(domain) + self.logger.info(f"Domain blocked: {domain}") return { 'code': Codes.CODE_ADD_DOMAIN, 'message': RESPONSE_MESSAGES['domain_blocked'] } - elif action == 'unblock': if self.db_manager.remove_blocked_domain(domain): + self.logger.info(f"Domain unblocked: {domain}") return { 'code': Codes.CODE_REMOVE_DOMAIN, 'message': RESPONSE_MESSAGES['success'] } else: + self.logger.warning(f"Domain not found for unblocking: {domain}") return { 'code': Codes.CODE_REMOVE_DOMAIN, 'message': RESPONSE_MESSAGES['domain_not_found'] } - - return { - 'code': request_data.get('code'), - 'message': RESPONSE_MESSAGES['invalid_request'] - } - + else: + self.logger.warning(f"Invalid action requested: {action}") + return { + 'code': Codes.CODE_ADD_DOMAIN, + 'message': RESPONSE_MESSAGES['invalid_request'] + } + except Exception as e: + self.logger.error(f"Error in domain block handler: {e}") return { - 'code': request_data.get('code'), + 'code': Codes.CODE_ADD_DOMAIN, 'message': str(e) } class DomainListHandler(RequestHandler): - def __init__(self, db_manager: DatabaseManager): - self.db_manager = db_manager - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle domain list requests.""" try: domains = self.db_manager.get_blocked_domains() + self.logger.info(f"Domain list requested, returned {len(domains)} domains") return { 'code': Codes.CODE_DOMAIN_LIST_UPDATE, 'domains': domains, 'message': RESPONSE_MESSAGES['success'] } except Exception as e: + self.logger.error(f"Error in domain list handler: {e}") return { 'code': Codes.CODE_DOMAIN_LIST_UPDATE, 'message': str(e) } class RequestFactory: + """Factory class for creating appropriate request handlers.""" def __init__(self, db_manager: DatabaseManager): self.db_manager = db_manager - self._handlers = { - Codes.CODE_AD_BLOCK: lambda: AdBlockHandler(self.db_manager), - Codes.CODE_ADULT_BLOCK: lambda: AdultContentBlockHandler(self.db_manager), - Codes.CODE_ADD_DOMAIN: lambda: DomainBlockHandler(self.db_manager), - Codes.CODE_REMOVE_DOMAIN: lambda: DomainBlockHandler(self.db_manager), - Codes.CODE_DOMAIN_LIST_UPDATE: lambda: DomainListHandler(self.db_manager) + self.logger = setup_logger(__name__) + self.handlers = { + Codes.CODE_AD_BLOCK: AdBlockHandler(db_manager), + Codes.CODE_ADULT_BLOCK: AdultContentBlockHandler(db_manager), + Codes.CODE_ADD_DOMAIN: DomainBlockHandler(db_manager), + Codes.CODE_REMOVE_DOMAIN: DomainBlockHandler(db_manager), + Codes.CODE_DOMAIN_LIST_UPDATE: DomainListHandler(db_manager) } - def create_request_handler(self, request_type: str) -> RequestHandler: - """Create appropriate handler based on request type.""" - handler_creator = self._handlers.get(request_type) - if handler_creator: - return handler_creator() - raise ValueError(f"Invalid request type: {request_type}") - def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """Handle incoming request using appropriate handler.""" + """Route request to appropriate handler.""" try: - request_type = request_data.get('code') - handler = self.create_request_handler(request_type) - return handler.handle_request(request_data) + code = request_data.get('code') + handler = self.handlers.get(code) + + if handler: + self.logger.debug(f"Handling request with code: {code}") + return handler.handle_request(request_data) + else: + self.logger.warning(f"Invalid request code: {code}") + return { + 'message': RESPONSE_MESSAGES['invalid_request'] + } except Exception as e: + self.logger.error(f"Error in request factory: {e}") return { - 'code': request_data.get('code', ''), 'message': str(e) } \ No newline at end of file diff --git a/server/src/logger.py b/server/src/logger.py new file mode 100644 index 0000000..6056c89 --- /dev/null +++ b/server/src/logger.py @@ -0,0 +1,45 @@ +"""Logger module for handling application-wide logging configuration.""" + +import logging +import os +from datetime import datetime +from typing import Optional +from .utils import LOG_DIR, LOG_FORMAT, LOG_DATE_FORMAT + +_logger: Optional[logging.Logger] = None + +def setup_logger(name: str) -> logging.Logger: + """ + Configure and return a logger instance. + + Args: + name: The name of the module requesting the logger. + + Returns: + logging.Logger: Configured logger instance. + """ + global _logger + + if _logger is not None: + return logging.getLogger(name) + + if not os.path.exists(LOG_DIR): + os.makedirs(LOG_DIR) + + log_file: str = os.path.join( + LOG_DIR, f"server_{datetime.now().strftime(LOG_DATE_FORMAT)}.log" + ) + + logging.basicConfig( + level=logging.INFO, + format=LOG_FORMAT, + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler(), + ], + ) + + _logger = logging.getLogger(name) + _logger.info("Logger setup complete") + + return _logger \ No newline at end of file diff --git a/server/src/server.py b/server/src/server.py index 4a9302d..e7a1dc2 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -1,68 +1,98 @@ +from typing import Dict, Any, Optional import socket import threading import json import asyncio -from typing import Dict, Any -from .config import HOST, CLIENT_PORT, KERNEL_PORT +from .utils import HOST, CLIENT_PORT, KERNEL_PORT from .db_manager import DatabaseManager from .handlers import RequestFactory +from .logger import setup_logger class Server: def __init__(self, db_manager: DatabaseManager) -> None: + """Initialize server with database manager.""" self.db_manager = db_manager self.request_factory = RequestFactory(self.db_manager) self.running = True + self.logger = setup_logger(__name__) + self.logger.info("Server initialized") def handle_client_thread(self) -> None: """Handle client connections using traditional socket.""" client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) client_socket.bind((HOST, CLIENT_PORT)) - client_socket.listen(1) # Only one client needed - print(f"Client server running on {HOST}:{CLIENT_PORT}") - - while self.running: - try: - conn, addr = client_socket.accept() - print(f"Client connected from {addr}") - - # Send initial domain list - domains = self.db_manager.get_blocked_domains() - conn.send(json.dumps({ - 'type': 'domain_list', - 'domains': domains - }).encode() + b'\n') - - while True: - data = conn.recv(1024) - if not data: - break + client_socket.listen(1) + client_socket.settimeout(1.0) + self.logger.info(f"Client server running on {HOST}:{CLIENT_PORT}") + try: + while self.running: + try: + conn, addr = client_socket.accept() + self.logger.info(f"Client connected from {addr}") + + # Set timeout for client connection as well + conn.settimeout(1.0) + try: - request_data = json.loads(data.decode()) - response = self.request_factory.handle_request(request_data) - conn.send(json.dumps(response).encode() + b'\n') - except json.JSONDecodeError: - conn.send(json.dumps({ - 'status': 'error', - 'message': 'Invalid JSON format' - }).encode() + b'\n') - except Exception as e: + # Send initial domain list + domains = self.db_manager.get_blocked_domains() conn.send(json.dumps({ - 'status': 'error', - 'message': str(e) + 'type': 'domain_list', + 'domains': domains }).encode() + b'\n') + self.logger.debug(f"Sent initial domain list: {domains}") + + while True: + try: + data = conn.recv(1024) + if not data: + break + + try: + request_data = json.loads(data.decode()) + self.logger.debug(f"Received request: {request_data}") + response = self.request_factory.handle_request(request_data) + conn.send(json.dumps(response).encode() + b'\n') + self.logger.debug(f"Sent response: {response}") + except json.JSONDecodeError: + self.logger.error("Invalid JSON format received") + conn.send(json.dumps({ + 'status': 'error', + 'message': 'Invalid JSON format' + }).encode() + b'\n') + except Exception as e: + self.logger.error(f"Error handling request: {e}") + conn.send(json.dumps({ + 'status': 'error', + 'message': str(e) + }).encode() + b'\n') + except socket.timeout: + if not self.running: + break + continue + finally: + conn.close() + + except socket.timeout: + if not self.running: + break + continue + except Exception as e: + self.logger.error(f"Client error: {e}") - except Exception as e: - print(f"Client error: {e}") - finally: - if 'conn' in locals(): - conn.close() + finally: + client_socket.close() - async def handle_kernel_requests(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + async def handle_kernel_requests( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter + ) -> None: """Handle kernel requests using asyncio for better performance.""" addr = writer.get_extra_info('peername') - print(f"Kernel module connected from {addr}") - + self.logger.info(f"Kernel module connected from {addr}") + try: while True: data = await reader.readline() @@ -70,54 +100,90 @@ async def handle_kernel_requests(self, reader: asyncio.StreamReader, writer: asy break request_data = json.loads(data.decode()) - domain = request_data.get('domain') - - # Fast domain check - should_block = ( - self.db_manager.is_domain_blocked(domain) or - (self.db_manager.get_setting('ad_block') == 'on' and - self.db_manager.is_easylist_blocked(domain)) or - (self.db_manager.get_setting('adult_block') == 'on' and - 'adult' in request_data.get('categories', [])) - ) - - writer.write(json.dumps({'block': should_block}).encode() + b'\n') + domain = request_data.get('domain', '').strip() + + if not domain: + continue + + # Get current settings state + ad_block_enabled = self.db_manager.get_setting('ad_block') == 'on' + adult_block_enabled = self.db_manager.get_setting('adult_block') == 'on' + + block_reason = None + should_block = False + + # Check custom blocked domains first + if self.db_manager.is_domain_blocked(domain): + should_block = True + block_reason = "custom_blocklist" + self.logger.info(f"Domain {domain} blocked (custom blocklist)") + + # Check if ad blocking is enabled and domain is an ad + elif ad_block_enabled and request_data.get('is_ad', False): + should_block = True + block_reason = "ads" + self.logger.info(f"Domain {domain} blocked (ads)") + + # Check adult content last if enabled + elif adult_block_enabled and 'adult' in request_data.get('categories', []): + should_block = True + block_reason = "adult_content" + self.logger.info(f"Domain {domain} blocked (adult content)") + + response = { + 'block': should_block, + 'reason': block_reason or 'allowed', + 'domain': domain + } + + self.logger.debug(f"Domain check result: {domain} -> {'blocked' if should_block else 'allowed'} ({block_reason or 'no reason'})") + + writer.write(json.dumps(response).encode() + b'\n') await writer.drain() except Exception as e: - print(f"Kernel error: {e}") + self.logger.error(f"Kernel error: {e}") finally: writer.close() await writer.wait_closed() + self.logger.info(f"Kernel connection closed for {addr}") - def start_server(self) -> None: + async def start_server(self) -> None: """Run both client and kernel handlers.""" + client_thread: Optional[threading.Thread] = None + kernel_server: Optional[asyncio.Server] = None + try: # Start client handler in a separate thread client_thread = threading.Thread(target=self.handle_client_thread) client_thread.start() + self.logger.info("Client handler thread started") # Run kernel handler with asyncio - async def start_kernel_server() -> None: - kernel_server = await asyncio.start_server( - self.handle_kernel_requests, - HOST, - KERNEL_PORT - ) - print(f"Kernel server running on {HOST}:{KERNEL_PORT}") + kernel_server = await asyncio.start_server( + self.handle_kernel_requests, + HOST, + KERNEL_PORT + ) + self.logger.info(f"Kernel server running on {HOST}:{KERNEL_PORT}") + + async with kernel_server: await kernel_server.serve_forever() - # Run the asyncio event loop for kernel handler - asyncio.run(start_kernel_server()) - - except KeyboardInterrupt: - self.running = False - print("\nServer stopping...") except Exception as e: - print(f"Server error: {e}") + self.logger.error(f"Server error: {e}") + raise + finally: + self.running = False + # Clean up resources + if kernel_server: + kernel_server.close() + await kernel_server.wait_closed() + if client_thread and client_thread.is_alive(): + client_thread.join(timeout=1.0) def initialize_server(db_file: str) -> None: """Initialize and run the server.""" db_manager = DatabaseManager(db_file) server = Server(db_manager) - server.start_server() \ No newline at end of file + asyncio.run(server.start_server()) \ No newline at end of file diff --git a/server/src/response_codes.py b/server/src/utils.py similarity index 54% rename from server/src/response_codes.py rename to server/src/utils.py index d318126..6592587 100644 --- a/server/src/response_codes.py +++ b/server/src/utils.py @@ -1,6 +1,22 @@ +import os +from pathlib import Path from typing import Dict -# Client Command Codes (exact match with client's codes) +# Base directories +BASE_DIR = Path(__file__).parent.parent +LOG_DIR = os.path.join(BASE_DIR, "logs") + +# Network Configuration +HOST: str = '127.0.0.1' +CLIENT_PORT: int = 65432 +KERNEL_PORT: int = 65433 +DB_FILE: str = 'my_internet.db' + +# Logging configuration +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +LOG_DATE_FORMAT = "%Y%m%d_%H%M%S" + +# Client Command Codes class Codes: CODE_AD_BLOCK = "50" CODE_ADULT_BLOCK = "51" diff --git a/server/tests/conftest.py b/server/tests/conftest.py index a1cfeff..a44aaac 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,86 +1,55 @@ import pytest from unittest import mock from typing import Generator +from My_Internet.server.src.server import Server from My_Internet.server.src.db_manager import DatabaseManager -from My_Internet.server.src.handlers import RequestFactory -@pytest.fixture(scope="function") +@pytest.fixture def mock_db_manager() -> mock.Mock: - """Fixture to provide a mock database manager.""" - mock_db = mock.Mock(spec=DatabaseManager) - - # Setup default returns for common methods - mock_db.is_domain_blocked.return_value = False - mock_db.is_easylist_blocked.return_value = False - mock_db.get_setting.return_value = 'off' # Default setting state - mock_db.get_blocked_domains.return_value = [] # Default empty domain list - - return mock_db + """Create a mock database manager.""" + db_manager = mock.Mock(spec=DatabaseManager) + db_manager.get_blocked_domains.return_value = [] + db_manager.get_setting.return_value = 'off' + db_manager.is_domain_blocked.return_value = False + return db_manager @pytest.fixture -def mock_requests() -> Generator[mock.Mock, None, None]: - """Fixture to mock requests library for easylist downloading.""" - with mock.patch('My_Internet.server.src.handlers.requests') as mock_req: - # Create a mock response - mock_response = mock.Mock() - mock_response.text = "test.com\n!comment\nexample.com" - mock_req.get.return_value = mock_response - mock_response.raise_for_status = mock.Mock() - - # Reset the mock to clear any previous calls - mock_req.reset_mock() - yield mock_req - -@pytest.fixture(scope="function") -def request_factory(mock_db_manager: mock.Mock) -> RequestFactory: - """Fixture to create a RequestFactory instance.""" - return RequestFactory(mock_db_manager) - -@pytest.fixture(scope="session") -def sample_domains() -> list[str]: - """Fixture to provide test domains.""" - return [ - "example.com", - "test.com", - "sample.org" - ] +def server_instance(mock_db_manager: mock.Mock) -> Server: + """Create a server instance for testing.""" + server = Server(mock_db_manager) + server.logger = mock.Mock() # Mock the logger to prevent actual logging + return server -@pytest.fixture(scope="session") -def sample_requests() -> dict: - """Fixture to provide sample request data.""" - return { - "adult_block": { - "code": "51", - "action": "on" - }, - "domain_block": { - "code": "52", - "action": "block", - "domain": "example.com" - }, - "ad_block": { - "code": "50", - "action": "on" - }, - "check_domain": { - "code": "50", - "domain": "ads.example.com" - } - } +@pytest.fixture +def mock_socket() -> mock.Mock: + """Create a mock socket for testing.""" + socket_mock = mock.Mock() + socket_mock.bind = mock.Mock() + socket_mock.listen = mock.Mock() + socket_mock.accept = mock.Mock() + socket_mock.close = mock.Mock() + socket_mock.settimeout = mock.Mock() + return socket_mock @pytest.fixture def mock_stream_reader() -> mock.AsyncMock: - """Mock for asyncio StreamReader.""" + """Create a mock stream reader.""" reader = mock.AsyncMock() reader.readline = mock.AsyncMock() return reader @pytest.fixture def mock_stream_writer() -> mock.Mock: - """Mock for asyncio StreamWriter.""" + """Create a mock stream writer.""" writer = mock.Mock() writer.write = mock.Mock() writer.drain = mock.AsyncMock() writer.close = mock.Mock() writer.wait_closed = mock.AsyncMock() - return writer \ No newline at end of file + writer.get_extra_info = mock.Mock(return_value=('127.0.0.1', 12345)) + return writer + +@pytest.fixture +def mock_asyncio_start_server() -> mock.AsyncMock: + """Create a mock for asyncio.start_server.""" + return mock.AsyncMock() \ No newline at end of file diff --git a/server/tests/test_handlers.py b/server/tests/test_handlers.py index 3d31cb4..af6dc36 100644 --- a/server/tests/test_handlers.py +++ b/server/tests/test_handlers.py @@ -1,166 +1,136 @@ +from typing import Dict, Any import pytest from unittest import mock -from typing import Dict, Any from My_Internet.server.src.handlers import ( - RequestHandler, + RequestHandler, + AdBlockHandler, AdultContentBlockHandler, DomainBlockHandler, - AdBlockHandler, - RequestFactory, - EASYLIST_URL -) -from My_Internet.server.src.response_codes import ( - Codes, - RESPONSE_MESSAGES + DomainListHandler, + RequestFactory ) +from My_Internet.server.src.utils import Codes, RESPONSE_MESSAGES + +@pytest.fixture +def mock_db_manager() -> mock.Mock: + """Create a mock database manager.""" + return mock.Mock() + +@pytest.fixture +def mock_logger() -> mock.Mock: + """Create a mock logger.""" + return mock.Mock() class TestAdBlockHandler: @pytest.fixture - def handler(self, mock_db_manager: mock.Mock, mock_requests: mock.Mock) -> AdBlockHandler: - """Create handler instance.""" - with mock.patch('My_Internet.server.src.handlers.AdBlockHandler.load_easylist'): - handler = AdBlockHandler(mock_db_manager) - mock_requests.reset_mock() - return handler - - def test_load_easylist(self, handler: AdBlockHandler, mock_requests: mock.Mock) -> None: - """Test loading and parsing easylist.""" - # Configure mock response - mock_response = mock.Mock() - mock_response.text = "test.com\n!comment\nexample.com" - mock_requests.get.return_value = mock_response - - # Call method - handler.load_easylist() + def handler(self, mock_db_manager: mock.Mock) -> AdBlockHandler: + """Create AdBlockHandler instance.""" + return AdBlockHandler(mock_db_manager) + + def test_handle_request_toggle_on(self, handler: AdBlockHandler) -> None: + """Test handling ad block toggle on request.""" + request_data: Dict[str, Any] = {'action': 'on'} + response = handler.handle_request(request_data) - # Verify calls - mock_requests.get.assert_called_once_with(EASYLIST_URL) - mock_response.raise_for_status.assert_called_once() - handler.db_manager.clear_easylist.assert_called_once() - - # Verify easylist storage - expected_entries = [('test.com',), ('example.com',)] - handler.db_manager.store_easylist_entries.assert_called_once_with(expected_entries) - - def test_handle_toggle_request(self, handler: AdBlockHandler) -> None: - """Test toggling ad blocking on/off.""" - # Test enabling - response = handler.handle_request({'action': 'on'}) - handler.db_manager.update_setting.assert_called_with('ad_block', 'on') + handler.db_manager.update_setting.assert_called_once_with('ad_block', 'on') assert response['code'] == Codes.CODE_AD_BLOCK + assert response['message'] == "Ad blocking turned on" - # Test disabling - response = handler.handle_request({'action': 'off'}) - handler.db_manager.update_setting.assert_called_with('ad_block', 'off') - assert response['code'] == Codes.CODE_AD_BLOCK - - def test_handle_check_domain(self, handler: AdBlockHandler) -> None: - """Test checking domain with ad blocking.""" - # Setup: ad blocking enabled and domain matched in easylist - handler.db_manager.get_setting.return_value = 'on' - handler.db_manager.is_easylist_blocked.return_value = True + def test_handle_request_error(self, handler: AdBlockHandler) -> None: + """Test handling error in ad block request.""" + handler.db_manager.update_setting.side_effect = Exception("Test error") + response = handler.handle_request({'action': 'on'}) - response = handler.handle_request({'domain': 'ads.example.com'}) assert response['code'] == Codes.CODE_AD_BLOCK - assert handler.db_manager.is_easylist_blocked.called + assert response['message'] == "Test error" class TestAdultContentBlockHandler: @pytest.fixture def handler(self, mock_db_manager: mock.Mock) -> AdultContentBlockHandler: - """Create handler instance.""" + """Create AdultContentBlockHandler instance.""" return AdultContentBlockHandler(mock_db_manager) - def test_handle_toggle_request(self, handler: AdultContentBlockHandler) -> None: - """Test toggling adult content blocking.""" - # Test enabling - response = handler.handle_request({'action': 'on'}) - handler.db_manager.update_setting.assert_called_with('adult_block', 'on') - assert response['code'] == Codes.CODE_ADULT_BLOCK - - # Test disabling - response = handler.handle_request({'action': 'off'}) - handler.db_manager.update_setting.assert_called_with('adult_block', 'off') - assert response['code'] == Codes.CODE_ADULT_BLOCK - - def test_handle_check_request(self, handler: AdultContentBlockHandler) -> None: - """Test checking domain with adult content blocking.""" - # Setup: adult blocking enabled - handler.db_manager.get_setting.return_value = 'on' + def test_handle_request_toggle_on(self, handler: AdultContentBlockHandler) -> None: + """Test handling adult content block toggle on request.""" + request_data: Dict[str, Any] = {'action': 'on'} + response = handler.handle_request(request_data) - response = handler.handle_request({ - 'action': 'check', - 'domain': 'example.com' - }) + handler.db_manager.update_setting.assert_called_once_with('adult_block', 'on') assert response['code'] == Codes.CODE_ADULT_BLOCK + assert response['message'] == "Adult content blocking turned on" class TestDomainBlockHandler: @pytest.fixture def handler(self, mock_db_manager: mock.Mock) -> DomainBlockHandler: - """Create handler instance.""" + """Create DomainBlockHandler instance.""" return DomainBlockHandler(mock_db_manager) def test_block_domain(self, handler: DomainBlockHandler) -> None: """Test blocking a domain.""" - response = handler.handle_request({ + request_data: Dict[str, Any] = { 'action': 'block', 'domain': 'example.com' - }) + } + response = handler.handle_request(request_data) handler.db_manager.add_blocked_domain.assert_called_once_with('example.com') assert response['code'] == Codes.CODE_ADD_DOMAIN + assert response['message'] == RESPONSE_MESSAGES['domain_blocked'] def test_unblock_domain(self, handler: DomainBlockHandler) -> None: """Test unblocking a domain.""" - # Setup: domain exists - handler.db_manager.is_domain_blocked.return_value = True - - response = handler.handle_request({ + handler.db_manager.remove_blocked_domain.return_value = True + request_data: Dict[str, Any] = { 'action': 'unblock', 'domain': 'example.com' - }) + } + response = handler.handle_request(request_data) handler.db_manager.remove_blocked_domain.assert_called_once_with('example.com') assert response['code'] == Codes.CODE_REMOVE_DOMAIN + assert response['message'] == RESPONSE_MESSAGES['success'] - def test_unblock_nonexistent_domain(self, handler: DomainBlockHandler) -> None: - """Test unblocking a nonexistent domain.""" - handler.db_manager.remove_blocked_domain.return_value = False + def test_invalid_request(self, handler: DomainBlockHandler) -> None: + """Test handling invalid request.""" + request_data: Dict[str, Any] = {'action': 'block'} # Missing domain + response = handler.handle_request(request_data) - response = handler.handle_request({ - 'action': 'unblock', - 'domain': 'nonexistent.com' - }) - - assert 'domain not found' in response['message'].lower() + assert response['code'] == Codes.CODE_ADD_DOMAIN + assert response['message'] == RESPONSE_MESSAGES['invalid_request'] +class TestDomainListHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock) -> DomainListHandler: + """Create DomainListHandler instance.""" + return DomainListHandler(mock_db_manager) + + def test_get_domain_list(self, handler: DomainListHandler) -> None: + """Test getting list of blocked domains.""" + domains = ['example.com', 'test.com'] + handler.db_manager.get_blocked_domains.return_value = domains + response = handler.handle_request({}) + + assert response['code'] == Codes.CODE_DOMAIN_LIST_UPDATE + assert response['domains'] == domains + assert response['message'] == RESPONSE_MESSAGES['success'] class TestRequestFactory: @pytest.fixture def factory(self, mock_db_manager: mock.Mock) -> RequestFactory: - """Create factory instance.""" + """Create RequestFactory instance.""" return RequestFactory(mock_db_manager) - def test_create_handlers(self, factory: RequestFactory) -> None: - """Test creating different types of handlers.""" - test_cases = [ - (Codes.CODE_AD_BLOCK, AdBlockHandler), - (Codes.CODE_ADULT_BLOCK, AdultContentBlockHandler), - (Codes.CODE_ADD_DOMAIN, DomainBlockHandler), - (Codes.CODE_REMOVE_DOMAIN, DomainBlockHandler) - ] - - for code, handler_class in test_cases: - handler = factory.create_request_handler(code) - assert isinstance(handler, handler_class) - - def test_handle_request(self, factory: RequestFactory, sample_requests: dict) -> None: - """Test handling different types of requests.""" - for request in sample_requests.values(): - response = factory.handle_request(request) - assert 'code' in response - assert 'message' in response - - def test_invalid_request_type(self, factory: RequestFactory) -> None: - """Test handling invalid request type.""" - response = factory.handle_request({'code': 'invalid'}) - assert 'invalid' in response['message'].lower() \ No newline at end of file + def test_handle_valid_request(self, factory: RequestFactory) -> None: + """Test handling valid request with correct code.""" + request_data: Dict[str, Any] = { + 'code': Codes.CODE_AD_BLOCK, + 'action': 'on' + } + response = factory.handle_request(request_data) + assert response['code'] == Codes.CODE_AD_BLOCK + + def test_handle_invalid_code(self, factory: RequestFactory) -> None: + """Test handling request with invalid code.""" + request_data: Dict[str, Any] = {'code': 'invalid_code'} + response = factory.handle_request(request_data) + assert response['message'] == RESPONSE_MESSAGES['invalid_request'] \ No newline at end of file diff --git a/server/tests/test_server.py b/server/tests/test_server.py index 3cdaf5f..121eaf2 100644 --- a/server/tests/test_server.py +++ b/server/tests/test_server.py @@ -2,233 +2,217 @@ import json import asyncio from unittest import mock -from typing import Dict, Any -from My_Internet.server.src.server import ( - Server, - start_server -) -from My_Internet.server.src.config import HOST, CLIENT_PORT, KERNEL_PORT -from My_Internet.server.src.response_codes import Codes -from My_Internet.server.src.db_manager import DatabaseManager - -class TestServer: - @pytest.fixture - def db_manager(self) -> mock.Mock: - """Create a mock database manager.""" - mock_db = mock.Mock(spec=DatabaseManager) - mock_db.db_file = "test.db" - mock_db.is_domain_blocked.return_value = True - mock_db.get_setting.return_value = "on" - return mock_db - - @pytest.fixture - def server(self, db_manager: mock.Mock) -> Server: - """Create server instance for testing with a mocked DatabaseManager.""" - return Server(db_manager) - - @pytest.fixture - def mock_stream_reader(self) -> mock.AsyncMock: - """Mock for asyncio StreamReader.""" - reader = mock.AsyncMock() - reader.readline = mock.AsyncMock() - return reader - - @pytest.fixture - def mock_stream_writer(self) -> mock.Mock: - """Mock for asyncio StreamWriter.""" - writer = mock.Mock() - writer.write = mock.Mock() - writer.drain = mock.AsyncMock() - writer.close = mock.Mock() - writer.wait_closed = mock.AsyncMock() - return writer - - @pytest.mark.asyncio - async def test_handle_client( - self, - server: Server, - mock_stream_reader: mock.AsyncMock, - mock_stream_writer: mock.Mock - ) -> None: - """Test client request handling.""" - # Setup mocks - mock_stream_writer.write = mock.Mock() - mock_stream_writer.drain = mock.AsyncMock() - mock_stream_writer.get_extra_info = mock.Mock(return_value="test_client") - server.db_manager.get_blocked_domains.return_value = [] - - # Setup test request - test_request = { - 'code': Codes.CODE_ADD_DOMAIN, - 'action': 'block', - 'domain': 'example.com' - } - - # Configure mock reader responses - mock_stream_reader.readline.side_effect = [ - json.dumps(test_request).encode() + b'\n', - b'' # End connection after request - ] - - # Handle client connection - await server.handle_client(mock_stream_reader, mock_stream_writer) - - # Verify response was sent at least twice (domain list + request response) - assert mock_stream_writer.write.call_count >= 2 - - @pytest.mark.asyncio - async def test_handle_kernel(self, server: Server, mock_stream_reader: mock.AsyncMock, mock_stream_writer: mock.Mock) -> None: - """Test kernel module request handling.""" - # Setup test request - test_request = { - 'domain': 'example.com', - 'categories': ['adult'] - } - - # Configure mocks - server.db_manager.is_domain_blocked.return_value = True - mock_stream_writer.write = mock.Mock() - mock_stream_writer.drain = mock.AsyncMock() - - # Configure mock reader responses - mock_stream_reader.readline.side_effect = [ - json.dumps(test_request).encode() + b'\n', - b'' # End connection after request - ] - - await server.handle_kernel(mock_stream_reader, mock_stream_writer) - assert mock_stream_writer.write.called - - def test_handle_kernel_request(self, server: Server) -> None: - """Test kernel request processing.""" - test_cases = [ - # Test manually blocked domain - { - 'setup': { - 'is_domain_blocked': True, - 'get_setting': 'off' - }, - 'request': { - 'domain': 'blocked.com' - }, - 'expected': True - }, - # Test ad blocking - { - 'setup': { - 'is_domain_blocked': False, - 'get_setting': 'on', - 'is_easylist_blocked': True - }, - 'request': { - 'domain': 'ads.example.com' - }, - 'expected': True - }, - # Test adult content blocking - { - 'setup': { - 'is_domain_blocked': False, - 'get_setting': 'on' - }, - 'request': { - 'domain': 'example.com', - 'categories': ['adult'] - }, - 'expected': True - }, - # Test allowed domain - { - 'setup': { - 'is_domain_blocked': False, - 'get_setting': 'off', - 'is_easylist_blocked': False - }, - 'request': { - 'domain': 'example.com' - }, - 'expected': False - } - ] - - for case in test_cases: - # Setup mocks - server.db_manager.is_domain_blocked.return_value = case['setup'].get('is_domain_blocked', False) - server.db_manager.get_setting.return_value = case['setup'].get('get_setting', 'off') - server.db_manager.is_easylist_blocked.return_value = case['setup'].get('is_easylist_blocked', False) - - # Test request - response = server.handle_kernel_request(case['request']) - assert response['block'] is case['expected'] - - @pytest.mark.asyncio - async def test_start_server(self, db_manager: mock.Mock) -> None: - """Test server startup.""" - # Configure mock - db_manager.db_file = "test.db" - mock_client_server = mock.AsyncMock() - mock_kernel_server = mock.AsyncMock() - - with mock.patch('asyncio.start_server', side_effect=[mock_client_server, mock_kernel_server]) as mock_start: - task = asyncio.create_task(start_server(db_manager)) - await asyncio.sleep(0.1) - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - @pytest.mark.asyncio - async def test_client_error_handling( - self, - server: Server, - mock_stream_reader: mock.AsyncMock, - mock_stream_writer: mock.Mock - ) -> None: - """Test handling of client errors.""" - # Configure mocks - mock_stream_writer.write = mock.Mock() - mock_stream_writer.drain = mock.AsyncMock() - mock_stream_writer.get_extra_info = mock.Mock(return_value="test_client") - server.db_manager.get_blocked_domains.return_value = [] - - # Set up the mock to return invalid JSON - mock_stream_reader.readline.side_effect = [ - b'invalid json\n', - b'' # End connection after invalid request - ] - - # Handle the client connection - await server.handle_client(mock_stream_reader, mock_stream_writer) - - # Verify responses were sent (domain list + error response) - assert mock_stream_writer.write.call_count >= 2 - - @pytest.mark.asyncio - async def test_kernel_error_handling( - self, - server: Server, - mock_stream_reader: mock.AsyncMock, - mock_stream_writer: mock.Mock - ) -> None: - """Test handling of kernel errors.""" - # Test connection error - mock_stream_reader.readline.side_effect = ConnectionError() - - await server.handle_kernel(mock_stream_reader, mock_stream_writer) - - # Verify connection was closed - assert mock_stream_writer.close.called - - def test_multiple_kernel_requests(self, server: Server) -> None: - """Test handling multiple kernel requests.""" - # Setup mocks - server.db_manager.get_setting.side_effect = ['on', 'off'] - server.db_manager.is_domain_blocked.return_value = True - - # Test first request (blocking enabled) - response1 = server.handle_kernel_request({ +from typing import Dict, Any, Generator +from My_Internet.server.src.server import Server +from My_Internet.server.src.utils import HOST, CLIENT_PORT, KERNEL_PORT + +@pytest.mark.asyncio +async def test_handle_kernel_requests_block_custom_domain( + server_instance: Server, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock +) -> None: + """Test handling kernel requests for custom blocked domain.""" + server_instance.db_manager.is_domain_blocked.return_value = True + mock_stream_reader.readline.side_effect = [ + json.dumps({ 'domain': 'example.com', + 'is_ad': False, + 'categories': [] + }).encode() + b'\n', + b'' + ] + + await server_instance.handle_kernel_requests(mock_stream_reader, mock_stream_writer) + + response_data = json.loads(mock_stream_writer.write.call_args[0][0].decode().strip()) + assert response_data['block'] is True + assert response_data['reason'] == 'custom_blocklist' + assert response_data['domain'] == 'example.com' + +@pytest.mark.asyncio +async def test_handle_kernel_requests_block_ad( + server_instance: Server, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock +) -> None: + """Test handling kernel requests for ad blocking.""" + server_instance.db_manager.is_domain_blocked.return_value = False + server_instance.db_manager.get_setting.side_effect = lambda x: 'on' if x == 'ad_block' else 'off' + mock_stream_reader.readline.side_effect = [ + json.dumps({ + 'domain': 'ad.example.com', + 'is_ad': True, + 'categories': [] + }).encode() + b'\n', + b'' + ] + + await server_instance.handle_kernel_requests(mock_stream_reader, mock_stream_writer) + + response_data = json.loads(mock_stream_writer.write.call_args[0][0].decode().strip()) + assert response_data['block'] is True + assert response_data['reason'] == 'ads' + assert response_data['domain'] == 'ad.example.com' + +@pytest.mark.asyncio +async def test_handle_kernel_requests_block_adult_content( + server_instance: Server, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock +) -> None: + """Test handling kernel requests for adult content blocking.""" + server_instance.db_manager.is_domain_blocked.return_value = False + server_instance.db_manager.get_setting.side_effect = lambda x: 'on' if x == 'adult_block' else 'off' + mock_stream_reader.readline.side_effect = [ + json.dumps({ + 'domain': 'adult.example.com', + 'is_ad': False, 'categories': ['adult'] - }) - assert response1['block'] is True + }).encode() + b'\n', + b'' + ] + + await server_instance.handle_kernel_requests(mock_stream_reader, mock_stream_writer) + + response_data = json.loads(mock_stream_writer.write.call_args[0][0].decode().strip()) + assert response_data['block'] is True + assert response_data['reason'] == 'adult_content' + assert response_data['domain'] == 'adult.example.com' + +def test_handle_client_thread_initial_domain_list( + server_instance: Server, + mock_socket: mock.Mock, + monkeypatch: pytest.MonkeyPatch +) -> None: + """Test sending initial domain list to client.""" + mock_conn = mock.Mock() + mock_conn.send = mock.Mock() + mock_conn.recv.return_value = b'' + + server_instance.db_manager.get_blocked_domains.return_value = ['example.com'] + + mock_socket_instance = mock.Mock() + mock_socket_instance.accept = mock.Mock(return_value=(mock_conn, ('127.0.0.1', 12345))) + + mock_socket_class = mock.Mock(return_value=mock_socket_instance) + monkeypatch.setattr('socket.socket', mock_socket_class) + + def mock_accept(*args: Any, **kwargs: Any) -> tuple[mock.Mock, tuple[str, int]]: + server_instance.running = False + return mock_conn, ('127.0.0.1', 12345) + + mock_socket_instance.accept = mock.Mock(side_effect=mock_accept) + + server_instance.handle_client_thread() + + mock_conn.send.assert_called() + sent_data = json.loads(mock_conn.send.call_args_list[0][0][0].decode().strip()) + assert sent_data['type'] == 'domain_list' + assert isinstance(sent_data['domains'], list) + assert 'example.com' in sent_data['domains'] + +def test_handle_client_thread_process_request( + server_instance: Server, + mock_socket: mock.Mock, + monkeypatch: pytest.MonkeyPatch +) -> None: + """Test processing client request.""" + mock_conn = mock.Mock() + mock_conn.send = mock.Mock() + + mock_conn.recv.side_effect = [ + json.dumps({'code': '50', 'action': 'on'}).encode(), + b'' + ] + + mock_socket_instance = mock.Mock() + mock_socket_instance.accept = mock.Mock(return_value=(mock_conn, ('127.0.0.1', 12345))) + + mock_socket_class = mock.Mock(return_value=mock_socket_instance) + monkeypatch.setattr('socket.socket', mock_socket_class) + + def mock_accept(*args: Any, **kwargs: Any) -> tuple[mock.Mock, tuple[str, int]]: + server_instance.running = False + return mock_conn, ('127.0.0.1', 12345) + + mock_socket_instance.accept = mock.Mock(side_effect=mock_accept) + + server_instance.db_manager.get_blocked_domains.return_value = ['example.com'] + + mock_request_factory = mock.Mock() + mock_request_factory.handle_request.return_value = { + 'status': 'success', + 'message': 'Request processed' + } + server_instance.request_factory = mock_request_factory + + server_instance.handle_client_thread() + + assert mock_conn.send.call_count >= 2 + assert server_instance.request_factory.handle_request.called + +def test_handle_client_thread_invalid_json( + server_instance: Server, + mock_socket: mock.Mock, + monkeypatch: pytest.MonkeyPatch +) -> None: + """Test handling invalid JSON request.""" + mock_conn = mock.Mock() + mock_conn.send = mock.Mock() + mock_conn.recv.side_effect = [b'invalid json', b''] + + mock_socket_instance = mock.Mock() + mock_socket_instance.accept = mock.Mock(return_value=(mock_conn, ('127.0.0.1', 12345))) + + mock_socket_class = mock.Mock(return_value=mock_socket_instance) + monkeypatch.setattr('socket.socket', mock_socket_class) + + def mock_accept(*args: Any, **kwargs: Any) -> tuple[mock.Mock, tuple[str, int]]: + server_instance.running = False + return mock_conn, ('127.0.0.1', 12345) + + mock_socket_instance.accept = mock.Mock(side_effect=mock_accept) + + server_instance.db_manager.get_blocked_domains.return_value = ['example.com'] + + server_instance.handle_client_thread() + + assert mock_conn.send.call_count >= 2 + sent_data = json.loads(mock_conn.send.call_args_list[1][0][0].decode().strip()) + assert sent_data['status'] == 'error' + assert 'Invalid JSON format' in sent_data['message'] + +@pytest.mark.asyncio +async def test_start_server( + server_instance: Server, + mock_asyncio_start_server: mock.AsyncMock, + monkeypatch: pytest.MonkeyPatch +) -> None: + """Test server startup process.""" + monkeypatch.setattr('asyncio.start_server', mock_asyncio_start_server) + + mock_thread = mock.Mock() + mock_thread_class = mock.Mock(return_value=mock_thread) + monkeypatch.setattr('threading.Thread', mock_thread_class) + + mock_kernel_server = mock.AsyncMock() + mock_kernel_server.__aenter__ = mock.AsyncMock(return_value=mock_kernel_server) + mock_kernel_server.__aexit__ = mock.AsyncMock() + mock_kernel_server.close = mock.AsyncMock() + + async def mock_serve_forever(): + server_instance.running = False + + mock_kernel_server.serve_forever = mock.AsyncMock(side_effect=mock_serve_forever) + mock_asyncio_start_server.return_value = mock_kernel_server + + await server_instance.start_server() + + mock_thread_class.assert_called_once() + mock_thread.start.assert_called_once() + assert mock_asyncio_start_server.called + assert mock_kernel_server.__aenter__.called + assert mock_kernel_server.serve_forever.called + assert not server_instance.running + await mock_kernel_server.close() \ No newline at end of file From 396ebe731007c4abc765f28971f48e7ef13ec6c3 Mon Sep 17 00:00:00 2001 From: elipaz Date: Mon, 11 Nov 2024 12:49:38 +0200 Subject: [PATCH 30/38] Merged with updated server files --- server/src/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/src/server.py b/server/src/server.py index 4a9302d..22f142e 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -41,11 +41,13 @@ def handle_client_thread(self) -> None: request_data = json.loads(data.decode()) response = self.request_factory.handle_request(request_data) conn.send(json.dumps(response).encode() + b'\n') + except json.JSONDecodeError: conn.send(json.dumps({ 'status': 'error', 'message': 'Invalid JSON format' }).encode() + b'\n') + except Exception as e: conn.send(json.dumps({ 'status': 'error', @@ -120,4 +122,4 @@ def initialize_server(db_file: str) -> None: """Initialize and run the server.""" db_manager = DatabaseManager(db_file) server = Server(db_manager) - server.start_server() \ No newline at end of file + server.start_server() From 1ba71b074dec69dbeb82f2c56463bd636fbb9328 Mon Sep 17 00:00:00 2001 From: elipaz Date: Mon, 11 Nov 2024 14:08:46 +0200 Subject: [PATCH 31/38] Modify code --- server/src/db_manager.py | 25 +++++++++++++---------- server/src/handlers.py | 43 ++++++++++++++++++++-------------------- server/src/server.py | 3 +++ 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/server/src/db_manager.py b/server/src/db_manager.py index 3a8f26e..f13b18e 100644 --- a/server/src/db_manager.py +++ b/server/src/db_manager.py @@ -15,7 +15,6 @@ def _create_tables(self) -> None: with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - # Create settings table cursor.execute(""" CREATE TABLE IF NOT EXISTS settings ( key TEXT PRIMARY KEY, @@ -23,14 +22,12 @@ def _create_tables(self) -> None: ) """) - # Create blocked_domains table - simplified without timestamp cursor.execute(""" CREATE TABLE IF NOT EXISTS blocked_domains ( domain TEXT PRIMARY KEY ) """) - # Initialize settings if not exists cursor.execute(""" INSERT OR IGNORE INTO settings (key, value) VALUES @@ -45,7 +42,9 @@ def get_setting(self, setting: str) -> str: """Get setting value.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - cursor.execute("SELECT value FROM settings WHERE key = ?", (setting,)) + cursor.execute("""SELECT value + FROM settings + WHERE key = ?""", (setting,)) result = cursor.fetchone() value = result[0] if result else 'off' self.logger.debug(f"Retrieved setting {setting}: {value}") @@ -68,7 +67,8 @@ def add_blocked_domain(self, domain: str) -> None: with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() try: - cursor.execute("INSERT INTO blocked_domains (domain) VALUES (?)", (domain,)) + cursor.execute("""INSERT INTO blocked_domains (domain) + VALUES (?)""", (domain,)) conn.commit() self.logger.info(f"Domain {domain} added to block list") except sqlite3.IntegrityError: @@ -78,20 +78,21 @@ def remove_blocked_domain(self, domain: str) -> bool: """Remove a domain from blocked list.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - cursor.execute("DELETE FROM blocked_domains WHERE domain = ?", (domain,)) + cursor.execute("""DELETE FROM blocked_domains + WHERE domain = ?""", (domain,)) conn.commit() - success = cursor.rowcount > 0 - if success: + if cursor.rowcount: self.logger.info(f"Domain {domain} removed from block list") else: self.logger.warning(f"Domain {domain} not found in block list") - return success + return bool(cursor.rowcount) def get_blocked_domains(self) -> List[str]: """Get list of all blocked domains.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - cursor.execute("SELECT domain FROM blocked_domains") + cursor.execute("""SELECT domain + FROM blocked_domains""") domains = [row[0] for row in cursor.fetchall()] self.logger.debug(f"Retrieved {len(domains)} blocked domains") return domains @@ -100,7 +101,9 @@ def is_domain_blocked(self, domain: str) -> bool: """Check if domain is in blocked list.""" with sqlite3.connect(self.db_file) as conn: cursor = conn.cursor() - cursor.execute("SELECT 1 FROM blocked_domains WHERE domain = ?", (domain,)) + cursor.execute("""SELECT 1 + FROM blocked_domains + WHERE domain = ?""", (domain,)) is_blocked = cursor.fetchone() is not None self.logger.debug(f"Domain {domain} blocked status: {is_blocked}") return is_blocked \ No newline at end of file diff --git a/server/src/handlers.py b/server/src/handlers.py index 0213156..4aa8918 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -14,8 +14,7 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle ad block requests.""" try: if 'action' in request_data: - # Handle toggle request - state = request_data['action'] # 'on' or 'off' + state = request_data['action'] self.db_manager.update_setting('ad_block', state) self.logger.info(f"Ad blocking turned {state}") return { @@ -40,8 +39,7 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle adult content block requests.""" try: if 'action' in request_data: - # Handle toggle request - state = request_data['action'] # 'on' or 'off' + state = request_data['action'] self.db_manager.update_setting('adult_block', state) self.logger.info(f"Adult content blocking turned {state}") return { @@ -77,31 +75,34 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: if action == 'block': self.db_manager.add_blocked_domain(domain) + self.logger.info(f"Domain blocked: {domain}") + return { 'code': Codes.CODE_ADD_DOMAIN, 'message': RESPONSE_MESSAGES['domain_blocked'] } - elif action == 'unblock': + + if action == 'unblock': if self.db_manager.remove_blocked_domain(domain): self.logger.info(f"Domain unblocked: {domain}") return { 'code': Codes.CODE_REMOVE_DOMAIN, 'message': RESPONSE_MESSAGES['success'] } - else: - self.logger.warning(f"Domain not found for unblocking: {domain}") - return { - 'code': Codes.CODE_REMOVE_DOMAIN, - 'message': RESPONSE_MESSAGES['domain_not_found'] - } - else: - self.logger.warning(f"Invalid action requested: {action}") + + self.logger.warning(f"Domain not found for unblocking: {domain}") return { - 'code': Codes.CODE_ADD_DOMAIN, - 'message': RESPONSE_MESSAGES['invalid_request'] + 'code': Codes.CODE_REMOVE_DOMAIN, + 'message': RESPONSE_MESSAGES['domain_not_found'] } - + + self.logger.warning(f"Invalid action requested: {action}") + return { + 'code': Codes.CODE_ADD_DOMAIN, + 'message': RESPONSE_MESSAGES['invalid_request'] + } + except Exception as e: self.logger.error(f"Error in domain block handler: {e}") return { @@ -133,11 +134,11 @@ def __init__(self, db_manager: DatabaseManager): self.db_manager = db_manager self.logger = setup_logger(__name__) self.handlers = { - Codes.CODE_AD_BLOCK: AdBlockHandler(db_manager), - Codes.CODE_ADULT_BLOCK: AdultContentBlockHandler(db_manager), - Codes.CODE_ADD_DOMAIN: DomainBlockHandler(db_manager), - Codes.CODE_REMOVE_DOMAIN: DomainBlockHandler(db_manager), - Codes.CODE_DOMAIN_LIST_UPDATE: DomainListHandler(db_manager) + Codes.CODE_AD_BLOCK : AdBlockHandler(db_manager), + Codes.CODE_ADULT_BLOCK : AdultContentBlockHandler(db_manager), + Codes.CODE_ADD_DOMAIN : DomainBlockHandler(db_manager), + Codes.CODE_REMOVE_DOMAIN : DomainBlockHandler(db_manager), + Codes.CODE_DOMAIN_LIST_UPDATE : DomainListHandler(db_manager) } def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: diff --git a/server/src/server.py b/server/src/server.py index 60640df..7a2498a 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -55,18 +55,21 @@ def handle_client_thread(self) -> None: response = self.request_factory.handle_request(request_data) conn.send(json.dumps(response).encode() + b'\n') self.logger.debug(f"Sent response: {response}") + except json.JSONDecodeError: self.logger.error("Invalid JSON format received") conn.send(json.dumps({ 'status': 'error', 'message': 'Invalid JSON format' }).encode() + b'\n') + except Exception as e: self.logger.error(f"Error handling request: {e}") conn.send(json.dumps({ 'status': 'error', 'message': str(e) }).encode() + b'\n') + except socket.timeout: if not self.running: break From 036d447482a91aac6d11b8ffde9a685b20ff9619 Mon Sep 17 00:00:00 2001 From: elipaz Date: Tue, 12 Nov 2024 11:50:12 +0200 Subject: [PATCH 32/38] Finish integration --- client/src/Application.py | 12 +++- client/src/Communicator.py | 21 +++--- client/src/View.py | 14 +++- client/src/utils.py | 62 ++++++++++------- server/src/handlers.py | 135 ++++++++++++++++++------------------- server/src/server.py | 43 ++++++------ server/src/utils.py | 104 ++++++++++++++++++++++------ 7 files changed, 233 insertions(+), 158 deletions(-) diff --git a/client/src/Application.py b/client/src/Application.py index d4ec6bf..bab6116 100644 --- a/client/src/Application.py +++ b/client/src/Application.py @@ -6,7 +6,7 @@ from .ConfigManager import ConfigManager from .utils import ( - STR_CODE, STR_CONTENT, + STR_CODE, STR_DOMAINS, STR_CONTENT, Codes ) @@ -87,13 +87,19 @@ def _handle_request(self, request: str) -> None: with self._request_lock: match request_dict[STR_CODE]: + ## Codes sent to server case Codes.CODE_AD_BLOCK | \ Codes.CODE_ADULT_BLOCK | \ Codes.CODE_ADD_DOMAIN | \ Codes.CODE_REMOVE_DOMAIN: - self._communicator.send_message(json.dumps(request)) + self._communicator.send_message(request_dict) + ## Codes received from server case Codes.CODE_DOMAIN_LIST_UPDATE: - self._view.update_domain_list(request_dict[STR_CONTENT]) + self._view.update_domain_list(request_dict[STR_DOMAINS]) + case Codes.CODE_ERROR: + self._view._show_error(request_dict[STR_CONTENT]) + case Codes.CODE_SUCCESS: + self._view._show_success(request_dict[STR_CONTENT]) except json.JSONDecodeError as e: self._logger.error(f"Invalid JSON format: {str(e)}") diff --git a/client/src/Communicator.py b/client/src/Communicator.py index e6275b7..9907a44 100644 --- a/client/src/Communicator.py +++ b/client/src/Communicator.py @@ -3,9 +3,8 @@ import json from .Logger import setup_logger from .utils import ( - DEFAULT_HOST, DEFAULT_PORT, DEFAULT_BUFFER_SIZE, ERR_SOCKET_NOT_SETUP, - STR_NETWORK + STR_NETWORK, STR_HOST, STR_PORT, STR_RECEIVE_BUFFER_SIZE ) class Communicator: @@ -22,9 +21,9 @@ def __init__(self, config_manager, message_callback: Callable[[str], None]) -> N self.config = config_manager.get_config() self._message_callback = message_callback - self._host = self.config[STR_NETWORK][DEFAULT_HOST] - self._port = self.config[STR_NETWORK][DEFAULT_PORT] - self._receive_buffer_size = self.config[STR_NETWORK][DEFAULT_BUFFER_SIZE] + self._host = self.config[STR_NETWORK][STR_HOST] + self._port = int(self.config[STR_NETWORK][STR_PORT]) + self._receive_buffer_size = int(self.config[STR_NETWORK][STR_RECEIVE_BUFFER_SIZE]) self._socket: Optional[socket.socket] = None def connect(self) -> None: @@ -42,12 +41,12 @@ def connect(self) -> None: self.logger.error(f"Failed to connect to server: {str(e)}") raise - def send_message(self, message: str) -> None: + def send_message(self, request: dict) -> None: """ - Send a json message to the server. + Send a json request to the server. Args: - message_json: The message to send to the server. + request: The request to send to the server. Raises: RuntimeError: If socket connection is not established. @@ -55,10 +54,10 @@ def send_message(self, message: str) -> None: self._validate_connection() try: - self._socket.send(message.encode('utf-8')) - self.logger.info(f"Message sent: {message}") + self._socket.send(json.dumps(request).encode('utf-8')) + self.logger.info(f"Request sent: {request}") except Exception as e: - self.logger.error(f"Failed to send message: {str(e)}") + self.logger.error(f"Failed to send request: {str(e)}") raise def receive_message(self) -> None: diff --git a/client/src/View.py b/client/src/View.py index 2943421..f973c41 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -10,8 +10,8 @@ Codes, WINDOW_SIZE, WINDOW_TITLE, ERR_DUPLICATE_DOMAIN, ERR_NO_DOMAIN_SELECTED, ERR_DOMAIN_LIST_UPDATE_FAILED, - STR_AD_BLOCK, STR_ADULT_BLOCK, STR_CODE, - STR_BLOCKED_DOMAINS, STR_CONTENT, STR_SETTINGS, STR_ERROR + STR_AD_BLOCK, STR_ADULT_BLOCK, STR_CODE, STR_BLOCKED_DOMAINS, + STR_CONTENT, STR_SETTINGS, STR_ERROR, STR_SUCCESS ) @@ -183,6 +183,16 @@ def _show_error(self, message: str) -> None: self.logger.error(f"Error message displayed: {message}") tk.messagebox.showerror(STR_ERROR, message) + def _show_success(self, message: str) -> None: + """ + Display a success message in a popup window. + + Args: + message: The success message to display. + """ + self.logger.info(f"Success message displayed: {message}") + tk.messagebox.showinfo(STR_SUCCESS, message) + def _setup_ui(self) -> None: """Set up the UI components including block controls and domain list.""" # Main container with increased padding diff --git a/client/src/utils.py b/client/src/utils.py index 31b59f2..e3c97b5 100644 --- a/client/src/utils.py +++ b/client/src/utils.py @@ -1,9 +1,9 @@ """Utility module containing constants and common functions for the application.""" # Network related constants -DEFAULT_HOST = "host" -DEFAULT_PORT = "port" -DEFAULT_BUFFER_SIZE = "receive_buffer_size" +DEFAULT_HOST: str = "127.0.0.1" +DEFAULT_PORT: str = "65432" +DEFAULT_BUFFER_SIZE: str = "1024" # GUI constants WINDOW_TITLE = "Site Blocker" @@ -19,24 +19,8 @@ class Codes: CODE_ADD_DOMAIN = "52" CODE_REMOVE_DOMAIN = "53" CODE_DOMAIN_LIST_UPDATE = "54" - -# Default settings -DEFAULT_CONFIG = { - "network": { - "host": DEFAULT_HOST, - "port": DEFAULT_PORT, - "receive_buffer_size": DEFAULT_BUFFER_SIZE - }, - "blocked_domains": {}, - "settings": { - "ad_block": "off", - "adult_block": "off" - }, - "logging": { - "level": "INFO", - "log_dir": "client_logs" - } -} + CODE_SUCCESS = "100" + CODE_ERROR = "101" # Logging constants LOG_DIR = "client_logs" @@ -56,9 +40,37 @@ class Codes: STR_CODE = "code" STR_CONTENT = "content" STR_ERROR = "Error" +STR_DOMAINS = "domains" +STR_SUCCESS = "Success" # Config Constants -STR_BLOCKED_DOMAINS = "blocked_domains" -STR_NETWORK = "network" -STR_SETTINGS = "settings" -STR_LOGGING = "logging" +STR_BLOCKED_DOMAINS = "blocked_domains" +STR_NETWORK = "network" +STR_SETTINGS = "settings" +STR_LOGGING = "logging" +STR_HOST = "host" +STR_PORT = "port" +STR_RECEIVE_BUFFER_SIZE = "receive_buffer_size" +STR_LEVEL = "level" +STR_LOG_DIR = "log_dir" + +# Default settings +DEFAULT_CONFIG = { + STR_NETWORK: { + STR_HOST: DEFAULT_HOST, + STR_PORT: DEFAULT_PORT, + STR_RECEIVE_BUFFER_SIZE: DEFAULT_BUFFER_SIZE + }, + + STR_BLOCKED_DOMAINS: {}, + + STR_SETTINGS: { + STR_AD_BLOCK: "off", + STR_ADULT_BLOCK: "off" + }, + + STR_LOGGING: { + STR_LEVEL: "INFO", + STR_LOG_DIR: LOG_DIR + } +} diff --git a/server/src/handlers.py b/server/src/handlers.py index 4aa8918..0299dce 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -1,6 +1,13 @@ from typing import Dict, Any from .db_manager import DatabaseManager -from .utils import Codes, RESPONSE_MESSAGES +from .utils import ( + Codes, + STR_AD_BLOCK, STR_ADULT_BLOCK, + STR_CODE, STR_CONTENT, STR_DOMAINS, STR_DOMAIN, + STR_DOMAIN_BLOCKED_MSG, STR_DOMAIN_NOT_FOUND_MSG, + STR_BLOCK, STR_UNBLOCK, STR_DOMAIN_UNBLOCKED_MSG, + invalid_json_response +) from .logger import setup_logger class RequestHandler: @@ -13,101 +20,87 @@ class AdBlockHandler(RequestHandler): def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle ad block requests.""" try: - if 'action' in request_data: - state = request_data['action'] - self.db_manager.update_setting('ad_block', state) + if STR_CONTENT in request_data: + state = request_data[STR_CONTENT] + self.db_manager.update_setting(STR_AD_BLOCK, state) self.logger.info(f"Ad blocking turned {state}") return { - 'code': Codes.CODE_AD_BLOCK, - 'message': f"Ad blocking turned {state}" + STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: f"Ad blocking turned {state}" } - return { - 'code': Codes.CODE_AD_BLOCK, - 'message': RESPONSE_MESSAGES['success'] - } + return invalid_json_response() except Exception as e: self.logger.error(f"Error in ad block handler: {e}") return { - 'code': Codes.CODE_AD_BLOCK, - 'message': str(e) + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e) } class AdultContentBlockHandler(RequestHandler): def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle adult content block requests.""" try: - if 'action' in request_data: - state = request_data['action'] - self.db_manager.update_setting('adult_block', state) + if STR_CONTENT in request_data: + state = request_data[STR_CONTENT] + self.db_manager.update_setting(STR_ADULT_BLOCK, state) self.logger.info(f"Adult content blocking turned {state}") return { - 'code': Codes.CODE_ADULT_BLOCK, - 'message': f"Adult content blocking turned {state}" + STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: f"Adult content blocking turned {state}" } - return { - 'code': Codes.CODE_ADULT_BLOCK, - 'message': RESPONSE_MESSAGES['success'] - } - + return invalid_json_response() + except Exception as e: self.logger.error(f"Error in adult content block handler: {e}") return { - 'code': Codes.CODE_ADULT_BLOCK, - 'message': str(e) + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e) } class DomainBlockHandler(RequestHandler): def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Handle domain blocking requests.""" try: - if 'action' not in request_data or 'domain' not in request_data: - self.logger.warning("Invalid request format: missing action or domain") - return { - 'code': Codes.CODE_ADD_DOMAIN, - 'message': RESPONSE_MESSAGES['invalid_request'] - } + if STR_CONTENT not in request_data: + self.logger.warning("Invalid request format: missing content") + return invalid_json_response() - domain = request_data['domain'] - action = request_data['action'] + match request_data[STR_CODE]: + case Codes.CODE_ADD_DOMAIN: + self.db_manager.add_blocked_domain(request_data[STR_CONTENT]) + self.logger.info(f"Domain blocked: {request_data[STR_CONTENT]}") + + return { + STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: STR_DOMAIN_BLOCKED_MSG + } - if action == 'block': - self.db_manager.add_blocked_domain(domain) + case Codes.CODE_REMOVE_DOMAIN: + if self.db_manager.remove_blocked_domain(request_data[STR_CONTENT]): + self.logger.info(f"Domain unblocked: {request_data[STR_CONTENT]}") - self.logger.info(f"Domain blocked: {domain}") + return { + STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: STR_DOMAIN_UNBLOCKED_MSG + } - return { - 'code': Codes.CODE_ADD_DOMAIN, - 'message': RESPONSE_MESSAGES['domain_blocked'] - } - - if action == 'unblock': - if self.db_manager.remove_blocked_domain(domain): - self.logger.info(f"Domain unblocked: {domain}") + self.logger.warning(f"Domain not found for unblocking: {request_data[STR_CONTENT]}") return { - 'code': Codes.CODE_REMOVE_DOMAIN, - 'message': RESPONSE_MESSAGES['success'] + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: STR_DOMAIN_NOT_FOUND_MSG } - - self.logger.warning(f"Domain not found for unblocking: {domain}") - return { - 'code': Codes.CODE_REMOVE_DOMAIN, - 'message': RESPONSE_MESSAGES['domain_not_found'] - } - self.logger.warning(f"Invalid action requested: {action}") - return { - 'code': Codes.CODE_ADD_DOMAIN, - 'message': RESPONSE_MESSAGES['invalid_request'] - } + self.logger.warning(f"Invalid action requested: {request_data[STR_CODE]}") + return invalid_json_response() except Exception as e: self.logger.error(f"Error in domain block handler: {e}") return { - 'code': Codes.CODE_ADD_DOMAIN, - 'message': str(e) + STR_CODE: Codes.CODE_ADD_DOMAIN, + STR_CONTENT: str(e) } class DomainListHandler(RequestHandler): @@ -117,15 +110,15 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: domains = self.db_manager.get_blocked_domains() self.logger.info(f"Domain list requested, returned {len(domains)} domains") return { - 'code': Codes.CODE_DOMAIN_LIST_UPDATE, - 'domains': domains, - 'message': RESPONSE_MESSAGES['success'] + STR_CODE: Codes.CODE_SUCCESS, + STR_DOMAINS: domains } + except Exception as e: self.logger.error(f"Error in domain list handler: {e}") return { - 'code': Codes.CODE_DOMAIN_LIST_UPDATE, - 'message': str(e) + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e) } class RequestFactory: @@ -144,19 +137,19 @@ def __init__(self, db_manager: DatabaseManager): def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """Route request to appropriate handler.""" try: - code = request_data.get('code') + code = request_data.get(STR_CODE) handler = self.handlers.get(code) if handler: self.logger.debug(f"Handling request with code: {code}") return handler.handle_request(request_data) - else: - self.logger.warning(f"Invalid request code: {code}") - return { - 'message': RESPONSE_MESSAGES['invalid_request'] - } + + self.logger.warning(f"Invalid request code: {code}") + return invalid_json_response() + except Exception as e: self.logger.error(f"Error in request factory: {e}") return { - 'message': str(e) - } \ No newline at end of file + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e) + } diff --git a/server/src/server.py b/server/src/server.py index 7a2498a..4d20125 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -3,7 +3,12 @@ import threading import json import asyncio -from .utils import HOST, CLIENT_PORT, KERNEL_PORT +from .utils import ( + CLIENT_PORT, DEFAULT_ADDRESS, KERNEL_PORT, + STR_AD_BLOCK, STR_ADULT_BLOCK, STR_CODE, STR_DOMAINS, STR_CONTENT, + STR_TOGGLE_ON, STR_TOGGLE_OFF, STR_DOMAIN, + Codes, invalid_json_response +) from .db_manager import DatabaseManager from .handlers import RequestFactory from .logger import setup_logger @@ -20,10 +25,10 @@ def __init__(self, db_manager: DatabaseManager) -> None: def handle_client_thread(self) -> None: """Handle client connections using traditional socket.""" client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - client_socket.bind((HOST, CLIENT_PORT)) + client_socket.bind((DEFAULT_ADDRESS, CLIENT_PORT)) client_socket.listen(1) client_socket.settimeout(1.0) - self.logger.info(f"Client server running on {HOST}:{CLIENT_PORT}") + self.logger.info(f"Client server running on {DEFAULT_ADDRESS}:{CLIENT_PORT}") try: while self.running: @@ -31,15 +36,13 @@ def handle_client_thread(self) -> None: conn, addr = client_socket.accept() self.logger.info(f"Client connected from {addr}") - # Set timeout for client connection as well conn.settimeout(1.0) try: - # Send initial domain list domains = self.db_manager.get_blocked_domains() conn.send(json.dumps({ - 'type': 'domain_list', - 'domains': domains + STR_CODE: Codes.CODE_DOMAIN_LIST_UPDATE, + STR_DOMAINS: domains }).encode() + b'\n') self.logger.debug(f"Sent initial domain list: {domains}") @@ -53,21 +56,19 @@ def handle_client_thread(self) -> None: request_data = json.loads(data.decode()) self.logger.debug(f"Received request: {request_data}") response = self.request_factory.handle_request(request_data) + conn.send(json.dumps(response).encode() + b'\n') self.logger.debug(f"Sent response: {response}") except json.JSONDecodeError: self.logger.error("Invalid JSON format received") - conn.send(json.dumps({ - 'status': 'error', - 'message': 'Invalid JSON format' - }).encode() + b'\n') + conn.send(json.dumps(invalid_json_response()).encode() + b'\n') except Exception as e: self.logger.error(f"Error handling request: {e}") conn.send(json.dumps({ - 'status': 'error', - 'message': str(e) + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e) }).encode() + b'\n') except socket.timeout: @@ -103,31 +104,27 @@ async def handle_kernel_requests( break request_data = json.loads(data.decode()) - domain = request_data.get('domain', '').strip() + domain = request_data.get(STR_DOMAIN, '').strip() if not domain: continue - # Get current settings state - ad_block_enabled = self.db_manager.get_setting('ad_block') == 'on' - adult_block_enabled = self.db_manager.get_setting('adult_block') == 'on' + ad_block_enabled = self.db_manager.get_setting(STR_AD_BLOCK) == STR_TOGGLE_ON + adult_block_enabled = self.db_manager.get_setting(STR_ADULT_BLOCK) == STR_TOGGLE_ON block_reason = None should_block = False - # Check custom blocked domains first if self.db_manager.is_domain_blocked(domain): should_block = True block_reason = "custom_blocklist" self.logger.info(f"Domain {domain} blocked (custom blocklist)") - # Check if ad blocking is enabled and domain is an ad elif ad_block_enabled and request_data.get('is_ad', False): should_block = True block_reason = "ads" self.logger.info(f"Domain {domain} blocked (ads)") - # Check adult content last if enabled elif adult_block_enabled and 'adult' in request_data.get('categories', []): should_block = True block_reason = "adult_content" @@ -157,18 +154,16 @@ async def start_server(self) -> None: kernel_server: Optional[asyncio.Server] = None try: - # Start client handler in a separate thread client_thread = threading.Thread(target=self.handle_client_thread) client_thread.start() self.logger.info("Client handler thread started") - # Run kernel handler with asyncio kernel_server = await asyncio.start_server( self.handle_kernel_requests, - HOST, + DEFAULT_ADDRESS, KERNEL_PORT ) - self.logger.info(f"Kernel server running on {HOST}:{KERNEL_PORT}") + self.logger.info(f"Kernel server running on {DEFAULT_ADDRESS}:{KERNEL_PORT}") async with kernel_server: await kernel_server.serve_forever() diff --git a/server/src/utils.py b/server/src/utils.py index 6592587..ce45733 100644 --- a/server/src/utils.py +++ b/server/src/utils.py @@ -1,34 +1,94 @@ +"""Utility module containing constants and common functions for the application.""" + import os from pathlib import Path -from typing import Dict + +# Network related constants +DEFAULT_ADDRESS: str = "127.0.0.1" +CLIENT_PORT: int = 65432 +KERNEL_PORT: int = 65433 +BUFFER_SIZE: int = 1024 # Base directories BASE_DIR = Path(__file__).parent.parent LOG_DIR = os.path.join(BASE_DIR, "logs") - -# Network Configuration -HOST: str = '127.0.0.1' -CLIENT_PORT: int = 65432 -KERNEL_PORT: int = 65433 DB_FILE: str = 'my_internet.db' -# Logging configuration +# Message codes +class Codes: + """Constants for message codes used in communication.""" + CODE_AD_BLOCK = "50" + CODE_ADULT_BLOCK = "51" + CODE_ADD_DOMAIN = "52" + CODE_REMOVE_DOMAIN = "53" + CODE_DOMAIN_LIST_UPDATE = "54" + CODE_SUCCESS = "100" + CODE_ERROR = "101" + CODE_ACK = "99" + +# Logging constants LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" LOG_DATE_FORMAT = "%Y%m%d_%H%M%S" -# Client Command Codes -class Codes: - CODE_AD_BLOCK = "50" - CODE_ADULT_BLOCK = "51" - CODE_ADD_DOMAIN = "52" - CODE_REMOVE_DOMAIN = "53" - CODE_DOMAIN_LIST_UPDATE = "54" +# Message Types and Codes +STR_CODE = "code" +STR_CONTENT = "content" +# STR_TYPE = "type" +# STR_ACTION = "action" +# STR_MESSAGE_ID = "message_id" +# STR_ACK = "ack" + +# Domain Related +STR_DOMAIN = "domain" +STR_DOMAINS = "domains" +STR_BLOCK = "block" +STR_UNBLOCK = "unblock" +# STR_IS_AD = "is_ad" +# STR_CATEGORIES = "categories" +# STR_REASON = "reason" + +# Features and Settings +STR_AD_BLOCK = "ad_block" +STR_ADULT_BLOCK = "adult_block" +STR_TOGGLE_ON = "on" +STR_TOGGLE_OFF = "off" + +# Status and Response Keys +STR_ERROR = "Error" +STR_SUCCESS = "success" +# STR_INVALID_REQUEST = "invalid_request" +# STR_DOMAIN_BLOCKED = "domain_blocked" +# STR_DOMAIN_NOT_FOUND = "domain_not_found" +# STR_INVALID_JSON = "invalid_json" + +# Block Reasons +# STR_CUSTOM_BLOCKLIST = "custom_blocklist" +# STR_ADS = "ads" +# STR_ADULT_CONTENT = "adult_content" +# STR_ALLOWED = "allowed" +# STR_DOMAIN_LIST = "domain_list" + +# Response Messages +STR_DOMAIN_BLOCKED_MSG = "Domain has been successfully blocked." +STR_DOMAIN_UNBLOCKED_MSG = "Domain has been successfully unblocked." +STR_DOMAIN_NOT_FOUND_MSG = "Domain not found in block list." +STR_INVALID_JSON_MSG = "Invalid JSON format." +# STR_REQUEST_PROCESSED = "Request processed successfully." +# STR_INVALID_REQUEST_FORMAT = "Invalid request format." +# STR_DOMAIN_EXISTS_MSG = "Domain already exists in block list." +# STR_ACK_TIMEOUT_MSG = "Acknowledgment timeout occurred." + +# Config Constants +# STR_BLOCKED_DOMAINS = "blocked_domains" +# STR_NETWORK = "network" +# STR_SETTINGS = "settings" +# STR_LOGGING = "logging" + +# Timeouts +# ACK_TIMEOUT = 5.0 # seconds -# Response messages -RESPONSE_MESSAGES = { - 'success': "Request processed successfully.", - 'invalid_request': "Invalid request format.", - 'domain_blocked': "Domain has been successfully blocked.", - 'domain_not_found': "Domain not found in block list.", - 'domain_exists': "Domain already exists in block list." -} \ No newline at end of file +def invalid_json_response(): + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: STR_INVALID_JSON_MSG + } From fce11b838c026cae21590854e4a45a637c95e18d Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Tue, 12 Nov 2024 15:08:52 +0200 Subject: [PATCH 33/38] change handlers response --- client/config.json | 9 ++++++--- server/src/handlers.py | 41 ++++++++++++++++++++++++++--------------- server/src/utils.py | 1 + 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/client/config.json b/client/config.json index a0dfe4e..8a456b5 100644 --- a/client/config.json +++ b/client/config.json @@ -7,11 +7,14 @@ "blocked_domains": { "example.com": true, "ads.example.com": true, - "fxp.co.il": true + "fxp.co.il": true, + "test.com": true, + "e.com": true, + "fxp.com": true }, "settings": { - "ad_block": "off", - "adult_block": "off" + "ad_block": "on", + "adult_block": "on" }, "logging": { "level": "INFO", diff --git a/server/src/handlers.py b/server/src/handlers.py index 0299dce..a20fcd6 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -3,9 +3,9 @@ from .utils import ( Codes, STR_AD_BLOCK, STR_ADULT_BLOCK, - STR_CODE, STR_CONTENT, STR_DOMAINS, STR_DOMAIN, + STR_CODE, STR_CONTENT, STR_DOMAINS, STR_DOMAIN_BLOCKED_MSG, STR_DOMAIN_NOT_FOUND_MSG, - STR_BLOCK, STR_UNBLOCK, STR_DOMAIN_UNBLOCKED_MSG, + STR_DOMAIN_UNBLOCKED_MSG, STR_OPERATION, invalid_json_response ) from .logger import setup_logger @@ -26,7 +26,8 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: self.logger.info(f"Ad blocking turned {state}") return { STR_CODE: Codes.CODE_SUCCESS, - STR_CONTENT: f"Ad blocking turned {state}" + STR_CONTENT: f"Ad blocking turned {state}", + STR_OPERATION: Codes.CODE_AD_BLOCK } return invalid_json_response() @@ -35,7 +36,8 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: self.logger.error(f"Error in ad block handler: {e}") return { STR_CODE: Codes.CODE_ERROR, - STR_CONTENT: str(e) + STR_CONTENT: str(e), + STR_OPERATION: Codes.CODE_AD_BLOCK } class AdultContentBlockHandler(RequestHandler): @@ -48,7 +50,8 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: self.logger.info(f"Adult content blocking turned {state}") return { STR_CODE: Codes.CODE_SUCCESS, - STR_CONTENT: f"Adult content blocking turned {state}" + STR_CONTENT: f"Adult content blocking turned {state}", + STR_OPERATION: Codes.CODE_ADULT_BLOCK } return invalid_json_response() @@ -57,7 +60,8 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: self.logger.error(f"Error in adult content block handler: {e}") return { STR_CODE: Codes.CODE_ERROR, - STR_CONTENT: str(e) + STR_CONTENT: str(e), + STR_OPERATION: Codes.CODE_ADULT_BLOCK } class DomainBlockHandler(RequestHandler): @@ -68,14 +72,16 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: self.logger.warning("Invalid request format: missing content") return invalid_json_response() - match request_data[STR_CODE]: + operation_code = request_data[STR_CODE] + match operation_code: case Codes.CODE_ADD_DOMAIN: self.db_manager.add_blocked_domain(request_data[STR_CONTENT]) self.logger.info(f"Domain blocked: {request_data[STR_CONTENT]}") return { STR_CODE: Codes.CODE_SUCCESS, - STR_CONTENT: STR_DOMAIN_BLOCKED_MSG + STR_CONTENT: STR_DOMAIN_BLOCKED_MSG, + STR_OPERATION: Codes.CODE_ADD_DOMAIN } case Codes.CODE_REMOVE_DOMAIN: @@ -84,23 +90,26 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: return { STR_CODE: Codes.CODE_SUCCESS, - STR_CONTENT: STR_DOMAIN_UNBLOCKED_MSG + STR_CONTENT: STR_DOMAIN_UNBLOCKED_MSG, + STR_OPERATION: Codes.CODE_REMOVE_DOMAIN } self.logger.warning(f"Domain not found for unblocking: {request_data[STR_CONTENT]}") return { STR_CODE: Codes.CODE_ERROR, - STR_CONTENT: STR_DOMAIN_NOT_FOUND_MSG + STR_CONTENT: STR_DOMAIN_NOT_FOUND_MSG, + STR_OPERATION: Codes.CODE_REMOVE_DOMAIN } self.logger.warning(f"Invalid action requested: {request_data[STR_CODE]}") return invalid_json_response() - + except Exception as e: self.logger.error(f"Error in domain block handler: {e}") return { - STR_CODE: Codes.CODE_ADD_DOMAIN, - STR_CONTENT: str(e) + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e), + STR_OPERATION: operation_code } class DomainListHandler(RequestHandler): @@ -111,14 +120,16 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: self.logger.info(f"Domain list requested, returned {len(domains)} domains") return { STR_CODE: Codes.CODE_SUCCESS, - STR_DOMAINS: domains + STR_DOMAINS: domains, + STR_OPERATION: Codes.CODE_DOMAIN_LIST_UPDATE } except Exception as e: self.logger.error(f"Error in domain list handler: {e}") return { STR_CODE: Codes.CODE_ERROR, - STR_CONTENT: str(e) + STR_CONTENT: str(e), + STR_OPERATION: Codes.CODE_DOMAIN_LIST_UPDATE } class RequestFactory: diff --git a/server/src/utils.py b/server/src/utils.py index ce45733..aa10f12 100644 --- a/server/src/utils.py +++ b/server/src/utils.py @@ -33,6 +33,7 @@ class Codes: # Message Types and Codes STR_CODE = "code" STR_CONTENT = "content" +STR_OPERATION = "operation" # STR_TYPE = "type" # STR_ACTION = "action" # STR_MESSAGE_ID = "message_id" From 886ccd4894098ddbf97190617d4a1c76557ddebd Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Tue, 12 Nov 2024 16:18:14 +0200 Subject: [PATCH 34/38] add method to send init settings --- server/src/server.py | 37 +++++++++++++++++++++++-------------- server/src/utils.py | 2 +- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/server/src/server.py b/server/src/server.py index 4d20125..640fdac 100644 --- a/server/src/server.py +++ b/server/src/server.py @@ -6,7 +6,7 @@ from .utils import ( CLIENT_PORT, DEFAULT_ADDRESS, KERNEL_PORT, STR_AD_BLOCK, STR_ADULT_BLOCK, STR_CODE, STR_DOMAINS, STR_CONTENT, - STR_TOGGLE_ON, STR_TOGGLE_OFF, STR_DOMAIN, + STR_TOGGLE_ON, STR_TOGGLE_OFF, STR_DOMAIN, Codes, invalid_json_response ) from .db_manager import DatabaseManager @@ -39,12 +39,10 @@ def handle_client_thread(self) -> None: conn.settimeout(1.0) try: - domains = self.db_manager.get_blocked_domains() - conn.send(json.dumps({ - STR_CODE: Codes.CODE_DOMAIN_LIST_UPDATE, - STR_DOMAINS: domains - }).encode() + b'\n') - self.logger.debug(f"Sent initial domain list: {domains}") + # Send initial settings + initial_settings = self._get_initial_settings() + conn.send(json.dumps(initial_settings).encode() + b'\n') + self.logger.debug(f"Sent initial settings: {initial_settings}") while True: try: @@ -64,13 +62,6 @@ def handle_client_thread(self) -> None: self.logger.error("Invalid JSON format received") conn.send(json.dumps(invalid_json_response()).encode() + b'\n') - except Exception as e: - self.logger.error(f"Error handling request: {e}") - conn.send(json.dumps({ - STR_CODE: Codes.CODE_ERROR, - STR_CONTENT: str(e) - }).encode() + b'\n') - except socket.timeout: if not self.running: break @@ -180,6 +171,24 @@ async def start_server(self) -> None: if client_thread and client_thread.is_alive(): client_thread.join(timeout=1.0) + def _get_initial_settings(self) -> Dict[str, Any]: + """Get initial settings and domain list for client initialization.""" + try: + domains = self.db_manager.get_blocked_domains() + settings = { + STR_AD_BLOCK: self.db_manager.get_setting(STR_AD_BLOCK), + STR_ADULT_BLOCK: self.db_manager.get_setting(STR_ADULT_BLOCK) + } + + return { + STR_CODE: Codes.CODE_INIT_SETTINGS, + STR_DOMAINS: domains, + Codes.CODE_INIT_SETTINGS: settings + } + except Exception as e: + self.logger.error(f"Error getting initial settings: {e}") + return invalid_json_response() + def initialize_server(db_file: str) -> None: """Initialize and run the server.""" db_manager = DatabaseManager(db_file) diff --git a/server/src/utils.py b/server/src/utils.py index aa10f12..3db8c35 100644 --- a/server/src/utils.py +++ b/server/src/utils.py @@ -25,7 +25,7 @@ class Codes: CODE_SUCCESS = "100" CODE_ERROR = "101" CODE_ACK = "99" - + CODE_INIT_SETTINGS = "55" # Logging constants LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" LOG_DATE_FORMAT = "%Y%m%d_%H%M%S" From 8e3b20bd2a2d46355d4c31708337fee46779eb8b Mon Sep 17 00:00:00 2001 From: elipaz Date: Tue, 12 Nov 2024 16:18:37 +0200 Subject: [PATCH 35/38] Modify client code. modify the handle request in Application.py some changes in view.py that will suit the changes above. --- client/config.json | 9 - client/src/Application.py | 40 +++-- client/src/Communicator.py | 2 +- client/src/View.py | 290 ++++++++++++++++++++++--------- client/src/utils.py | 18 +- client/tests/test_application.py | 2 +- client/tests/test_view.py | 10 +- server/src/handlers.py | 2 +- 8 files changed, 250 insertions(+), 123 deletions(-) diff --git a/client/config.json b/client/config.json index a0dfe4e..1036558 100644 --- a/client/config.json +++ b/client/config.json @@ -4,15 +4,6 @@ "port": 65432, "receive_buffer_size": 1024 }, - "blocked_domains": { - "example.com": true, - "ads.example.com": true, - "fxp.co.il": true - }, - "settings": { - "ad_block": "off", - "adult_block": "off" - }, "logging": { "level": "INFO", "log_dir": "client_logs" diff --git a/client/src/Application.py b/client/src/Application.py index bab6116..aeabcb0 100644 --- a/client/src/Application.py +++ b/client/src/Application.py @@ -6,7 +6,7 @@ from .ConfigManager import ConfigManager from .utils import ( - STR_CODE, STR_DOMAINS, STR_CONTENT, + STR_DOMAINS, STR_OPERATION, Codes ) @@ -26,7 +26,6 @@ def __init__(self) -> None: """Initialize application components.""" self._logger = setup_logger(__name__) self._config_manager = ConfigManager() - self._request_lock = threading.Lock() self._view = Viewer(config_manager=self._config_manager, message_callback=self._handle_request) self._communicator = Communicator(config_manager=self._config_manager, message_callback=self._handle_request) @@ -74,7 +73,7 @@ def _start_gui(self) -> None: self._logger.error(f"Failed to start GUI: {str(e)}") raise - def _handle_request(self, request: str) -> None: + def _handle_request(self, request: str, server: bool = True) -> None: """ Handle outgoing messages from the UI and Server. @@ -85,22 +84,25 @@ def _handle_request(self, request: str) -> None: self._logger.info(f"Processing request: {request}") request_dict = json.loads(request) - with self._request_lock: - match request_dict[STR_CODE]: - ## Codes sent to server - case Codes.CODE_AD_BLOCK | \ - Codes.CODE_ADULT_BLOCK | \ - Codes.CODE_ADD_DOMAIN | \ - Codes.CODE_REMOVE_DOMAIN: - self._communicator.send_message(request_dict) - ## Codes received from server - case Codes.CODE_DOMAIN_LIST_UPDATE: - self._view.update_domain_list(request_dict[STR_DOMAINS]) - case Codes.CODE_ERROR: - self._view._show_error(request_dict[STR_CONTENT]) - case Codes.CODE_SUCCESS: - self._view._show_success(request_dict[STR_CONTENT]) - + if server: + message = request if isinstance(request, dict) else json.loads(request) + self._communicator.send_message(message) + return + + match request_dict[STR_OPERATION]: + case Codes.CODE_INIT_SETTINGS: + self._view.update_initial_settings(request_dict) + case Codes.CODE_AD_BLOCK: + self._view.ad_block_response(request_dict) + case Codes.CODE_ADULT_BLOCK: + self._view.adult_block_response(request_dict) + case Codes.CODE_ADD_DOMAIN: + self._view.add_domain_response(request_dict) + case Codes.CODE_REMOVE_DOMAIN: + self._view.remove_domain_response(request_dict) + case Codes.CODE_DOMAIN_LIST_UPDATE: + self._view.update_domain_list_response(request_dict[STR_DOMAINS]) + except json.JSONDecodeError as e: self._logger.error(f"Invalid JSON format: {str(e)}") raise diff --git a/client/src/Communicator.py b/client/src/Communicator.py index 9907a44..1713641 100644 --- a/client/src/Communicator.py +++ b/client/src/Communicator.py @@ -81,7 +81,7 @@ def receive_message(self) -> None: break message = message_bytes.decode('utf-8') self.logger.info(f"Received message: {message}") - self._message_callback(message) + self._message_callback(message, False) except Exception as e: self.logger.error(f"Error receiving message: {str(e)}") raise diff --git a/client/src/View.py b/client/src/View.py index f973c41..e303a5f 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -2,19 +2,20 @@ from tkinter import ttk, messagebox from typing import Callable, List import json -import threading from .Logger import setup_logger from .ConfigManager import ConfigManager from .utils import ( Codes, WINDOW_SIZE, WINDOW_TITLE, - ERR_DUPLICATE_DOMAIN, ERR_NO_DOMAIN_SELECTED, ERR_DOMAIN_LIST_UPDATE_FAILED, + ERR_NO_DOMAIN_SELECTED, ERR_DOMAIN_LIST_UPDATE_FAILED, STR_AD_BLOCK, STR_ADULT_BLOCK, STR_CODE, STR_BLOCKED_DOMAINS, - STR_CONTENT, STR_SETTINGS, STR_ERROR, STR_SUCCESS + STR_CONTENT, STR_SETTINGS, STR_ERROR, STR_SUCCESS, + STR_ADD_DOMAIN_RESPONSE, STR_REMOVE_DOMAIN_REQUEST, STR_ADD_DOMAIN_REQUEST, + STR_AD_BLOCK_RESPONSE, STR_ADULT_BLOCK_RESPONSE, STR_REMOVE_DOMAIN_RESPONSE, + STR_DOMAINS, ) - class Viewer: """ Graphical user interface for the application. @@ -33,7 +34,6 @@ def __init__(self, config_manager: ConfigManager, message_callback: Callable[[st self.config_manager = config_manager self.config = config_manager.get_config() self._message_callback = message_callback - self._update_list_lock = threading.Lock() # Initialize root window first self.root: tk.Tk = tk.Tk() @@ -82,115 +82,251 @@ def get_block_settings(self) -> dict[str, str]: STR_ADULT_BLOCK: self.adult_var.get() } - def update_domain_list(self, domains: List[str]) -> None: + def update_domain_list_response(self, domains: List[str]) -> None: """ Update the domains listbox with a new list of domains from the server. Args: domains: List of domain strings to be displayed in the listbox. """ - with self._update_list_lock: - self.logger.info("Updating domain list from server") + self.logger.info("Updating domain list from server") + + try: + self.domains_listbox.delete(0, tk.END) + + for domain in domains: + self.domains_listbox.insert(tk.END, domain) + except Exception as e: + self.logger.error(f"Error updating domain list: {str(e)}") + self._show_error(ERR_DOMAIN_LIST_UPDATE_FAILED) + return + + self.logger.info(f"Updated domain list with {len(domains)} domains") + + def add_domain_response(self, response: dict) -> None: + """ + Handle the response from the server after attempting to add a domain. + + Args: + response: Dictionary containing the server's response with code and content. + """ try: - self.domains_listbox.delete(0, tk.END) - - for domain in domains: - self.domains_listbox.insert(tk.END, domain) - - self.logger.info(f"Updated domain list with {len(domains)} domains") + match response[STR_CODE]: + case Codes.CODE_SUCCESS: + domain = response[STR_CONTENT] + with self._update_list_lock: + self.domains_listbox.insert(tk.END, domain) + self.domain_entry.delete(0, tk.END) + + self._show_success( + message=f"Domain '{domain}' added successfully", + operation=STR_ADD_DOMAIN_RESPONSE + ) + case Codes.CODE_ERROR: + self._show_error( + message=response[STR_CONTENT], + operation=STR_ADD_DOMAIN_RESPONSE + ) + except Exception as e: - self.logger.error(f"Error updating domain list: {str(e)}") - self._show_error(ERR_DOMAIN_LIST_UPDATE_FAILED) + self._show_error( + message="An unexpected error occurred", + operation=f"Processing add domain response: {str(e)}" + ) + + def ad_block_response(self, response: dict) -> None: + """ + Handle the response from the server after changing ad block setting. + + Args: + response: Dictionary containing the server's response with code and content. + """ + prev_state = "off" if self.ad_var.get() == "on" else "on" + + try: + match response[STR_CODE]: + case Codes.CODE_SUCCESS: + self._show_success( + message=f"Ad blocking turned {self.ad_var.get()}", + operation=STR_AD_BLOCK_RESPONSE + ) + case Codes.CODE_ERROR: + self.ad_var.set(prev_state) + self._show_error( + message=response[STR_CONTENT], + operation=STR_AD_BLOCK_RESPONSE + ) + except Exception as e: + self.ad_var.set(prev_state) + self._show_error( + message="An unexpected error occurred", + operation=f"Processing ad block response: {str(e)}" + ) - def _add_domain(self) -> None: + def adult_block_response(self, response: dict) -> None: + """ + Handle the response from the server after changing adult block setting. + + Args: + response: Dictionary containing the server's response with code and content. + """ + prev_state = "off" if self.adult_var.get() == "on" else "on" + + try: + match response[STR_CODE]: + case Codes.CODE_SUCCESS: + self._show_success( + message=f"Adult content blocking turned {self.adult_var.get()}", + operation=STR_ADULT_BLOCK_RESPONSE + ) + case Codes.CODE_ERROR: + self.adult_var.set(prev_state) + self._show_error( + message=response[STR_CONTENT], + operation=STR_ADULT_BLOCK_RESPONSE + ) + except Exception as e: + self.adult_var.set(prev_state) + self._show_error( + message="An unexpected error occurred", + operation=f"Processing adult block response: {str(e)}" + ) + + def remove_domain_response(self, response: dict) -> None: + """ + Handle the response from the server after removing a domain. + + Args: + response: Dictionary containing the server's response with code and content. + """ + try: + match response[STR_CODE]: + case Codes.CODE_SUCCESS: + domain = response[STR_CONTENT] + self.domains_listbox.delete(self.domains_listbox.curselection()) + self._show_success( + message=f"Domain '{domain}' removed successfully", + operation=STR_REMOVE_DOMAIN_RESPONSE + ) + case Codes.CODE_ERROR: + self._show_error( + message=response[STR_CONTENT], + operation=STR_REMOVE_DOMAIN_RESPONSE + ) + except Exception as e: + self._show_error( + message="An unexpected error occurred", + operation=f"Processing remove domain response: {str(e)}" + ) + + def update_initial_settings(self, response: dict) -> None: + """ + Update all initial settings from server response. + + Args: + response: Dictionary containing initial settings: + - domains: List of blocked domains + - settings: Dictionary with ad_block and adult_block states + """ + try: + self.root.after(0, lambda: self.update_domain_list_response(response[STR_DOMAINS])) + self.root.after(0, lambda: self._update_block_settings(response[STR_SETTINGS])) + + self.logger.info("Successfully initialized settings from server") + + except Exception as e: + self._show_error( + message="Failed to initialize settings", + operation=f"Initial settings update: {str(e)}" + ) + + def _add_domain_request(self) -> None: """Add a domain to the blocked sites list.""" domain = self.domain_entry.get().strip() if domain: - if domain not in self.config[STR_BLOCKED_DOMAINS]: - self.domains_listbox.insert(tk.END, domain) - self.domain_entry.delete(0, tk.END) - - self.config[STR_BLOCKED_DOMAINS][domain] = True - self.config_manager.save_config(self.config) - - self._message_callback(json.dumps({ - STR_CODE: Codes.CODE_ADD_DOMAIN, - STR_CONTENT: domain - })) + self.logger.debug(f"Sending add domain request for: {domain}") + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_ADD_DOMAIN, + STR_CONTENT: domain + })) + else: + self._show_error( + message="Please enter a domain name", + operation=STR_ADD_DOMAIN_REQUEST + ) - self.logger.info(f"Domain added: {domain}") - else: - self.logger.warning(f"Attempted to add duplicate domain: {domain}") - self._show_error(ERR_DUPLICATE_DOMAIN) - - def _remove_domain(self) -> None: + def _remove_domain_request(self) -> None: """Remove the selected domain from the blocked sites list.""" selection = self.domains_listbox.curselection() if selection: domain = self.domains_listbox.get(selection) - self.domains_listbox.delete(selection) - - del self.config[STR_BLOCKED_DOMAINS][domain] - self.config_manager.save_config(self.config) - + self.logger.debug(f"Sending remove domain request for: {domain}") self._message_callback(json.dumps({ STR_CODE: Codes.CODE_REMOVE_DOMAIN, STR_CONTENT: domain - })) - - self.logger.info(f"Domain removed: {domain}") + })) else: - self.logger.warning("Attempted to remove domain without selection") - self._show_error(ERR_NO_DOMAIN_SELECTED) + self._show_error( + message=ERR_NO_DOMAIN_SELECTED, + operation=STR_REMOVE_DOMAIN_REQUEST + ) - def _handle_ad_block(self) -> None: + def _handle_ad_block_request(self) -> None: """Handle changes to the ad block setting.""" state = self.ad_var.get() - self.config[STR_SETTINGS][STR_AD_BLOCK] = state - self.config_manager.save_config(self.config) - + self.logger.debug(f"Sending ad block request: {state}") + self._message_callback(json.dumps({ STR_CODE: Codes.CODE_AD_BLOCK, STR_CONTENT: state - })) - - self.logger.info(f"Ad blocking state changed to: {state}") + })) - def _handle_adult_block(self) -> None: + def _handle_adult_block_request(self) -> None: """Handle changes to the adult sites block setting.""" state = self.adult_var.get() - self.config[STR_SETTINGS][STR_ADULT_BLOCK] = state - self.config_manager.save_config(self.config) - + self.logger.debug(f"Sending adult block request: {state}") + self._message_callback(json.dumps({ STR_CODE: Codes.CODE_ADULT_BLOCK, STR_CONTENT: state - })) - - self.logger.info(f"Adult site blocking state changed to: {state}") + })) - def _show_error(self, message: str) -> None: + def _update_block_settings(self, settings: dict) -> None: + """Update the block settings radio buttons.""" + self.ad_var.set(settings[STR_AD_BLOCK]) + self.adult_var.set(settings[STR_ADULT_BLOCK]) + + def _show_error(self, message: str, operation: str = "") -> None: """ - Display an error message in a popup window. + Display and log an error message for an operation. Args: - message: The error message to display. + message: The error message to display to the user. + operation: Optional description of the operation that failed. + If provided, will be included in the log message. """ - self.logger.error(f"Error message displayed: {message}") + if operation: + self.logger.error(f"Operation failed: {operation} - Error: {message}") + else: + self.logger.error(f"Error: {message}") + tk.messagebox.showerror(STR_ERROR, message) - def _show_success(self, message: str) -> None: + def _show_success(self, message: str, operation: str = "") -> None: """ - Display a success message in a popup window. + Display and log a success message for an operation. Args: - message: The success message to display. + message: The success message to display to the user. + operation: Optional description of the operation that succeeded. + If provided, will be included in the log message. """ - self.logger.info(f"Success message displayed: {message}") + log_message = f"Operation successful: {operation}" if operation else message + self.logger.info(log_message) tk.messagebox.showinfo(STR_SUCCESS, message) def _setup_ui(self) -> None: @@ -291,14 +427,14 @@ def _setup_ui(self) -> None: button_frame, text="Add Domain", style='Action.TButton', - command=self._add_domain + command=self._add_domain_request ).grid(row=0, column=0, padx=5) ttk.Button( button_frame, text="Remove Domain", style='Action.TButton', - command=self._remove_domain + command=self._remove_domain_request ).grid(row=0, column=1, padx=5) # Right side controls with improved spacing @@ -324,20 +460,20 @@ def _setup_ui(self) -> None: ) # Initialize with config value - self.ad_var = tk.StringVar(value=self.config[STR_SETTINGS][STR_AD_BLOCK]) + self.ad_var = tk.StringVar() ttk.Radiobutton( ad_frame, text="Enable", value="on", variable=self.ad_var, - command=self._handle_ad_block + command=self._handle_ad_block_request ).grid(row=0, column=0, padx=10) ttk.Radiobutton( ad_frame, text="Disable", value="off", variable=self.ad_var, - command=self._handle_ad_block + command=self._handle_ad_block_request ).grid(row=0, column=1, padx=10) # Adult sites Block controls @@ -354,20 +490,20 @@ def _setup_ui(self) -> None: ) # Initialize with config value - self.adult_var = tk.StringVar(value=self.config[STR_SETTINGS][STR_ADULT_BLOCK]) + self.adult_var = tk.StringVar() ttk.Radiobutton( adult_frame, text="Enable", value="on", variable=self.adult_var, - command=self._handle_adult_block + command=self._handle_adult_block_request ).grid(row=0, column=0, padx=10) ttk.Radiobutton( adult_frame, text="Disable", value="off", variable=self.adult_var, - command=self._handle_adult_block + command=self._handle_adult_block_request ).grid(row=0, column=1, padx=10) # Configure grid weights for better resizing @@ -381,8 +517,4 @@ def _setup_ui(self) -> None: button_frame.columnconfigure(1, weight=1) # Bind events - self.domains_listbox.bind('', lambda e: self._remove_domain()) - - # Load saved domains - for domain in self.config[STR_BLOCKED_DOMAINS].keys(): - self.domains_listbox.insert(tk.END, domain) + self.domains_listbox.bind('', lambda e: self._remove_domain_request()) diff --git a/client/src/utils.py b/client/src/utils.py index e3c97b5..d9c9eca 100644 --- a/client/src/utils.py +++ b/client/src/utils.py @@ -19,6 +19,7 @@ class Codes: CODE_ADD_DOMAIN = "52" CODE_REMOVE_DOMAIN = "53" CODE_DOMAIN_LIST_UPDATE = "54" + CODE_INIT_SETTINGS = "55" CODE_SUCCESS = "100" CODE_ERROR = "101" @@ -42,6 +43,15 @@ class Codes: STR_ERROR = "Error" STR_DOMAINS = "domains" STR_SUCCESS = "Success" +STR_OPERATION = "operation" + +# Operation constants +STR_REMOVE_DOMAIN_REQUEST = "Remove domain request" +STR_ADD_DOMAIN_REQUEST = "Add domain request" +STR_AD_BLOCK_RESPONSE = "Ad block response" +STR_ADD_DOMAIN_RESPONSE = "Add domain response" +STR_ADULT_BLOCK_RESPONSE = "Adult block response" +STR_REMOVE_DOMAIN_RESPONSE = "Remove domain response" # Config Constants STR_BLOCKED_DOMAINS = "blocked_domains" @@ -61,14 +71,6 @@ class Codes: STR_PORT: DEFAULT_PORT, STR_RECEIVE_BUFFER_SIZE: DEFAULT_BUFFER_SIZE }, - - STR_BLOCKED_DOMAINS: {}, - - STR_SETTINGS: { - STR_AD_BLOCK: "off", - STR_ADULT_BLOCK: "off" - }, - STR_LOGGING: { STR_LEVEL: "INFO", STR_LOG_DIR: LOG_DIR diff --git a/client/tests/test_application.py b/client/tests/test_application.py index f55c206..eaabd55 100644 --- a/client/tests/test_application.py +++ b/client/tests/test_application.py @@ -90,7 +90,7 @@ def test_handle_request_domain_list_update(application: Application) -> None: }) application._handle_request(test_request) - application._view.update_domain_list.assert_called_once_with(test_content) + application._view.update_domain_list_response.assert_called_once_with(test_content) def test_cleanup(application: Application) -> None: diff --git a/client/tests/test_view.py b/client/tests/test_view.py index f83f0b9..c611e5c 100644 --- a/client/tests/test_view.py +++ b/client/tests/test_view.py @@ -86,7 +86,7 @@ def test_handle_ad_block(viewer: Viewer) -> None: """Test handling ad block setting changes.""" # Configure the mock StringVar to return "on" viewer.ad_var.get.return_value = "on" - viewer._handle_ad_block() + viewer._handle_ad_block_request() expected_json = json.dumps({ STR_CODE: Codes.CODE_AD_BLOCK, @@ -101,7 +101,7 @@ def test_handle_adult_block(viewer: Viewer) -> None: """Test handling adult block setting changes.""" # Configure the mock StringVar to return "on" viewer.adult_var.get.return_value = "on" - viewer._handle_adult_block() + viewer._handle_adult_block_request() expected_json = json.dumps({ STR_CODE: Codes.CODE_ADULT_BLOCK, @@ -116,7 +116,7 @@ def test_add_domain(viewer: Viewer) -> None: """Test adding a domain.""" domain = "test.com" viewer.domain_entry.get.return_value = domain - viewer._add_domain() + viewer._add_domain_request() expected_json = json.dumps({ STR_CODE: Codes.CODE_ADD_DOMAIN, @@ -133,7 +133,7 @@ def test_add_duplicate_domain(viewer: Viewer) -> None: viewer.config[STR_BLOCKED_DOMAINS][domain] = True viewer.domain_entry.get.return_value = domain - viewer._add_domain() + viewer._add_domain_request() viewer._message_callback.assert_not_called() viewer._show_error.assert_called_once_with(ERR_DUPLICATE_DOMAIN) @@ -146,7 +146,7 @@ def test_remove_domain(viewer: Viewer) -> None: viewer.domains_listbox.curselection.return_value = (0,) viewer.domains_listbox.get.return_value = domain - viewer._remove_domain() + viewer._remove_domain_request() expected_json = json.dumps({ STR_CODE: Codes.CODE_REMOVE_DOMAIN, diff --git a/server/src/handlers.py b/server/src/handlers.py index 0299dce..e99f8d9 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -75,7 +75,7 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: return { STR_CODE: Codes.CODE_SUCCESS, - STR_CONTENT: STR_DOMAIN_BLOCKED_MSG + STR_CONTENT: STR_DOMAIN_BLOCKED_MSG, } case Codes.CODE_REMOVE_DOMAIN: From e956888cab6c480618618348e40e895e0b7393d9 Mon Sep 17 00:00:00 2001 From: elipaz Date: Tue, 12 Nov 2024 16:56:58 +0200 Subject: [PATCH 36/38] Modify code, content in response msg --- client/src/View.py | 5 ++--- server/src/handlers.py | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/client/src/View.py b/client/src/View.py index e303a5f..6dc6505 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -115,9 +115,8 @@ def add_domain_response(self, response: dict) -> None: match response[STR_CODE]: case Codes.CODE_SUCCESS: domain = response[STR_CONTENT] - with self._update_list_lock: - self.domains_listbox.insert(tk.END, domain) - self.domain_entry.delete(0, tk.END) + self.domains_listbox.insert(tk.END, domain) + self.domain_entry.delete(0, tk.END) self._show_success( message=f"Domain '{domain}' added successfully", diff --git a/server/src/handlers.py b/server/src/handlers.py index a20fcd6..cd3668a 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -73,31 +73,33 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: return invalid_json_response() operation_code = request_data[STR_CODE] + domain = request_data[STR_CONTENT] + match operation_code: case Codes.CODE_ADD_DOMAIN: - self.db_manager.add_blocked_domain(request_data[STR_CONTENT]) - self.logger.info(f"Domain blocked: {request_data[STR_CONTENT]}") + self.db_manager.add_blocked_domain(domain) + self.logger.info(f"Domain blocked: {domain}") return { STR_CODE: Codes.CODE_SUCCESS, - STR_CONTENT: STR_DOMAIN_BLOCKED_MSG, + STR_CONTENT: domain, STR_OPERATION: Codes.CODE_ADD_DOMAIN } case Codes.CODE_REMOVE_DOMAIN: - if self.db_manager.remove_blocked_domain(request_data[STR_CONTENT]): - self.logger.info(f"Domain unblocked: {request_data[STR_CONTENT]}") + if self.db_manager.remove_blocked_domain(domain): + self.logger.info(f"Domain unblocked: {domain}") return { STR_CODE: Codes.CODE_SUCCESS, - STR_CONTENT: STR_DOMAIN_UNBLOCKED_MSG, + STR_CONTENT: domain, STR_OPERATION: Codes.CODE_REMOVE_DOMAIN } - self.logger.warning(f"Domain not found for unblocking: {request_data[STR_CONTENT]}") + self.logger.warning(f"Domain not found for unblocking: {domain}") return { STR_CODE: Codes.CODE_ERROR, - STR_CONTENT: STR_DOMAIN_NOT_FOUND_MSG, + STR_CONTENT: domain, STR_OPERATION: Codes.CODE_REMOVE_DOMAIN } From 460915543cf353aff939610bb23ec34f4b507947 Mon Sep 17 00:00:00 2001 From: yoaz11 <129755179+yoaz11@users.noreply.github.com> Date: Tue, 12 Nov 2024 17:16:42 +0200 Subject: [PATCH 37/38] Bug fix - duplicates domains --- server/src/handlers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/server/src/handlers.py b/server/src/handlers.py index cd3668a..d273392 100644 --- a/server/src/handlers.py +++ b/server/src/handlers.py @@ -77,9 +77,16 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: match operation_code: case Codes.CODE_ADD_DOMAIN: + if self.db_manager.is_domain_blocked(domain): + self.logger.warning(f"Domain already blocked: {domain}") + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: f"Domain {domain} is already blocked", + STR_OPERATION: Codes.CODE_ADD_DOMAIN + } + self.db_manager.add_blocked_domain(domain) self.logger.info(f"Domain blocked: {domain}") - return { STR_CODE: Codes.CODE_SUCCESS, STR_CONTENT: domain, @@ -89,7 +96,6 @@ def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: case Codes.CODE_REMOVE_DOMAIN: if self.db_manager.remove_blocked_domain(domain): self.logger.info(f"Domain unblocked: {domain}") - return { STR_CODE: Codes.CODE_SUCCESS, STR_CONTENT: domain, From 466fe6c98f48ea39d7ee990d319de0a47adae2b9 Mon Sep 17 00:00:00 2001 From: elipaz Date: Thu, 14 Nov 2024 12:38:59 +0200 Subject: [PATCH 38/38] Modify tests --- client/src/Application.py | 4 +- client/src/View.py | 2 +- client/tests/test_application.py | 110 +++----------- client/tests/test_communicator.py | 124 +++------------ client/tests/test_view.py | 243 +++++++++++++++--------------- 5 files changed, 170 insertions(+), 313 deletions(-) diff --git a/client/src/Application.py b/client/src/Application.py index aeabcb0..dd16448 100644 --- a/client/src/Application.py +++ b/client/src/Application.py @@ -73,7 +73,7 @@ def _start_gui(self) -> None: self._logger.error(f"Failed to start GUI: {str(e)}") raise - def _handle_request(self, request: str, server: bool = True) -> None: + def _handle_request(self, request: str, to_server: bool = True) -> None: """ Handle outgoing messages from the UI and Server. @@ -84,7 +84,7 @@ def _handle_request(self, request: str, server: bool = True) -> None: self._logger.info(f"Processing request: {request}") request_dict = json.loads(request) - if server: + if to_server: message = request if isinstance(request, dict) else json.loads(request) self._communicator.send_message(message) return diff --git a/client/src/View.py b/client/src/View.py index 6dc6505..8e546e7 100644 --- a/client/src/View.py +++ b/client/src/View.py @@ -82,7 +82,7 @@ def get_block_settings(self) -> dict[str, str]: STR_ADULT_BLOCK: self.adult_var.get() } - def update_domain_list_response(self, domains: List[str]) -> None: + def update_domain_list_response(self, domains: list[str]) -> None: """ Update the domains listbox with a new list of domains from the server. diff --git a/client/tests/test_application.py b/client/tests/test_application.py index eaabd55..c139ffe 100644 --- a/client/tests/test_application.py +++ b/client/tests/test_application.py @@ -1,141 +1,71 @@ -import logging -from unittest import mock -from typing import Optional, Callable import json - import pytest +from unittest import mock +from typing import Dict, Any from src.Application import Application -from src.View import Viewer -from src.Communicator import Communicator from src.utils import ( - STR_CODE, STR_CONTENT, + STR_CODE, STR_CONTENT, STR_OPERATION, STR_DOMAINS, Codes, DEFAULT_CONFIG ) - @pytest.fixture def mock_config_manager() -> mock.Mock: """Fixture to provide a mock configuration manager.""" config_manager = mock.Mock() - config_manager.get_config.return_value = DEFAULT_CONFIG + config_manager.get_config.return_value = DEFAULT_CONFIG.copy() return config_manager - @pytest.fixture def application(mock_config_manager: mock.Mock) -> Application: - """Fixture to create an Application instance.""" + """Fixture to create an Application instance with mocked components.""" with mock.patch('src.Application.Viewer') as mock_viewer, \ mock.patch('src.Application.Communicator') as mock_comm, \ - mock.patch('src.Application.setup_logger') as mock_logger: + mock.patch('src.Application.setup_logger'): app = Application() app._logger = mock.Mock() app._config_manager = mock_config_manager return app - def test_init(application: Application) -> None: """Test the initialization of Application.""" assert hasattr(application, '_logger') - assert hasattr(application, '_view') assert hasattr(application, '_communicator') - assert hasattr(application, '_request_lock') assert hasattr(application, '_config_manager') - -@mock.patch('src.Application.threading.Thread') -def test_start_communication( - mock_thread: mock.Mock, - application: Application -) -> None: - """Test the communication startup.""" - application._start_communication() - - application._communicator.connect.assert_called_once() - mock_thread.assert_called_once_with( - target=application._communicator.receive_message, - daemon=True - ) - mock_thread.return_value.start.assert_called_once() - - -def test_start_gui(application: Application) -> None: - """Test the GUI startup.""" - application._start_gui() - application._view.run.assert_called_once() - - def test_handle_request_ad_block(application: Application) -> None: """Test handling ad block request.""" test_request = { STR_CODE: Codes.CODE_AD_BLOCK, - STR_CONTENT: "test" + STR_CONTENT: "on", } application._communicator.send_message = mock.Mock() - application._handle_request(json.dumps(test_request)) - actual_arg = application._communicator.send_message.call_args[0][0] - - assert json.loads(json.loads(actual_arg)) == test_request - + application._communicator.send_message.assert_called_once() + sent_data = application._communicator.send_message.call_args[0][0] + assert sent_data == test_request def test_handle_request_domain_list_update(application: Application) -> None: """Test handling domain list update request.""" - test_content = ["domain1.com", "domain2.com"] - test_request = json.dumps({ - STR_CODE: Codes.CODE_DOMAIN_LIST_UPDATE, - STR_CONTENT: test_content - }) + test_domains = ["domain1.com", "domain2.com"] + test_request = { + STR_CODE: Codes.CODE_SUCCESS, + STR_DOMAINS: test_domains, + STR_OPERATION: Codes.CODE_DOMAIN_LIST_UPDATE + } - application._handle_request(test_request) - application._view.update_domain_list_response.assert_called_once_with(test_content) - + application._handle_request(json.dumps(test_request), to_server=False) + application._view.update_domain_list_response.assert_called_once_with(test_domains) def test_cleanup(application: Application) -> None: """Test cleanup process.""" application._cleanup() - application._communicator.close.assert_called_once() - application._view.root.destroy.assert_called_once() - - -def test_run_success(application: Application) -> None: - """Test successful application run.""" - with mock.patch.object(application, '_start_communication'), \ - mock.patch.object(application, '_start_gui'), \ - mock.patch.object(application, '_cleanup'): - - application.run() - - application._start_communication.assert_called_once() - application._start_gui.assert_called_once() - application._cleanup.assert_called_once() - - -def test_run_exception(application: Application) -> None: - """Test application run with exception.""" - error_msg = "Test error" - - with mock.patch.object(application, '_start_communication') as mock_start_comm, \ - mock.patch.object(application, '_cleanup') as mock_cleanup: - - mock_start_comm.side_effect = Exception(error_msg) - - with pytest.raises(Exception) as exc_info: - application.run() - - assert str(exc_info.value) == error_msg - application._logger.error.assert_called_with( - f"Error during execution: {error_msg}", - exc_info=True - ) - mock_cleanup.assert_called_once() - -def test_handle_request_json_error(application: Application) -> None: - """Test handling of invalid JSON in request.""" +def test_handle_request_invalid_json(application: Application) -> None: + """Test handling invalid JSON in request.""" invalid_json = "{" with pytest.raises(json.JSONDecodeError): diff --git a/client/tests/test_communicator.py b/client/tests/test_communicator.py index 2a761e4..acfbfdd 100644 --- a/client/tests/test_communicator.py +++ b/client/tests/test_communicator.py @@ -1,17 +1,15 @@ import socket -from unittest import mock -from typing import Optional, Callable - import pytest +from unittest import mock +from typing import Callable +import json from src.Communicator import Communicator from src.utils import ( - DEFAULT_HOST, DEFAULT_PORT, DEFAULT_BUFFER_SIZE, - ERR_SOCKET_NOT_SETUP, STR_NETWORK, - DEFAULT_CONFIG + DEFAULT_CONFIG, ERR_SOCKET_NOT_SETUP, + STR_NETWORK, STR_HOST, STR_PORT, STR_RECEIVE_BUFFER_SIZE ) - @pytest.fixture def mock_config_manager() -> mock.Mock: """Fixture to provide a mock configuration manager.""" @@ -19,13 +17,11 @@ def mock_config_manager() -> mock.Mock: config_manager.get_config.return_value = DEFAULT_CONFIG return config_manager - @pytest.fixture def mock_callback() -> Callable[[str], None]: """Fixture to provide a mock callback function.""" return mock.Mock() - @pytest.fixture def communicator( mock_config_manager: mock.Mock, @@ -37,125 +33,55 @@ def communicator( message_callback=mock_callback ) - -def test_init( - communicator: Communicator, - mock_callback: Callable[[str], None] -) -> None: - """Test the initialization of Communicator.""" - assert communicator._host == DEFAULT_CONFIG[STR_NETWORK][DEFAULT_HOST] - assert communicator._port == DEFAULT_CONFIG[STR_NETWORK][DEFAULT_PORT] - assert communicator._receive_buffer_size == DEFAULT_CONFIG[STR_NETWORK][DEFAULT_BUFFER_SIZE] +def test_init(communicator: Communicator, mock_callback: Callable[[str], None]) -> None: + """Test initialization of Communicator.""" + config = DEFAULT_CONFIG[STR_NETWORK] + assert communicator._host == config[STR_HOST] + assert communicator._port == int(config[STR_PORT]) + assert communicator._receive_buffer_size == int(config[STR_RECEIVE_BUFFER_SIZE]) assert communicator._socket is None assert communicator._message_callback == mock_callback - @mock.patch('socket.socket') -def test_connect( - mock_socket_class: mock.Mock, - communicator: Communicator -) -> None: - """Test the connect method initializes and connects the socket.""" +def test_connect(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test socket connection.""" mock_socket_instance = mock_socket_class.return_value communicator.connect() - mock_socket_class.assert_called_once_with( - socket.AF_INET, - socket.SOCK_STREAM - ) + mock_socket_class.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) mock_socket_instance.connect.assert_called_once_with( (communicator._host, communicator._port) ) assert communicator._socket is mock_socket_instance - @mock.patch('socket.socket') def test_send_message_without_setup( mock_socket_class: mock.Mock, communicator: Communicator ) -> None: - """Test sending a message without setting up the socket raises RuntimeError.""" + """Test sending message without socket setup.""" with pytest.raises(RuntimeError) as exc_info: - communicator.send_message("Hello") + communicator.send_message("test message") assert str(exc_info.value) == ERR_SOCKET_NOT_SETUP - @mock.patch('socket.socket') -def test_send_message( - mock_socket_class: mock.Mock, - communicator: Communicator -) -> None: - """Test sending a message successfully.""" +def test_send_message(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test sending message successfully.""" mock_socket_instance = mock_socket_class.return_value communicator._socket = mock_socket_instance - - message: str = "Hello, World!" + + message = {"test": "message"} communicator.send_message(message) - - mock_socket_instance.send.assert_called_once_with( - message.encode('utf-8') - ) - - -@mock.patch('socket.socket') -def test_receive_message_without_setup( - mock_socket_class: mock.Mock, - communicator: Communicator -) -> None: - """Test receiving a message without setting up the socket raises RuntimeError.""" - with pytest.raises(RuntimeError) as exc_info: - communicator.receive_message() - assert str(exc_info.value) == ERR_SOCKET_NOT_SETUP - - -@mock.patch('socket.socket') -def test_receive_message( - mock_socket_class: mock.Mock, - communicator: Communicator, - mock_callback: Callable[[str], None] -) -> None: - """Test receiving a message successfully.""" - mock_socket_instance = mock_socket_class.return_value - communicator._socket = mock_socket_instance - - mock_socket_instance.recv.side_effect = [b'Hello, Client!', b''] - communicator.receive_message() - - mock_socket_instance.recv.assert_called_with( - DEFAULT_CONFIG[STR_NETWORK][DEFAULT_BUFFER_SIZE] - ) - mock_callback.assert_called_once_with('Hello, Client!') - + mock_socket_instance.send.assert_called_once_with(json.dumps(message).encode('utf-8')) @mock.patch('socket.socket') -def test_close_socket( - mock_socket_class: mock.Mock, - communicator: Communicator -) -> None: - """Test closing the socket.""" +def test_close_socket(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test closing socket connection.""" mock_socket_instance = mock_socket_class.return_value communicator._socket = mock_socket_instance - + communicator.close() - + mock_socket_instance.close.assert_called_once() assert communicator._socket is None - - -@mock.patch('socket.socket') -def test_receive_message_decode_error( - mock_socket_class: mock.Mock, - communicator: Communicator, - mock_callback: Callable[[str], None] -) -> None: - """Test handling of decode errors in receive_message.""" - mock_socket_instance = mock_socket_class.return_value - communicator._socket = mock_socket_instance - - mock_socket_instance.recv.side_effect = [bytes([0xFF, 0xFE, 0xFD]), b''] - - with pytest.raises(UnicodeDecodeError): - communicator.receive_message() - - mock_callback.assert_not_called() diff --git a/client/tests/test_view.py b/client/tests/test_view.py index c611e5c..fa994eb 100644 --- a/client/tests/test_view.py +++ b/client/tests/test_view.py @@ -1,158 +1,159 @@ +"""Unit tests for the Viewer class.""" + import pytest from unittest import mock import json -from typing import Callable +from typing import Dict, Any from src.View import Viewer from src.utils import ( - Codes, STR_CODE, STR_CONTENT, - STR_SETTINGS, STR_AD_BLOCK, STR_ADULT_BLOCK, - STR_BLOCKED_DOMAINS, DEFAULT_CONFIG, - ERR_DUPLICATE_DOMAIN + Codes, STR_CODE, STR_CONTENT, STR_DOMAINS, STR_SETTINGS, + STR_AD_BLOCK, STR_ADULT_BLOCK ) @pytest.fixture def mock_config_manager() -> mock.Mock: - """Fixture to provide a mock configuration manager.""" + """Create a mock configuration manager fixture.""" config_manager = mock.Mock() - config_manager.get_config.return_value = DEFAULT_CONFIG.copy() + config_manager.get_config.return_value = { + "network": { + "host": "127.0.0.1", + "port": 65432 + } + } return config_manager @pytest.fixture -def mock_callback() -> Callable[[str], None]: - """Fixture to provide a mock callback function.""" +def mock_callback() -> mock.Mock: + """Create a mock callback function fixture.""" return mock.Mock() @pytest.fixture -def viewer(mock_config_manager: mock.Mock, mock_callback: mock.Mock) -> Viewer: - """Fixture to create a Viewer instance with mocked components.""" - with mock.patch('tkinter.Tk') as mock_tk, \ - mock.patch('tkinter.ttk.Style'): - # Create a mock Tk instance - root = mock_tk.return_value +def mock_tk() -> mock.Mock: + """Create a mock for tkinter components.""" + with mock.patch('src.View.tk') as mock_tk: + # Mock Tk instance + mock_root = mock.Mock() + mock_tk.Tk.return_value = mock_root + + # Mock StringVar + mock_string_var = mock.Mock() + mock_string_var.get.return_value = "on" + mock_tk.StringVar.return_value = mock_string_var - # Set up the mock root properly - mock_tk._default_root = root - root._default_root = root + # Mock Listbox + mock_listbox = mock.Mock() + mock_listbox.get.side_effect = lambda start, end: ["domain1.com", "domain2.com"] + mock_tk.Listbox.return_value = mock_listbox - # Create StringVar mock that returns string values - with mock.patch('tkinter.StringVar') as mock_string_var: - string_var_instance = mock.Mock() - string_var_instance.get.return_value = "off" - mock_string_var.return_value = string_var_instance - - # Create Entry and Listbox mocks - with mock.patch('tkinter.Entry') as mock_entry, \ - mock.patch('tkinter.Listbox') as mock_listbox: - - # Setup Entry mock - entry_instance = mock.Mock() - entry_instance.get.return_value = "" - mock_entry.return_value = entry_instance - - # Setup Listbox mock - listbox_instance = mock.Mock() - listbox_instance.curselection.return_value = () - listbox_instance.get.return_value = "" - mock_listbox.return_value = listbox_instance - - viewer = Viewer( - config_manager=mock_config_manager, - message_callback=mock_callback - ) - - # Store mock instances for easy access in tests - viewer.domain_entry = entry_instance - viewer.domains_listbox = listbox_instance - - # Mock the _show_error method - viewer._show_error = mock.Mock() - - return viewer + yield mock_tk -def test_get_block_settings(viewer: Viewer) -> None: - """Test getting block settings.""" - # Configure the mock StringVar to return specific values - viewer.ad_var.get.return_value = "off" - viewer.adult_var.get.return_value = "off" - - settings = viewer.get_block_settings() - assert STR_AD_BLOCK in settings - assert STR_ADULT_BLOCK in settings - assert isinstance(settings[STR_AD_BLOCK], str) - assert isinstance(settings[STR_ADULT_BLOCK], str) +@pytest.fixture +def viewer( + mock_config_manager: mock.Mock, + mock_callback: mock.Mock, + mock_tk: mock.Mock +) -> Viewer: + """Create a Viewer instance with mocked dependencies.""" + with mock.patch('src.View.ttk'), \ + mock.patch('src.View.messagebox'), \ + mock.patch('src.View.setup_logger') as mock_logger: + logger_instance = mock.Mock() + mock_logger.return_value = logger_instance + + viewer = Viewer( + config_manager=mock_config_manager, + message_callback=mock_callback + ) + + # Set up instance variables that would normally be created in _setup_ui + viewer.domains_listbox = mock_tk.Listbox.return_value + viewer.ad_var = mock_tk.StringVar.return_value + viewer.adult_var = mock_tk.StringVar.return_value + viewer.domain_entry = mock.Mock() + + return viewer -def test_handle_ad_block(viewer: Viewer) -> None: - """Test handling ad block setting changes.""" - # Configure the mock StringVar to return "on" - viewer.ad_var.get.return_value = "on" - viewer._handle_ad_block_request() - - expected_json = json.dumps({ +def test_handle_ad_block_request(viewer: Viewer) -> None: + """Test handling ad block request message formation.""" + expected_message = json.dumps({ STR_CODE: Codes.CODE_AD_BLOCK, STR_CONTENT: "on" }) - viewer._message_callback.assert_called_once_with(expected_json) - viewer.config_manager.save_config.assert_called_once_with(viewer.config) - assert viewer.config[STR_SETTINGS][STR_AD_BLOCK] == "on" + viewer._handle_ad_block_request() + viewer._message_callback.assert_called_once_with(expected_message) -def test_handle_adult_block(viewer: Viewer) -> None: - """Test handling adult block setting changes.""" - # Configure the mock StringVar to return "on" - viewer.adult_var.get.return_value = "on" - viewer._handle_adult_block_request() - - expected_json = json.dumps({ +def test_handle_adult_block_request(viewer: Viewer) -> None: + """Test handling adult block request message formation.""" + expected_message = json.dumps({ STR_CODE: Codes.CODE_ADULT_BLOCK, STR_CONTENT: "on" }) - viewer._message_callback.assert_called_once_with(expected_json) - viewer.config_manager.save_config.assert_called_once_with(viewer.config) - assert viewer.config[STR_SETTINGS][STR_ADULT_BLOCK] == "on" + viewer._handle_adult_block_request() + viewer._message_callback.assert_called_once_with(expected_message) -def test_add_domain(viewer: Viewer) -> None: - """Test adding a domain.""" - domain = "test.com" - viewer.domain_entry.get.return_value = domain - viewer._add_domain_request() +def test_update_initial_settings(viewer: Viewer) -> None: + """Test updating initial settings from server response.""" + test_settings = { + STR_DOMAINS: ["example.com", "test.com"], + STR_SETTINGS: { + STR_AD_BLOCK: "on", + STR_ADULT_BLOCK: "off" + } + } - expected_json = json.dumps({ - STR_CODE: Codes.CODE_ADD_DOMAIN, - STR_CONTENT: domain - }) - - viewer._message_callback.assert_called_once_with(expected_json) - viewer.config_manager.save_config.assert_called_once_with(viewer.config) - assert viewer.config[STR_BLOCKED_DOMAINS][domain] is True + viewer.update_initial_settings(test_settings) + viewer.logger.info.assert_called_with("Successfully initialized settings from server") -def test_add_duplicate_domain(viewer: Viewer) -> None: - """Test adding a duplicate domain.""" - domain = "test.com" - viewer.config[STR_BLOCKED_DOMAINS][domain] = True - viewer.domain_entry.get.return_value = domain - - viewer._add_domain_request() +def test_update_domain_list_response(viewer: Viewer) -> None: + """Test updating domain list from server response.""" + test_domains = ["domain1.com", "domain2.com"] - viewer._message_callback.assert_not_called() - viewer._show_error.assert_called_once_with(ERR_DUPLICATE_DOMAIN) - assert len(viewer.config[STR_BLOCKED_DOMAINS]) == 1 + viewer.update_domain_list_response(test_domains) + viewer.logger.info.assert_called_with(f"Updated domain list with {len(test_domains)} domains") -def test_remove_domain(viewer: Viewer) -> None: - """Test removing a domain.""" - domain = "test.com" - viewer.config[STR_BLOCKED_DOMAINS][domain] = True - viewer.domains_listbox.curselection.return_value = (0,) - viewer.domains_listbox.get.return_value = domain +@pytest.mark.parametrize("response,expected_log", [ + ( + {STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: "test.com"}, + "info" + ), + ( + {STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: "Failed to add domain"}, + "error" + ) +]) +def test_add_domain_response( + viewer: Viewer, + response: Dict[str, Any], + expected_log: str +) -> None: + """Test handling add domain response from server.""" + # Reset the mock call counts before our test + viewer.logger.info.reset_mock() + viewer.logger.error.reset_mock() - viewer._remove_domain_request() + viewer.add_domain_response(response) - expected_json = json.dumps({ - STR_CODE: Codes.CODE_REMOVE_DOMAIN, - STR_CONTENT: domain - }) - - viewer._message_callback.assert_called_once_with(expected_json) - viewer.config_manager.save_config.assert_called_once_with(viewer.config) - assert domain not in viewer.config[STR_BLOCKED_DOMAINS] + if expected_log == "info": + viewer.logger.info.assert_called_once() + viewer.logger.error.assert_not_called() + else: + viewer.logger.error.assert_called_once() + +def test_get_blocked_domains(viewer: Viewer) -> None: + """Test getting list of blocked domains.""" + expected_domains = ["domain1.com", "domain2.com"] + domains = list(viewer.get_blocked_domains()) + assert domains == expected_domains + +def test_get_block_settings(viewer: Viewer) -> None: + """Test getting block settings.""" + settings = viewer.get_block_settings() + assert settings == { + STR_AD_BLOCK: "on", + STR_ADULT_BLOCK: "on" + }