From 76874ca7721402e38bba60ebbd29719bba9e3c5c Mon Sep 17 00:00:00 2001 From: Benjamin Braatz Date: Wed, 28 May 2025 15:35:38 +0200 Subject: [PATCH] Correct handling of exceptions during __aenter__ and __aexit__. --- controlpi_plugins/graph_connection.py | 45 ++++++++++----------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/controlpi_plugins/graph_connection.py b/controlpi_plugins/graph_connection.py index 1d1358e..86d707b 100644 --- a/controlpi_plugins/graph_connection.py +++ b/controlpi_plugins/graph_connection.py @@ -45,8 +45,7 @@ class GraphConnection: async def __aenter__(self) -> 'GraphConnection': """Open connection if first context.""" async with self.connection.lock: - self.connection.contexts += 1 - if self.connection.contexts == 1: + if self.connection.contexts == 0: # Open connection: (r, w) = await asyncio.open_connection(self.hostname, self.port, @@ -59,22 +58,27 @@ class GraphConnection: # Set reader and writer in data: self.connection.reader = r self.connection.writer = w + # Increase contexts using this connection + # (only if no exception until here): + self.connection.contexts += 1 + return self async def __aexit__(self, exc_type, exc_value, traceback) -> None: """Close connection if last context.""" async with self.connection.lock: + # Decrease contexts using this connection: self.connection.contexts -= 1 if self.connection.contexts == 0: - # Close writer: writer = self.connection.writer - if writer is not None: - writer.close() - await writer.wait_closed() - self.connection.closed += 1 # Remove reader and writer from data: self.connection.reader = None self.connection.writer = None + # Close writer (if existed): + if writer is not None: + writer.close() + await writer.wait_closed() + self.connection.closed += 1 async def call(self, method: str, params: list): """Execute a call on connection.""" @@ -112,25 +116,10 @@ class GraphConnection: # Return result: return response['result'] else: - if writer is not None: - writer.close() - self.connection.closed += 1 - # Open connection: - (r, w) = await asyncio.open_connection(self.hostname, - self.port, - ssl=self.ssl) - self.connection.opened += 1 - # Read banner: - size_bytes = await r.readexactly(4) - size_int = struct.unpack('