diff --git a/nginx_config_reloader/__init__.py b/nginx_config_reloader/__init__.py index 6d910a0..d281157 100644 --- a/nginx_config_reloader/__init__.py +++ b/nginx_config_reloader/__init__.py @@ -2,23 +2,23 @@ from __future__ import absolute_import import argparse +import asyncio import fnmatch import logging import logging.handlers import os import shutil import signal +import socket import subprocess import sys import threading import time -from typing import Callable, Optional import pyinotify -from pynats import NATSClient, NATSMessage from nginx_config_reloader.copy_files import safe_copy_files -from nginx_config_reloader.nats import initialize_nats, publish_nats_message +from nginx_config_reloader.nats_client import get_nats_client from nginx_config_reloader.settings import ( BACKUP_CONFIG_DIR, CUSTOM_CONFIG_DIR, @@ -50,7 +50,6 @@ def my_init( magento2_flag=None, notifier=None, use_systemd=False, - nats_client: Optional[NATSClient] = None, ): """Constructor called by ProcessEvent @@ -75,8 +74,8 @@ def my_init( self.notifier = notifier self.use_systemd = use_systemd self.dirty = False + self.dirty_cluster = False self.applying = False - self.nats_client = nats_client def process_IN_DELETE(self, event): """Triggered by inotify on removal of file or removal of dir @@ -107,11 +106,8 @@ def process_IN_MOVE_SELF(self, event): def handle_event(self, event): if not any(fnmatch.fnmatch(event.name, pat) for pat in WATCH_IGNORE_FILES): self.logger.info("{} detected on {}.".format(event.maskname, event.name)) - if self.nats_client: - if not self.dirty: - self.nats_client = publish_nats_message(self.nats_client) - else: - self.dirty = True + self.dirty = True + self.dirty_cluster = True def install_magento_config(self): # Check if configs are present @@ -311,49 +307,25 @@ def after_loop(nginx_config_reloader: NginxConfigReloader) -> None: nginx_config_reloader.applying = False -def construct_message_handler( - nginx_config_reloader: NginxConfigReloader, -) -> Callable[[NATSMessage], None]: - def message_handler(msg: NATSMessage) -> None: - if msg.subject == NATS_SUBJECT and msg.payload == NATS_RELOAD_BODY: - logger.debug("NATS message received, reloading config") - nginx_config_reloader.dirty = True - # Trigger manually to ensure it's running. The `.applying` flag will prevent - # concurrent runs. - after_loop(nginx_config_reloader) - - return message_handler - - -def start_message_subscribe_loop( - nginx_config_reloader: NginxConfigReloader, nats_server: str -) -> None: - def listen_nats() -> None: +def watch_for_changes(nginx_config_reloader: NginxConfigReloader, nats_server: str): + async def listen() -> None: + client = await get_nats_client(nats_server) while True: - # Create new connection to throw away any old subscriptions. - # This is useful when there are many writes queued up, and we - # only want to reload once. - try: - nc = initialize_nats(nats_server) - nginx_config_reloader.nats_client = nc - sub = nc.subscribe( - NATS_SUBJECT, - callback=construct_message_handler(nginx_config_reloader), - max_messages=1, - ) - nc.auto_unsubscribe(sub) - except Exception as e: - logger.debug(f"Couldn't make NATS client: {e}") - continue - - logger.debug(f"Waiting for message on {NATS_SUBJECT}") - try: - nc.wait(count=1) - except Exception as e: - logger.debug(f"NATS error: {e}") + # Dirty flag is only set for local changes, not for NATS changes + # So if nginx_config_reloader + if nginx_config_reloader.dirty_cluster: + nginx_config_reloader.dirty_cluster = False + try: + await client.publish( + NATS_SUBJECT, + NATS_RELOAD_BODY, + headers={"From": socket.gethostname()}, + ) + except Exception as e: + logger.error(f"Error while publishing event: {e}") + await asyncio.sleep(1) - t = threading.Thread(target=listen_nats) - t.start() + asyncio.run(listen()) def wait_loop( @@ -373,11 +345,13 @@ def wait_loop( renamed or removed, the inotify-handler raises an exception to break out of the inner loop and we're back here in the outer loop. - :param obj logger: The logger object + :param logging.Logger logger: The logger object :param bool no_magento_config: True if we should not install Magento configuration :param bool no_custom_config: True if we should not copy custom configuration :param str dir_to_watch: The directory to watch :param bool recursive_watch: True if we should watch the dir recursively + :param use_systemd: True if we should reload nginx using systemd instead of process signal + :param nats_server: NATS server to connect to. If not set, NATS will not be used. :return None: """ dir_to_watch = os.path.abspath(dir_to_watch) @@ -400,7 +374,11 @@ def process_IN_DELETE(self, event): ) if nats_server: - start_message_subscribe_loop(nginx_config_changed_handler, nats_server) + nats_thread = threading.Thread( + target=watch_for_changes, + args=(nginx_config_changed_handler, nats_server), + ) + nats_thread.start() while True: while not os.path.exists(dir_to_watch): @@ -431,8 +409,6 @@ def process_IN_DELETE(self, event): logger.critical(err) except ListenTargetTerminated: logger.warning("Configuration dir lost, waiting for it to reappear") - if nc: - nc.close() def as_unprivileged_user(): diff --git a/nginx_config_reloader/nats.py b/nginx_config_reloader/nats.py deleted file mode 100644 index db96221..0000000 --- a/nginx_config_reloader/nats.py +++ /dev/null @@ -1,32 +0,0 @@ -import logging - -from pynats import NATSClient - -from nginx_config_reloader.settings import NATS_RELOAD_BODY, NATS_SUBJECT - -logger = logging.getLogger(__name__) - - -def initialize_nats(url: str) -> NATSClient: - logger.debug(f"Initializing NATS connection to {url}") - - nc = NATSClient(url) - nc.connect() - nc.ping() - return nc - - -def client_to_url(nc: NATSClient) -> str: - url = f"{nc._conn_options.scheme}://{nc._conn_options.hostname}:{nc._conn_options.port}" - logger.debug(f"Converting NATS client to URL: {url}") - return url - - -def publish_nats_message(nc: NATSClient) -> NATSClient: - logger.debug(f"Publishing to NATS: {NATS_SUBJECT} {NATS_RELOAD_BODY!r}") - try: - nc.publish(subject=NATS_SUBJECT, payload=NATS_RELOAD_BODY) - except Exception as e: - logger.exception(f"NATS publish failed, recreating connection: {e}") - return initialize_nats(client_to_url(nc)) - return nc diff --git a/nginx_config_reloader/nats_client.py b/nginx_config_reloader/nats_client.py new file mode 100644 index 0000000..5acecfb --- /dev/null +++ b/nginx_config_reloader/nats_client.py @@ -0,0 +1,67 @@ +import argparse +import logging +import ssl +from typing import Optional + +import nats +from nats.aio.client import Client + +logger = logging.getLogger(__name__) + + +def get_default_nats_ssl_context() -> dict: + # NATS SSL context is defined in /etc/defaults/nginx_config_reloader + try: + ssl_context = {} + with open("/etc/default/nginx_config_reloader") as f: + for line in f.readlines(): + if line.startswith("NATS_CERT="): + ssl_context["crt"] = line.split("=")[1].strip() + elif line.startswith("NATS_KEY="): + ssl_context["key"] = line.split("=")[1].strip() + elif line.startswith("NATS_CA="): + ssl_context["ca"] = line.split("=")[1].strip() + return ssl_context + except FileNotFoundError: + pass + logger.warning(f"Couldn't find NATS_SSL_CONTEXT, assuming no SSL") + return {} + + +def get_ssl_context(args: Optional[argparse.Namespace] = None): + if args and args.nats_cert and args.nats_key and args.nats_ca: + return {"crt": args.nats_cert, "key": args.nats_key, "ca": args.nats_ca} + return get_default_nats_ssl_context() + + +async def error_cb(e): + logger.warning(f"Error: {e}") + + +async def reconnected_cb(): + logger.info("Got reconnected to NATS...") + + +async def get_nats_client(server) -> Client: + logger.debug(f"Connecting to NATS server on {server}") + options = { + "servers": [server], + "error_cb": error_cb, + "reconnected_cb": reconnected_cb, + "drain_timeout": 3, + "max_reconnect_attempts": 5, # 3 tries in total + "connect_timeout": 1, + "reconnect_time_wait": 1, + } + + ssl_context = get_ssl_context() + if ssl_context: + ssl_ctx = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH) + ssl_ctx.load_verify_locations(ssl_context.get("ca")) + ssl_ctx.load_cert_chain( + certfile=ssl_context.get("crt"), keyfile=ssl_context.get("key") + ) + options["tls"] = ssl_ctx + + nc = await nats.connect(**options) + return nc diff --git a/requirements.txt b/requirements.txt index 681275c..b6932d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,4 @@ pytest-xdist==3.2.0 tox==4.4.5 black==23.1.0 pre-commit==2.21.0 --e git+https://github.com/ByteInternet/nats-python.git@755ce98487ad15bec2889365d8b7caa4b2455e84#egg=nats-python +nats-py==2.6.0