Skip to content

Commit

Permalink
fix snmpv3 auth_none. fix unknown pid. typing. python3.12.
Browse files Browse the repository at this point in the history
  • Loading branch information
Koos85 committed Sep 16, 2024
1 parent 9eaddca commit 15c6c45
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 61 deletions.
4 changes: 3 additions & 1 deletion asyncsnmplib/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class Class(enum.IntEnum):
TNumber = Union[Number, int]
TType = Union[Type, int]
TClass = Union[Class, int]
TOid = Tuple[int, ...]
TValue = Any


class Tag(NamedTuple):
Expand Down Expand Up @@ -546,7 +548,7 @@ def _decode_null(bytes_data: bytes) -> None:
raise Error("ASN1 syntax error")

@staticmethod
def _decode_object_identifier(bytes_data: bytes) -> tuple:
def _decode_object_identifier(bytes_data: bytes) -> TOid:
result: List[int] = []
value: int = 0
for i in range(len(bytes_data)):
Expand Down
62 changes: 38 additions & 24 deletions asyncsnmplib/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
from typing import Iterable, Optional, Tuple, List
from .exceptions import (
SnmpNoConnection,
SnmpErrorNoSuchName,
SnmpTooMuchRows,
SnmpNoAuthParams,
)
from .asn1 import Tag, TOid, TValue
from .package import SnmpMessage
from .pdu import SnmpGet, SnmpGetNext, SnmpGetBulk
from .protocol import SnmpProtocol
Expand All @@ -17,8 +19,14 @@
class Snmp:
version = 1 # = v2

def __init__(self, host, port=161, community='public', max_rows=10000):
self._loop = asyncio.get_event_loop()
def __init__(
self,
host: str,
port: int = 161,
community: str = 'public',
max_rows: int = 10_000,
loop: Optional[asyncio.AbstractEventLoop] = None):
self._loop = loop if loop else asyncio.get_running_loop()
self._protocol = None
self._transport = None
self.host = host
Expand All @@ -28,7 +36,7 @@ def __init__(self, host, port=161, community='public', max_rows=10000):

# On some systems it seems to be required to set the remote_addr argument
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_datagram_endpoint
async def connect(self, timeout=10):
async def connect(self, timeout: float = 10.0):
try:
infos = await self._loop.getaddrinfo(self.host, self.port)
family, *_, addr = infos[0]
Expand All @@ -44,7 +52,7 @@ async def connect(self, timeout=10):
self._transport = transport

def _get(self, oids, timeout=None):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGet(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
Expand All @@ -54,32 +62,34 @@ def _get(self, oids, timeout=None):
return self._protocol.send(message)

def _get_next(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGetNext(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
return self._protocol.send(message)

def _get_bulk(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGetBulk(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
return self._protocol.send(message)

async def get(self, oid, timeout=None):
async def get(self, oid: TOid, timeout: Optional[float] = None
) -> Tuple[TOid, Tag, TValue]:
vbs = await self._get([oid], timeout)
return vbs[0]

async def get_next(self, oid):
async def get_next(self, oid: TOid) -> Tuple[TOid, Tag, TValue]:
vbs = await self._get_next([oid])
return vbs[0]

async def get_next_multi(self, oids):
async def get_next_multi(self, oids: Iterable[TOid]
) -> List[Tuple[TOid, TValue]]:
vbs = await self._get_next(oids)
return [(oid, value) for oid, _, value in vbs if oid[:-1] in oids]

async def walk(self, oid):
async def walk(self, oid: TOid) -> List[Tuple[TOid, TValue]]:
next_oid = oid
prefixlen = len(oid)
rows = []
Expand Down Expand Up @@ -115,7 +125,7 @@ def close(self):
class SnmpV1(Snmp):
version = 0

async def walk(self, oid):
async def walk(self, oid: TOid) -> List[Tuple[TOid, TValue]]:
next_oid = oid
prefixlen = len(oid)
rows = []
Expand Down Expand Up @@ -150,15 +160,16 @@ class SnmpV3(Snmp):

def __init__(
self,
host,
username,
auth_proto='USM_AUTH_NONE',
auth_passwd=None,
priv_proto='USM_PRIV_NONE',
priv_passwd=None,
port=161,
max_rows=10000):
self._loop = asyncio.get_event_loop()
host: str,
username: str,
auth_proto: str = 'USM_AUTH_NONE',
auth_passwd: Optional[str] = None,
priv_proto: str = 'USM_PRIV_NONE',
priv_passwd: Optional[str] = None,
port: int = 161,
max_rows: int = 10_000,
loop: Optional[asyncio.AbstractEventLoop] = None):
self._loop = loop if loop else asyncio.get_running_loop()
self._protocol = None
self._transport = None
self.host = host
Expand Down Expand Up @@ -191,7 +202,7 @@ def __init__(

# On some systems it seems to be required to set the remote_addr argument
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_datagram_endpoint
async def connect(self, timeout=10):
async def connect(self, timeout: float = 10.0):
try:
infos = await self._loop.getaddrinfo(self.host, self.port)
family, *_, addr = infos[0]
Expand All @@ -211,6 +222,9 @@ async def connect(self, timeout=10):
raise SnmpNoAuthParams

async def _get_auth_params(self, timeout=10):
# TODO for long requests this will need to be refreshed
# https://datatracker.ietf.org/doc/html/rfc3414#section-2.2.3
assert self._protocol is not None
pdu = SnmpGet(0, [])
message = SnmpV3Message.make(pdu, [b'', 0, 0, b'', b'', b''])
# this request will not retry like the other requests
Expand All @@ -225,7 +239,7 @@ async def _get_auth_params(self, timeout=10):
if self._priv_proto else None

def _get(self, oids, timeout=None):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand All @@ -248,7 +262,7 @@ def _get(self, oids, timeout=None):
self._priv_hash_localized)

def _get_next(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand All @@ -262,7 +276,7 @@ def _get_next(self, oids):
self._priv_hash_localized)

def _get_bulk(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand Down
1 change: 0 additions & 1 deletion asyncsnmplib/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
__all__ = (
"SnmpTimeoutError",
"SnmpUnsupportedValueType",
"SnmpErrorTooBig",
"SnmpErrorNoSuchName",
"SnmpErrorBadValue",
Expand Down
21 changes: 14 additions & 7 deletions asyncsnmplib/mib/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Tuple, Union
from typing import Tuple, Union, List
from ..asn1 import TOid, TValue
from .mib_index import MIB_INDEX
from .syntax_funs import SYNTAX_FUNS

Expand All @@ -8,7 +9,7 @@
FLAGS_SEPERATOR = ','


def on_octet_string(value: bytes) -> str:
def on_octet_string(value: TValue) -> Union[str, None]:
"""
used as a fallback for OCTET STRING when no formatter is found/defined
"""
Expand All @@ -18,13 +19,13 @@ def on_octet_string(value: bytes) -> str:
return


def on_integer(value: int) -> str:
def on_integer(value: TValue) -> Union[int, None]:
if not isinstance(value, int):
return
return value


def on_oid_map(oid: Tuple[int]) -> str:
def on_oid_map(oid: TValue) -> Union[str, None]:
if not isinstance(oid, tuple):
# some devices don't follow mib's syntax
# for example ipAddressTable.ipAddressPrefix returns an int in case of
Expand All @@ -45,7 +46,7 @@ def on_value_map_b(value: bytes, map_: dict) -> str:
v for k, v in map_.items() if value[k // 8] & (1 << k % 8))


def on_syntax(syntax: dict, value: Union[int, str, bytes]):
def on_syntax(syntax: dict, value: TValue):
"""
this is point where bytes are converted to right datatype
"""
Expand All @@ -65,7 +66,10 @@ def on_syntax(syntax: dict, value: Union[int, str, bytes]):
raise Exception(f'Invalid syntax {syntax}')


def on_result(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
def on_result(
base_oid: TOid,
result: List[Tuple[TOid, TValue]],
) -> Tuple[str, List[dict]]:
"""returns a more compat result (w/o prefixes) and groups formatted
metrics by base_oid
"""
Expand Down Expand Up @@ -109,7 +113,10 @@ def on_result(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
return result_name, list(table.values())


def on_result_base(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
def on_result_base(
base_oid: TOid,
result: List[Tuple[TOid, TValue]],
) -> Tuple[str, List[dict]]:
"""returns formatted metrics grouped by base_oid
"""
base = MIB_INDEX[base_oid]
Expand Down
12 changes: 7 additions & 5 deletions asyncsnmplib/package.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .asn1 import Decoder, Encoder, Number
from typing import Optional, Tuple, List
from .asn1 import Decoder, Encoder, Number, Tag, TOid, TValue


class Package:
Expand All @@ -8,12 +9,13 @@ class Package:
pdu = None

def __init__(self):
self.request_id = None
self.error_status = None
self.error_index = None
self.variable_bindings = []
self.request_id: Optional[int] = None
self.error_status: Optional[int] = None
self.error_index: Optional[int] = None
self.variable_bindings: List[Tuple[TOid, Tag, TValue]] = []

def encode(self):
assert self.pdu is not None
encoder = Encoder()

with encoder.enter(Number.Sequence):
Expand Down
11 changes: 7 additions & 4 deletions asyncsnmplib/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class SnmpProtocol(asyncio.DatagramProtocol):
__slots__ = ('loop', 'target', 'transport', 'requests', '_request_id')

def __init__(self, target):
self.loop = asyncio.get_event_loop()
self.loop = asyncio.get_running_loop()
self.target = target
self.requests = {}
self._request_id = 0
Expand All @@ -47,8 +47,11 @@ def datagram_received(self, data: bytes, *args):
# before request_id is known we cannot do anything and the query
# will time out
pid = pkg.request_id
if pid is not None:
if pid in self.requests:
self.requests[pid].set_exception(exceptions.SnmpDecodeError)
elif pid is not None:
logging.error(
self._log_with_suffix(f'Unknown package pid {pid}'))
else:
logging.error(
self._log_with_suffix('Failed to decode package'))
Expand All @@ -59,9 +62,9 @@ def datagram_received(self, data: bytes, *args):
self._log_with_suffix(f'Unknown package pid {pid}'))
else:
exception = None
if pkg.error_status != 0:
if pkg.error_status: # also exclude None for trap-pdu
oid = None
if pkg.error_index != 0:
if pkg.error_index: # also exclude None for trap-pdu
oidtuple = \
pkg.variable_bindings[pkg.error_index - 1][0]
oid = '.'.join(map(str, oidtuple))
Expand Down
38 changes: 23 additions & 15 deletions asyncsnmplib/trapserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,31 +53,39 @@ def datagram_received(self, data: bytes, *args):
logging.error(
self._log_with_suffix('Failed to decode package'))
else:
# print(pkg.variable_bindings)
# for oid, tag, value in pkg.variable_bindings:
# print(oid, MIB_INDEX.get(oid[:-1])['name'])

for oid, tag, value in pkg.variable_bindings[1:]:
print(MIB_INDEX.get(oid[:-1])['name'], MIB_INDEX.get(value[:-1])['name'])
logging.debug('Trap message received')
for oid, tag, value in pkg.variable_bindings:
mib_object = MIB_INDEX.get(oid[:-1])
if mib_object is None:
# only accept oids from loaded mibs
continue
logging.info(
f'oid: {oid} name: {mib_object["name"]} value: {value}'
)
# TODO some values need oid lookup for the value, do here or in
# outside processor


class SnmpTrap:
def __init__(self, host='0.0.0.0', port=162, community='public', max_rows=10000):
self._loop = asyncio.get_event_loop()
def __init__(self, host='0.0.0.0', port=162, community='public', max_rows=10000, loop=None):
self._loop = loop if loop else asyncio.get_running_loop()
self._protocol = None
self._transport = None
self.host = host
self.port = port
self.community = community
self.max_rows = max_rows

def start(self):
transport, protocol = self._loop.run_until_complete(
self._loop.create_datagram_endpoint(
lambda: SnmpTrapProtocol((None, None)),
local_addr=(self.host, self.port),
)
async def listen(self):
transport, protocol = await self._loop.create_datagram_endpoint(
lambda: SnmpTrapProtocol((None, None)),
local_addr=(self.host, self.port),
)
self._protocol = protocol
self._transport = transport
self._loop.run_forever()

def close(self):
if self._transport is not None and not self._transport.is_closing():
self._transport.close()
self._protocol = None
self._transport = None
7 changes: 3 additions & 4 deletions asyncsnmplib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,9 @@ def snmpv3_credentials(config: dict):
'auth_proto': auth_type,
'auth_passwd': auth_passwd,
}
else:
return {
'username': user_name,
}
return {
'username': user_name,
}


async def snmp_queries(
Expand Down

0 comments on commit 15c6c45

Please sign in to comment.