Skip to content

Commit

Permalink
Refactor bulb discovery
Browse files Browse the repository at this point in the history
* Allow for all kinds of missing connection and os errors
  • Loading branch information
codingjoe committed Jan 21, 2018
1 parent 2956e12 commit ed1d357
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 44 deletions.
13 changes: 6 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ Example:
@asyncio.coroutine
def turn_all_lights_on(bulbs):
while True:
print(bulbs)
for b in bulbs.values():
asyncio.Task(b.send_command("set_power",
["off", "sudden", 40]))
Expand All @@ -39,12 +38,12 @@ Example:
def main():
loop = asyncio.get_event_loop()
with search_bulbs() as bulbs:
loop.create_task(turn_all_lights_on(bulbs))
try:
loop.run_forever()
except KeyboardInterrupt:
loop.stop()
bulbs = loop.run_until_complete(search_bulbs())
loop.create_task(turn_all_lights_on(bulbs))
try:
loop.run_forever()
except KeyboardInterrupt:
loop.stop()
if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/test_discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ def test_wrong_location(self):


def test_search_bulbs():
asyncio.Task(search_bulbs())
loop = asyncio.get_event_loop()
with search_bulbs():
loop.run_until_complete(asyncio.sleep(1))
loop.run_until_complete(asyncio.sleep(1))
68 changes: 33 additions & 35 deletions yeelib/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,16 @@
import socket
import struct
import time
from contextlib import contextmanager

from ssdp import SSDPRequest, SimpleServiceDiscoveryProtocol

from .exceptions import YeelightError
from .devices import Bulb
from .exceptions import YeelightError

__all__ = ('search_bulbs', 'YeelightProtocol', 'bulbs')

logger = logging.getLogger('yeelib')


bulbs = {}

MCAST_PORT = 1982
Expand All @@ -31,16 +29,34 @@ async def send_search_broadcast(transport, search_interval=30):
('ST', 'wifi_bulb'),
])
while True:
request.sendto(transport, MCAST_ADDR)
try:
request.sendto(transport, MCAST_ADDR)
except OSError:
logger.exception("Connection error")
await asyncio.sleep(search_interval)


class YeelightProtocol(SimpleServiceDiscoveryProtocol):
excluded_headers = ['DATE', 'EXT', 'SERVER', 'CACHE-CONTROL', 'LOCATION']
location_patter = r'yeelight://(?P<ip>\d{1,3}(\.\d{1,3}){3}):(?P<port>\d+)'

def __init__(self, bulb_class=Bulb):
def __init__(self, bulb_class=Bulb, loop=None):
self.bulb_class = bulb_class
self.loop = loop or asyncio.get_event_loop()

def connection_made(self, transport):
ucast_socket = transport.get_extra_info('socket')
try:
ucast_socket.bind(('', MCAST_PORT))
fcntl.fcntl(ucast_socket, fcntl.F_SETFL, os.O_NONBLOCK)
group = socket.inet_aton(
SimpleServiceDiscoveryProtocol.MULTICAST_ADDRESS)
mreq = struct.pack("4sl", group, socket.INADDR_ANY)
ucast_socket.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
except socket.error as e:
ucast_socket.close()
self.connection_lost(exc=e)

@classmethod
def header_to_kwargs(cls, headers):
Expand Down Expand Up @@ -79,39 +95,21 @@ def register_bulb(self, **kwargs):
else:
bulbs[idx].last_seen = time.time()

def connection_lost(self, exc):
logger.exception("connection error")

def open_unicast_socket():
ucast_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
ucast_socket.bind(('', MCAST_PORT))
fcntl.fcntl(ucast_socket, fcntl.F_SETFL, os.O_NONBLOCK)
group = socket.inet_aton(
SimpleServiceDiscoveryProtocol.MULTICAST_ADDRESS)
mreq = struct.pack("4sl", group, socket.INADDR_ANY)
ucast_socket.setsockopt(
socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
return ucast_socket
async def _restart():
await asyncio.sleep(10)
await search_bulbs(self.bulb_class, self.loop)

asyncio.Task(_restart())

@contextmanager
def search_bulbs(bulb_class=Bulb, loop=None):

async def search_bulbs(bulb_class=Bulb, loop=None):
if loop is None:
loop = asyncio.get_event_loop()
multicast_connection = loop.create_datagram_endpoint(
lambda: YeelightProtocol(bulb_class),
family=socket.AF_INET)
mcast_transport, _ = loop.run_until_complete(multicast_connection)
loop.create_task(send_search_broadcast(mcast_transport))

ucast_socket = open_unicast_socket()
unicast_connection = loop.create_datagram_endpoint(
lambda: YeelightProtocol(bulb_class),
sock=ucast_socket)
ucast_transport, _ = loop.run_until_complete(unicast_connection)
try:
yield bulbs
except BaseException:
pass
finally:
ucast_transport.close()
mcast_transport.close()
ucast_socket.close()
lambda: YeelightProtocol(bulb_class), family=socket.AF_INET)
ucast_transport, _ = await unicast_connection
loop.create_task(send_search_broadcast(ucast_transport))
return bulbs

0 comments on commit ed1d357

Please sign in to comment.