# Source code for sttp.transport.datapublisher

# ******************************************************************************************************
#  datapublisher.py - Gbtc
#
#  Copyright © 2026, Grid Protection Alliance.  All Rights Reserved.
#
#  Licensed to the Grid Protection Alliance (GPA) under one or more contributor license agreements. See
#  the NOTICE file distributed with this work for additional information regarding copyright ownership.
#  The GPA licenses this file to you under the MIT License (MIT), the "License"; you may not use this
#  file except in compliance with the License. You may obtain a copy of the License at:
#
#      http://opensource.org/licenses/MIT
#
#  Unless agreed to in writing, the subject software distributed under the License is distributed on an
#  "AS-IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. Refer to the
#  License for the specific language governing permissions and limitations.
#
#  Code Modification History:
#  ----------------------------------------------------------------------------------------------------
#  01/06/2026 - Generated by porting C++ DataPublisher
#       Ported from cppapi/src/lib/transport/DataPublisher.{h,cpp}
#
# ******************************************************************************************************

# Ported from cppapi/src/lib/transport/DataPublisher.cpp : class DataPublisher
# Differences: Python uses socket server and threading; otherwise parity maintained.

# Standard library
import socket
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from threading import RLock, Thread
from typing import Callable, List, Optional, Set
from uuid import UUID, uuid4

# Third-party
import numpy as np

# Local package
from sttp.metadata.record.measurement import MeasurementRecord as MetadataMeasurementRecord

from ..data.dataset import DataSet
from ..data.filterexpressionparser import FilterExpressionParser
from .constants import SecurityMode, ServerCommand, ServerResponse
from .measurement import Measurement
from .routingtables import RoutingTables
from .subscriberconnection import SubscriberConnection

# Note: MeasurementRecord for filtering is imported from sttp.metadata.record.measurement
# This provides full measurement metadata, matching C++ MeasurementMetadata


class DataPublisher:
    """
    Represents a data publisher that accepts subscriber connections and publishes measurements.
    """

    def __init__(self):
        """
        Creates a new data publisher.

        Initializes all configuration defaults and immediately starts the
        background callback-dispatch thread; the listening socket is not
        opened until `start` is called.
        """
        # Configuration
        self._node_id = uuid4()
        self._security_mode = SecurityMode.OFF
        self._maximum_allowed_connections = -1  # -1 = unlimited
        self._is_metadata_refresh_allowed = True
        self._is_nan_value_filter_allowed = True
        self._is_nan_value_filter_forced = False
        self._supports_temporal_subscriptions = False
        self._use_base_time_offsets = True
        self._cipher_key_rotation_period = 60000  # presumably milliseconds — TODO confirm against C++ source

        # State
        self._started = False
        self._stopped = False
        self._shuttingdown = False
        self._reverse_connection = False

        # Metadata (raw publisher metadata and the derived ActiveMeasurements view)
        self._metadata: Optional[DataSet] = None
        self._filtering_metadata: Optional[DataSet] = None

        # Connections (guarded by the re-entrant lock below)
        self._subscriber_connections: Set[SubscriberConnection] = set()
        self._subscriber_connections_lock = RLock()

        # Routing
        self._routing_tables = RoutingTables()

        # Server socket
        self._server_socket: Optional[socket.socket] = None
        self._accept_thread: Optional[Thread] = None
        self._port = np.uint16(0)
        self._ipv6 = False

        # Thread pool for callbacks
        self._callback_executor = ThreadPoolExecutor(max_workers=4)
        self._callback_queue: Queue = Queue()
        self._callback_thread: Optional[Thread] = None

        # Callbacks (all optional; invoked asynchronously via the callback queue)
        self.statusmessage_callback: Optional[Callable[[str], None]] = None
        self.errormessage_callback: Optional[Callable[[str], None]] = None
        self.clientconnected_callback: Optional[Callable[[SubscriberConnection], None]] = None
        self.clientdisconnected_callback: Optional[Callable[[SubscriberConnection], None]] = None
        self.processingintervalchange_callback: Optional[Callable[[SubscriberConnection], None]] = None
        self.temporalsubscription_requested_callback: Optional[Callable[[SubscriberConnection], None]] = None
        self.temporalsubscription_canceled_callback: Optional[Callable[[SubscriberConnection], None]] = None
        self.usercommand_callback: Optional[Callable[[SubscriberConnection, np.uint8, bytes], None]] = None

        # User data
        self._user_data = None

        # Start callback thread (daemon: does not block interpreter exit)
        self._callback_thread = Thread(target=self._process_callbacks, daemon=True)
        self._callback_thread.start()

    # Properties

    @property
    def node_id(self) -> UUID:
        """Gets the node ID."""
        return self._node_id

    @node_id.setter
    def node_id(self, value: UUID):
        """Sets the node ID."""
        self._node_id = value

    @property
    def metadata(self) -> Optional[DataSet]:
        """Gets the primary metadata."""
        return self._metadata

    @property
    def filtering_metadata(self) -> Optional[DataSet]:
        """Gets the filtering metadata."""
        return self._filtering_metadata

    @property
    def security_mode(self) -> SecurityMode:
        """Gets the security mode."""
        return self._security_mode

    @security_mode.setter
    def security_mode(self, value: SecurityMode):
        """Sets the security mode."""
        self._security_mode = value

    @property
    def maximum_allowed_connections(self) -> int:
        """Gets the maximum allowed connections (-1 = unlimited)."""
        return self._maximum_allowed_connections

    @maximum_allowed_connections.setter
    def maximum_allowed_connections(self, value: int):
        """Sets the maximum allowed connections."""
        self._maximum_allowed_connections = value

    @property
    def is_metadata_refresh_allowed(self) -> bool:
        """Gets flag indicating if metadata refresh is allowed."""
        return self._is_metadata_refresh_allowed

    @is_metadata_refresh_allowed.setter
    def is_metadata_refresh_allowed(self, value: bool):
        """Sets metadata refresh allowed flag."""
        self._is_metadata_refresh_allowed = value

    @property
    def is_nan_value_filter_allowed(self) -> bool:
        """Gets flag indicating if NaN value filtering is allowed."""
        return self._is_nan_value_filter_allowed

    @is_nan_value_filter_allowed.setter
    def is_nan_value_filter_allowed(self, value: bool):
        """Sets NaN value filter allowed flag."""
        self._is_nan_value_filter_allowed = value

    @property
    def is_nan_value_filter_forced(self) -> bool:
        """Gets flag indicating if NaN value filtering is forced."""
        return self._is_nan_value_filter_forced

    @is_nan_value_filter_forced.setter
    def is_nan_value_filter_forced(self, value: bool):
        """Sets NaN value filter forced flag."""
        self._is_nan_value_filter_forced = value

    @property
    def supports_temporal_subscriptions(self) -> bool:
        """Gets flag indicating if temporal subscriptions are supported."""
        return self._supports_temporal_subscriptions

    @supports_temporal_subscriptions.setter
    def supports_temporal_subscriptions(self, value: bool):
        """Sets temporal subscriptions supported flag."""
        self._supports_temporal_subscriptions = value

    @property
    def use_base_time_offsets(self) -> bool:
        """Gets flag indicating if base time offsets should be used."""
        return self._use_base_time_offsets

    @use_base_time_offsets.setter
    def use_base_time_offsets(self, value: bool):
        """Sets use base time offsets flag."""
        self._use_base_time_offsets = value

    @property
    def port(self) -> np.uint16:
        """Gets the listening port (read-only; assigned by `start`)."""
        return self._port

    @property
    def is_ipv6(self) -> bool:
        """Gets flag indicating if using IPv6."""
        return self._ipv6

    @property
    def is_started(self) -> bool:
        """Gets flag indicating if publisher is started."""
        return self._started

    # Metadata management
[docs] def define_metadata(self, metadata: DataSet): """ Defines the publisher metadata from a DataSet. Builds the ActiveMeasurements filtering metadata table from the source metadata. """ import os from sttp.data.dataset import DataSet from decimal import Decimal self._metadata = metadata # Helper functions matching C++ implementation def get_protocol_type(protocol_name: str) -> str: """Returns protocol type based on protocol name.""" if not protocol_name: return "Frame" protocol_upper = protocol_name.upper() if (protocol_upper == "STREAMING TELEMETRY TRANSPORT PROTOCOL" or protocol_upper == "STTP" or protocol_upper.startswith("GATEWAY") or protocol_upper.startswith("MODBUS") or protocol_upper.startswith("DNP")): return "Measurement" return "Frame" def get_engineering_units(signal_type: str) -> str: """Returns engineering units based on signal type.""" if not signal_type: return "" signal_upper = signal_type.upper() if signal_upper == "IPHM": return "Amps" if signal_upper == "VPHM": return "Volts" if signal_upper == "FREQ": return "Hz" if signal_upper.endswith("PHA"): return "Degrees" return "" def get_column_index(table, column_name: str) -> int: """Gets column index by name, raises error if not found.""" col = table.column_byname(column_name) if col is None: raise ValueError(f"Column name '{column_name}' was not found in table '{table.name}'") return col.index # Create device data map used to build flatter metadata view device_detail = metadata.table("DeviceDetail") device_data = {} if device_detail: acronym_idx = get_column_index(device_detail, "Acronym") protocol_idx = get_column_index(device_detail, "ProtocolName") fps_idx = get_column_index(device_detail, "FramesPerSecond") company_idx = get_column_index(device_detail, "CompanyAcronym") longitude_idx = get_column_index(device_detail, "Longitude") latitude_idx = get_column_index(device_detail, "Latitude") for i, row in enumerate(device_detail): device_acronym = row[acronym_idx] if row[acronym_idx] else "" if 
not device_acronym: continue protocol = row[protocol_idx] if row[protocol_idx] else "" device_data[device_acronym] = { 'DeviceID': i, 'FramesPerSecond': int(row[fps_idx]) if row[fps_idx] else 30, 'Company': row[company_idx] if row[company_idx] else "", 'Protocol': protocol, 'ProtocolType': get_protocol_type(protocol), 'Longitude': Decimal(str(row[longitude_idx])) if row[longitude_idx] else Decimal('0'), 'Latitude': Decimal(str(row[latitude_idx])) if row[latitude_idx] else Decimal('0') } # Create phasor data map phasor_detail = metadata.table("PhasorDetail") phasor_data = {} if phasor_detail: id_idx = get_column_index(phasor_detail, "ID") device_acronym_idx = get_column_index(phasor_detail, "DeviceAcronym") type_idx = get_column_index(phasor_detail, "Type") phase_idx = get_column_index(phasor_detail, "Phase") source_index_idx = get_column_index(phasor_detail, "SourceIndex") for row in phasor_detail: device_acronym = row[device_acronym_idx] if row[device_acronym_idx] else "" if not device_acronym: continue if device_acronym not in phasor_data: phasor_data[device_acronym] = {} source_index = int(row[source_index_idx]) if row[source_index_idx] else 0 phasor_data[device_acronym][source_index] = { 'PhasorID': int(row[id_idx]) if row[id_idx] else 0, 'PhasorType': row[type_idx] if row[type_idx] else "", 'Phase': row[phase_idx] if row[phase_idx] else "" } # Load active measurements schema from embedded resource schema_path = os.path.join(os.path.dirname(__file__), "ActiveMeasurementsSchema.xml") with open(schema_path, 'r') as f: schema_xml = f.read() filtering_metadata, err = DataSet.from_xml(schema_xml) if err: raise RuntimeError(f"Failed to load ActiveMeasurementsSchema.xml: {err}") # Build active measurements table from measurement detail measurement_detail = metadata.table("MeasurementDetail") active_measurements = filtering_metadata.table("ActiveMeasurements") if measurement_detail is not None and active_measurements is not None: try: # Lookup column indices for 
measurement detail table md_device_acronym = get_column_index(measurement_detail, "DeviceAcronym") md_id = get_column_index(measurement_detail, "ID") md_signal_id = get_column_index(measurement_detail, "SignalID") md_point_tag = get_column_index(measurement_detail, "PointTag") md_signal_reference = get_column_index(measurement_detail, "SignalReference") md_signal_acronym = get_column_index(measurement_detail, "SignalAcronym") md_phasor_source_index = get_column_index(measurement_detail, "PhasorSourceIndex") md_description = get_column_index(measurement_detail, "Description") md_internal = get_column_index(measurement_detail, "Internal") md_enabled = get_column_index(measurement_detail, "Enabled") md_updated_on = get_column_index(measurement_detail, "UpdatedOn") # Lookup column indices for active measurements table am_source_node_id = get_column_index(active_measurements, "SourceNodeID") am_id = get_column_index(active_measurements, "ID") am_signal_id = get_column_index(active_measurements, "SignalID") am_point_tag = get_column_index(active_measurements, "PointTag") am_signal_reference = get_column_index(active_measurements, "SignalReference") am_internal = get_column_index(active_measurements, "Internal") am_subscribed = get_column_index(active_measurements, "Subscribed") am_device = get_column_index(active_measurements, "Device") am_device_id = get_column_index(active_measurements, "DeviceID") am_frames_per_second = get_column_index(active_measurements, "FramesPerSecond") am_protocol = get_column_index(active_measurements, "Protocol") am_protocol_type = get_column_index(active_measurements, "ProtocolType") am_signal_type = get_column_index(active_measurements, "SignalType") am_engineering_units = get_column_index(active_measurements, "EngineeringUnits") am_phasor_id = get_column_index(active_measurements, "PhasorID") am_phasor_type = get_column_index(active_measurements, "PhasorType") am_phase = get_column_index(active_measurements, "Phase") am_adder = 
get_column_index(active_measurements, "Adder") am_multiplier = get_column_index(active_measurements, "Multiplier") am_company = get_column_index(active_measurements, "Company") am_longitude = get_column_index(active_measurements, "Longitude") am_latitude = get_column_index(active_measurements, "Latitude") am_description = get_column_index(active_measurements, "Description") am_updated_on = get_column_index(active_measurements, "UpdatedOn") # Build active measurements from measurement detail rows_processed = 0 rows_enabled = 0 rows_added = 0 for md_row in measurement_detail: rows_processed += 1 # Skip disabled measurements if md_row[md_enabled] == 0: continue rows_enabled += 1 am_row = active_measurements.create_row() # Set basic measurement info am_row[am_source_node_id] = self._node_id am_row[am_id] = md_row[md_id] am_row[am_signal_id] = md_row[md_signal_id] am_row[am_point_tag] = md_row[md_point_tag] am_row[am_signal_reference] = md_row[md_signal_reference] am_row[am_internal] = 1 if md_row[md_internal] else 0 am_row[am_subscribed] = 0 am_row[am_description] = md_row[md_description] am_row[am_adder] = 0.0 am_row[am_multiplier] = 1.0 am_row[am_updated_on] = md_row[md_updated_on] # Set signal type and engineering units signal_type = md_row[md_signal_acronym] if md_row[md_signal_acronym] else "CALC" am_row[am_signal_type] = signal_type am_row[am_engineering_units] = get_engineering_units(signal_type) # Get device info device_acronym = md_row[md_device_acronym] if md_row[md_device_acronym] else "" if not device_acronym: # Set default values when measurement is not associated with a device am_row[am_frames_per_second] = 30 else: am_row[am_device] = device_acronym # Lookup associated device record device = device_data.get(device_acronym) if device: am_row[am_device_id] = device['DeviceID'] am_row[am_frames_per_second] = device['FramesPerSecond'] am_row[am_company] = device['Company'] am_row[am_protocol] = device['Protocol'] am_row[am_protocol_type] = 
device['ProtocolType'] am_row[am_longitude] = device['Longitude'] am_row[am_latitude] = device['Latitude'] # Lookup associated phasor records if device_acronym in phasor_data: source_index = int(md_row[md_phasor_source_index]) if md_row[md_phasor_source_index] else 0 phasor = phasor_data[device_acronym].get(source_index) if phasor: am_row[am_phasor_id] = phasor['PhasorID'] am_row[am_phasor_type] = phasor['PhasorType'] am_row[am_phase] = phasor['Phase'] active_measurements.add_row(am_row) rows_added += 1 except Exception as ex: self._dispatch_status_message(f"ERROR building ActiveMeasurements: {ex}") self._filtering_metadata = filtering_metadata # Notify all subscribers that configuration metadata has changed with self._subscriber_connections_lock: for connection in self._subscriber_connections: connection.send_response(ServerResponse.CONFIGURATIONCHANGED, ServerCommand.SUBSCRIBE) self._dispatch_status_message( f"Metadata defined with {len(metadata)} tables" )
    def filter_metadata(self, filter_expression: str) -> List[MetadataMeasurementRecord]:
        """
        Filters metadata using a filter expression against the MeasurementDetail table.

        This is for the publisher application to decide what measurements to publish.
        Matches C++ DataPublisher::FilterMetadata which filters from m_metadata.

        Args:
            filter_expression: Filter expression (e.g., "SignalAcronym <> 'STAT'")

        Returns:
            List of MeasurementRecord objects from MeasurementDetail.
            Returns an empty list (never raises) on any filtering error;
            errors are reported through the error-message callback.
        """
        results = []

        try:
            if not self._metadata:
                raise RuntimeError("Cannot filter metadata, no metadata has been defined.")

            # Find MeasurementDetail table from original metadata (matching C++)
            measurement_detail = self._metadata.table("MeasurementDetail")

            if not measurement_detail:
                self._dispatch_error_message("MeasurementDetail table not found in metadata")
                return results

            # Parse and execute filter expression on MeasurementDetail table
            # Matching C++: FilterExpressionParser::Select(m_metadata, filterExpression, "MeasurementDetail")
            filtered_rows, err = FilterExpressionParser.select_datarows(
                self._metadata, filter_expression, "MeasurementDetail"
            )

            if err is not None:
                self._dispatch_error_message(f"Error filtering metadata: {type(err).__name__}: {err}")
                return results

            if not filtered_rows:
                return results

            # Get column indices from MeasurementDetail (matching C++)
            def get_column_index(table, column_name: str) -> int:
                col = table.column_byname(column_name)
                if col is None:
                    raise ValueError(f"Column name '{column_name}' was not found in table '{table.name}'")
                return col.index

            device_acronym_idx = get_column_index(measurement_detail, "DeviceAcronym")
            id_idx = get_column_index(measurement_detail, "ID")
            signal_id_idx = get_column_index(measurement_detail, "SignalID")
            point_tag_idx = get_column_index(measurement_detail, "PointTag")
            signal_reference_idx = get_column_index(measurement_detail, "SignalReference")
            phasor_source_index_idx = get_column_index(measurement_detail, "PhasorSourceIndex")
            description_idx = get_column_index(measurement_detail, "Description")
            enabled_idx = get_column_index(measurement_detail, "Enabled")
            updated_on_idx = get_column_index(measurement_detail, "UpdatedOn")
            signal_acronym_idx = get_column_index(measurement_detail, "SignalAcronym")

            # Build MeasurementRecord objects from filtered rows (matching C++)
            for row in filtered_rows:
                # Skip disabled measurements (matching C++)
                if not row[enabled_idx]:
                    continue

                # Parse measurement key (ID column contains "Source:ID" format like "PPA:12")
                id_value = row[id_idx] if row[id_idx] else ""
                source = ""
                numeric_id = 0

                if isinstance(id_value, str) and ':' in id_value:
                    parts = id_value.split(':', 1)
                    source = parts[0]
                    try:
                        numeric_id = int(parts[1])
                    except (ValueError, IndexError):
                        # Malformed key — fall back to 0 rather than failing the whole filter
                        numeric_id = 0

                # Create MeasurementRecord with all metadata fields (matching C++)
                metadata = MetadataMeasurementRecord(
                    signalid=row[signal_id_idx] if row[signal_id_idx] else MetadataMeasurementRecord.DEFAULT_SIGNALID,
                    id=np.uint64(numeric_id),
                    source=source if source else MetadataMeasurementRecord.DEFAULT_SOURCE,
                    pointtag=row[point_tag_idx] if row[point_tag_idx] else MetadataMeasurementRecord.DEFAULT_POINTTAG,
                    signalreference=row[signal_reference_idx] if row[signal_reference_idx] else MetadataMeasurementRecord.DEFAULT_SIGNALREFERENCE,
                    deviceacronym=row[device_acronym_idx] if row[device_acronym_idx] else MetadataMeasurementRecord.DEFAULT_DEVICEACRONYM,
                    signaltypename=row[signal_acronym_idx] if row[signal_acronym_idx] else MetadataMeasurementRecord.DEFAULT_SIGNALTYPENAME,
                    description=row[description_idx] if row[description_idx] else MetadataMeasurementRecord.DEFAULT_DESCRIPTION,
                    updatedon=row[updated_on_idx] if row[updated_on_idx] else MetadataMeasurementRecord.DEFAULT_UPDATEDON
                )

                # Store phasor source index (C++ stores this in metadata->PhasorSourceIndex)
                metadata.phasor_source_index = int(row[phasor_source_index_idx]) if row[phasor_source_index_idx] else 0

                results.append(metadata)

            self._dispatch_status_message(
                f"Filter expression '{filter_expression}' matched {len(results)} measurements"
            )
        except Exception as ex:
            self._dispatch_error_message(f"Error filtering metadata: {ex}")

        return results
    # Connection management
[docs] def start(self, port: int, ipv6: bool = False): """Starts the data publisher listening on the specified port.""" if self._started: raise RuntimeError("Publisher is already started") try: self._port = np.uint16(port) self._ipv6 = ipv6 # Create server socket family = socket.AF_INET6 if ipv6 else socket.AF_INET self._server_socket = socket.socket(family, socket.SOCK_STREAM) self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Bind and listen host = "::" if ipv6 else "0.0.0.0" self._server_socket.bind((host, port)) self._server_socket.listen(10) self._started = True self._stopped = False self._shuttingdown = False # Start accept thread self._accept_thread = Thread(target=self._accept_connections, daemon=True) self._accept_thread.start() self._dispatch_status_message(f"DataPublisher listening on port {port}") except Exception as ex: self._dispatch_error_message(f"Failed to start DataPublisher: {ex}") raise
[docs] def stop(self): """Stops the data publisher.""" if self._stopped: return self._shuttingdown = True self._started = False # Close server socket if self._server_socket: try: self._server_socket.close() except: pass self._server_socket = None # Wait for accept thread if self._accept_thread and self._accept_thread.is_alive(): self._accept_thread.join(timeout=2.0) # Disconnect all subscribers with self._subscriber_connections_lock: connections = list(self._subscriber_connections) for connection in connections: try: connection.stop() except: pass # Clear routing tables self._routing_tables.clear() self._stopped = True self._dispatch_status_message("DataPublisher stopped")
def _accept_connections(self): """Background thread that accepts client connections.""" while self._started and not self._shuttingdown: try: client_socket, address = self._server_socket.accept() # Check connection limit with self._subscriber_connections_lock: connection_count = len(self._subscriber_connections) connection_accepted = (self._maximum_allowed_connections == -1 or connection_count < self._maximum_allowed_connections) if connection_accepted: # Create subscriber connection connection = SubscriberConnection(self, client_socket, address) with self._subscriber_connections_lock: self._subscriber_connections.add(connection) # Start connection connection.start() # Notify callback self._dispatch_client_connected(connection) self._dispatch_status_message( f"Client connected from {address[0]}:{address[1]}" ) else: # Reject connection self._dispatch_error_message( f"Connection from {address[0]}:{address[1]} refused: " + f"would exceed {self._maximum_allowed_connections} maximum connections" ) try: client_socket.close() except: pass except Exception as ex: if self._started and not self._shuttingdown: self._dispatch_error_message(f"Error accepting connection: {ex}") def _connection_terminated(self, connection: SubscriberConnection): """Called when a subscriber connection is terminated.""" # Remove from active connections with self._subscriber_connections_lock: self._subscriber_connections.discard(connection) # Remove from routing tables self._routing_tables.remove_routes(connection) # Notify callback self._dispatch_client_disconnected(connection) self._dispatch_status_message(f"Client disconnected: {connection.connection_id}") # Measurement publication
[docs] def publish_measurements(self, measurements: List[Measurement]): """Publishes measurements to subscribed clients.""" if not self._started or self._shuttingdown: return # Route measurements to appropriate subscribers self._routing_tables.publish_measurements(measurements)
# Callback dispatching def _process_callbacks(self): """Background thread that processes callback queue.""" while True: try: callback, args = self._callback_queue.get(timeout=1.0) if callback: self._callback_executor.submit(callback, *args) except: pass def _dispatch_status_message(self, message: str): """Dispatches a status message to the callback.""" if self.statusmessage_callback: self._callback_queue.put((self.statusmessage_callback, (message,))) def _dispatch_error_message(self, message: str): """Dispatches an error message to the callback.""" if self.errormessage_callback: self._callback_queue.put((self.errormessage_callback, (message,))) def _dispatch_client_connected(self, connection: SubscriberConnection): """Dispatches a client connected event to the callback.""" if self.clientconnected_callback: self._callback_queue.put((self.clientconnected_callback, (connection,))) def _dispatch_client_disconnected(self, connection: SubscriberConnection): """Dispatches a client disconnected event to the callback.""" if self.clientdisconnected_callback: self._callback_queue.put((self.clientdisconnected_callback, (connection,))) # Cleanup
    def dispose(self):
        """
        Disposes of the data publisher.

        Stops the listener and all subscriber connections first, then tears
        down routing and the callback executor (non-blocking shutdown).
        """
        self.stop()
        self._routing_tables.dispose()
        self._callback_executor.shutdown(wait=False)