From: Benjamin Braatz Date: Wed, 28 May 2025 13:35:38 +0000 (+0200) Subject: Correct handling of exceptions during __aenter__ and __aexit__. X-Git-Url: http://git.graph-it.com/?a=commitdiff_plain;h=76874ca7721402e38bba60ebbd29719bba9e3c5c;p=graphit%2Fcontrolpi-graph.git Correct handling of exceptions during __aenter__ and __aexit__. --- diff --git a/controlpi_plugins/graph_connection.py b/controlpi_plugins/graph_connection.py index 1d1358e..86d707b 100644 --- a/controlpi_plugins/graph_connection.py +++ b/controlpi_plugins/graph_connection.py @@ -45,8 +45,7 @@ class GraphConnection: async def __aenter__(self) -> 'GraphConnection': """Open connection if first context.""" async with self.connection.lock: - self.connection.contexts += 1 - if self.connection.contexts == 1: + if self.connection.contexts == 0: # Open connection: (r, w) = await asyncio.open_connection(self.hostname, self.port, @@ -59,22 +58,27 @@ class GraphConnection: # Set reader and writer in data: self.connection.reader = r self.connection.writer = w + # Increase contexts using this connection + # (only if no exception until here): + self.connection.contexts += 1 + return self async def __aexit__(self, exc_type, exc_value, traceback) -> None: """Close connection if last context.""" async with self.connection.lock: + # Decrease contexts using this connection: self.connection.contexts -= 1 if self.connection.contexts == 0: - # Close writer: writer = self.connection.writer - if writer is not None: - writer.close() - await writer.wait_closed() - self.connection.closed += 1 # Remove reader and writer from data: self.connection.reader = None self.connection.writer = None + # Close writer (if existed): + if writer is not None: + writer.close() + await writer.wait_closed() + self.connection.closed += 1 async def call(self, method: str, params: list): """Execute a call on connection.""" @@ -112,25 +116,10 @@ class GraphConnection: # Return result: return response['result'] else: - if writer is not None: - writer.close() - self.connection.closed += 1 - # Open connection: - (r, w) = await asyncio.open_connection(self.hostname, - self.port, - ssl=self.ssl) - self.connection.opened += 1 - # Read banner: - size_bytes = await r.readexactly(4) - size_int = struct.unpack('