diff --git a/apns.py b/apns.py index 3f39438..654e96a 100644 --- a/apns.py +++ b/apns.py @@ -85,10 +85,28 @@ WAIT_WRITE_TIMEOUT_SEC = 10 WAIT_READ_TIMEOUT_SEC = 10 WRITE_RETRY = 3 +WAIT_READ = 1 +WAIT_WRITE = 2 ER_STATUS = 'status' ER_IDENTIFER = 'identifier' + +def _wait_for_socket(sock, direction, timeout=None): + try: + poll = select.poll() + poll.register(sock, select.POLLIN if direction == WAIT_READ else select.POLLOUT) + if timeout: + timeout *= 1000 + events = poll.poll(timeout) + return bool(events) + except AttributeError: # fallback for systems not supporting poll() + rlist = [sock] if direction == WAIT_READ else [] + wlist = [sock] if direction == WAIT_WRITE else [] + rlist, wlist, _ = select.select(rlist, wlist, [], timeout) + return bool(rlist or wlist) + + class APNs(object): """A class representing an Apple Push Notification service connection""" @@ -211,9 +229,9 @@ def _connect(self): break except ssl.SSLError as err: if ssl.SSL_ERROR_WANT_READ == err.args[0]: - select.select([self._ssl], [], []) + _wait_for_socket(self._ssl, WAIT_READ) elif ssl.SSL_ERROR_WANT_WRITE == err.args[0]: - select.select([], [self._ssl], []) + _wait_for_socket(self._ssl, WAIT_WRITE) else: raise @@ -254,9 +272,9 @@ def read(self, n=None): def write(self, string): if self.enhanced: # nonblocking socket self._last_activity_time = time.time() - _, wlist, _ = select.select([], [self._connection()], [], WAIT_WRITE_TIMEOUT_SEC) - - if len(wlist) > 0: + writeable = _wait_for_socket(self._connection(), WAIT_WRITE, WAIT_WRITE_TIMEOUT_SEC) + + if writeable: length = self._connection().sendall(string) if length == 0: _logger.debug("sent length: %d" % length) #DEBUG @@ -594,9 +612,8 @@ def run(self): continue try: - rlist, _, _ = select.select([self._apns_connection._connection()], [], [], WAIT_READ_TIMEOUT_SEC) - - if len(rlist) > 0: # there's some data from APNs + readable = _wait_for_socket(self._apns_connection._connection(), WAIT_READ, WAIT_READ_TIMEOUT_SEC) + if readable: # there's some data from APNs with self._apns_connection._send_lock: buff = self._apns_connection.read(ERROR_RESPONSE_LENGTH) if len(buff) == ERROR_RESPONSE_LENGTH: diff --git a/tests.py b/tests.py index fb17a54..3b15235 100644 --- a/tests.py +++ b/tests.py @@ -1,8 +1,12 @@ #!/usr/bin/env python # coding: utf-8 +from contextlib import contextmanager + from apns import * +from apns import _wait_for_socket from binascii import a2b_hex from random import random +import socket import hashlib import os @@ -209,5 +213,56 @@ def testPayloadTooLargeError(self): self.assertRaises(PayloadTooLargeError, Payload, u'\u0100' * (int(max_raw_payload_bytes / 2) + 1)) + def testWaitForSocket(self): + @contextmanager + def assert_timing(expected, delta): + start = time.time() + yield + end = time.time() + took = end - start + self.assertTrue(expected > took - delta / 2) + self.assertTrue(expected < took + delta / 2) + + socket1, socket2 = socket.socketpair() + socket1.setblocking(False) + socket2.setblocking(False) + + # Nothing was written, therefore waiting for reading should time out + with assert_timing(1, 0.1): + result = _wait_for_socket(socket1, WAIT_READ, 1) + self.assertFalse(result) + + # Send-buffer is empty, waiting for write shouldn't block + with assert_timing(0, 0.1): + result = _wait_for_socket(socket1, WAIT_WRITE, 5) + self.assertTrue(result) + socket2.send('test') + + # We just sent something, reading on the other ending shouldn't block now + with assert_timing(0, 0.1): + result = _wait_for_socket(socket1, WAIT_READ, 5) + self.assertTrue(result) + self.assertEquals(socket1.recv(1024), 'test') + + # Fill up the write-buffer + try: + while socket1.send(1024 * 'a') == 1024: + continue + except socket.error: + pass + + # Waiting for write should block now + with assert_timing(1, 0.1): + result = _wait_for_socket(socket1, WAIT_WRITE, 1) + self.assertFalse(result) + + # Closed socket returns being readable + socket2.close() + with assert_timing(0, 0.1): + result = _wait_for_socket(socket1, WAIT_READ) + self.assertTrue(result) + + socket1.close() + if __name__ == '__main__': unittest.main()