Refactor to minimise duplication.
authorBenjamin Braatz <bb@bbraatz.eu>
Wed, 28 Jul 2021 18:15:31 +0000 (20:15 +0200)
committerBenjamin Braatz <bb@bbraatz.eu>
Wed, 28 Jul 2021 18:15:31 +0000 (20:15 +0200)
controlpi_plugins/wsclient.py

index 5155e021a55f28080d32f941bb1b034c9fa349dd..60b09e2a21f9d47406ea772a2c58893143274b32 100644 (file)
@@ -1,17 +1,95 @@
-"""Provide …
-
-…
-
-TODO: documentation, doctests
-"""
 import asyncio
 import fcntl
 import json
 import socket
 import struct
-from websockets import ConnectionClosed, connect
+from websockets.exceptions import ConnectionClosed
+from websockets.legacy.client import connect, WebSocketClientProtocol
+
 from controlpi import BasePlugin, Message, MessageTemplate
 
+from typing import Optional, Dict, Any
+
+
+def translate_message(original_message: Dict[str, Any], sender: str,
+                      receiver: str) -> Optional[Dict[str, Any]]:
+    """Translate message from sender to receiver.
+
+    The message comes from the message bus of the sender and is intended
+    for the message bus of the receiver. The name of the sender is prepended
+    to the 'original sender' and 'target' keys. If the 'original sender' key
+    already started with the receiver None is returned to avoid message
+    loops. If the 'target' key started with the receiver it is removed, so
+    that recipients on the receiver message bus get the message.
+    """
+    message = json.loads(json.dumps(original_message))
+    prefix = receiver + '/'
+    original_sender = sender
+    if 'original sender' in message:
+        assert isinstance(message['original sender'], str)
+        if message['original sender'].startswith(prefix):
+            return None
+        original_sender += '/' + message['original sender']
+    elif 'sender' in message:
+        assert isinstance(message['sender'], str)
+        if message['sender'] != '':
+            original_sender += '/' + message['sender']
+        del message['sender']
+    message['original sender'] = original_sender
+    if 'target' in message:
+        assert isinstance(message['target'], str)
+        target = message['target']
+        if target == '':
+            target = sender
+        elif target.startswith(prefix):
+            target = target[len(prefix):]
+        else:
+            target = sender + '/' + target
+        message['target'] = target
+    return message
+
+
+def translate_template(original_template: Dict[str, Any], sender: str,
+                       receiver: str) -> Optional[Dict[str, Any]]:
+    """Translate message template from sender to receiver.
+
+    Same functionality as translate_message, but for templates. Templates
+    do not necessarily have a 'sender' or 'original sender' key, so no
+    'original sender' is added if none is present before. And they have
+    JSON schema instances instead of plain strings as values. This function
+    only deals with 'const' schemas.
+    """
+    template = json.loads(json.dumps(original_template))
+    prefix = receiver + '/'
+    if 'original sender' in template:
+        assert isinstance(template['original sender'], dict)
+        if 'const' in template['original sender']:
+            assert isinstance(template['original sender']['const'], str)
+            original_sender = sender + '/' + \
+                template['original sender']['const']
+            template['original sender'] = {'const': original_sender}
+    elif 'sender' in template:
+        assert isinstance(template['sender'], dict)
+        if 'const' in template['sender']:
+            assert isinstance(template['sender']['const'], str)
+            original_sender = sender + '/' + template['sender']['const']
+            template['original sender'] = {'const': original_sender}
+    if 'sender' in template:
+        del template['sender']
+    if 'target' in template:
+        assert isinstance(template['target'], dict)
+        if 'const' in template['target']:
+            assert isinstance(template['target']['const'], str)
+            target = template['target']['const']
+            if target == '':
+                target = sender
+            elif target.startswith(prefix):
+                target = target[len(prefix):]
+            else:
+                target = sender + '/' + target
+            template['target'] = {'const': target}
+    return template
+
 
 class WSClient(BasePlugin):
     """Websocket client plugin.
@@ -32,56 +110,22 @@ class WSClient(BasePlugin):
     async def _receive(self, message: Message) -> None:
         if not self._websocket:
             return
-        assert isinstance(message['sender'], str)
-        prefix = f"{self.name}/"
-        original_sender = self._client
-        if 'original sender' in message:
-            if message['original sender'].startswith(prefix):
-                return
-            original_sender += f"/{message['original sender']}"
-        elif message['sender'] != '':
-            original_sender += f"/{message['sender']}"
-        message['original sender'] = original_sender
-        del message['sender']
-        if 'target' in message:
-            assert isinstance(message['target'], str)
-            target = message['target']
-            if target == '':
-                target = self._client
-            elif target.startswith(prefix):
-                target = target[len(prefix):]
-            else:
-                target = f"{self._client}/{target}"
-            message['target'] = target
-        json_message = json.dumps(message)
-        await self._websocket.send(json_message)
+        translated_message = translate_message(message,
+                                               self._client, self.name)
+        if translated_message is not None:
+            json_message = json.dumps(translated_message)
+            await self._websocket.send(json_message)
 
     async def _send(self, json_message: str) -> None:
         message = json.loads(json_message)
-        prefix = f"{self._client}/"
-        original_sender = self.name
-        if 'original sender' in message:
-            if message['original sender'].startswith(prefix):
-                return
-            original_sender += f"/{message['original sender']}"
-        elif message['sender'] != '':
-            original_sender += f"/{message['sender']}"
-        message['original sender'] = original_sender
-        message['sender'] = self.name
-        if 'target' in message:
-            target = message['target']
-            if target == '':
-                target = self.name
-            elif target.startswith(prefix):
-                target = target[len(prefix):]
-            else:
-                target = f"{self.name}/{target}"
-            message['target'] = target
-        await self.bus.send(message)
+        translated_message = translate_message(message,
+                                               self.name, self._client)
+        if translated_message is not None:
+            await self.bus.send(Message(self.name, translated_message))
 
     def process_conf(self) -> None:
         """Register plugin as bus client."""
-        self._websocket = None
+        self._websocket: Optional[WebSocketClientProtocol] = None
         if 'client' in self.conf:
             self._client = self.conf['client']
         if 'interface' in self.conf:
@@ -92,52 +136,14 @@ class WSClient(BasePlugin):
                                            bytes(self.conf['interface'],
                                                  'utf-8')[:15]))
             self._mac = ':'.join('%02x' % b for b in info[18:24])
-        sends = []
-        sends.append(MessageTemplate({'event':
-                                      {'const': 'registered'}}))
-        sends.append(MessageTemplate({'event':
-                                      {'const': 'connection opened'}}))
-        sends.append(MessageTemplate({'event':
-                                      {'const': 'connection closed'}}))
-        for template in self.conf['down filter']:
-            send_template = MessageTemplate(template)
-            if ('sender' in send_template and
-                    'const' in send_template['sender']):
-                original_sender = self.name
-                if ('original sender' in send_template and
-                        'const' in send_template['original sender']):
-                    const = send_template['original sender']['const']
-                    original_sender += f"/{const}"
-                elif send_template['sender']['const'] != '':
-                    const = send_template['sender']['const']
-                    original_sender += f"/{const}"
-                send_template['original sender'] = {'const': original_sender}
-                del send_template['sender']
-            if ('target' in send_template and
-                    'const' in send_template['target']):
-                target = send_template['target']['const']
-                if target == '':
-                    target = self.name
-                elif 'client' in self.conf:
-                    prefix = f"{self.conf['client']}/"
-                    if target.startswith(prefix):
-                        target = target[len(prefix):]
-                    else:
-                        target = f"{self.name}/{target}"
-                else:
-                    target = f"{self.name}/{target}"
-                send_template['target'] = {'const': target}
-            sends.append(send_template)
-        self.bus.register(self.name, 'WSClient', sends,
-                          self.conf['up filter'], self._receive)
 
     async def run(self) -> None:
         """Connect to wsserver and process messages from it."""
         while True:
             try:
                 async with connect(self.conf['url']) as websocket:
-                    conf_command = {'command': 'configure websocket',
-                                    'target': ''}
+                    conf_command: Dict[str, Any] = \
+                            {'command': 'configure websocket', 'target': ''}
                     if 'client' in self.conf:
                         conf_command['name'] = self._client
                     else:
@@ -146,33 +152,26 @@ class WSClient(BasePlugin):
                         self._client = f"{address}:{port}"
                     if 'interface' in self.conf:
                         conf_command['mac'] = self._mac
+                    sends = []
+                    sends.append(MessageTemplate({'event':
+                                                 {'const': 'registered'}}))
+                    sends.append(MessageTemplate({'event':
+                                                 {'const':
+                                                  'connection opened'}}))
+                    sends.append(MessageTemplate({'event':
+                                                 {'const':
+                                                  'connection closed'}}))
+                    for template in self.conf['down filter']:
+                        template = translate_template(template,
+                                                      self.name, self._client)
+                        sends.append(MessageTemplate(template))
+                    self.bus.register(self.name, 'WSClient', sends,
+                                      self.conf['up filter'], self._receive)
                     up_filter = []
                     for template in self.conf['up filter']:
-                        up_template = MessageTemplate(template)
-                        if ('sender' in up_template and
-                                'const' in up_template['sender']):
-                            original_sender = self._client
-                            if ('original sender' in up_template and
-                                    'const' in up_template['original sender']):
-                                const = up_template['original sender']['const']
-                                original_sender += f"/{const}"
-                            elif up_template['sender']['const'] != '':
-                                const = up_template['sender']['const']
-                                original_sender += f"/{const}"
-                            up_template['original sender'] = {'const': original_sender}
-                            del up_template['sender']
-                        if ('target' in up_template and
-                                'const' in up_template['target']):
-                            target = up_template['target']['const']
-                            prefix = f"{self.name}/"
-                            if target == '':
-                                target = self._client
-                            elif target.startswith(prefix):
-                                target = target[len(prefix):]
-                            else:
-                                target = f"{self._client}/{target}"
-                            up_template['target'] = {'const': target}
-                        up_filter.append(up_template)
+                        template = translate_template(template,
+                                                      self._client, self.name)
+                        up_filter.append(MessageTemplate(template))
                     conf_command['up filter'] = up_filter
                     conf_command['down filter'] = self.conf['down filter']
                     json_command = json.dumps(conf_command)
@@ -183,12 +182,15 @@ class WSClient(BasePlugin):
                     self._websocket = websocket
                     try:
                         async for json_message in websocket:
+                            assert isinstance(json_message, str)
                             await self._send(json_message)
                     except ConnectionClosed:
-                        self._websocket = None
+                        pass
+                    self._websocket = None
                     await self.bus.send(Message(self.name,
                                                 {'event':
                                                  'connection closed'}))
+                    self.bus.unregister(self.name)
             except OSError:
                 pass
             await asyncio.sleep(1)