From bd55c77f4bb98fba1a22c454585229a0a9ed8350 Mon Sep 17 00:00:00 2001 From: linuxdaemon Date: Sun, 25 Jun 2017 23:41:56 -0500 Subject: [PATCH] Initial source commit --- .gitignore | 177 +++++++++++++++++++++++++++ README.md | 30 +++++ asyncirc/__init__.py | 6 + asyncirc/irc.py | 282 +++++++++++++++++++++++++++++++++++++++++++ asyncirc/protocol.py | 280 ++++++++++++++++++++++++++++++++++++++++++ asyncirc/server.py | 35 ++++++ requirements.txt | 2 + setup.cfg | 5 + setup.py | 23 ++++ tests/parser_test.py | 46 +++++++ 10 files changed, 886 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 asyncirc/__init__.py create mode 100644 asyncirc/irc.py create mode 100644 asyncirc/protocol.py create mode 100644 asyncirc/server.py create mode 100644 requirements.txt create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/parser_test.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7eb69c4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,177 @@ +# Created by .ignore support plugin (hsz.mobi) +### VirtualEnv template +# Virtualenv +# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ +.Python +[Bb]in +[Ii]nclude +[Ll]ib +[Ll]ib64 +[Ll]ocal +[Ss]cripts +pyvenv.cfg +.venv +pip-selfcheck.json +### Linux template +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*,cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +venv/ +ENV/ + +# Spyder project settings +.spyderproject + +# Rope project settings +.ropeproject +### JetBrains template +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff: +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/dictionaries + +# Sensitive or high-churn files: +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.xml +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml + +# Gradle: +.idea/**/gradle.xml +.idea/**/libraries + +# Mongo Explorer plugin: +.idea/**/mongoSettings.xml + +## File-based project format: +*.iws + +## Plugin-specific files: + +# IntelliJ +/out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties +### Vim template +# swap +[._]*.s[a-v][a-z] +[._]*.sw[a-p] +[._]s[a-v][a-z] +[._]sw[a-p] +# session +Session.vim +# temporary +.netrwhist +# auto-generated tag files +tags diff --git a/README.md b/README.md new file mode 100644 index 0000000..53e3a5b --- /dev/null +++ b/README.md @@ -0,0 +1,30 @@ +# async-irc +An implementation of asyncio.Protocol for IRC + +### Example +```(py) +import asyncio + +from asyncirc.protocol import IrcProtocol +from asyncirc.server import Server + +servers = [ + Server("irc.example.org", 6697, True), + Server("irc.example.com", 6667), +] + +async def log(conn, message): + print(message) + +async def main(): + conn = IrcProtocol(servers, "BotNick", loop=loop) + conn.register_cap('userhost-in-names') + conn.register('*', log) + await conn.connect() + await asyncio.sleep(24 * 60 * 60) + +try: + loop.run_until_complete(main()) +finally: + loop.stop() +``` \ No newline at end of file diff --git a/asyncirc/__init__.py b/asyncirc/__init__.py new file mode 100644 index 0000000..fb40124 --- /dev/null +++ b/asyncirc/__init__.py @@ -0,0 +1,6 @@ +# coding=utf-8 +""" +Async IRC Interface Library +""" + +__all__ = ["error", "irc", "protocol", "server"] diff --git a/asyncirc/irc.py b/asyncirc/irc.py new file mode 100644 index 0000000..4690066 --- /dev/null +++ b/asyncirc/irc.py @@ -0,0 +1,282 @@ +# coding=utf-8 +""" +Basic parser objects + logic +""" + +import re +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import List, Tuple, Dict, Iterable, AnyStr + +TAGS_SENTINEL = '@' +TAGS_SEP = ';' +TAG_VALUE_SEP = '=' + +PREFIX_SENTINEL = ':' +PREFIX_USER_SEP = '!' +PREFIX_HOST_SEP = '@' + +PARAM_SEP = ' ' +TRAIL_SENTINEL = ':' + +CAP_SEP = ' ' +CAP_VALUE_SEP = '=' + +PREFIX_RE = re.compile(r':?(?P.+?)(?:!(?P.+?))?(?:@(?P.+?))?') + +TAG_VALUE_ESCAPES = { + '\\s': ' ', + '\\:': ';', + '\\r': '\r', + '\\n': '\n', + '\\\\': '\\', +} +TAG_VALUE_UNESCAPES = { + unescaped: escaped + for escaped, unescaped in TAG_VALUE_ESCAPES.items() +} + + +class Parseable(ABC): + """Abstract class for parseable objects""" + + @abstractmethod + def __str__(self): + return NotImplemented + + @staticmethod + @abstractmethod + def parse(text: str) -> 'Parseable': + """Parse the object from a string""" + return NotImplemented + + +class Cap(Parseable): + """Represents a CAP entity as defined in IRCv3.2""" + + def __init__(self, name: str, value: str = None): + self.name = name + self.value = value or None + + def __str__(self): + if self.value: + return CAP_VALUE_SEP.join((self.name, self.value)) + return self.name + + @staticmethod + def parse(text: str) -> 'Cap': + """Parse a CAP entity from a string""" + name, _, value = text.partition(CAP_VALUE_SEP) + return Cap(name, value) + + +class CapList(Parseable, List[Cap]): + """Represents a list of CAP entities""" + + def __str__(self) -> str: + return CAP_SEP.join(self) + + @staticmethod + def parse(text: str) -> 'CapList': + """Parse a list of CAPs from a string""" + return CapList(map(Cap.parse, text.split(CAP_SEP))) + + +class MessageTag(Parseable): + """ + Basic class to wrap a message tag + """ + + def __init__(self, name: str, value: str = None) -> None: + self.name = name + self.value = value + + @staticmethod + def unescape(value: str) -> str: + """ + Replace the escaped characters in a tag value with their literals + :param value: Escaped string + :return: Unescaped string + """ + new_value = "" + found = False + for i in range(len(value)): + if found: + found = False + continue + if value[i] == '\\': + if i + 1 >= len(value): + raise ValueError("Unexpected end of string while parsing: {}".format(value)) + new_value += TAG_VALUE_ESCAPES[value[i:i + 2]] + found = True + else: + new_value += value[i] + return new_value + + @staticmethod + def escape(value: str) -> str: + """ + Replace characters with their escaped variants + :param value: The raw string + :return: The escaped string + """ + return "".join(TAG_VALUE_UNESCAPES.get(c, c) for c in value) + + def __str__(self) -> str: + if self.value: + return "{}{}{}".format( + self.name, TAG_VALUE_SEP, self.escape(self.value) + ) + return self.name + + @staticmethod + def parse(text: str) -> 'MessageTag': + """ + Parse a tag from a string + :param text: The basic tag string + :return: The MessageTag object + """ + name, _, value = text.partition(TAG_VALUE_SEP) + if value: + value = MessageTag.unescape(value) + return MessageTag(name, value or None) + + +class TagList(Parseable, OrderedDict, Dict[str, MessageTag]): + """Object representing the list of message tags on a line""" + + def __init__(self, tags: Iterable[MessageTag]) -> None: + super().__init__((tag.name, tag) for tag in tags) + + def __str__(self) -> str: + return TAGS_SENTINEL + TAGS_SEP.join(map(str, self.values())) + + @staticmethod + def parse(text: str) -> 'TagList': + """ + Parse the list of tags from a string + :param text: The string to parse + :return: The parsed object + """ + return TagList( + map(MessageTag.parse, filter(None, text.split(TAGS_SEP))) + ) + + +class Prefix(Parseable): + """ + Object representing the prefix of a line + """ + + def __init__(self, nick: str, user: str = None, host: str = None) -> None: + self.nick = nick + self.user = user + self.host = host + + @property + def mask(self) -> str: + """ + The complete n!u@h mask + """ + m = self.nick + if self.user: + m += PREFIX_USER_SEP + self.user + if self.host: + m += PREFIX_HOST_SEP + self.host + return m + + def __str__(self) -> str: + return PREFIX_SENTINEL + self.mask + + def __bool__(self) -> bool: + return bool(self.nick) + + @staticmethod + def parse(text: str) -> 'Prefix': + """ + Parse the prefix from a string + :param text: String to parse + :return: Parsed Object + """ + if not text: + return Prefix('') + match = PREFIX_RE.fullmatch(text) + assert match, "Prefix did not match prefix pattern" + nick, user, host = match.groups() + return Prefix(nick, user, host) + + +class ParamList(Parseable, List[str]): + """ + An object representing the parameter list from a line + """ + + def __init__(self, seq: Iterable[str], has_trail: bool = False) -> None: + super().__init__(seq) + self.has_trail = has_trail or (self and PARAM_SEP in self[-1]) + + def __str__(self) -> str: + if self.has_trail and self[-1][0] != TRAIL_SENTINEL: + return PARAM_SEP.join(self[:-1] + [TRAIL_SENTINEL + self[-1]]) + return PARAM_SEP.join(self) + + @staticmethod + def parse(text: str) -> 'ParamList': + """ + Parse a list of parameters + :param text: The list of parameters + :return: The parsed object + """ + args = [] + has_trail = False + while text: + if text[0] == TRAIL_SENTINEL: + args.append(text[1:]) + has_trail = True + break + arg, _, text = text.partition(PARAM_SEP) + if arg: + args.append(arg) + return ParamList(args, has_trail=has_trail) + + +class Message(Parseable): + """ + An object representing a parsed IRC line + """ + + def __init__(self, tags: TagList = None, prefix: Prefix = None, command: str = None, + parameters: ParamList = None) -> None: + self.tags = tags + self.prefix = prefix + self.command = command + self.parameters = parameters + + @property + def parts(self) -> Tuple[TagList, Prefix, str, ParamList]: + """The parts that make up this message""" + return self.tags, self.prefix, self.command, self.parameters + + def __str__(self) -> str: + return PARAM_SEP.join(map(str, filter(None, self.parts))) + + def __bool__(self) -> bool: + return any(self.parts) + + @staticmethod + def parse(text: AnyStr) -> 'Message': + """Parse an IRC message in to objects""" + if isinstance(text, bytes): + text = text.decode() + tags = '' + prefix = '' + if text.startswith(TAGS_SENTINEL): + tags, _, text = text.partition(PARAM_SEP) + if text.startswith(PREFIX_SENTINEL): + prefix, _, text = text.partition(PARAM_SEP) + command, _, params = text.partition(PARAM_SEP) + tags = TagList.parse(tags[1:]) + prefix = Prefix.parse(prefix[1:]) + command = command.upper() + params = ParamList.parse(params) + return Message(tags, prefix, command, params) diff --git a/asyncirc/protocol.py b/asyncirc/protocol.py new file mode 100644 index 0000000..9c4cd98 --- /dev/null +++ b/asyncirc/protocol.py @@ -0,0 +1,280 @@ +# coding=utf-8 +""" +Basic asyncio.Protocol interface for IRC connections +""" +import asyncio +import base64 +import random +import ssl +from asyncio import Protocol +from collections import defaultdict +from enum import IntEnum, auto, unique +from itertools import cycle +from typing import Sequence, Optional, Tuple, Callable, Dict, Coroutine, AnyStr, TYPE_CHECKING + +from asyncirc.irc import Message, CapList +from asyncirc.server import ConnectedServer + +if TYPE_CHECKING: + from logging import Logger + from asyncirc.server import Server + from asyncio import AbstractEventLoop, Transport + + +@unique +class SASLMechanism(IntEnum): + """Represents different SASL auth mechanisms""" + NONE = auto() + PLAIN = auto() + EXTERNAL = auto() + + +async def _internal_ping(conn: 'IrcProtocol', message: 'Message'): + conn.send("PONG {}".format(message.parameters)) + + +async def _internal_cap_handler(conn: 'IrcProtocol', message: 'Message'): + if message.parameters[1] == 'LS': + caplist = CapList.parse(message.parameters[-1]) + for cap in caplist: + if cap.name in conn.cap_handlers: + conn.server.caps[cap.name] = None + + if message.parameters[2] != '*': + for cap in conn.server.caps: + conn.send("CAP REQ :{}".format(cap)) + + elif message.parameters[1] in ('ACK', 'NAK'): + caplist = CapList.parse(message.parameters[-1]) + enabled = message.parameters[1] == 'ACK' + for cap in caplist: + conn.server.caps[cap.name] = enabled + if enabled: + handlers = filter(None, conn.cap_handlers[cap.name]) + await asyncio.gather(*[func(conn) for func in handlers]) + if all(val is not None for val in conn.server.caps.values()): + conn.send("CAP END") + + +async def _do_sasl(conn: 'IrcProtocol'): + if not conn.sasl_mech or conn.sasl_mech is SASLMechanism.NONE: + return + conn.send("AUTHENTICATE {}".format(conn.sasl_mech.name)) + auth_msg = await conn.wait_for("AUTHENTICATE", timeout=5) + if auth_msg and auth_msg.parameters[0] == '+': + auth_line = '+' + if conn.sasl_mech is SASLMechanism.PLAIN: + auth_line = '\0'.join((conn.nick, *conn.sasl_auth)) + auth_line = base64.b64encode(auth_line.encode()).decode() + conn.send("AUTHENTICATE {}".format(auth_line)) + + +class IrcProtocol(Protocol): + """Async IRC Interface""" + + _transport: Optional['Transport'] = None + _buff = b"" + _server: Optional['ConnectedServer'] = None + + def __init__(self, servers: Sequence['Server'], nick: str, user: str = None, realname: str = None, + certpath: str = None, sasl_auth: Tuple[str, str] = None, sasl_mech: SASLMechanism = None, + logger: 'Logger' = None, loop: 'AbstractEventLoop' = None) -> None: + self.servers = servers + self.nick = nick + self._user = user + self._realname = realname + self.certpath = certpath + self.sasl_auth = sasl_auth + self.sasl_mech = SASLMechanism(sasl_mech or SASLMechanism.NONE) + self.logger = logger + self.loop = loop or asyncio.get_event_loop() + + if self.sasl_mech == SASLMechanism.PLAIN: + assert self.sasl_auth, "You must specify sasl_auth when using SASL PLAIN" + + self._connected = False + self._quitting = False + + self.handlers: Dict[int, Tuple[str, Callable]] = {} + self.cap_handlers = defaultdict(list) + + self._connected_future = self.loop.create_future() + self.quit_future = self.loop.create_future() + + self.register("PING", _internal_ping) + self.register("CAP", _internal_cap_handler) + self.register_cap('sasl', _do_sasl) + + def __call__(self, *args, **kwargs) -> 'IrcProtocol': + """ + This is here to allow an instance of IrcProtocol to be passed + directly to AbstractEventLoop.create_connection() + """ + return self + + async def __aenter__(self) -> 'IrcProtocol': + return self.__enter__() + + async def __aexit__(self, *exc): + self.quit() + await self.quit_future + + def __enter__(self): + return self + + def __exit__(self, *exc): + self.quit() + + async def connect(self) -> None: + """Attempt to connect to the server, cycling through the server list until successful""" + for server in cycle(self.servers): + if await self._connect(server): + break + + async def _connect(self, server: 'Server') -> bool: + self._connected_future = self.loop.create_future() + self.quit_future = self.loop.create_future() + self._server = ConnectedServer(server) + if self.logger: + if self.connected: + self.logger.info("Reconnecting to %s", self.server) + else: + self.logger.info("Connecting to %s", self.server) + if self.server.is_ssl: + ssl_ctx = ssl.create_default_context() + if self.certpath: + ssl_ctx.load_cert_chain(self.certpath) + else: + ssl_ctx = None + fut = self.loop.create_connection(self, self.server.host, self.server.port, ssl=ssl_ctx) + try: + await asyncio.wait_for(fut, 30) + except asyncio.TimeoutError: + return False + return True + + def register(self, cmd: str, handler: Callable[['IrcProtocol', 'Message'], Coroutine]) -> int: + """Register a command handler""" + hook_id = 0 + while not hook_id or hook_id in self.handlers: + hook_id = random.randint(1, (2 ** 32) - 1) + self.handlers[hook_id] = (cmd, handler) + return hook_id + + def unregister(self, hook_id: int) -> None: + """Unregister a hook""" + del self.handlers[hook_id] + + def register_cap(self, cap: str, handler: Optional[Callable[['IrcProtocol'], Coroutine]] = None) -> None: + """Register a CAP handler + + If the handler is None, the CAP will be requested from the server, but no handler will be called, + allowing registration of CAPs that only require basic requests + """ + self.cap_handlers[cap].append(handler) + + async def wait_for(self, *cmds: str, timeout: int = None) -> None: + """Wait for a specific command from the server, optionally returning after [timeout] seconds""" + if not cmds: + return + fut = self.loop.create_future() + + # noinspection PyUnusedLocal + async def _wait(conn: 'IrcProtocol', message: 'Message') -> None: + if not fut.done(): + fut.set_result(message) + + hooks = [ + self.register(cmd, _wait) for cmd in cmds + ] + + try: + result = await asyncio.wait_for(fut, timeout) + except asyncio.TimeoutError: + result = None + finally: + for hook_id in hooks: + self.unregister(hook_id) + return result + + def send(self, text: AnyStr) -> None: + """Send a raw line to the server""" + asyncio.run_coroutine_threadsafe(self._send(text), self.loop) + + async def _send(self, text: AnyStr) -> None: + if not self.connected: + await self._connected_future + if isinstance(text, str): + text = text.encode() + if self.logger: + self.logger.info(">> %s", text.decode()) + self._transport.write(text + b'\r\n') + + def quit(self, reason: str = None) -> None: + """Quit the IRC connection with an optional reason""" + if not self._quitting: + self._quitting = True + if reason: + self.send("QUIT {}".format(reason)) + else: + self.send("QUIT") + + def connection_made(self, transport: 'Transport') -> None: + """Called by the event loop when the connection has been established""" + self._transport = transport + self._connected = True + self._connected_future.set_result(None) + del self._connected_future + self.send("CAP LS 302") + if self.server.password: + self.send("PASS {}".format(self.server.password)) + self.send("NICK {}".format(self.nick)) + self.send("USER {} 0 * :{}".format(self.user, self.realname)) + + def connection_lost(self, exc) -> None: + """Connection to the IRC server has been lost""" + self._transport = None + self._connected = False + if not self._quitting: + self._connected_future = self.loop.create_future() + asyncio.run_coroutine_threadsafe(self.connect(), self.loop) + else: + self.quit_future.set_result(None) + + def data_received(self, data: bytes) -> None: + """Called by the event loop when data has been read from the socket""" + self._buff += data + while b'\r\n' in self._buff: + raw_line, self._buff = self._buff.split(b'\r\n', 1) + message = Message.parse(raw_line) + for trigger, func in self.handlers.values(): + if trigger in (message.command, '*'): + self.loop.create_task(func(self, message)) + + @property + def user(self) -> str: + """The username used for this connection""" + return self._user or self.nick + + @user.setter + def user(self, value: str) -> None: + self._user = value + + @property + def realname(self) -> str: + """The realname or GECOS used for this connection""" + return self._realname or self.nick + + @realname.setter + def realname(self, value: str) -> None: + self._realname = value + + @property + def connected(self) -> bool: + """Whether or not the connection is still active""" + return self._connected + + @property + def server(self) -> Optional['ConnectedServer']: + """The current server object""" + return self._server diff --git a/asyncirc/server.py b/asyncirc/server.py new file mode 100644 index 0000000..668df38 --- /dev/null +++ b/asyncirc/server.py @@ -0,0 +1,35 @@ +# coding=utf-8 +""" +Server objects used for different stages in the connect process +to store contextual data +""" +from typing import Optional, Dict + +from asyncirc.irc import Cap + + +class Server: + """Represents a server to connect to""" + + def __init__(self, host: str, port: int, is_ssl: bool = False, password: str = None): + self.host = host + self.port = port + self.password = password + self.is_ssl = is_ssl + + def __str__(self) -> str: + if self.is_ssl: + return "{}:+{}".format(self.host, self.port) + return "{}:{}".format(self.host, self.port) + + +class ConnectedServer(Server): + """Represents a connected server + + Used to store session data like ISUPPORT tokens and enabled CAPs + """ + + def __init__(self, server): + super().__init__(server.host, server.port, server.is_ssl, server.password) + self.isupport_tokens: Dict[str, str] = {} + self.caps: Dict[Cap, Optional[bool]] = {} diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..27472ef --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +pytest +setuptools diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..42b1684 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,5 @@ +[aliases] +test=pytest + +[tool:pytest] +testpaths = tests \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..169a1bb --- /dev/null +++ b/setup.py @@ -0,0 +1,23 @@ +# coding=utf-8 +from setuptools import setup + +setup( + name='async-irc', + version='0.1.2', + python_requires=">=3.6", + description="A simple asyncio.Protocol implementation designed for IRC", + url='https://github.com/SnoonetIRC/async-irc', + author='linuxdaemon', + author_email='linuxdaemon@snoonet.org', + license='MIT', + classifiers=[ + 'Development Status :: 3 - Alpha', + 'License :: OSI Approved :: MIT License', + 'Programming Language :: Python :: 3.6', + ], + keywords='asyncio irc asyncirc async-irc irc-framework', + packages=['asyncirc'], + install_requires=[], + setup_requires=['pytest-runner'], + tests_require=['pytest'], +) diff --git a/tests/parser_test.py b/tests/parser_test.py new file mode 100644 index 0000000..8f4edff --- /dev/null +++ b/tests/parser_test.py @@ -0,0 +1,46 @@ +# coding=utf-8 +from pytest import raises + +from asyncirc.irc import MessageTag, Cap, CapList + + +class TestCaps: + def test_cap_list(self): + cases = ( + ( + "blah blah-blah cap-1 test-cap=value-data", + (("blah", None), ("blah-blah", None), ("cap-1", None), ("test-cap", "value-data")) + ), + ) + + for text, expected in cases: + parsed = CapList.parse(text) + assert len(parsed) == len(expected) + for (name, value), actual in zip(expected, parsed): + assert actual.name == name + assert actual.value == value + + def test_caps(self): + cases = ( + ("vendor.example.org/cap-name", "vendor.example.org/cap-name", None), + ) + + for text, name, value in cases: + cap = Cap.parse(text) + assert cap.name == name + assert cap.value == value + + +def test_message_tags(): + cases = ( + ("a=b", "a", "b"), + ("test/blah=", "test/blah", None), + ("blah=aa\\r\\n\\:\\\\", "blah", "aa\r\n;\\"), + ) + for text, name, value in cases: + tag = MessageTag.parse(text) + assert tag.name == name + assert tag.value == value + + with raises(ValueError): + MessageTag.parse("key=value\\")