From 2d04d876a1002c91cf6f69f3f7f3b4d7442631eb Mon Sep 17 00:00:00 2001 From: Benjamin Braatz Date: Wed, 27 Oct 2021 14:24:36 +0200 Subject: [PATCH] Manage callbacks in TemplateRegistry. --- controlpi/messagebus.py | 67 ++++++++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/controlpi/messagebus.py b/controlpi/messagebus.py index 4c61817..6aa69bb 100644 --- a/controlpi/messagebus.py +++ b/controlpi/messagebus.py @@ -86,7 +86,8 @@ import json import fastjsonschema # type: ignore import sys -from typing import Union, Dict, List, Any, Iterable, Callable, Coroutine +from typing import (Union, Dict, List, Any, Callable, Coroutine, + Optional, Iterable) MessageValue = Union[None, str, int, float, bool, Dict[str, Any], List[Any]] # Should really be: # MessageValue = Union[None, str, int, float, bool, @@ -764,13 +765,15 @@ class TemplateRegistry: >>> r = TemplateRegistry() """ self._clients: List[str] = [] + self._callbacks: Dict[str, List[MessageCallback]] = {} self._constants: Dict[str, Dict[str, TemplateRegistry]] = {} # First key is the message key, second key is the constant string self._schemas: Dict[str, Dict[str, TemplateRegistry]] = {} # First key is the message key, second key is the JSON schema string self._templates: Dict[str, List[MessageTemplate]] = {} - def insert(self, template: MessageTemplate, client: str) -> None: + def insert(self, template: MessageTemplate, client: str, + callback: Optional[MessageCallback] = None) -> None: """Register a client for a template. >>> r = TemplateRegistry() @@ -780,8 +783,15 @@ class TemplateRegistry: >>> r.insert({'k1': {'type': 'integer'}, 'k2': {'const': 'v2'}}, 'C 4') >>> r.insert({}, 'C 5') """ + if client not in self._templates: + self._templates[client] = [] + self._templates[client].append(template) if not template: self._clients.append(client) + if callback: + if client not in self._callbacks: + self._callbacks[client] = [] + self._callbacks[client].append(callback) else: key, schema = next(iter(template.items())) reduced_template = MessageTemplate({k: template[k] @@ -794,7 +804,8 @@ class TemplateRegistry: self._constants[key] = {} if value not in self._constants[key]: self._constants[key][value] = TemplateRegistry() - self._constants[key][value].insert(reduced_template, client) + self._constants[key][value].insert(reduced_template, + client, callback) else: schema_string = json.dumps(schema) if key not in self._schemas: @@ -802,10 +813,7 @@ class TemplateRegistry: if schema_string not in self._schemas[key]: self._schemas[key][schema_string] = TemplateRegistry() self._schemas[key][schema_string].insert(reduced_template, - client) - if client not in self._templates: - self._templates[client] = [] - self._templates[client].append(template) + client, callback) def delete(self, client: str) -> bool: """Unregister a client from all templates. @@ -821,7 +829,11 @@ class TemplateRegistry: >>> r.delete('C 4') True """ + if client in self._templates: + del self._templates[client] self._clients = [c for c in self._clients if c != client] + if client in self._callbacks: + del self._callbacks[client] new_constants: Dict[str, Dict[str, TemplateRegistry]] = {} for key in self._constants: new_constants[key] = {} @@ -840,9 +852,8 @@ class TemplateRegistry: if not new_schemas[key]: del new_schemas[key] self._schemas = new_schemas - if client in self._templates: - del self._templates[client] - if self._clients or self._constants or self._schemas: + if (self._clients or self._callbacks or + self._constants or self._schemas): return True return False @@ -923,6 +934,31 @@ class TemplateRegistry: result.append(client) return result + def get_callbacks(self, message: Message) -> List[MessageCallback]: + result = [] + for client in self._callbacks: + for callback in self._callbacks[client]: + if callback not in result: + result.append(callback) + for key in self._constants: + if (key in message and isinstance(message[key], str) and + message[key] in self._constants[key]): + value = message[key] + assert isinstance(value, str) + child = self._constants[key][value] + for callback in child.get_callbacks(message): + if callback not in result: + result.append(callback) + for key in self._schemas: + if key in message: + for schema_string in self._schemas[key]: + if validate(schema_string, message[key]): + child = self._schemas[key][schema_string] + for callback in child.get_callbacks(message): + if callback not in result: + result.append(callback) + return result + def get_templates(self, client: str) -> List[MessageTemplate]: """Get all templates for a client. @@ -1071,7 +1107,6 @@ class MessageBus: self._plugins: Dict[str, str] = {} self._send_reg: TemplateRegistry = TemplateRegistry() self._recv_reg: TemplateRegistry = TemplateRegistry() - self._callbacks: Dict[str, MessageCallback] = {} def register(self, client: str, plugin: str, sends: Iterable[MessageTemplate], @@ -1114,9 +1149,8 @@ class MessageBus: self._send_reg.insert(template, client) event['sends'] = self._send_reg.get_templates(client) for template in receives: - self._recv_reg.insert(template, client) + self._recv_reg.insert(template, client, callback) event['receives'] = self._recv_reg.get_templates(client) - self._callbacks[client] = callback self._queue.put_nowait(event) def unregister(self, client: str) -> None: @@ -1141,8 +1175,6 @@ class MessageBus: del self._plugins[client] self._send_reg.delete(client) self._recv_reg.delete(client) - if client in self._callbacks: - del self._callbacks[client] self._queue.put_nowait(event) async def run(self) -> None: @@ -1187,9 +1219,8 @@ class MessageBus: {'event': 'conf changed'}))) with open(sys.argv[1], 'w') as conf_file: json.dump(message['conf'], conf_file) - for client in self._recv_reg.get(message): - if client in self._callbacks: - await self._callbacks[client](message) + for callback in self._recv_reg.get_callbacks(message): + await callback(message) self._queue.task_done() async def send(self, message: Message) -> None: -- 2.34.1