From 53e3ecac1e803f1fd0319e78ff7abfdb2f4c7f1d Mon Sep 17 00:00:00 2001 From: Benjamin Braatz Date: Thu, 22 May 2025 14:57:22 +0200 Subject: [PATCH] Refactor into context manager. --- conf.json | 6 +- controlpi_plugins/graph.py | 174 ++++++-------------------- controlpi_plugins/graph_connection.py | 115 +++++++++++++++++ test_graph_connection.py | 38 ++++++ 4 files changed, 197 insertions(+), 136 deletions(-) create mode 100644 controlpi_plugins/graph_connection.py create mode 100644 test_graph_connection.py diff --git a/conf.json b/conf.json index 0fb89eb..6b11cba 100644 --- a/conf.json +++ b/conf.json @@ -1,7 +1,7 @@ { "Master": { "plugin": "WSServer", - "port": 8080, + "port": 8123, "web": { "/": { "module": "controlpi_plugins.wsserver", @@ -14,9 +14,9 @@ }, "Graph": { "plugin": "Graph", - "url": "tls://graph.example.com", + "url": "tls://it.hsrobotics:4339", "crt": "graph.crt", - "name": "te", + "name": "d1", "filter": [ { "sender": { "const": "Example State" }, diff --git a/controlpi_plugins/graph.py b/controlpi_plugins/graph.py index 23572a6..2c43c13 100644 --- a/controlpi_plugins/graph.py +++ b/controlpi_plugins/graph.py @@ -1,14 +1,8 @@ """Provide Graph Connections as ControlPi Plugin.""" -import asyncio -import os.path -import ssl -import struct -import urllib.parse -import msgpack # type: ignore import json from controlpi import BasePlugin, Message, MessageTemplate -from typing import List, Dict, Any +from controlpi_plugins.graph_connection import GraphConnection class Graph(BasePlugin): @@ -24,147 +18,61 @@ class Graph(BasePlugin): def process_conf(self) -> None: """Register plugin as bus client.""" - res = urllib.parse.urlparse(self.conf['url']) - if res.scheme != 'tls': - raise NotImplementedError("Only implemented scheme is 'tls'.") - self._host = res.hostname - self._port = res.port - if not os.path.isfile(self.conf['crt']): - raise FileNotFoundError("Cannot find certificate file" - f"'{self.conf['crt']}'.") - self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) - self._ssl_ctx.load_cert_chain(self.conf['crt']) - self._open_connections = 0 - self._call_lock = asyncio.Lock() - self._messages: List[Message] = [] + self.graph_connection = GraphConnection(self.conf['url'], + self.conf['crt']) + self.messages: list[Message] = [] self.bus.register(self.name, 'Graph', [], [([MessageTemplate({'target': {'const': self.name}, 'command': {'const': 'sync'}})], - self._sync), - (self.conf['filter'], self._receive)]) + self.sync), + (self.conf['filter'], self.receive)]) - async def _sync(self, message: Message) -> None: - if self._messages: - messages = self._messages - self._messages = [] - try: - await self._open() - coroot_guid = await self._call('attributsknoten', - ['coroot_name', - self.conf['name']]) - if coroot_guid: - comessage_guid = await self._call('erzeuge', - ['comessage']) - if comessage_guid: - await self._call('verknuepfe', [comessage_guid, - coroot_guid]) - await self._call('setze', [comessage_guid, - 'comessage_json', - json.dumps(messages)]) - await self._call('setze', [comessage_guid, - 'comessage_ready', - True]) - else: - raise Exception("Could not create comessage instance") + async def send(self, messages: list[Message]) -> None: + """Send a list of messages to configured graph.""" + async with self.graph_connection as gc: + coroot_guid = await gc.call('attributsknoten', + ['coroot_name', + self.conf['name']]) + if coroot_guid: + comessage_guid = await gc.call('erzeuge', ['comessage']) + if comessage_guid: + await gc.call('verknuepfe', [comessage_guid, + coroot_guid]) + await gc.call('setze', [comessage_guid, + 'comessage_json', + json.dumps(messages)]) + await gc.call('setze', [comessage_guid, + 'comessage_ready', + True]) else: - raise Exception("Did not find coroot instance" - f" {self.conf['name']}.") - await self._close() + raise Exception("Could not create comessage instance") + else: + raise Exception("Did not find coroot instance" + f" '{self.conf['name']}'.") + + async def sync(self, message: Message) -> None: + """Sync cached messages to configured graph.""" + if self.messages: + messages = self.messages + self.messages = [] + try: + await self.send(messages) except Exception as e: self._messages.extend(messages) print(f"Graph connection '{self.name}'" f" to {self.conf['url']}:" - f" Exception in '_sync()': {e}") - - async def _receive(self, message: Message) -> None: - self._messages.append(message) + f" Exception in 'sync()': {e}") - async def _open(self) -> None: - self._open_connections += 1 - if self._open_connections == 1: - # First connection: - (reader, writer) = await asyncio.open_connection(self._host, - self._port, - ssl=self._ssl_ctx) - if writer and reader: - # Read banner: - size_bytes = await reader.readexactly(4) - size_int = struct.unpack(' Any: - if self._writer and self._reader: - async with self._call_lock: - # Build request: - self._call_id += 1 - request = {'jsonrpc': '2.0', 'method': method, - 'params': params, 'id': self._call_id} - message = msgpack.packb(request) - size_bytes = struct.pack(' None: - if self._open_connections > 0: - self._open_connections -= 1 - if self._open_connections == 0: - if self._writer: - # Close connection: - self._writer.close() - self._reader = None - self._writer = None + async def receive(self, message: Message) -> None: + """Receive message through controlpi bus.""" + self.messages.append(message) async def run(self) -> None: """Get coroot instance for name and send connection opened event.""" try: - await self._open() - coroot_guid = await self._call('attributsknoten', - ['coroot_name', - self.conf['name']]) - if coroot_guid: - comessage_guid = await self._call('erzeuge', - ['comessage']) - if comessage_guid: - await self._call('verknuepfe', [comessage_guid, - coroot_guid]) - messages = [Message(self.name, - {'event': 'connection opened'})] - await self._call('setze', [comessage_guid, - 'comessage_json', - json.dumps(messages)]) - await self._call('setze', [comessage_guid, - 'comessage_ready', - True]) - else: - raise Exception("Could not create comessage instance") - else: - raise Exception("Did not find coroot instance" - f" {self.conf['name']}.") - await self._close() + await self.send([Message(self.name, + {'event': 'connection opened'})]) except Exception as e: print(f"Graph connection '{self.name}'" f" to {self.conf['url']}:" diff --git a/controlpi_plugins/graph_connection.py b/controlpi_plugins/graph_connection.py new file mode 100644 index 0000000..8a9d305 --- /dev/null +++ b/controlpi_plugins/graph_connection.py @@ -0,0 +1,115 @@ +"""Provide an asynchronous context manager for graph connections.""" +import asyncio +from dataclasses import dataclass +import msgpack # type: ignore +import os.path +import ssl +import struct +import urllib.parse + + +@dataclass +class GraphConnectionData: + """Data for one graph connection.""" + + lock: asyncio.Lock + contexts: int = 0 + reader: asyncio.StreamReader | None = None + writer: asyncio.StreamWriter | None = None + message_count: int = 0 + opened: int = 0 + closed: int = 0 + + +class GraphConnection: + """Asynchronous graph connection context manager.""" + + connections: dict[str, GraphConnectionData] = {} + + def __init__(self, url: str, crt: str) -> None: + """Initialise with graph URL and TLS certificate.""" + self.url = url + parsed_url = urllib.parse.urlparse(url) + if parsed_url.scheme != 'tls': + raise Exception("Only implemented scheme is 'tls'.") + self.hostname = parsed_url.hostname + self.port = parsed_url.port + if not os.path.isfile(crt): + raise Exception(f"Cannot find certificate file '{crt}'.") + self.ssl = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + self.ssl.load_cert_chain(crt) + if self.url not in self.connections: + self.connections[self.url] = GraphConnectionData(asyncio.Lock()) + self.connection = self.connections[self.url] + + async def __aenter__(self) -> 'GraphConnection': + """Open connection if first context.""" + async with self.connection.lock: + self.connection.contexts += 1 + if self.connection.contexts == 1: + # Open connection: + (r, w) = await asyncio.open_connection(self.hostname, + self.port, + ssl=self.ssl) + self.connection.opened += 1 + # Read banner: + size_bytes = await r.readexactly(4) + size_int = struct.unpack(' None: + """Close connection if last context.""" + async with self.connection.lock: + self.connection.contexts -= 1 + if self.connection.contexts == 0: + # Close writer: + writer = self.connection.writer + if writer is not None: + writer.close() + await writer.wait_closed() + self.connection.closed += 1 + # Remove reader and writer from data: + self.connection.reader = None + self.connection.writer = None + + async def call(self, method: str, params: list): + """Execute a call on connection.""" + async with self.connection.lock: + reader = self.connection.reader + writer = self.connection.writer + if reader is not None and writer is not None: + self.connection.message_count += 1 + # Build and send request: + request = {'jsonrpc': '2.0', 'method': method, + 'params': params, + 'id': self.connection.message_count} + message = msgpack.packb(request) + size_bytes = struct.pack(' None: + async with GraphConnection('tls://it.hsrobotics:4339', + 'graph.crt') as gc: + print(f"Task {instance} started.") + remotegraphen = await gc.call('alleattribute', + ['remotegraph', + 'name, url', + 'name asc', + 'name != "local"']) + print(f"Task {instance}: {len(remotegraphen)} remote graphs") + for remotegraph_guid, remotegraph in remotegraphen.items(): + async with GraphConnection(remotegraph['url'], + 'graph.crt') as inner_gc: + graphmodule = await inner_gc.call('alleattribute', + ['graphmodul', + 'prefix', + 'prefix asc']) + module = ', '.join([graphmodul['prefix'] + for graphmodul in graphmodule.values()]) + print(f"Task {instance}: {remotegraph['name']}: {module}") + print(f"Task {instance} finished.") + + +async def main() -> None: + await asyncio.gather(get_remote_modules(1), + get_remote_modules(2), + get_remote_modules(3)) + pprint(GraphConnection.connections) + + +asyncio.run(main()) -- 2.34.1