improve framecounter thread safety

This commit is contained in:
onyx-and-iris 2026-03-02 20:44:08 +00:00
parent 7d741d6e8b
commit a210766b7b
2 changed files with 48 additions and 48 deletions

View File

@ -44,6 +44,7 @@ class VbanCmd(abc.ABC):
setattr(self, attr, val) setattr(self, attr, val)
self._framecounter = 0 self._framecounter = 0
self._framecounter_lock = threading.Lock()
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.sock.settimeout(self.timeout) self.sock.settimeout(self.timeout)
@ -130,13 +131,19 @@ class VbanCmd(abc.ABC):
def stopped(self): def stopped(self):
return self.stop_event is None or self.stop_event.is_set() return self.stop_event is None or self.stop_event.is_set()
def _get_next_framecounter(self) -> int:
"""Thread-safe method to get and increment framecounter."""
with self._framecounter_lock:
current = self._framecounter
self._framecounter = bump_framecounter(self._framecounter)
return current
def _ping(self, timeout: float = None) -> None: def _ping(self, timeout: float = None) -> None:
"""Send a PING packet and wait for PONG response to verify connectivity.""" """Send a PING packet and wait for PONG response to verify connectivity."""
if timeout is None: if timeout is None:
timeout = min(self.timeout, 3.0) timeout = min(self.timeout, 3.0)
ping_packet = VbanPing0Payload.create_packet(self._framecounter) ping_packet = VbanPing0Payload.create_packet(self._get_next_framecounter())
self._framecounter = bump_framecounter(self._framecounter)
original_timeout = self.sock.gettimeout() original_timeout = self.sock.gettimeout()
self.sock.settimeout(0.5) self.sock.settimeout(0.5)
@ -220,12 +227,11 @@ class VbanCmd(abc.ABC):
name=self.streamname, name=self.streamname,
bps_index=self.BPS_OPTS.index(self.bps), bps_index=self.BPS_OPTS.index(self.bps),
channel=self.channel, channel=self.channel,
framecounter=self._framecounter, framecounter=self._get_next_framecounter(),
payload=payload, payload=payload,
), ),
(socket.gethostbyname(self.ip), self.port), (socket.gethostbyname(self.ip), self.port),
) )
self._framecounter = bump_framecounter(self._framecounter)
def _set_rt(self, cmd: str, val: Union[str, float]): def _set_rt(self, cmd: str, val: Union[str, float]):
"""Sends a string request command over a network.""" """Sends a string request command over a network."""
@ -239,12 +245,16 @@ class VbanCmd(abc.ABC):
if self.disable_rt_listeners and script.endswith(('?', '?;')): if self.disable_rt_listeners and script.endswith(('?', '?;')):
try: try:
response = VbanMatrixResponseHeader.extract_payload( data, _ = self.sock.recvfrom(2048)
self.sock.recv(1024) payload = VbanMatrixResponseHeader.extract_payload(data)
)
return response
except ValueError as e: except ValueError as e:
self.logger.warning(f'Error extracting matrix response: {e}') self.logger.warning(f'Error extracting matrix response: {e}')
except TimeoutError as e:
self.logger.exception(f'Timeout waiting for matrix response: {e}')
raise VBANCMDConnectionError(
f'Timeout waiting for response from {self.ip}:{self.port}'
) from e
return payload
time.sleep(self.DELAY) time.sleep(self.DELAY)

View File

@ -1,5 +1,4 @@
import logging import logging
import socket
import threading import threading
import time import time
@ -13,7 +12,6 @@ from .packet.headers import (
) )
from .packet.nbs0 import VbanPacketNBS0 from .packet.nbs0 import VbanPacketNBS0
from .packet.nbs1 import VbanPacketNBS1 from .packet.nbs1 import VbanPacketNBS1
from .util import bump_framecounter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,24 +24,24 @@ class Subscriber(threading.Thread):
self._remote = remote self._remote = remote
self.stop_event = stop_event self.stop_event = stop_event
self.logger = logger.getChild(self.__class__.__name__) self.logger = logger.getChild(self.__class__.__name__)
self._framecounter = 0
def run(self): def run(self):
while not self.stopped(): while not self.stopped():
try: try:
for nbs in NBS: for nbs in NBS:
sub_packet = VbanSubscribeHeader().to_bytes(nbs, self._framecounter) sub_packet = VbanSubscribeHeader().to_bytes(
nbs, self._remote._get_next_framecounter()
)
self._remote.sock.sendto( self._remote.sock.sendto(
sub_packet, (self._remote.ip, self._remote.port) sub_packet, (self._remote.ip, self._remote.port)
) )
self._framecounter = bump_framecounter(self._framecounter) except TimeoutError as e:
self.wait_until_stopped(10)
except socket.gaierror as e:
self.logger.exception(f'{type(e).__name__}: {e}') self.logger.exception(f'{type(e).__name__}: {e}')
raise VBANCMDConnectionError( raise VBANCMDConnectionError(
f'unable to resolve hostname {self._remote.ip}' f'timeout sending subscription to {self._remote.ip}:{self._remote.port}'
) from e ) from e
self.wait_until_stopped(10)
self.logger.debug(f'terminating {self.name} thread') self.logger.debug(f'terminating {self.name} thread')
def stopped(self): def stopped(self):
@ -66,7 +64,6 @@ class Producer(threading.Thread):
self.queue = queue self.queue = queue
self.stop_event = stop_event self.stop_event = stop_event
self.logger = logger.getChild(self.__class__.__name__) self.logger = logger.getChild(self.__class__.__name__)
self._remote.sock.settimeout(self._remote.timeout)
self._remote._public_packets = [None] * (max(NBS) + 1) self._remote._public_packets = [None] * (max(NBS) + 1)
_pp = self._get_rt() _pp = self._get_rt()
self._remote._public_packets[_pp.nbs] = _pp self._remote._public_packets[_pp.nbs] = _pp
@ -77,16 +74,11 @@ class Producer(threading.Thread):
def _get_rt(self) -> VbanPacket: def _get_rt(self) -> VbanPacket:
"""Attempt to fetch data packet until a valid one found""" """Attempt to fetch data packet until a valid one found"""
while True: while True:
if resp := self._fetch_rt_packet():
return resp
def _fetch_rt_packet(self) -> VbanPacket | None:
try: try:
data, _ = self._remote.sock.recvfrom(2048) data, _ = self._remote.sock.recvfrom(2048)
if len(data) < HEADER_SIZE: if len(data) < HEADER_SIZE:
return continue
except TimeoutError as e: except TimeoutError as e:
self.logger.exception(f'{type(e).__name__}: {e}') self.logger.exception(f'{type(e).__name__}: {e}')
raise VBANCMDConnectionError( raise VBANCMDConnectionError(
@ -97,7 +89,7 @@ class Producer(threading.Thread):
header = VbanResponseHeader.from_bytes(data[:HEADER_SIZE]) header = VbanResponseHeader.from_bytes(data[:HEADER_SIZE])
except ValueError as e: except ValueError as e:
self.logger.debug(f'Error parsing response packet: {e}') self.logger.debug(f'Error parsing response packet: {e}')
return None continue
match header.format_nbs: match header.format_nbs:
case NBS.zero: case NBS.zero:
@ -110,8 +102,6 @@ class Producer(threading.Thread):
nbs=NBS.one, kind=self._remote.kind, data=data nbs=NBS.one, kind=self._remote.kind, data=data
) )
return None
def stopped(self): def stopped(self):
return self.stop_event.is_set() return self.stop_event.is_set()