Skip to content

Commit

Permalink
Implement new options (-4/--ipv4, -6/--ipv6, -p/--port <port>).
Browse files Browse the repository at this point in the history
By default both IPv4 and IPv6 is supported and order of precedence depends on OS.
By using -46, IPv4 is prefered, but by using -64, IPv6 is preferd.
For now the old way how to specify port (host:port) has been kept intact.
  • Loading branch information
arthepsy committed Oct 26, 2016
1 parent 8018209 commit 66b9e07
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 39 deletions.
165 changes: 135 additions & 30 deletions ssh-audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,22 @@

VERSION = 'v1.6.1.dev'

if sys.version_info >= (3,):
if sys.version_info >= (3,): # pragma: nocover
StringIO, BytesIO = io.StringIO, io.BytesIO
text_type = str
binary_type = bytes
else:
else: # pragma: nocover
import StringIO as _StringIO # pylint: disable=import-error
StringIO = BytesIO = _StringIO.StringIO
text_type = unicode # pylint: disable=undefined-variable
binary_type = str
try:
try: # pragma: nocover
# pylint: disable=unused-import
from typing import List, Tuple, Optional, Callable, Union, Any
from typing import List, Set, Sequence, Tuple, Iterable
from typing import Callable, Optional, Union, Any
except ImportError: # pragma: nocover
pass
try:
try: # pragma: nocover
from colorama import init as colorama_init
colorama_init() # pragma: nocover
except ImportError: # pragma: nocover
Expand All @@ -53,13 +54,16 @@ def usage(err=None):
# type: (Optional[str]) -> None
uout = Output()
p = os.path.basename(sys.argv[0])
uout.head('# {0} {1}, moo@arthepsy.eu'.format(p, VERSION))
uout.head('# {0} {1}, moo@arthepsy.eu\n'.format(p, VERSION))
if err is not None:
uout.fail('\n' + err)
uout.info('\nusage: {0} [-12bnv] [-l <level>] <host[:port]>\n'.format(p))
uout.info('usage: {0} [-1246pbnvl] <host>\n'.format(p))
uout.info(' -h, --help print this help')
uout.info(' -1, --ssh1 force ssh version 1 only')
uout.info(' -2, --ssh2 force ssh version 2 only')
uout.info(' -4, --ipv4 enable IPv4 (order of precedence)')
uout.info(' -6, --ipv6 enable IPv6 (order of precedence)')
uout.info(' -p, --port=<port> port to connect')
uout.info(' -b, --batch batch output')
uout.info(' -n, --no-colors disable colors')
uout.info(' -v, --verbose verbose output')
Expand All @@ -69,6 +73,7 @@ def usage(err=None):


class AuditConf(object):
# pylint: disable=too-many-instance-attributes
def __init__(self, host=None, port=22):
# type: (Optional[str], int) -> None
self.host = host
Expand All @@ -79,12 +84,35 @@ def __init__(self, host=None, port=22):
self.colors = True
self.verbose = False
self.minlevel = 'info'
self.ipvo = () # type: Sequence[int]
self.ipv4 = False
self.ipv6 = False

def __setattr__(self, name, value):
# type: (str, Union[str, int, bool]) -> None
# type: (str, Union[str, int, bool, Sequence[int]]) -> None
valid = False
if name in ['ssh1', 'ssh2', 'batch', 'colors', 'verbose']:
valid, value = True, True if value else False
elif name in ['ipv4', 'ipv6']:
valid = False
value = True if value else False
ipv = 4 if name == 'ipv4' else 6
if value:
value = tuple(list(self.ipvo) + [ipv])
else:
if len(self.ipvo) == 0:
value = (6,) if ipv == 4 else (4,)
else:
value = tuple(filter(lambda x: x != ipv, self.ipvo))
self.__setattr__('ipvo', value)
elif name == 'ipvo':
if isinstance(value, (tuple, list)):
uniq_value = utils.unique_seq(value)
value = tuple(filter(lambda x: x in (4, 6), uniq_value))
valid = True
ipv_both = len(value) == 0
object.__setattr__(self, 'ipv4', ipv_both or 4 in value)
object.__setattr__(self, 'ipv6', ipv_both or 6 in value)
elif name == 'port':
valid, port = True, utils.parse_int(value)
if port < 1 or port > 65535:
Expand All @@ -105,20 +133,27 @@ def from_cmdline(cls, args, usage_cb):
# pylint: disable=too-many-branches
aconf = cls()
try:
sopts = 'h12bnvl:'
lopts = ['help', 'ssh1', 'ssh2', 'batch',
'no-colors', 'verbose', 'level=']
sopts = 'h1246p:bnvl:'
lopts = ['help', 'ssh1', 'ssh2', 'ipv4', 'ipv6', 'port',
'batch', 'no-colors', 'verbose', 'level=']
opts, args = getopt.getopt(args, sopts, lopts)
except getopt.GetoptError as err:
usage_cb(str(err))
aconf.ssh1, aconf.ssh2 = False, False
oport = None
for o, a in opts:
if o in ('-h', '--help'):
usage_cb()
elif o in ('-1', '--ssh1'):
aconf.ssh1 = True
elif o in ('-2', '--ssh2'):
aconf.ssh2 = True
elif o in ('-4', '--ipv4'):
aconf.ipv4 = True
elif o in ('-6', '--ipv6'):
aconf.ipv6 = True
elif o in ('-p', '--port'):
oport = a
elif o in ('-b', '--batch'):
aconf.batch = True
aconf.verbose = True
Expand All @@ -132,14 +167,20 @@ def from_cmdline(cls, args, usage_cb):
aconf.minlevel = a
if len(args) == 0:
usage_cb()
s = args[0].split(':')
host, port = s[0].strip(), 22
if len(s) > 1:
port = utils.parse_int(s[1])
if oport is not None:
host = args[0]
port = utils.parse_int(oport)
else:
s = args[0].split(':')
host = s[0].strip()
if len(s) == 2:
oport, port = s[1], utils.parse_int(s[1])
else:
oport, port = '22', 22
if not host:
usage_cb('host is empty')
if port <= 0 or port > 65535:
usage_cb('port {0} is not valid'.format(s[1]))
usage_cb('port {0} is not valid'.format(oport))
aconf.host = host
aconf.port = port
if not (aconf.ssh1 or aconf.ssh2):
Expand Down Expand Up @@ -1038,24 +1079,67 @@ class InsufficientReadException(Exception):

SM_BANNER_SENT = 1

def __init__(self, host, port, cto=3.0, rto=5.0):
# type: (str, int, float, float) -> None
def __init__(self, host, port):
# type: (str, int) -> None
super(SSH.Socket, self).__init__()
self.__block_size = 8
self.__state = 0
self.__header = [] # type: List[text_type]
self.__banner = None # type: Optional[SSH.Banner]
super(SSH.Socket, self).__init__()
try:
self.__sock = socket.create_connection((host, port), cto)
self.__sock.settimeout(rto)
except Exception as e: # pylint: disable=broad-except
out.fail('[fail] {0}'.format(e))
sys.exit(1)
self.__host = host
self.__port = port
self.__sock = None # type: socket.socket

def __enter__(self):
# type: () -> SSH.Socket
return self

def _resolve(self, ipvo):
# type: (Sequence[int]) -> Iterable[Tuple[int, Tuple[Any, ...]]]
ipvo = tuple(filter(lambda x: x in (4, 6), utils.unique_seq(ipvo)))
ipvo_len = len(ipvo)
prefer_ipvo = ipvo_len > 0
prefer_ipv4 = prefer_ipvo and ipvo[0] == 4
if len(ipvo) == 1:
family = {4: socket.AF_INET, 6: socket.AF_INET6}.get(ipvo[0])
else:
family = socket.AF_UNSPEC
try:
stype = socket.SOCK_STREAM
r = socket.getaddrinfo(self.__host, self.__port, family, stype)
if prefer_ipvo:
r = sorted(r, key=lambda x: x[0], reverse=not prefer_ipv4)
check = any(stype == rline[2] for rline in r)
for (af, socktype, proto, canonname, addr) in r:
if not check or socktype == socket.SOCK_STREAM:
yield (af, addr)
except socket.error as e:
out.fail('[exception] {0}'.format(e))
sys.exit(1)

def connect(self, ipvo=(), cto=3.0, rto=5.0):
# type: (Sequence[int], float, float) -> None
err = None
for (af, addr) in self._resolve(ipvo):
s = None
try:
s = socket.socket(af, socket.SOCK_STREAM)
s.settimeout(cto)
s.connect(addr)
s.settimeout(rto)
self.__sock = s
return
except socket.error as e:
err = e
self._close_socket(s)
if err is None:
errm = 'host {0} has no DNS records'.format(self.__host)
else:
errt = (self.__host, self.__port, err)
errm = 'cannot connect to {0} port {1}: {2}'.format(*errt)
out.fail('[exception] {0}'.format(errm))
sys.exit(1)

def get_banner(self, sshv=2):
# type: (int) -> Tuple[Optional[SSH.Banner], List[text_type]]
banner = 'SSH-{0}-OpenSSH_7.3'.format('1.5' if sshv == 1 else '2.0')
Expand Down Expand Up @@ -1188,6 +1272,15 @@ def send_packet(self):
data = struct.pack('>Ib', plen, padding) + payload + pad_bytes
return self.send(data)

def _close_socket(self, s):
# type: (Optional[socket.socket]) -> None
try:
if s is not None:
s.shutdown(socket.SHUT_RDWR)
s.close()
except: # pylint: disable=bare-except
pass

def __del__(self):
# type: () -> None
self.__cleanup()
Expand All @@ -1198,11 +1291,7 @@ def __exit__(self, *args):

def __cleanup(self):
# type: () -> None
try:
self.__sock.shutdown(socket.SHUT_RDWR)
self.__sock.close()
except: # pylint: disable=bare-except
pass
self._close_socket(self.__sock)


class KexDH(object):
Expand Down Expand Up @@ -1847,6 +1936,21 @@ def to_ascii(cls, v, errors='replace'):
return cls.to_ntext(v.encode('ascii', errors))
raise cls._type_err(v, 'ascii')

@classmethod
def unique_seq(cls, seq):
# type: (Sequence[Any]) -> Sequence[Any]
seen = set() # type: Set[Any]

def _seen_add(x):
# type: (Any) -> bool
seen.add(x)
return False

if isinstance(seq, tuple):
return tuple(x for x in seq if x not in seen and not _seen_add(x))
else:
return [x for x in seq if x not in seen and not _seen_add(x)]

@staticmethod
def parse_int(v):
# type: (Any) -> int
Expand All @@ -1863,6 +1967,7 @@ def audit(aconf, sshv=None):
out.verbose = aconf.verbose
out.minlevel = aconf.minlevel
s = SSH.Socket(aconf.host, aconf.port)
s.connect(aconf.ipvo)
if sshv is None:
sshv = 2 if aconf.ssh2 else 1
err = None
Expand Down
34 changes: 26 additions & 8 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pytest, os, sys, io, socket
import os
import io
import sys
import socket
import pytest


if sys.version_info[0] == 2:
import StringIO
import StringIO # pylint: disable=import-error
StringIO = StringIO.StringIO
else:
StringIO = io.StringIO
Expand All @@ -17,6 +21,7 @@ def ssh_audit():
return __import__('ssh-audit')


# pylint: disable=attribute-defined-outside-init
class _OutputSpy(list):
def begin(self):
self.__out = StringIO()
Expand Down Expand Up @@ -50,11 +55,14 @@ def _check_err(self, method):
if method_error:
raise method_error

def _connect(self, address):
def connect(self, address):
return self._connect(address, False)

def _connect(self, address, ret=True):
self.peer_address = address
self._connected = True
self._check_err('connect')
return self
return self if ret else None

def settimeout(self, timeout):
self.timeout = timeout
Expand All @@ -77,13 +85,15 @@ def listen(self, backlog):
pass

def accept(self):
# pylint: disable=protected-access
conn = _VirtualSocket()
conn.sock_address = self.sock_address
conn.peer_address = ('127.0.0.1', 0)
conn._connected = True
return conn, conn.peer_address

def recv(self, bufsize, flags=0):
# pylint: disable=unused-argument
if not self._connected:
raise socket.error(54, 'Connection reset by peer')
if not len(self.rdata) > 0:
Expand All @@ -103,10 +113,18 @@ def send(self, data):
@pytest.fixture()
def virtual_socket(monkeypatch):
vsocket = _VirtualSocket()
def _c(address):
return vsocket._connect(address)

# pylint: disable=unused-argument
def _socket(family=socket.AF_INET,
socktype=socket.SOCK_STREAM,
proto=0,
fileno=None):
return vsocket

def _cc(address, timeout=0, source_address=None):
return vsocket._connect(address)
# pylint: disable=protected-access
return vsocket._connect(address, True)

monkeypatch.setattr(socket, 'create_connection', _cc)
monkeypatch.setattr(socket.socket, 'connect', _c)
monkeypatch.setattr(socket, 'socket', _socket)
return vsocket
Loading

0 comments on commit 66b9e07

Please sign in to comment.