Manage callbacks in TemplateRegistry.
authorBenjamin Braatz <benjamin.braatz@graph-it.com>
Wed, 27 Oct 2021 12:24:36 +0000 (14:24 +0200)
committerBenjamin Braatz <benjamin.braatz@graph-it.com>
Wed, 27 Oct 2021 12:24:36 +0000 (14:24 +0200)
controlpi/messagebus.py

index 4c61817017bedcf1850cb30c918ac93a8a6b3f05..6aa69bb267cab8a29aa2c0fb7af28ec63665894b 100644 (file)
@@ -86,7 +86,8 @@ import json
 import fastjsonschema  # type: ignore
 import sys
 
-from typing import Union, Dict, List, Any, Iterable, Callable, Coroutine
+from typing import (Union, Dict, List, Any, Callable, Coroutine,
+                    Optional, Iterable)
 MessageValue = Union[None, str, int, float, bool, Dict[str, Any], List[Any]]
 # Should really be:
 # MessageValue = Union[None, str, int, float, bool,
@@ -764,13 +765,15 @@ class TemplateRegistry:
         >>> r = TemplateRegistry()
         """
         self._clients: List[str] = []
+        self._callbacks: Dict[str, List[MessageCallback]] = {}
         self._constants: Dict[str, Dict[str, TemplateRegistry]] = {}
         # First key is the message key, second key is the constant string
         self._schemas: Dict[str, Dict[str, TemplateRegistry]] = {}
         # First key is the message key, second key is the JSON schema string
         self._templates: Dict[str, List[MessageTemplate]] = {}
 
-    def insert(self, template: MessageTemplate, client: str) -> None:
+    def insert(self, template: MessageTemplate, client: str,
+               callback: Optional[MessageCallback] = None) -> None:
         """Register a client for a template.
 
         >>> r = TemplateRegistry()
@@ -780,8 +783,15 @@ class TemplateRegistry:
         >>> r.insert({'k1': {'type': 'integer'}, 'k2': {'const': 'v2'}}, 'C 4')
         >>> r.insert({}, 'C 5')
         """
+        if client not in self._templates:
+            self._templates[client] = []
+        self._templates[client].append(template)
         if not template:
             self._clients.append(client)
+            if callback:
+                if client not in self._callbacks:
+                    self._callbacks[client] = []
+                self._callbacks[client].append(callback)
         else:
             key, schema = next(iter(template.items()))
             reduced_template = MessageTemplate({k: template[k]
@@ -794,7 +804,8 @@ class TemplateRegistry:
                     self._constants[key] = {}
                 if value not in self._constants[key]:
                     self._constants[key][value] = TemplateRegistry()
-                self._constants[key][value].insert(reduced_template, client)
+                self._constants[key][value].insert(reduced_template,
+                                                   client, callback)
             else:
                 schema_string = json.dumps(schema)
                 if key not in self._schemas:
@@ -802,10 +813,7 @@ class TemplateRegistry:
                 if schema_string not in self._schemas[key]:
                     self._schemas[key][schema_string] = TemplateRegistry()
                 self._schemas[key][schema_string].insert(reduced_template,
-                                                         client)
-        if client not in self._templates:
-            self._templates[client] = []
-        self._templates[client].append(template)
+                                                         client, callback)
 
     def delete(self, client: str) -> bool:
         """Unregister a client from all templates.
@@ -821,7 +829,11 @@ class TemplateRegistry:
         >>> r.delete('C 4')
         True
         """
+        if client in self._templates:
+            del self._templates[client]
         self._clients = [c for c in self._clients if c != client]
+        if client in self._callbacks:
+            del self._callbacks[client]
         new_constants: Dict[str, Dict[str, TemplateRegistry]] = {}
         for key in self._constants:
             new_constants[key] = {}
@@ -840,9 +852,8 @@ class TemplateRegistry:
             if not new_schemas[key]:
                 del new_schemas[key]
         self._schemas = new_schemas
-        if client in self._templates:
-            del self._templates[client]
-        if self._clients or self._constants or self._schemas:
+        if (self._clients or self._callbacks or
+                self._constants or self._schemas):
             return True
         return False
 
@@ -923,6 +934,31 @@ class TemplateRegistry:
                                 result.append(client)
         return result
 
+    def get_callbacks(self, message: Message) -> List[MessageCallback]:
+        result = []
+        for client in self._callbacks:
+            for callback in self._callbacks[client]:
+                if callback not in result:
+                    result.append(callback)
+        for key in self._constants:
+            if (key in message and isinstance(message[key], str) and
+                    message[key] in self._constants[key]):
+                value = message[key]
+                assert isinstance(value, str)
+                child = self._constants[key][value]
+                for callback in child.get_callbacks(message):
+                    if callback not in result:
+                        result.append(callback)
+        for key in self._schemas:
+            if key in message:
+                for schema_string in self._schemas[key]:
+                    if validate(schema_string, message[key]):
+                        child = self._schemas[key][schema_string]
+                        for callback in child.get_callbacks(message):
+                            if callback not in result:
+                                result.append(callback)
+        return result
+
     def get_templates(self, client: str) -> List[MessageTemplate]:
         """Get all templates for a client.
 
@@ -1071,7 +1107,6 @@ class MessageBus:
         self._plugins: Dict[str, str] = {}
         self._send_reg: TemplateRegistry = TemplateRegistry()
         self._recv_reg: TemplateRegistry = TemplateRegistry()
-        self._callbacks: Dict[str, MessageCallback] = {}
 
     def register(self, client: str, plugin: str,
                  sends: Iterable[MessageTemplate],
@@ -1114,9 +1149,8 @@ class MessageBus:
             self._send_reg.insert(template, client)
         event['sends'] = self._send_reg.get_templates(client)
         for template in receives:
-            self._recv_reg.insert(template, client)
+            self._recv_reg.insert(template, client, callback)
         event['receives'] = self._recv_reg.get_templates(client)
-        self._callbacks[client] = callback
         self._queue.put_nowait(event)
 
     def unregister(self, client: str) -> None:
@@ -1141,8 +1175,6 @@ class MessageBus:
         del self._plugins[client]
         self._send_reg.delete(client)
         self._recv_reg.delete(client)
-        if client in self._callbacks:
-            del self._callbacks[client]
         self._queue.put_nowait(event)
 
     async def run(self) -> None:
@@ -1187,9 +1219,8 @@ class MessageBus:
                                             {'event': 'conf changed'})))
                         with open(sys.argv[1], 'w') as conf_file:
                             json.dump(message['conf'], conf_file)
-            for client in self._recv_reg.get(message):
-                if client in self._callbacks:
-                    await self._callbacks[client](message)
+            for callback in self._recv_reg.get_callbacks(message):
+                await callback(message)
             self._queue.task_done()
 
     async def send(self, message: Message) -> None: