Refactor into context manager.
authorBenjamin Braatz <bb@bbraatz.eu>
Thu, 22 May 2025 12:57:22 +0000 (14:57 +0200)
committerBenjamin Braatz <bb@bbraatz.eu>
Thu, 22 May 2025 12:57:22 +0000 (14:57 +0200)
conf.json
controlpi_plugins/graph.py
controlpi_plugins/graph_connection.py [new file with mode: 0644]
test_graph_connection.py [new file with mode: 0644]

index 0fb89eb737cf8ab666c2e71bd9ff114322688e12..6b11cba20c76ebe0af12834322ece85fc3096154 100644 (file)
--- a/conf.json
+++ b/conf.json
@@ -1,7 +1,7 @@
 {
   "Master": {
     "plugin": "WSServer",
-    "port": 8080,
+    "port": 8123,
     "web": {
       "/": {
         "module": "controlpi_plugins.wsserver",
@@ -14,9 +14,9 @@
   },
   "Graph": {
     "plugin": "Graph",
-    "url": "tls://graph.example.com",
+    "url": "tls://it.hsrobotics:4339",
     "crt": "graph.crt",
-    "name": "te",
+    "name": "d1",
     "filter": [
       {
         "sender": { "const": "Example State" },
index 23572a64dd6435dbae434049a093f7ada127da33..2c43c13c7b56faca7242b72689f7d5ad080aa807 100644 (file)
@@ -1,14 +1,8 @@
 """Provide Graph Connections as ControlPi Plugin."""
-import asyncio
-import os.path
-import ssl
-import struct
-import urllib.parse
-import msgpack  # type: ignore
 import json
 from controlpi import BasePlugin, Message, MessageTemplate
 
-from typing import List, Dict, Any
+from controlpi_plugins.graph_connection import GraphConnection
 
 
 class Graph(BasePlugin):
@@ -24,147 +18,61 @@ class Graph(BasePlugin):
 
     def process_conf(self) -> None:
         """Register plugin as bus client."""
-        res = urllib.parse.urlparse(self.conf['url'])
-        if res.scheme != 'tls':
-            raise NotImplementedError("Only implemented scheme is 'tls'.")
-        self._host = res.hostname
-        self._port = res.port
-        if not os.path.isfile(self.conf['crt']):
-            raise FileNotFoundError("Cannot find certificate file"
-                                    f"'{self.conf['crt']}'.")
-        self._ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
-        self._ssl_ctx.load_cert_chain(self.conf['crt'])
-        self._open_connections = 0
-        self._call_lock = asyncio.Lock()
-        self._messages: List[Message] = []
+        self.graph_connection = GraphConnection(self.conf['url'],
+                                                self.conf['crt'])
+        self.messages: list[Message] = []
         self.bus.register(self.name, 'Graph',
                           [],
                           [([MessageTemplate({'target': {'const': self.name},
                                               'command': {'const': 'sync'}})],
-                            self._sync),
-                           (self.conf['filter'], self._receive)])
+                            self.sync),
+                           (self.conf['filter'], self.receive)])
 
-    async def _sync(self, message: Message) -> None:
-        if self._messages:
-            messages = self._messages
-            self._messages = []
-            try:
-                await self._open()
-                coroot_guid = await self._call('attributsknoten',
-                                               ['coroot_name',
-                                                self.conf['name']])
-                if coroot_guid:
-                    comessage_guid = await self._call('erzeuge',
-                                                      ['comessage'])
-                    if comessage_guid:
-                        await self._call('verknuepfe', [comessage_guid,
-                                                        coroot_guid])
-                        await self._call('setze', [comessage_guid,
-                                                   'comessage_json',
-                                                   json.dumps(messages)])
-                        await self._call('setze', [comessage_guid,
-                                                   'comessage_ready',
-                                                   True])
-                    else:
-                        raise Exception("Could not create comessage instance")
+    async def send(self, messages: list[Message]) -> None:
+        """Send a list of messages to configured graph."""
+        async with self.graph_connection as gc:
+            coroot_guid = await gc.call('attributsknoten',
+                                        ['coroot_name',
+                                         self.conf['name']])
+            if coroot_guid:
+                comessage_guid = await gc.call('erzeuge', ['comessage'])
+                if comessage_guid:
+                    await gc.call('verknuepfe', [comessage_guid,
+                                                 coroot_guid])
+                    await gc.call('setze', [comessage_guid,
+                                            'comessage_json',
+                                            json.dumps(messages)])
+                    await gc.call('setze', [comessage_guid,
+                                            'comessage_ready',
+                                            True])
                 else:
-                    raise Exception("Did not find coroot instance"
-                                    f" {self.conf['name']}.")
-                await self._close()
+                    raise Exception("Could not create comessage instance")
+            else:
+                raise Exception("Did not find coroot instance"
+                                f" '{self.conf['name']}'.")
+
+    async def sync(self, message: Message) -> None:
+        """Sync cached messages to configured graph."""
+        if self.messages:
+            messages = self.messages
+            self.messages = []
+            try:
+                await self.send(messages)
             except Exception as e:
                 self._messages.extend(messages)
                 print(f"Graph connection '{self.name}'"
                       f" to {self.conf['url']}:"
-                      f" Exception in '_sync()': {e}")
-
-    async def _receive(self, message: Message) -> None:
-        self._messages.append(message)
+                      f" Exception in 'sync()': {e}")
 
-    async def _open(self) -> None:
-        self._open_connections += 1
-        if self._open_connections == 1:
-            # First connection:
-            (reader, writer) = await asyncio.open_connection(self._host,
-                                                             self._port,
-                                                             ssl=self._ssl_ctx)
-            if writer and reader:
-                # Read banner:
-                size_bytes = await reader.readexactly(4)
-                size_int = struct.unpack('<i', size_bytes)[0]
-                message = await reader.readexactly(size_int)
-                # Inititalise call id:
-                self._call_id = 0
-                # Set attributes:
-                self._reader = reader
-                self._writer = writer
-
-    async def _call(self, method: str, params: List[Any]) -> Any:
-        if self._writer and self._reader:
-            async with self._call_lock:
-                # Build request:
-                self._call_id += 1
-                request = {'jsonrpc': '2.0', 'method': method,
-                           'params': params, 'id': self._call_id}
-                message = msgpack.packb(request)
-                size_bytes = struct.pack('<i', len(message))
-                # Write request:
-                self._writer.write(size_bytes)
-                self._writer.write(message)
-                await self._writer.drain()
-                # Read response:
-                size_bytes = await self._reader.readexactly(4)
-                size_int = struct.unpack('<i', size_bytes)[0]
-                message = await self._reader.readexactly(size_int)
-                response = msgpack.unpackb(message)
-                if ('jsonrpc' not in response or
-                        response['jsonrpc'] != request['jsonrpc']):
-                    raise Exception(f"Not a JSON-RPC 2.0 response: {response}")
-                if 'error' in response:
-                    raise Exception("JSON-RPC remote error:"
-                                    f" {response[b'error']}")
-                if 'id' not in response or response['id'] != request['id']:
-                    raise Exception("JSON-RPC id missing or invalid.")
-                return response['result']
-        else:
-            raise Exception("Reader or writer for graph connection not found.")
-
-    async def _close(self) -> None:
-        if self._open_connections > 0:
-            self._open_connections -= 1
-            if self._open_connections == 0:
-                if self._writer:
-                    # Close connection:
-                    self._writer.close()
-                self._reader = None
-                self._writer = None
+    async def receive(self, message: Message) -> None:
+        """Receive message through controlpi bus."""
+        self.messages.append(message)
 
     async def run(self) -> None:
         """Get coroot instance for name and send connection opened event."""
         try:
-            await self._open()
-            coroot_guid = await self._call('attributsknoten',
-                                           ['coroot_name',
-                                            self.conf['name']])
-            if coroot_guid:
-                comessage_guid = await self._call('erzeuge',
-                                                  ['comessage'])
-                if comessage_guid:
-                    await self._call('verknuepfe', [comessage_guid,
-                                                    coroot_guid])
-                    messages = [Message(self.name,
-                                        {'event': 'connection opened'})]
-                    await self._call('setze', [comessage_guid,
-                                               'comessage_json',
-                                               json.dumps(messages)])
-                    await self._call('setze', [comessage_guid,
-                                               'comessage_ready',
-                                               True])
-                else:
-                    raise Exception("Could not create comessage instance")
-            else:
-                raise Exception("Did not find coroot instance"
-                                f" {self.conf['name']}.")
-            await self._close()
+            await self.send([Message(self.name,
+                                     {'event': 'connection opened'})])
         except Exception as e:
             print(f"Graph connection '{self.name}'"
                   f" to {self.conf['url']}:"
diff --git a/controlpi_plugins/graph_connection.py b/controlpi_plugins/graph_connection.py
new file mode 100644 (file)
index 0000000..8a9d305
--- /dev/null
@@ -0,0 +1,115 @@
+"""Provide an asynchronous context manager for graph connections."""
+import asyncio
+from dataclasses import dataclass
+import msgpack  # type: ignore
+import os.path
+import ssl
+import struct
+import urllib.parse
+
+
+@dataclass
+class GraphConnectionData:
+    """Data for one graph connection."""
+
+    lock: asyncio.Lock
+    contexts: int = 0
+    reader: asyncio.StreamReader | None = None
+    writer: asyncio.StreamWriter | None = None
+    message_count: int = 0
+    opened: int = 0
+    closed: int = 0
+
+
+class GraphConnection:
+    """Asynchronous graph connection context manager."""
+
+    connections: dict[str, GraphConnectionData] = {}
+
+    def __init__(self, url: str, crt: str) -> None:
+        """Initialise with graph URL and TLS certificate."""
+        self.url = url
+        parsed_url = urllib.parse.urlparse(url)
+        if parsed_url.scheme != 'tls':
+            raise Exception("Only implemented scheme is 'tls'.")
+        self.hostname = parsed_url.hostname
+        self.port = parsed_url.port
+        if not os.path.isfile(crt):
+            raise Exception(f"Cannot find certificate file '{crt}'.")
+        self.ssl = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
+        self.ssl.load_cert_chain(crt)
+        if self.url not in self.connections:
+            self.connections[self.url] = GraphConnectionData(asyncio.Lock())
+        self.connection = self.connections[self.url]
+
+    async def __aenter__(self) -> 'GraphConnection':
+        """Open connection if first context."""
+        async with self.connection.lock:
+            self.connection.contexts += 1
+            if self.connection.contexts == 1:
+                # Open connection:
+                (r, w) = await asyncio.open_connection(self.hostname,
+                                                       self.port,
+                                                       ssl=self.ssl)
+                self.connection.opened += 1
+                # Read banner:
+                size_bytes = await r.readexactly(4)
+                size_int = struct.unpack('<i', size_bytes)[0]
+                message = await r.readexactly(size_int)
+                # Set reader and writer in data:
+                self.connection.reader = r
+                self.connection.writer = w
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
+        """Close connection if last context."""
+        async with self.connection.lock:
+            self.connection.contexts -= 1
+            if self.connection.contexts == 0:
+                # Close writer:
+                writer = self.connection.writer
+                if writer is not None:
+                    writer.close()
+                    await writer.wait_closed()
+                self.connection.closed += 1
+                # Remove reader and writer from data:
+                self.connection.reader = None
+                self.connection.writer = None
+
+    async def call(self, method: str, params: list):
+        """Execute a call on connection."""
+        async with self.connection.lock:
+            reader = self.connection.reader
+            writer = self.connection.writer
+            if reader is not None and writer is not None:
+                self.connection.message_count += 1
+                # Build and send request:
+                request = {'jsonrpc': '2.0', 'method': method,
+                           'params': params,
+                           'id': self.connection.message_count}
+                message = msgpack.packb(request)
+                size_bytes = struct.pack('<i', len(message))
+                writer.write(size_bytes)
+                writer.write(message)
+                await writer.drain()
+                # Receive and unpack response:
+                size_bytes = await reader.readexactly(4)
+                size_int = struct.unpack('<i', size_bytes)[0]
+                message = await reader.readexactly(size_int)
+                response = msgpack.unpackb(message)
+                # Raise errors:
+                if ('jsonrpc' not in response or
+                        response['jsonrpc'] != request['jsonrpc']):
+                    raise Exception("Not a JSON-RPC 2.0 response:"
+                                    f" '{response}'")
+                if 'error' in response:
+                    raise Exception("JSON-RPC remote error:"
+                                    f" {response['error']}")
+                if 'id' not in response:
+                    raise Exception("JSON-RPC id missing.")
+                if response['id'] != request['id']:
+                    raise Exception("JSON-RPC id invalid.")
+                # Return result:
+                return response['result']
+            else:
+                raise Exception("Reader or writer missing.")
diff --git a/test_graph_connection.py b/test_graph_connection.py
new file mode 100644 (file)
index 0000000..bf668a4
--- /dev/null
@@ -0,0 +1,38 @@
+import asyncio
+import json
+from pprint import pprint
+
+from controlpi_plugins.graph_connection import GraphConnection
+
+
+async def get_remote_modules(instance: int) -> None:
+    async with GraphConnection('tls://it.hsrobotics:4339',
+                               'graph.crt') as gc:
+        print(f"Task {instance} started.")
+        remotegraphen = await gc.call('alleattribute',
+                                      ['remotegraph',
+                                       'name, url',
+                                       'name asc',
+                                       'name != "local"'])
+        print(f"Task {instance}: {len(remotegraphen)} remote graphs")
+        for remotegraph_guid, remotegraph in remotegraphen.items():
+            async with GraphConnection(remotegraph['url'],
+                                       'graph.crt') as inner_gc:
+                graphmodule = await inner_gc.call('alleattribute',
+                                                  ['graphmodul',
+                                                   'prefix',
+                                                   'prefix asc'])
+                module = ', '.join([graphmodul['prefix']
+                                    for graphmodul in graphmodule.values()])
+                print(f"Task {instance}: {remotegraph['name']}: {module}")
+    print(f"Task {instance} finished.")
+
+
+async def main() -> None:
+    await asyncio.gather(get_remote_modules(1),
+                         get_remote_modules(2),
+                         get_remote_modules(3))
+    pprint(GraphConnection.connections)
+
+
+asyncio.run(main())