diff --git a/.gitignore b/.gitignore index 594cee2..5319942 100644 --- a/.gitignore +++ b/.gitignore @@ -845,3 +845,5 @@ FodyWeavers.xsd # Additional files built by Visual Studio # End of https://www.toptal.com/developers/gitignore/api/python,c,windows,linux,visualstudio,pycharm,clion +server/client.py +server/my_internet.db diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/__init__.py b/client/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/config.json b/client/config.json new file mode 100644 index 0000000..1036558 --- /dev/null +++ b/client/config.json @@ -0,0 +1,11 @@ +{ + "network": { + "host": "127.0.0.1", + "port": 65432, + "receive_buffer_size": 1024 + }, + "logging": { + "level": "INFO", + "log_dir": "client_logs" + } +} \ No newline at end of file diff --git a/client/main.py b/client/main.py new file mode 100644 index 0000000..7ef3d01 --- /dev/null +++ b/client/main.py @@ -0,0 +1,8 @@ +from src.Application import Application + +def main() -> None: + application: Application = Application() + application.run() + +if __name__ == "__main__": + main() diff --git a/client/pytest.ini b/client/pytest.ini new file mode 100644 index 0000000..a635c5c --- /dev/null +++ b/client/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = . diff --git a/client/requirments.txt b/client/requirments.txt new file mode 100644 index 0000000..0a08159 --- /dev/null +++ b/client/requirments.txt @@ -0,0 +1,8 @@ +-e git+https://github.com/pazMenachem/My_Internet.git@b0c04b626f09baa0dace19fb70902cc8189f7ce0#egg=client&subdirectory=client +colorama==0.4.6 +exceptiongroup==1.2.2 +iniconfig==2.0.0 +packaging==24.1 +pluggy==1.5.0 +pytest==8.3.3 +tomli==2.0.2 diff --git a/client/setup.py b/client/setup.py new file mode 100644 index 0000000..37ab367 --- /dev/null +++ b/client/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="client", + packages=find_packages(), + version="0.1", +) \ No newline at end of file diff --git a/client/src/Application.py b/client/src/Application.py new file mode 100644 index 0000000..dd16448 --- /dev/null +++ b/client/src/Application.py @@ -0,0 +1,124 @@ +import json +import threading +from .Communicator import Communicator +from .View import Viewer +from .Logger import setup_logger +from .ConfigManager import ConfigManager + +from .utils import ( + STR_DOMAINS, STR_OPERATION, + Codes +) + +class Application: + """ + Main application class that coordinates communication between UI and server. + + Uses threading to handle simultaneous GUI and network operations. + + Attributes: + _logger: Logger instance for application logging + _view: Viewer instance for GUI operations + _communicator: Communicator instance for network operations + """ + + def __init__(self) -> None: + """Initialize application components.""" + self._logger = setup_logger(__name__) + self._config_manager = ConfigManager() + + self._view = Viewer(config_manager=self._config_manager, message_callback=self._handle_request) + self._communicator = Communicator(config_manager=self._config_manager, message_callback=self._handle_request) + + def run(self) -> None: + """ + Start the application with threaded communication handling. + + Raises: + Exception: If there's an error during startup of either component. + """ + self._logger.info("Starting application") + + try: + self._start_communication() + self._start_gui() + + except Exception as e: + self._logger.error(f"Error during execution: {str(e)}", exc_info=True) + raise + finally: + self._cleanup() + + def _start_communication(self) -> None: + """Initialize and start the communication thread.""" + try: + self._communicator.connect() + threading.Thread( + target=self._communicator.receive_message, + daemon=True + ).start() + + self._logger.info("Communication server started successfully") + except Exception as e: + self._logger.error(f"Failed to start communication: {str(e)}") + raise + + def _start_gui(self) -> None: + """Start the GUI main loop.""" + try: + self._logger.info("Starting GUI") + self._view.run() + + except Exception as e: + self._logger.error(f"Failed to start GUI: {str(e)}") + raise + + def _handle_request(self, request: str, to_server: bool = True) -> None: + """ + Handle outgoing messages from the UI and Server. + + Args: + request: received request from server or user input from UI. + """ + try: + self._logger.info(f"Processing request: {request}") + request_dict = json.loads(request) + + if to_server: + message = request if isinstance(request, dict) else json.loads(request) + self._communicator.send_message(message) + return + + match request_dict[STR_OPERATION]: + case Codes.CODE_INIT_SETTINGS: + self._view.update_initial_settings(request_dict) + case Codes.CODE_AD_BLOCK: + self._view.ad_block_response(request_dict) + case Codes.CODE_ADULT_BLOCK: + self._view.adult_block_response(request_dict) + case Codes.CODE_ADD_DOMAIN: + self._view.add_domain_response(request_dict) + case Codes.CODE_REMOVE_DOMAIN: + self._view.remove_domain_response(request_dict) + case Codes.CODE_DOMAIN_LIST_UPDATE: + self._view.update_domain_list_response(request_dict[STR_DOMAINS]) + + except json.JSONDecodeError as e: + self._logger.error(f"Invalid JSON format: {str(e)}") + raise + except Exception as e: + self._logger.error(f"Error handling request: {str(e)}") + raise + + def _cleanup(self) -> None: + """Clean up resources and stop threads.""" + self._logger.info("Cleaning up application resources") + try: + if self._communicator: + self._communicator.close() + + if self._view and self._view.root.winfo_exists(): + self._view.root.destroy() + + except Exception as e: + self._logger.warning(f"Cleanup encountered an error: {str(e)}") diff --git a/client/src/Communicator.py b/client/src/Communicator.py new file mode 100644 index 0000000..1713641 --- /dev/null +++ b/client/src/Communicator.py @@ -0,0 +1,104 @@ +import socket +from typing import Optional, Callable +import json +from .Logger import setup_logger +from .utils import ( + ERR_SOCKET_NOT_SETUP, + STR_NETWORK, STR_HOST, STR_PORT, STR_RECEIVE_BUFFER_SIZE +) + +class Communicator: + def __init__(self, config_manager, message_callback: Callable[[str], None]) -> None: + """ + Initialize the communicator. + + Args: + config_manager: Configuration manager instance + message_callback: Callback function to handle received messages. + """ + self.logger = setup_logger(__name__) + self.logger.info("Initializing Communicator") + self.config = config_manager.get_config() + self._message_callback = message_callback + + self._host = self.config[STR_NETWORK][STR_HOST] + self._port = int(self.config[STR_NETWORK][STR_PORT]) + self._receive_buffer_size = int(self.config[STR_NETWORK][STR_RECEIVE_BUFFER_SIZE]) + self._socket: Optional[socket.socket] = None + + def connect(self) -> None: + """ + Establish connection to the server. + + Raises: + socket.error: If connection cannot be established. + """ + try: + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.connect((self._host, self._port)) + self.logger.info(f"Connected to server at {self._host}:{self._port}") + except socket.error as e: + self.logger.error(f"Failed to connect to server: {str(e)}") + raise + + def send_message(self, request: dict) -> None: + """ + Send a json request to the server. + + Args: + request: The request to send to the server. + + Raises: + RuntimeError: If socket connection is not established. + """ + self._validate_connection() + + try: + self._socket.send(json.dumps(request).encode('utf-8')) + self.logger.info(f"Request sent: {request}") + except Exception as e: + self.logger.error(f"Failed to send request: {str(e)}") + raise + + def receive_message(self) -> None: + """Continuously receive and process messages from the socket connection. + + This method runs in a loop to receive messages from the socket. Each received + message is decoded from UTF-8 and passed to the message callback function. + + Raises: + RuntimeError: If socket connection is not established. + socket.error: If there's an error receiving data from the socket. + UnicodeDecodeError: If received data cannot be decoded as UTF-8. + """ + self._validate_connection() + + self.logger.info("Starting message receive loop") + try: + while message_bytes := self._socket.recv(self._receive_buffer_size): + if not message_bytes: + self.logger.warning("Received empty message, breaking receive loop") + break + message = message_bytes.decode('utf-8') + self.logger.info(f"Received message: {message}") + self._message_callback(message, False) + except Exception as e: + self.logger.error(f"Error receiving message: {str(e)}") + raise + + def close(self) -> None: + """Close the socket connection and clean up resources.""" + if self._socket: + try: + self._socket.close() + self.logger.info("Socket connection closed") + except Exception as e: + self.logger.error(f"Error closing socket: {str(e)}") + finally: + self._socket = None + + def _validate_connection(self) -> None: + """Validate the socket connection.""" + if not self._socket: + self.logger.error(ERR_SOCKET_NOT_SETUP) + raise RuntimeError(ERR_SOCKET_NOT_SETUP) diff --git a/client/src/ConfigManager.py b/client/src/ConfigManager.py new file mode 100644 index 0000000..62e67df --- /dev/null +++ b/client/src/ConfigManager.py @@ -0,0 +1,87 @@ +"""Configuration management module for the application.""" + +import json +import os +from typing import Dict, Any +from .Logger import setup_logger +from .utils import DEFAULT_CONFIG + + +class ConfigManager: + """Manages application configuration loading and saving.""" + + def __init__(self, config_file: str = "config.json") -> None: + """ + Initialize the configuration manager. + + Args: + config_file: Path to the configuration file. + """ + self.logger = setup_logger(__name__) + self.config_file = config_file + self.config = self._load_config() + + def _load_config(self) -> Dict[str, Any]: + """ + Load configuration from JSON file. + + Returns: + Dict containing configuration settings. + """ + try: + if os.path.exists(self.config_file): + self.logger.info(f"Loading configuration from {self.config_file}") + with open(self.config_file, 'r') as f: + user_config = json.load(f) + return self._merge_configs(DEFAULT_CONFIG, user_config) + + self.logger.warning(f"Configuration file not found, using default configuration") + + except json.JSONDecodeError: + self.logger.error(f"Error decoding {self.config_file}, using default configuration") + + return DEFAULT_CONFIG.copy() + + def _merge_configs(self, default: Dict[str, Any], user: Dict[str, Any]) -> Dict[str, Any]: + """ + Recursively merge user configuration with default configuration. + + Args: + default: Default configuration dictionary + user: User configuration dictionary + + Returns: + Merged configuration dictionary + """ + result = default.copy() + + for key, value in user.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = self._merge_configs(result[key], value) + else: + result[key] = value + + return result + + def save_config(self, config: Dict[str, Any]) -> None: + """ + Save configuration to JSON file. + + Args: + config: Configuration dictionary to save + """ + try: + with open(self.config_file, 'w') as f: + json.dump(config, f, indent=4) + self.logger.info("Configuration saved successfully") + except Exception as e: + self.logger.error(f"Error saving configuration: {str(e)}") + + def get_config(self) -> Dict[str, Any]: + """ + Get the current configuration. + + Returns: + Current configuration dictionary + """ + return self.config diff --git a/client/src/Logger.py b/client/src/Logger.py new file mode 100644 index 0000000..a6c5709 --- /dev/null +++ b/client/src/Logger.py @@ -0,0 +1,45 @@ +"""Logger module for handling application-wide logging configuration.""" + +import logging +import os +from datetime import datetime +from typing import Optional +from .utils import LOG_DIR, LOG_FORMAT, LOG_DATE_FORMAT + +_logger: Optional[logging.Logger] = None + +def setup_logger(name: str) -> logging.Logger: + """ + Configure and return a logger instance. + + Args: + name: The name of the module requesting the logger. + + Returns: + logging.Logger: Configured logger instance. + """ + global _logger + + if _logger is not None: + return logging.getLogger(name) + + if not os.path.exists(LOG_DIR): + os.makedirs(LOG_DIR) + + log_file: str = os.path.join( + LOG_DIR, f"Client_{datetime.now().strftime(LOG_DATE_FORMAT)}.log" + ) + + logging.basicConfig( + level=logging.INFO, + format=LOG_FORMAT, + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler(), + ], + ) + + _logger = logging.getLogger(name) + _logger.info("Logger setup complete") + + return _logger \ No newline at end of file diff --git a/client/src/View.py b/client/src/View.py new file mode 100644 index 0000000..8e546e7 --- /dev/null +++ b/client/src/View.py @@ -0,0 +1,519 @@ +import tkinter as tk +from tkinter import ttk, messagebox +from typing import Callable, List +import json +from .Logger import setup_logger +from .ConfigManager import ConfigManager + +from .utils import ( + Codes, + WINDOW_SIZE, WINDOW_TITLE, + ERR_NO_DOMAIN_SELECTED, ERR_DOMAIN_LIST_UPDATE_FAILED, + STR_AD_BLOCK, STR_ADULT_BLOCK, STR_CODE, STR_BLOCKED_DOMAINS, + STR_CONTENT, STR_SETTINGS, STR_ERROR, STR_SUCCESS, + STR_ADD_DOMAIN_RESPONSE, STR_REMOVE_DOMAIN_REQUEST, STR_ADD_DOMAIN_REQUEST, + STR_AD_BLOCK_RESPONSE, STR_ADULT_BLOCK_RESPONSE, STR_REMOVE_DOMAIN_RESPONSE, + STR_DOMAINS, +) + +class Viewer: + """ + Graphical user interface for the application. + """ + + def __init__(self, config_manager: ConfigManager, message_callback: Callable[[str], None]) -> None: + """ + Initialize the viewer window and its components. + + Args: + config_manager: Configuration manager instance + message_callback: Callback function to handle message sending. + """ + self.logger = setup_logger(__name__) + self.logger.info("Initializing Viewer") + self.config_manager = config_manager + self.config = config_manager.get_config() + self._message_callback = message_callback + + # Initialize root window first + self.root: tk.Tk = tk.Tk() + self.root.title(WINDOW_TITLE) + self.root.geometry(WINDOW_SIZE) + + self.root.withdraw() # Hide the window temporarily + + # Configure styles + style = ttk.Style() + style.configure('TLabelframe', padding=10) + style.configure('TLabelframe.Label', font=('Arial', 10, 'bold')) + style.configure('TButton', padding=5) + style.configure('TRadiobutton', font=('Arial', 10)) + style.configure('TLabel', font=('Arial', 10)) + + self._setup_ui() + + # Show the window after setup is complete + self.root.deiconify() + self.logger.info("Viewer initialization complete") + + def run(self) -> None: + """Start the main event loop of the viewer.""" + self.logger.info("Starting main event loop") + self.root.mainloop() + + def get_blocked_domains(self) -> tuple[str, ...]: + """ + Get the list of currently blocked domains. + + Returns: + A tuple containing all blocked domains. + """ + return self.domains_listbox.get(0, tk.END) + + def get_block_settings(self) -> dict[str, str]: + """ + Get the current state of blocking settings. + + Returns: + A dictionary containing the current state of ad and adult content blocking. + """ + return { + STR_AD_BLOCK: self.ad_var.get(), + STR_ADULT_BLOCK: self.adult_var.get() + } + + def update_domain_list_response(self, domains: list[str]) -> None: + """ + Update the domains listbox with a new list of domains from the server. + + Args: + domains: List of domain strings to be displayed in the listbox. + """ + self.logger.info("Updating domain list from server") + + try: + self.domains_listbox.delete(0, tk.END) + + for domain in domains: + self.domains_listbox.insert(tk.END, domain) + + except Exception as e: + self.logger.error(f"Error updating domain list: {str(e)}") + self._show_error(ERR_DOMAIN_LIST_UPDATE_FAILED) + return + + self.logger.info(f"Updated domain list with {len(domains)} domains") + + def add_domain_response(self, response: dict) -> None: + """ + Handle the response from the server after attempting to add a domain. + + Args: + response: Dictionary containing the server's response with code and content. + """ + try: + match response[STR_CODE]: + case Codes.CODE_SUCCESS: + domain = response[STR_CONTENT] + self.domains_listbox.insert(tk.END, domain) + self.domain_entry.delete(0, tk.END) + + self._show_success( + message=f"Domain '{domain}' added successfully", + operation=STR_ADD_DOMAIN_RESPONSE + ) + + case Codes.CODE_ERROR: + self._show_error( + message=response[STR_CONTENT], + operation=STR_ADD_DOMAIN_RESPONSE + ) + + except Exception as e: + self._show_error( + message="An unexpected error occurred", + operation=f"Processing add domain response: {str(e)}" + ) + + def ad_block_response(self, response: dict) -> None: + """ + Handle the response from the server after changing ad block setting. + + Args: + response: Dictionary containing the server's response with code and content. + """ + prev_state = "off" if self.ad_var.get() == "on" else "on" + + try: + match response[STR_CODE]: + case Codes.CODE_SUCCESS: + self._show_success( + message=f"Ad blocking turned {self.ad_var.get()}", + operation=STR_AD_BLOCK_RESPONSE + ) + case Codes.CODE_ERROR: + self.ad_var.set(prev_state) + self._show_error( + message=response[STR_CONTENT], + operation=STR_AD_BLOCK_RESPONSE + ) + except Exception as e: + self.ad_var.set(prev_state) + self._show_error( + message="An unexpected error occurred", + operation=f"Processing ad block response: {str(e)}" + ) + + def adult_block_response(self, response: dict) -> None: + """ + Handle the response from the server after changing adult block setting. + + Args: + response: Dictionary containing the server's response with code and content. + """ + prev_state = "off" if self.adult_var.get() == "on" else "on" + + try: + match response[STR_CODE]: + case Codes.CODE_SUCCESS: + self._show_success( + message=f"Adult content blocking turned {self.adult_var.get()}", + operation=STR_ADULT_BLOCK_RESPONSE + ) + case Codes.CODE_ERROR: + self.adult_var.set(prev_state) + self._show_error( + message=response[STR_CONTENT], + operation=STR_ADULT_BLOCK_RESPONSE + ) + except Exception as e: + self.adult_var.set(prev_state) + self._show_error( + message="An unexpected error occurred", + operation=f"Processing adult block response: {str(e)}" + ) + + def remove_domain_response(self, response: dict) -> None: + """ + Handle the response from the server after removing a domain. + + Args: + response: Dictionary containing the server's response with code and content. + """ + try: + match response[STR_CODE]: + case Codes.CODE_SUCCESS: + domain = response[STR_CONTENT] + self.domains_listbox.delete(self.domains_listbox.curselection()) + self._show_success( + message=f"Domain '{domain}' removed successfully", + operation=STR_REMOVE_DOMAIN_RESPONSE + ) + case Codes.CODE_ERROR: + self._show_error( + message=response[STR_CONTENT], + operation=STR_REMOVE_DOMAIN_RESPONSE + ) + except Exception as e: + self._show_error( + message="An unexpected error occurred", + operation=f"Processing remove domain response: {str(e)}" + ) + + def update_initial_settings(self, response: dict) -> None: + """ + Update all initial settings from server response. + + Args: + response: Dictionary containing initial settings: + - domains: List of blocked domains + - settings: Dictionary with ad_block and adult_block states + """ + try: + self.root.after(0, lambda: self.update_domain_list_response(response[STR_DOMAINS])) + self.root.after(0, lambda: self._update_block_settings(response[STR_SETTINGS])) + + self.logger.info("Successfully initialized settings from server") + + except Exception as e: + self._show_error( + message="Failed to initialize settings", + operation=f"Initial settings update: {str(e)}" + ) + + def _add_domain_request(self) -> None: + """Add a domain to the blocked sites list.""" + domain = self.domain_entry.get().strip() + + if domain: + self.logger.debug(f"Sending add domain request for: {domain}") + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_ADD_DOMAIN, + STR_CONTENT: domain + })) + else: + self._show_error( + message="Please enter a domain name", + operation=STR_ADD_DOMAIN_REQUEST + ) + + def _remove_domain_request(self) -> None: + """Remove the selected domain from the blocked sites list.""" + selection = self.domains_listbox.curselection() + + if selection: + domain = self.domains_listbox.get(selection) + self.logger.debug(f"Sending remove domain request for: {domain}") + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_REMOVE_DOMAIN, + STR_CONTENT: domain + })) + else: + self._show_error( + message=ERR_NO_DOMAIN_SELECTED, + operation=STR_REMOVE_DOMAIN_REQUEST + ) + + def _handle_ad_block_request(self) -> None: + """Handle changes to the ad block setting.""" + state = self.ad_var.get() + self.logger.debug(f"Sending ad block request: {state}") + + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_AD_BLOCK, + STR_CONTENT: state + })) + + def _handle_adult_block_request(self) -> None: + """Handle changes to the adult sites block setting.""" + state = self.adult_var.get() + self.logger.debug(f"Sending adult block request: {state}") + + self._message_callback(json.dumps({ + STR_CODE: Codes.CODE_ADULT_BLOCK, + STR_CONTENT: state + })) + + def _update_block_settings(self, settings: dict) -> None: + """Update the block settings radio buttons.""" + self.ad_var.set(settings[STR_AD_BLOCK]) + self.adult_var.set(settings[STR_ADULT_BLOCK]) + + def _show_error(self, message: str, operation: str = "") -> None: + """ + Display and log an error message for an operation. + + Args: + message: The error message to display to the user. + operation: Optional description of the operation that failed. + If provided, will be included in the log message. + """ + if operation: + self.logger.error(f"Operation failed: {operation} - Error: {message}") + else: + self.logger.error(f"Error: {message}") + + tk.messagebox.showerror(STR_ERROR, message) + + def _show_success(self, message: str, operation: str = "") -> None: + """ + Display and log a success message for an operation. + + Args: + message: The success message to display to the user. + operation: Optional description of the operation that succeeded. + If provided, will be included in the log message. + """ + log_message = f"Operation successful: {operation}" if operation else message + self.logger.info(log_message) + tk.messagebox.showinfo(STR_SUCCESS, message) + + def _setup_ui(self) -> None: + """Set up the UI components including block controls and domain list.""" + # Main container with increased padding + main_container = ttk.Frame(self.root, padding="20") + main_container.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) + self.root.columnconfigure(0, weight=1) + self.root.rowconfigure(0, weight=1) + + # Left side - Specific sites block (now with better proportions) + sites_frame = ttk.LabelFrame( + main_container, + text="Specific Sites Block", + padding="15" + ) + sites_frame.grid( + row=0, + column=0, + rowspan=3, + padx=10, + sticky=(tk.W, tk.E, tk.N, tk.S) + ) + + # Create a frame for listbox and scrollbar + listbox_frame = ttk.Frame(sites_frame) + listbox_frame.grid(row=0, column=0, pady=5, sticky=(tk.W, tk.E, tk.N, tk.S)) + + # Domains listbox with scrollbars + self.domains_listbox = tk.Listbox( + listbox_frame, + width=40, + height=15, + selectmode=tk.SINGLE, + activestyle='dotbox', + font=('Arial', 10) + ) + scrollbar_y = ttk.Scrollbar( + listbox_frame, + orient=tk.VERTICAL, + command=self.domains_listbox.yview + ) + scrollbar_x = ttk.Scrollbar( + listbox_frame, + orient=tk.HORIZONTAL, + command=self.domains_listbox.xview + ) + + self.domains_listbox.configure( + yscrollcommand=scrollbar_y.set, + xscrollcommand=scrollbar_x.set + ) + + # Grid layout for listbox and scrollbars + self.domains_listbox.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S)) + scrollbar_y.grid(row=0, column=1, sticky=(tk.N, tk.S)) + scrollbar_x.grid(row=1, column=0, sticky=(tk.W, tk.E)) + + # Add domain entry with improved layout + domain_entry_frame = ttk.Frame(sites_frame) + domain_entry_frame.grid( + row=1, + column=0, + pady=15, + sticky=(tk.W, tk.E) + ) + + ttk.Label( + domain_entry_frame, + text="Add Domain:", + font=('Arial', 10) + ).grid(row=0, column=0, padx=5) + + self.domain_entry = ttk.Entry( + domain_entry_frame, + font=('Arial', 10) + ) + self.domain_entry.grid( + row=0, + column=1, + padx=5, + sticky=(tk.W, tk.E) + ) + + # Buttons with improved styling + button_frame = ttk.Frame(sites_frame) + button_frame.grid( + row=2, + column=0, + pady=10, + sticky=(tk.W, tk.E) + ) + + style = ttk.Style() + style.configure('Action.TButton', padding=5) + + ttk.Button( + button_frame, + text="Add Domain", + style='Action.TButton', + command=self._add_domain_request + ).grid(row=0, column=0, padx=5) + + ttk.Button( + button_frame, + text="Remove Domain", + style='Action.TButton', + command=self._remove_domain_request + ).grid(row=0, column=1, padx=5) + + # Right side controls with improved spacing + controls_frame = ttk.Frame(main_container) + controls_frame.grid( + row=0, + column=1, + padx=20, + sticky=(tk.N, tk.S) + ) + + # Ad Block controls with better styling + ad_frame = ttk.LabelFrame( + controls_frame, + text="Ad Blocking", + padding="15" + ) + ad_frame.grid( + row=0, + column=0, + pady=10, + sticky=(tk.W, tk.E) + ) + + # Initialize with config value + self.ad_var = tk.StringVar() + ttk.Radiobutton( + ad_frame, + text="Enable", + value="on", + variable=self.ad_var, + command=self._handle_ad_block_request + ).grid(row=0, column=0, padx=10) + ttk.Radiobutton( + ad_frame, + text="Disable", + value="off", + variable=self.ad_var, + command=self._handle_ad_block_request + ).grid(row=0, column=1, padx=10) + + # Adult sites Block controls + adult_frame = ttk.LabelFrame( + controls_frame, + text="Adult Content Blocking", + padding="15" + ) + adult_frame.grid( + row=1, + column=0, + pady=10, + sticky=(tk.W, tk.E) + ) + + # Initialize with config value + self.adult_var = tk.StringVar() + ttk.Radiobutton( + adult_frame, + text="Enable", + value="on", + variable=self.adult_var, + command=self._handle_adult_block_request + ).grid(row=0, column=0, padx=10) + ttk.Radiobutton( + adult_frame, + text="Disable", + value="off", + variable=self.adult_var, + command=self._handle_adult_block_request + ).grid(row=0, column=1, padx=10) + + # Configure grid weights for better resizing + main_container.columnconfigure(0, weight=3) + main_container.columnconfigure(1, weight=1) + sites_frame.columnconfigure(0, weight=1) + listbox_frame.columnconfigure(0, weight=1) + listbox_frame.rowconfigure(0, weight=1) + domain_entry_frame.columnconfigure(1, weight=1) + button_frame.columnconfigure(0, weight=1) + button_frame.columnconfigure(1, weight=1) + + # Bind events + self.domains_listbox.bind('', lambda e: self._remove_domain_request()) diff --git a/client/src/__init__.py b/client/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/src/utils.py b/client/src/utils.py new file mode 100644 index 0000000..d9c9eca --- /dev/null +++ b/client/src/utils.py @@ -0,0 +1,78 @@ +"""Utility module containing constants and common functions for the application.""" + +# Network related constants +DEFAULT_HOST: str = "127.0.0.1" +DEFAULT_PORT: str = "65432" +DEFAULT_BUFFER_SIZE: str = "1024" + +# GUI constants +WINDOW_TITLE = "Site Blocker" +WINDOW_SIZE = "800x600" +PADDING_SMALL = "5" +PADDING_MEDIUM = "10" + +# Message codes +class Codes: + """Constants for message codes used in communication.""" + CODE_AD_BLOCK = "50" + CODE_ADULT_BLOCK = "51" + CODE_ADD_DOMAIN = "52" + CODE_REMOVE_DOMAIN = "53" + CODE_DOMAIN_LIST_UPDATE = "54" + CODE_INIT_SETTINGS = "55" + CODE_SUCCESS = "100" + CODE_ERROR = "101" + +# Logging constants +LOG_DIR = "client_logs" +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +LOG_DATE_FORMAT = "%Y%m%d_%H%M%S" + +# Error messages +ERR_SOCKET_NOT_SETUP = "Socket not set up. Call connect method first." +ERR_NO_CONNECTION = "Attempted to send message without connection" +ERR_DUPLICATE_DOMAIN = "Domain already exists in the list" +ERR_NO_DOMAIN_SELECTED = "Please select a domain to remove" +ERR_DOMAIN_LIST_UPDATE_FAILED = "Failed to update domain list" + +# String Constants +STR_AD_BLOCK = "ad_block" +STR_ADULT_BLOCK = "adult_block" +STR_CODE = "code" +STR_CONTENT = "content" +STR_ERROR = "Error" +STR_DOMAINS = "domains" +STR_SUCCESS = "Success" +STR_OPERATION = "operation" + +# Operation constants +STR_REMOVE_DOMAIN_REQUEST = "Remove domain request" +STR_ADD_DOMAIN_REQUEST = "Add domain request" +STR_AD_BLOCK_RESPONSE = "Ad block response" +STR_ADD_DOMAIN_RESPONSE = "Add domain response" +STR_ADULT_BLOCK_RESPONSE = "Adult block response" +STR_REMOVE_DOMAIN_RESPONSE = "Remove domain response" + +# Config Constants +STR_BLOCKED_DOMAINS = "blocked_domains" +STR_NETWORK = "network" +STR_SETTINGS = "settings" +STR_LOGGING = "logging" +STR_HOST = "host" +STR_PORT = "port" +STR_RECEIVE_BUFFER_SIZE = "receive_buffer_size" +STR_LEVEL = "level" +STR_LOG_DIR = "log_dir" + +# Default settings +DEFAULT_CONFIG = { + STR_NETWORK: { + STR_HOST: DEFAULT_HOST, + STR_PORT: DEFAULT_PORT, + STR_RECEIVE_BUFFER_SIZE: DEFAULT_BUFFER_SIZE + }, + STR_LOGGING: { + STR_LEVEL: "INFO", + STR_LOG_DIR: LOG_DIR + } +} diff --git a/client/tests/__init__.py b/client/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/test_application.py b/client/tests/test_application.py new file mode 100644 index 0000000..c139ffe --- /dev/null +++ b/client/tests/test_application.py @@ -0,0 +1,75 @@ +import json +import pytest +from unittest import mock +from typing import Dict, Any + +from src.Application import Application +from src.utils import ( + STR_CODE, STR_CONTENT, STR_OPERATION, STR_DOMAINS, + Codes, DEFAULT_CONFIG +) + +@pytest.fixture +def mock_config_manager() -> mock.Mock: + """Fixture to provide a mock configuration manager.""" + config_manager = mock.Mock() + config_manager.get_config.return_value = DEFAULT_CONFIG.copy() + return config_manager + +@pytest.fixture +def application(mock_config_manager: mock.Mock) -> Application: + """Fixture to create an Application instance with mocked components.""" + with mock.patch('src.Application.Viewer') as mock_viewer, \ + mock.patch('src.Application.Communicator') as mock_comm, \ + mock.patch('src.Application.setup_logger'): + app = Application() + app._logger = mock.Mock() + app._config_manager = mock_config_manager + return app + +def test_init(application: Application) -> None: + """Test the initialization of Application.""" + assert hasattr(application, '_logger') + assert hasattr(application, '_communicator') + assert hasattr(application, '_config_manager') + +def test_handle_request_ad_block(application: Application) -> None: + """Test handling ad block request.""" + test_request = { + STR_CODE: Codes.CODE_AD_BLOCK, + STR_CONTENT: "on", + } + + application._communicator.send_message = mock.Mock() + application._handle_request(json.dumps(test_request)) + + application._communicator.send_message.assert_called_once() + sent_data = application._communicator.send_message.call_args[0][0] + assert sent_data == test_request + +def test_handle_request_domain_list_update(application: Application) -> None: + """Test handling domain list update request.""" + test_domains = ["domain1.com", "domain2.com"] + test_request = { + STR_CODE: Codes.CODE_SUCCESS, + STR_DOMAINS: test_domains, + STR_OPERATION: Codes.CODE_DOMAIN_LIST_UPDATE + } + + application._handle_request(json.dumps(test_request), to_server=False) + application._view.update_domain_list_response.assert_called_once_with(test_domains) + +def test_cleanup(application: Application) -> None: + """Test cleanup process.""" + application._cleanup() + application._communicator.close.assert_called_once() + +def test_handle_request_invalid_json(application: Application) -> None: + """Test handling invalid JSON in request.""" + invalid_json = "{" + + with pytest.raises(json.JSONDecodeError): + application._handle_request(invalid_json) + + application._logger.error.assert_called() + \ No newline at end of file diff --git a/client/tests/test_communicator.py b/client/tests/test_communicator.py new file mode 100644 index 0000000..acfbfdd --- /dev/null +++ b/client/tests/test_communicator.py @@ -0,0 +1,87 @@ +import socket +import pytest +from unittest import mock +from typing import Callable +import json + +from src.Communicator import Communicator +from src.utils import ( + DEFAULT_CONFIG, ERR_SOCKET_NOT_SETUP, + STR_NETWORK, STR_HOST, STR_PORT, STR_RECEIVE_BUFFER_SIZE +) + +@pytest.fixture +def mock_config_manager() -> mock.Mock: + """Fixture to provide a mock configuration manager.""" + config_manager = mock.Mock() + config_manager.get_config.return_value = DEFAULT_CONFIG + return config_manager + +@pytest.fixture +def mock_callback() -> Callable[[str], None]: + """Fixture to provide a mock callback function.""" + return mock.Mock() + +@pytest.fixture +def communicator( + mock_config_manager: mock.Mock, + mock_callback: Callable[[str], None] +) -> Communicator: + """Fixture to create a Communicator instance.""" + return Communicator( + config_manager=mock_config_manager, + message_callback=mock_callback + ) + +def test_init(communicator: Communicator, mock_callback: Callable[[str], None]) -> None: + """Test initialization of Communicator.""" + config = DEFAULT_CONFIG[STR_NETWORK] + assert communicator._host == config[STR_HOST] + assert communicator._port == int(config[STR_PORT]) + assert communicator._receive_buffer_size == int(config[STR_RECEIVE_BUFFER_SIZE]) + assert communicator._socket is None + assert communicator._message_callback == mock_callback + +@mock.patch('socket.socket') +def test_connect(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test socket connection.""" + mock_socket_instance = mock_socket_class.return_value + communicator.connect() + + mock_socket_class.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM) + mock_socket_instance.connect.assert_called_once_with( + (communicator._host, communicator._port) + ) + assert communicator._socket is mock_socket_instance + +@mock.patch('socket.socket') +def test_send_message_without_setup( + mock_socket_class: mock.Mock, + communicator: Communicator +) -> None: + """Test sending message without socket setup.""" + with pytest.raises(RuntimeError) as exc_info: + communicator.send_message("test message") + assert str(exc_info.value) == ERR_SOCKET_NOT_SETUP + +@mock.patch('socket.socket') +def test_send_message(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test sending message successfully.""" + mock_socket_instance = mock_socket_class.return_value + communicator._socket = mock_socket_instance + + message = {"test": "message"} + communicator.send_message(message) + + mock_socket_instance.send.assert_called_once_with(json.dumps(message).encode('utf-8')) + +@mock.patch('socket.socket') +def test_close_socket(mock_socket_class: mock.Mock, communicator: Communicator) -> None: + """Test closing socket connection.""" + mock_socket_instance = mock_socket_class.return_value + communicator._socket = mock_socket_instance + + communicator.close() + + mock_socket_instance.close.assert_called_once() + assert communicator._socket is None diff --git a/client/tests/test_view.py b/client/tests/test_view.py new file mode 100644 index 0000000..fa994eb --- /dev/null +++ b/client/tests/test_view.py @@ -0,0 +1,159 @@ +"""Unit tests for the Viewer class.""" + +import pytest +from unittest import mock +import json +from typing import Dict, Any + +from src.View import Viewer +from src.utils import ( + Codes, STR_CODE, STR_CONTENT, STR_DOMAINS, STR_SETTINGS, + STR_AD_BLOCK, STR_ADULT_BLOCK +) + +@pytest.fixture +def mock_config_manager() -> mock.Mock: + """Create a mock configuration manager fixture.""" + config_manager = mock.Mock() + config_manager.get_config.return_value = { + "network": { + "host": "127.0.0.1", + "port": 65432 + } + } + return config_manager + +@pytest.fixture +def mock_callback() -> mock.Mock: + """Create a mock callback function fixture.""" + return mock.Mock() + +@pytest.fixture +def mock_tk() -> mock.Mock: + """Create a mock for tkinter components.""" + with mock.patch('src.View.tk') as mock_tk: + # Mock Tk instance + mock_root = mock.Mock() + mock_tk.Tk.return_value = mock_root + + # Mock StringVar + mock_string_var = mock.Mock() + mock_string_var.get.return_value = "on" + mock_tk.StringVar.return_value = mock_string_var + + # Mock Listbox + mock_listbox = mock.Mock() + mock_listbox.get.side_effect = lambda start, end: ["domain1.com", "domain2.com"] + mock_tk.Listbox.return_value = mock_listbox + + yield mock_tk + +@pytest.fixture +def viewer( + mock_config_manager: mock.Mock, + mock_callback: mock.Mock, + mock_tk: mock.Mock +) -> Viewer: + """Create a Viewer instance with mocked dependencies.""" + with mock.patch('src.View.ttk'), \ + mock.patch('src.View.messagebox'), \ + mock.patch('src.View.setup_logger') as mock_logger: + logger_instance = mock.Mock() + mock_logger.return_value = logger_instance + + viewer = Viewer( + config_manager=mock_config_manager, + message_callback=mock_callback + ) + + # Set up instance variables that would normally be created in _setup_ui + viewer.domains_listbox = mock_tk.Listbox.return_value + viewer.ad_var = mock_tk.StringVar.return_value + viewer.adult_var = mock_tk.StringVar.return_value + viewer.domain_entry = mock.Mock() + + return viewer + +def test_handle_ad_block_request(viewer: Viewer) -> None: + """Test handling ad block request message formation.""" + expected_message = json.dumps({ + STR_CODE: Codes.CODE_AD_BLOCK, + STR_CONTENT: "on" + }) + + viewer._handle_ad_block_request() + viewer._message_callback.assert_called_once_with(expected_message) + +def test_handle_adult_block_request(viewer: Viewer) -> None: + """Test handling adult block request message formation.""" + expected_message = json.dumps({ + STR_CODE: Codes.CODE_ADULT_BLOCK, + STR_CONTENT: "on" + }) + + viewer._handle_adult_block_request() + viewer._message_callback.assert_called_once_with(expected_message) + +def test_update_initial_settings(viewer: Viewer) -> None: + """Test updating initial settings from server response.""" + test_settings = { + STR_DOMAINS: ["example.com", "test.com"], + STR_SETTINGS: { + STR_AD_BLOCK: "on", + STR_ADULT_BLOCK: "off" + } + } + + viewer.update_initial_settings(test_settings) + viewer.logger.info.assert_called_with("Successfully initialized settings from server") + +def test_update_domain_list_response(viewer: Viewer) -> None: + """Test updating domain list from server response.""" + test_domains = ["domain1.com", "domain2.com"] + + viewer.update_domain_list_response(test_domains) + viewer.logger.info.assert_called_with(f"Updated domain list with {len(test_domains)} domains") + +@pytest.mark.parametrize("response,expected_log", [ + ( + {STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: "test.com"}, + "info" + ), + ( + {STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: "Failed to add domain"}, + "error" + ) +]) +def test_add_domain_response( + viewer: Viewer, + response: Dict[str, Any], + expected_log: str +) -> None: + """Test handling add domain response from server.""" + # Reset the mock call counts before our test + viewer.logger.info.reset_mock() + viewer.logger.error.reset_mock() + + viewer.add_domain_response(response) + + if expected_log == "info": + viewer.logger.info.assert_called_once() + viewer.logger.error.assert_not_called() + else: + viewer.logger.error.assert_called_once() + +def test_get_blocked_domains(viewer: Viewer) -> None: + """Test getting list of blocked domains.""" + expected_domains = ["domain1.com", "domain2.com"] + domains = list(viewer.get_blocked_domains()) + assert domains == expected_domains + +def test_get_block_settings(viewer: Viewer) -> None: + """Test getting block settings.""" + settings = viewer.get_block_settings() + assert settings == { + STR_AD_BLOCK: "on", + STR_ADULT_BLOCK: "on" + } diff --git a/server/__init__.py b/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/main.py b/server/main.py new file mode 100644 index 0000000..53ce51e --- /dev/null +++ b/server/main.py @@ -0,0 +1,5 @@ +from src.server import initialize_server +from src.utils import DB_FILE + +if __name__ == '__main__': + initialize_server(DB_FILE) \ No newline at end of file diff --git a/server/requirements.txt b/server/requirements.txt new file mode 100644 index 0000000..b3211cb --- /dev/null +++ b/server/requirements.txt @@ -0,0 +1,4 @@ +pytest==7.4.0 +pytest-mock==3.11.1 +pytest-asyncio==0.21.1 +pytest-cov==4.1.0 \ No newline at end of file diff --git a/server/src/__init__.py b/server/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/src/db_manager.py b/server/src/db_manager.py new file mode 100644 index 0000000..f13b18e --- /dev/null +++ b/server/src/db_manager.py @@ -0,0 +1,109 @@ +import sqlite3 +from typing import List +from .logger import setup_logger + +class DatabaseManager: + def __init__(self, db_file: str): + """Initialize database manager.""" + self.db_file = db_file + self.logger = setup_logger(__name__) + self._create_tables() + self.logger.info(f"Database initialized at {db_file}") + + def _create_tables(self) -> None: + """Create necessary database tables.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS blocked_domains ( + domain TEXT PRIMARY KEY + ) + """) + + cursor.execute(""" + INSERT OR IGNORE INTO settings (key, value) + VALUES + ('ad_block', 'off'), + ('adult_block', 'off') + """) + + conn.commit() + self.logger.info("Database tables created/verified") + + def get_setting(self, setting: str) -> str: + """Get setting value.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("""SELECT value + FROM settings + WHERE key = ?""", (setting,)) + result = cursor.fetchone() + value = result[0] if result else 'off' + self.logger.debug(f"Retrieved setting {setting}: {value}") + return value + + def update_setting(self, setting: str, value: str) -> None: + """Update setting value.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute(""" + UPDATE settings + SET value = ? + WHERE key = ? + """, (value, setting)) + conn.commit() + self.logger.info(f"Updated setting {setting} to {value}") + + def add_blocked_domain(self, domain: str) -> None: + """Add a domain to blocked list.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + try: + cursor.execute("""INSERT INTO blocked_domains (domain) + VALUES (?)""", (domain,)) + conn.commit() + self.logger.info(f"Domain {domain} added to block list") + except sqlite3.IntegrityError: + self.logger.warning(f"Domain {domain} already exists in the database") + + def remove_blocked_domain(self, domain: str) -> bool: + """Remove a domain from blocked list.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("""DELETE FROM blocked_domains + WHERE domain = ?""", (domain,)) + conn.commit() + if cursor.rowcount: + self.logger.info(f"Domain {domain} removed from block list") + else: + self.logger.warning(f"Domain {domain} not found in block list") + return bool(cursor.rowcount) + + def get_blocked_domains(self) -> List[str]: + """Get list of all blocked domains.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("""SELECT domain + FROM blocked_domains""") + domains = [row[0] for row in cursor.fetchall()] + self.logger.debug(f"Retrieved {len(domains)} blocked domains") + return domains + + def is_domain_blocked(self, domain: str) -> bool: + """Check if domain is in blocked list.""" + with sqlite3.connect(self.db_file) as conn: + cursor = conn.cursor() + cursor.execute("""SELECT 1 + FROM blocked_domains + WHERE domain = ?""", (domain,)) + is_blocked = cursor.fetchone() is not None + self.logger.debug(f"Domain {domain} blocked status: {is_blocked}") + return is_blocked \ No newline at end of file diff --git a/server/src/handlers.py b/server/src/handlers.py new file mode 100644 index 0000000..d273392 --- /dev/null +++ b/server/src/handlers.py @@ -0,0 +1,174 @@ +from typing import Dict, Any +from .db_manager import DatabaseManager +from .utils import ( + Codes, + STR_AD_BLOCK, STR_ADULT_BLOCK, + STR_CODE, STR_CONTENT, STR_DOMAINS, + STR_DOMAIN_BLOCKED_MSG, STR_DOMAIN_NOT_FOUND_MSG, + STR_DOMAIN_UNBLOCKED_MSG, STR_OPERATION, + invalid_json_response +) +from .logger import setup_logger + +class RequestHandler: + """Base class for request handlers.""" + def __init__(self, db_manager: DatabaseManager): + self.db_manager = db_manager + self.logger = setup_logger(__name__) + +class AdBlockHandler(RequestHandler): + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle ad block requests.""" + try: + if STR_CONTENT in request_data: + state = request_data[STR_CONTENT] + self.db_manager.update_setting(STR_AD_BLOCK, state) + self.logger.info(f"Ad blocking turned {state}") + return { + STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: f"Ad blocking turned {state}", + STR_OPERATION: Codes.CODE_AD_BLOCK + } + + return invalid_json_response() + + except Exception as e: + self.logger.error(f"Error in ad block handler: {e}") + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e), + STR_OPERATION: Codes.CODE_AD_BLOCK + } + +class AdultContentBlockHandler(RequestHandler): + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle adult content block requests.""" + try: + if STR_CONTENT in request_data: + state = request_data[STR_CONTENT] + self.db_manager.update_setting(STR_ADULT_BLOCK, state) + self.logger.info(f"Adult content blocking turned {state}") + return { + STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: f"Adult content blocking turned {state}", + STR_OPERATION: Codes.CODE_ADULT_BLOCK + } + + return invalid_json_response() + + except Exception as e: + self.logger.error(f"Error in adult content block handler: {e}") + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e), + STR_OPERATION: Codes.CODE_ADULT_BLOCK + } + +class DomainBlockHandler(RequestHandler): + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle domain blocking requests.""" + try: + if STR_CONTENT not in request_data: + self.logger.warning("Invalid request format: missing content") + return invalid_json_response() + + operation_code = request_data[STR_CODE] + domain = request_data[STR_CONTENT] + + match operation_code: + case Codes.CODE_ADD_DOMAIN: + if self.db_manager.is_domain_blocked(domain): + self.logger.warning(f"Domain already blocked: {domain}") + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: f"Domain {domain} is already blocked", + STR_OPERATION: Codes.CODE_ADD_DOMAIN + } + + self.db_manager.add_blocked_domain(domain) + self.logger.info(f"Domain blocked: {domain}") + return { + STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: domain, + STR_OPERATION: Codes.CODE_ADD_DOMAIN + } + + case Codes.CODE_REMOVE_DOMAIN: + if self.db_manager.remove_blocked_domain(domain): + self.logger.info(f"Domain unblocked: {domain}") + return { + STR_CODE: Codes.CODE_SUCCESS, + STR_CONTENT: domain, + STR_OPERATION: Codes.CODE_REMOVE_DOMAIN + } + + self.logger.warning(f"Domain not found for unblocking: {domain}") + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: domain, + STR_OPERATION: Codes.CODE_REMOVE_DOMAIN + } + + self.logger.warning(f"Invalid action requested: {request_data[STR_CODE]}") + return invalid_json_response() + + except Exception as e: + self.logger.error(f"Error in domain block handler: {e}") + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e), + STR_OPERATION: operation_code + } + +class DomainListHandler(RequestHandler): + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Handle domain list requests.""" + try: + domains = self.db_manager.get_blocked_domains() + self.logger.info(f"Domain list requested, returned {len(domains)} domains") + return { + STR_CODE: Codes.CODE_SUCCESS, + STR_DOMAINS: domains, + STR_OPERATION: Codes.CODE_DOMAIN_LIST_UPDATE + } + + except Exception as e: + self.logger.error(f"Error in domain list handler: {e}") + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e), + STR_OPERATION: Codes.CODE_DOMAIN_LIST_UPDATE + } + +class RequestFactory: + """Factory class for creating appropriate request handlers.""" + def __init__(self, db_manager: DatabaseManager): + self.db_manager = db_manager + self.logger = setup_logger(__name__) + self.handlers = { + Codes.CODE_AD_BLOCK : AdBlockHandler(db_manager), + Codes.CODE_ADULT_BLOCK : AdultContentBlockHandler(db_manager), + Codes.CODE_ADD_DOMAIN : DomainBlockHandler(db_manager), + Codes.CODE_REMOVE_DOMAIN : DomainBlockHandler(db_manager), + Codes.CODE_DOMAIN_LIST_UPDATE : DomainListHandler(db_manager) + } + + def handle_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """Route request to appropriate handler.""" + try: + code = request_data.get(STR_CODE) + handler = self.handlers.get(code) + + if handler: + self.logger.debug(f"Handling request with code: {code}") + return handler.handle_request(request_data) + + self.logger.warning(f"Invalid request code: {code}") + return invalid_json_response() + + except Exception as e: + self.logger.error(f"Error in request factory: {e}") + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: str(e) + } diff --git a/server/src/logger.py b/server/src/logger.py new file mode 100644 index 0000000..6056c89 --- /dev/null +++ b/server/src/logger.py @@ -0,0 +1,45 @@ +"""Logger module for handling application-wide logging configuration.""" + +import logging +import os +from datetime import datetime +from typing import Optional +from .utils import LOG_DIR, LOG_FORMAT, LOG_DATE_FORMAT + +_logger: Optional[logging.Logger] = None + +def setup_logger(name: str) -> logging.Logger: + """ + Configure and return a logger instance. + + Args: + name: The name of the module requesting the logger. + + Returns: + logging.Logger: Configured logger instance. + """ + global _logger + + if _logger is not None: + return logging.getLogger(name) + + if not os.path.exists(LOG_DIR): + os.makedirs(LOG_DIR) + + log_file: str = os.path.join( + LOG_DIR, f"server_{datetime.now().strftime(LOG_DATE_FORMAT)}.log" + ) + + logging.basicConfig( + level=logging.INFO, + format=LOG_FORMAT, + handlers=[ + logging.FileHandler(log_file), + logging.StreamHandler(), + ], + ) + + _logger = logging.getLogger(name) + _logger.info("Logger setup complete") + + return _logger \ No newline at end of file diff --git a/server/src/server.py b/server/src/server.py new file mode 100644 index 0000000..3b970f7 --- /dev/null +++ b/server/src/server.py @@ -0,0 +1,196 @@ +from typing import Dict, Any, Optional +import socket +import threading +import json +import asyncio +from .utils import ( + CLIENT_PORT, DEFAULT_ADDRESS, KERNEL_PORT, + STR_AD_BLOCK, STR_ADULT_BLOCK, STR_CODE, STR_DOMAINS, STR_CONTENT, + STR_TOGGLE_ON, STR_TOGGLE_OFF, STR_DOMAIN, STR_OPERATION, STR_SETTINGS, + Codes, invalid_json_response +) +from .db_manager import DatabaseManager +from .handlers import RequestFactory +from .logger import setup_logger + +class Server: + def __init__(self, db_manager: DatabaseManager) -> None: + """Initialize server with database manager.""" + self.db_manager = db_manager + self.request_factory = RequestFactory(self.db_manager) + self.running = True + self.logger = setup_logger(__name__) + self.logger.info("Server initialized") + + def handle_client_thread(self) -> None: + """Handle client connections using traditional socket.""" + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.bind((DEFAULT_ADDRESS, CLIENT_PORT)) + client_socket.listen(1) + client_socket.settimeout(1.0) + self.logger.info(f"Client server running on {DEFAULT_ADDRESS}:{CLIENT_PORT}") + + try: + while self.running: + try: + conn, addr = client_socket.accept() + self.logger.info(f"Client connected from {addr}") + + conn.settimeout(1.0) + + try: + # Send initial settings + initial_settings = self._get_initial_settings() + conn.send(json.dumps(initial_settings).encode() + b'\n') + self.logger.debug(f"Sent initial settings: {initial_settings}") + + while True: + try: + data = conn.recv(1024) + if not data: + break + + try: + request_data = json.loads(data.decode()) + self.logger.debug(f"Received request: {request_data}") + response = self.request_factory.handle_request(request_data) + + conn.send(json.dumps(response).encode() + b'\n') + self.logger.debug(f"Sent response: {response}") + + except json.JSONDecodeError: + self.logger.error("Invalid JSON format received") + conn.send(json.dumps(invalid_json_response()).encode() + b'\n') + + except socket.timeout: + if not self.running: + break + continue + finally: + conn.close() + + except socket.timeout: + if not self.running: + break + continue + except Exception as e: + self.logger.error(f"Client error: {e}") + + finally: + client_socket.close() + + async def handle_kernel_requests( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter + ) -> None: + """Handle kernel requests using asyncio for better performance.""" + addr = writer.get_extra_info('peername') + self.logger.info(f"Kernel module connected from {addr}") + + try: + while True: + data = await reader.readline() + if not data: + break + + request_data = json.loads(data.decode()) + domain = request_data.get(STR_DOMAIN, '').strip() + + if not domain: + continue + + ad_block_enabled = self.db_manager.get_setting(STR_AD_BLOCK) == STR_TOGGLE_ON + adult_block_enabled = self.db_manager.get_setting(STR_ADULT_BLOCK) == STR_TOGGLE_ON + + block_reason = None + should_block = False + + if self.db_manager.is_domain_blocked(domain): + should_block = True + block_reason = "custom_blocklist" + self.logger.info(f"Domain {domain} blocked (custom blocklist)") + + elif ad_block_enabled and request_data.get('is_ad', False): + should_block = True + block_reason = "ads" + self.logger.info(f"Domain {domain} blocked (ads)") + + elif adult_block_enabled and 'adult' in request_data.get('categories', []): + should_block = True + block_reason = "adult_content" + self.logger.info(f"Domain {domain} blocked (adult content)") + + response = { + 'block': should_block, + 'reason': block_reason or 'allowed', + 'domain': domain + } + + self.logger.debug(f"Domain check result: {domain} -> {'blocked' if should_block else 'allowed'} ({block_reason or 'no reason'})") + + writer.write(json.dumps(response).encode() + b'\n') + await writer.drain() + + except Exception as e: + self.logger.error(f"Kernel error: {e}") + finally: + writer.close() + await writer.wait_closed() + self.logger.info(f"Kernel connection closed for {addr}") + + async def start_server(self) -> None: + """Run both client and kernel handlers.""" + client_thread: Optional[threading.Thread] = None + kernel_server: Optional[asyncio.Server] = None + + try: + client_thread = threading.Thread(target=self.handle_client_thread) + client_thread.start() + self.logger.info("Client handler thread started") + + kernel_server = await asyncio.start_server( + self.handle_kernel_requests, + DEFAULT_ADDRESS, + KERNEL_PORT + ) + self.logger.info(f"Kernel server running on {DEFAULT_ADDRESS}:{KERNEL_PORT}") + + async with kernel_server: + await kernel_server.serve_forever() + + except Exception as e: + self.logger.error(f"Server error: {e}") + raise + finally: + self.running = False + # Clean up resources + if kernel_server: + kernel_server.close() + await kernel_server.wait_closed() + if client_thread and client_thread.is_alive(): + client_thread.join(timeout=1.0) + + def _get_initial_settings(self) -> Dict[str, Any]: + """Get initial settings and domain list for client initialization.""" + try: + domains = self.db_manager.get_blocked_domains() + settings = { + STR_AD_BLOCK: self.db_manager.get_setting(STR_AD_BLOCK), + STR_ADULT_BLOCK: self.db_manager.get_setting(STR_ADULT_BLOCK) + } + + return { + STR_OPERATION: Codes.CODE_INIT_SETTINGS, + STR_DOMAINS: domains, + STR_SETTINGS: settings + } + except Exception as e: + self.logger.error(f"Error getting initial settings: {e}") + return invalid_json_response() + +def initialize_server(db_file: str) -> None: + """Initialize and run the server.""" + db_manager = DatabaseManager(db_file) + server = Server(db_manager) + asyncio.run(server.start_server()) diff --git a/server/src/utils.py b/server/src/utils.py new file mode 100644 index 0000000..c479808 --- /dev/null +++ b/server/src/utils.py @@ -0,0 +1,96 @@ +"""Utility module containing constants and common functions for the application.""" + +import os +from pathlib import Path + +# Network related constants +DEFAULT_ADDRESS: str = "127.0.0.1" +CLIENT_PORT: int = 65432 +KERNEL_PORT: int = 65433 +BUFFER_SIZE: int = 1024 + +# Base directories +BASE_DIR = Path(__file__).parent.parent +LOG_DIR = os.path.join(BASE_DIR, "logs") +DB_FILE: str = 'my_internet.db' + +# Message codes +class Codes: + """Constants for message codes used in communication.""" + CODE_AD_BLOCK = "50" + CODE_ADULT_BLOCK = "51" + CODE_ADD_DOMAIN = "52" + CODE_REMOVE_DOMAIN = "53" + CODE_DOMAIN_LIST_UPDATE = "54" + CODE_SUCCESS = "100" + CODE_ERROR = "101" + CODE_ACK = "99" + CODE_INIT_SETTINGS = "55" +# Logging constants +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +LOG_DATE_FORMAT = "%Y%m%d_%H%M%S" + +# Message Types and Codes +STR_CODE = "code" +STR_CONTENT = "content" +STR_OPERATION = "operation" +STR_SETTINGS = "settings" +# STR_TYPE = "type" +# STR_ACTION = "action" +# STR_MESSAGE_ID = "message_id" +# STR_ACK = "ack" + +# Domain Related +STR_DOMAIN = "domain" +STR_DOMAINS = "domains" +STR_BLOCK = "block" +STR_UNBLOCK = "unblock" +# STR_IS_AD = "is_ad" +# STR_CATEGORIES = "categories" +# STR_REASON = "reason" + +# Features and Settings +STR_AD_BLOCK = "ad_block" +STR_ADULT_BLOCK = "adult_block" +STR_TOGGLE_ON = "on" +STR_TOGGLE_OFF = "off" + +# Status and Response Keys +STR_ERROR = "Error" +STR_SUCCESS = "success" +# STR_INVALID_REQUEST = "invalid_request" +# STR_DOMAIN_BLOCKED = "domain_blocked" +# STR_DOMAIN_NOT_FOUND = "domain_not_found" +# STR_INVALID_JSON = "invalid_json" + +# Block Reasons +# STR_CUSTOM_BLOCKLIST = "custom_blocklist" +# STR_ADS = "ads" +# STR_ADULT_CONTENT = "adult_content" +# STR_ALLOWED = "allowed" +# STR_DOMAIN_LIST = "domain_list" + +# Response Messages +STR_DOMAIN_BLOCKED_MSG = "Domain has been successfully blocked." +STR_DOMAIN_UNBLOCKED_MSG = "Domain has been successfully unblocked." +STR_DOMAIN_NOT_FOUND_MSG = "Domain not found in block list." +STR_INVALID_JSON_MSG = "Invalid JSON format." +# STR_REQUEST_PROCESSED = "Request processed successfully." +# STR_INVALID_REQUEST_FORMAT = "Invalid request format." +# STR_DOMAIN_EXISTS_MSG = "Domain already exists in block list." +# STR_ACK_TIMEOUT_MSG = "Acknowledgment timeout occurred." + +# Config Constants +# STR_BLOCKED_DOMAINS = "blocked_domains" +# STR_NETWORK = "network" +# STR_SETTINGS = "settings" +# STR_LOGGING = "logging" + +# Timeouts +# ACK_TIMEOUT = 5.0 # seconds + +def invalid_json_response(): + return { + STR_CODE: Codes.CODE_ERROR, + STR_CONTENT: STR_INVALID_JSON_MSG + } diff --git a/server/tests/__init__.py b/server/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/conftest.py b/server/tests/conftest.py new file mode 100644 index 0000000..a44aaac --- /dev/null +++ b/server/tests/conftest.py @@ -0,0 +1,55 @@ +import pytest +from unittest import mock +from typing import Generator +from My_Internet.server.src.server import Server +from My_Internet.server.src.db_manager import DatabaseManager + +@pytest.fixture +def mock_db_manager() -> mock.Mock: + """Create a mock database manager.""" + db_manager = mock.Mock(spec=DatabaseManager) + db_manager.get_blocked_domains.return_value = [] + db_manager.get_setting.return_value = 'off' + db_manager.is_domain_blocked.return_value = False + return db_manager + +@pytest.fixture +def server_instance(mock_db_manager: mock.Mock) -> Server: + """Create a server instance for testing.""" + server = Server(mock_db_manager) + server.logger = mock.Mock() # Mock the logger to prevent actual logging + return server + +@pytest.fixture +def mock_socket() -> mock.Mock: + """Create a mock socket for testing.""" + socket_mock = mock.Mock() + socket_mock.bind = mock.Mock() + socket_mock.listen = mock.Mock() + socket_mock.accept = mock.Mock() + socket_mock.close = mock.Mock() + socket_mock.settimeout = mock.Mock() + return socket_mock + +@pytest.fixture +def mock_stream_reader() -> mock.AsyncMock: + """Create a mock stream reader.""" + reader = mock.AsyncMock() + reader.readline = mock.AsyncMock() + return reader + +@pytest.fixture +def mock_stream_writer() -> mock.Mock: + """Create a mock stream writer.""" + writer = mock.Mock() + writer.write = mock.Mock() + writer.drain = mock.AsyncMock() + writer.close = mock.Mock() + writer.wait_closed = mock.AsyncMock() + writer.get_extra_info = mock.Mock(return_value=('127.0.0.1', 12345)) + return writer + +@pytest.fixture +def mock_asyncio_start_server() -> mock.AsyncMock: + """Create a mock for asyncio.start_server.""" + return mock.AsyncMock() \ No newline at end of file diff --git a/server/tests/test_handlers.py b/server/tests/test_handlers.py new file mode 100644 index 0000000..af6dc36 --- /dev/null +++ b/server/tests/test_handlers.py @@ -0,0 +1,136 @@ +from typing import Dict, Any +import pytest +from unittest import mock +from My_Internet.server.src.handlers import ( + RequestHandler, + AdBlockHandler, + AdultContentBlockHandler, + DomainBlockHandler, + DomainListHandler, + RequestFactory +) +from My_Internet.server.src.utils import Codes, RESPONSE_MESSAGES + +@pytest.fixture +def mock_db_manager() -> mock.Mock: + """Create a mock database manager.""" + return mock.Mock() + +@pytest.fixture +def mock_logger() -> mock.Mock: + """Create a mock logger.""" + return mock.Mock() + +class TestAdBlockHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock) -> AdBlockHandler: + """Create AdBlockHandler instance.""" + return AdBlockHandler(mock_db_manager) + + def test_handle_request_toggle_on(self, handler: AdBlockHandler) -> None: + """Test handling ad block toggle on request.""" + request_data: Dict[str, Any] = {'action': 'on'} + response = handler.handle_request(request_data) + + handler.db_manager.update_setting.assert_called_once_with('ad_block', 'on') + assert response['code'] == Codes.CODE_AD_BLOCK + assert response['message'] == "Ad blocking turned on" + + def test_handle_request_error(self, handler: AdBlockHandler) -> None: + """Test handling error in ad block request.""" + handler.db_manager.update_setting.side_effect = Exception("Test error") + response = handler.handle_request({'action': 'on'}) + + assert response['code'] == Codes.CODE_AD_BLOCK + assert response['message'] == "Test error" + +class TestAdultContentBlockHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock) -> AdultContentBlockHandler: + """Create AdultContentBlockHandler instance.""" + return AdultContentBlockHandler(mock_db_manager) + + def test_handle_request_toggle_on(self, handler: AdultContentBlockHandler) -> None: + """Test handling adult content block toggle on request.""" + request_data: Dict[str, Any] = {'action': 'on'} + response = handler.handle_request(request_data) + + handler.db_manager.update_setting.assert_called_once_with('adult_block', 'on') + assert response['code'] == Codes.CODE_ADULT_BLOCK + assert response['message'] == "Adult content blocking turned on" + +class TestDomainBlockHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock) -> DomainBlockHandler: + """Create DomainBlockHandler instance.""" + return DomainBlockHandler(mock_db_manager) + + def test_block_domain(self, handler: DomainBlockHandler) -> None: + """Test blocking a domain.""" + request_data: Dict[str, Any] = { + 'action': 'block', + 'domain': 'example.com' + } + response = handler.handle_request(request_data) + + handler.db_manager.add_blocked_domain.assert_called_once_with('example.com') + assert response['code'] == Codes.CODE_ADD_DOMAIN + assert response['message'] == RESPONSE_MESSAGES['domain_blocked'] + + def test_unblock_domain(self, handler: DomainBlockHandler) -> None: + """Test unblocking a domain.""" + handler.db_manager.remove_blocked_domain.return_value = True + request_data: Dict[str, Any] = { + 'action': 'unblock', + 'domain': 'example.com' + } + response = handler.handle_request(request_data) + + handler.db_manager.remove_blocked_domain.assert_called_once_with('example.com') + assert response['code'] == Codes.CODE_REMOVE_DOMAIN + assert response['message'] == RESPONSE_MESSAGES['success'] + + def test_invalid_request(self, handler: DomainBlockHandler) -> None: + """Test handling invalid request.""" + request_data: Dict[str, Any] = {'action': 'block'} # Missing domain + response = handler.handle_request(request_data) + + assert response['code'] == Codes.CODE_ADD_DOMAIN + assert response['message'] == RESPONSE_MESSAGES['invalid_request'] + +class TestDomainListHandler: + @pytest.fixture + def handler(self, mock_db_manager: mock.Mock) -> DomainListHandler: + """Create DomainListHandler instance.""" + return DomainListHandler(mock_db_manager) + + def test_get_domain_list(self, handler: DomainListHandler) -> None: + """Test getting list of blocked domains.""" + domains = ['example.com', 'test.com'] + handler.db_manager.get_blocked_domains.return_value = domains + response = handler.handle_request({}) + + assert response['code'] == Codes.CODE_DOMAIN_LIST_UPDATE + assert response['domains'] == domains + assert response['message'] == RESPONSE_MESSAGES['success'] + +class TestRequestFactory: + @pytest.fixture + def factory(self, mock_db_manager: mock.Mock) -> RequestFactory: + """Create RequestFactory instance.""" + return RequestFactory(mock_db_manager) + + def test_handle_valid_request(self, factory: RequestFactory) -> None: + """Test handling valid request with correct code.""" + request_data: Dict[str, Any] = { + 'code': Codes.CODE_AD_BLOCK, + 'action': 'on' + } + response = factory.handle_request(request_data) + assert response['code'] == Codes.CODE_AD_BLOCK + + def test_handle_invalid_code(self, factory: RequestFactory) -> None: + """Test handling request with invalid code.""" + request_data: Dict[str, Any] = {'code': 'invalid_code'} + response = factory.handle_request(request_data) + assert response['message'] == RESPONSE_MESSAGES['invalid_request'] \ No newline at end of file diff --git a/server/tests/test_server.py b/server/tests/test_server.py new file mode 100644 index 0000000..121eaf2 --- /dev/null +++ b/server/tests/test_server.py @@ -0,0 +1,218 @@ +import pytest +import json +import asyncio +from unittest import mock +from typing import Dict, Any, Generator +from My_Internet.server.src.server import Server +from My_Internet.server.src.utils import HOST, CLIENT_PORT, KERNEL_PORT + +@pytest.mark.asyncio +async def test_handle_kernel_requests_block_custom_domain( + server_instance: Server, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock +) -> None: + """Test handling kernel requests for custom blocked domain.""" + server_instance.db_manager.is_domain_blocked.return_value = True + mock_stream_reader.readline.side_effect = [ + json.dumps({ + 'domain': 'example.com', + 'is_ad': False, + 'categories': [] + }).encode() + b'\n', + b'' + ] + + await server_instance.handle_kernel_requests(mock_stream_reader, mock_stream_writer) + + response_data = json.loads(mock_stream_writer.write.call_args[0][0].decode().strip()) + assert response_data['block'] is True + assert response_data['reason'] == 'custom_blocklist' + assert response_data['domain'] == 'example.com' + +@pytest.mark.asyncio +async def test_handle_kernel_requests_block_ad( + server_instance: Server, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock +) -> None: + """Test handling kernel requests for ad blocking.""" + server_instance.db_manager.is_domain_blocked.return_value = False + server_instance.db_manager.get_setting.side_effect = lambda x: 'on' if x == 'ad_block' else 'off' + mock_stream_reader.readline.side_effect = [ + json.dumps({ + 'domain': 'ad.example.com', + 'is_ad': True, + 'categories': [] + }).encode() + b'\n', + b'' + ] + + await server_instance.handle_kernel_requests(mock_stream_reader, mock_stream_writer) + + response_data = json.loads(mock_stream_writer.write.call_args[0][0].decode().strip()) + assert response_data['block'] is True + assert response_data['reason'] == 'ads' + assert response_data['domain'] == 'ad.example.com' + +@pytest.mark.asyncio +async def test_handle_kernel_requests_block_adult_content( + server_instance: Server, + mock_stream_reader: mock.AsyncMock, + mock_stream_writer: mock.Mock +) -> None: + """Test handling kernel requests for adult content blocking.""" + server_instance.db_manager.is_domain_blocked.return_value = False + server_instance.db_manager.get_setting.side_effect = lambda x: 'on' if x == 'adult_block' else 'off' + mock_stream_reader.readline.side_effect = [ + json.dumps({ + 'domain': 'adult.example.com', + 'is_ad': False, + 'categories': ['adult'] + }).encode() + b'\n', + b'' + ] + + await server_instance.handle_kernel_requests(mock_stream_reader, mock_stream_writer) + + response_data = json.loads(mock_stream_writer.write.call_args[0][0].decode().strip()) + assert response_data['block'] is True + assert response_data['reason'] == 'adult_content' + assert response_data['domain'] == 'adult.example.com' + +def test_handle_client_thread_initial_domain_list( + server_instance: Server, + mock_socket: mock.Mock, + monkeypatch: pytest.MonkeyPatch +) -> None: + """Test sending initial domain list to client.""" + mock_conn = mock.Mock() + mock_conn.send = mock.Mock() + mock_conn.recv.return_value = b'' + + server_instance.db_manager.get_blocked_domains.return_value = ['example.com'] + + mock_socket_instance = mock.Mock() + mock_socket_instance.accept = mock.Mock(return_value=(mock_conn, ('127.0.0.1', 12345))) + + mock_socket_class = mock.Mock(return_value=mock_socket_instance) + monkeypatch.setattr('socket.socket', mock_socket_class) + + def mock_accept(*args: Any, **kwargs: Any) -> tuple[mock.Mock, tuple[str, int]]: + server_instance.running = False + return mock_conn, ('127.0.0.1', 12345) + + mock_socket_instance.accept = mock.Mock(side_effect=mock_accept) + + server_instance.handle_client_thread() + + mock_conn.send.assert_called() + sent_data = json.loads(mock_conn.send.call_args_list[0][0][0].decode().strip()) + assert sent_data['type'] == 'domain_list' + assert isinstance(sent_data['domains'], list) + assert 'example.com' in sent_data['domains'] + +def test_handle_client_thread_process_request( + server_instance: Server, + mock_socket: mock.Mock, + monkeypatch: pytest.MonkeyPatch +) -> None: + """Test processing client request.""" + mock_conn = mock.Mock() + mock_conn.send = mock.Mock() + + mock_conn.recv.side_effect = [ + json.dumps({'code': '50', 'action': 'on'}).encode(), + b'' + ] + + mock_socket_instance = mock.Mock() + mock_socket_instance.accept = mock.Mock(return_value=(mock_conn, ('127.0.0.1', 12345))) + + mock_socket_class = mock.Mock(return_value=mock_socket_instance) + monkeypatch.setattr('socket.socket', mock_socket_class) + + def mock_accept(*args: Any, **kwargs: Any) -> tuple[mock.Mock, tuple[str, int]]: + server_instance.running = False + return mock_conn, ('127.0.0.1', 12345) + + mock_socket_instance.accept = mock.Mock(side_effect=mock_accept) + + server_instance.db_manager.get_blocked_domains.return_value = ['example.com'] + + mock_request_factory = mock.Mock() + mock_request_factory.handle_request.return_value = { + 'status': 'success', + 'message': 'Request processed' + } + server_instance.request_factory = mock_request_factory + + server_instance.handle_client_thread() + + assert mock_conn.send.call_count >= 2 + assert server_instance.request_factory.handle_request.called + +def test_handle_client_thread_invalid_json( + server_instance: Server, + mock_socket: mock.Mock, + monkeypatch: pytest.MonkeyPatch +) -> None: + """Test handling invalid JSON request.""" + mock_conn = mock.Mock() + mock_conn.send = mock.Mock() + mock_conn.recv.side_effect = [b'invalid json', b''] + + mock_socket_instance = mock.Mock() + mock_socket_instance.accept = mock.Mock(return_value=(mock_conn, ('127.0.0.1', 12345))) + + mock_socket_class = mock.Mock(return_value=mock_socket_instance) + monkeypatch.setattr('socket.socket', mock_socket_class) + + def mock_accept(*args: Any, **kwargs: Any) -> tuple[mock.Mock, tuple[str, int]]: + server_instance.running = False + return mock_conn, ('127.0.0.1', 12345) + + mock_socket_instance.accept = mock.Mock(side_effect=mock_accept) + + server_instance.db_manager.get_blocked_domains.return_value = ['example.com'] + + server_instance.handle_client_thread() + + assert mock_conn.send.call_count >= 2 + sent_data = json.loads(mock_conn.send.call_args_list[1][0][0].decode().strip()) + assert sent_data['status'] == 'error' + assert 'Invalid JSON format' in sent_data['message'] + +@pytest.mark.asyncio +async def test_start_server( + server_instance: Server, + mock_asyncio_start_server: mock.AsyncMock, + monkeypatch: pytest.MonkeyPatch +) -> None: + """Test server startup process.""" + monkeypatch.setattr('asyncio.start_server', mock_asyncio_start_server) + + mock_thread = mock.Mock() + mock_thread_class = mock.Mock(return_value=mock_thread) + monkeypatch.setattr('threading.Thread', mock_thread_class) + + mock_kernel_server = mock.AsyncMock() + mock_kernel_server.__aenter__ = mock.AsyncMock(return_value=mock_kernel_server) + mock_kernel_server.__aexit__ = mock.AsyncMock() + mock_kernel_server.close = mock.AsyncMock() + + async def mock_serve_forever(): + server_instance.running = False + + mock_kernel_server.serve_forever = mock.AsyncMock(side_effect=mock_serve_forever) + mock_asyncio_start_server.return_value = mock_kernel_server + + await server_instance.start_server() + + mock_thread_class.assert_called_once() + mock_thread.start.assert_called_once() + assert mock_asyncio_start_server.called + assert mock_kernel_server.__aenter__.called + assert mock_kernel_server.serve_forever.called + assert not server_instance.running + await mock_kernel_server.close() \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b660429 --- /dev/null +++ b/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="My_Internet", + version="0.1", + packages=find_packages(), +) \ No newline at end of file