diff --git a/main.py b/main.py index 03b34f3..8a8a4f0 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,5 @@ import signal -import json import time -import socket import logging import os import argparse @@ -19,6 +17,14 @@ from selenium.webdriver import ActionChains from selenium.webdriver.support.wait import WebDriverWait from selenium.webdriver.support import expected_conditions as ec +# Plugin system imports +import importlib +import importlib.util +import inspect +import glob +import sys # Import the sys module +from utils.plugins_base import StatsSetupPlugin, StatsDownloadPlugin + logger = logging.getLogger(__name__) args = None @@ -41,6 +47,7 @@ def setupArgParser(): parser.add_argument('--hub-url', type=str, help='URL of the Selenium hub to connect to. If not provided, local Chrome driver will be used.') parser.add_argument('--webrtc-internals-path', type=str, help='Path to the WebRTC internals extension.') parser.add_argument('--log-level', type=str, help='Log level to use. Default: INFO') + parser.add_argument('--plugin-dir', type=str, help='Path to the plugin directory.') return parser @@ -73,147 +80,215 @@ def setupChromeDriver(command_executor: str | None, webrtc_internals_path: str) return driver -def saveStats(stats: list, socket_url: str, socket_port: int): - try: - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - logger.log(logging.DEBUG, f'Saving stats: {json.dumps(stats, indent=4)}') - sock.sendto(json.dumps(stats).encode(), (socket_url, socket_port)) - sock.close() - logger.log(logging.DEBUG, 'Sent stats to socket.') - except socket.error as e: - logger.error(f'Got socket error: {e}') - -def downloadStats(driver: webdriver.Remote | webdriver.Chrome, peersDict: dict, socket_url: str, socket_port: int): - html = driver.find_element(By.CLASS_NAME ,'vjs-stats-list').get_attribute('innerHTML') - if html is not None: - htmlBS = bs(html, 'html.parser') - else: - raise ValueError("html is None") - - stats = htmlBS.find_all('div', attrs={'style': 'display: block;'}) - - playerStats = { - stat.div.text: stat.span.text.replace('\u21d3', 'down').replace('down/', 'down /').replace('\u21d1 ', 'up').replace('\u21d1', 'up').replace('\u00b7', '-').strip() - for stat in stats - } - - keys = list(playerStats.keys()) - for stat in keys: - if 'Viewport / Frames' == stat: - viewport, frames = playerStats[stat].split(' / ') - width, height = viewport.split('x') - height, devicePixelRatio = height.split('*') - dropped, total = frames.split(' of ')[0].split()[0], frames.split(' of ')[1].split()[0] - playerStats[stat] = {'Width': int(width), 'Height': int(height), 'Pixel ratio': float(devicePixelRatio), 'Frames': {'Dropped': int(dropped), 'Total': int(total)}} - - if 'Codecs' == stat: - video, audio = playerStats[stat].split(' / ') - playerStats[stat] = {'Video': video, 'Audio': audio} - - if 'Volume' == stat: - if ' (' in playerStats[stat]: - volume, muted = playerStats[stat].split(' (') - playerStats[stat] = {'Volume': int(volume), 'Muted': 'muted' in muted} - else: - playerStats[stat] = {'Volume': int(playerStats[stat]), 'Muted': False} - - if 'Connection Speed' == stat: - speed, unit = playerStats[stat].split() - - speedBytes = float(speed) * (1024 ** {'B/s': 0, 'KB/s': 1, 'MB/s': 2, 'GB/s': 3}[unit]) - - playerStats[stat] = {'Speed': speedBytes, 'Granularity': 's'} - - if 'Network Activity' == stat: - downString, upString = playerStats[stat].split(' / ') - - down, downUnit = downString.replace('down', '').strip().split() - up, upUnit = upString.replace('up', '').strip().split() - - downBytes = convert_to_bytes(down, downUnit) - upBytes = convert_to_bytes(up, upUnit) - - playerStats[stat] = {'Down': downBytes, 'Up': upBytes} - - if 'Total Transfered' == stat: - downString, upString = playerStats[stat].split(' / ') - - down, downUnit = downString.replace('down', '').strip().split() - up, upUnit = upString.replace('up', '').strip().split() - - downBytes = convert_to_bytes(down, downUnit) - upBytes = convert_to_bytes(up, upUnit) - - playerStats[stat] = {'Down': downBytes, 'Up': upBytes} - - if 'Download Breakdown' == stat: - server, peer = playerStats[stat].split(' - ') - - server, serverUnit = server.replace('from servers', '').strip().split() - peer, peerUnit = peer.replace('from peers', '').strip().split() - - serverBytes = convert_to_bytes(server, serverUnit) - peerBytes = convert_to_bytes(peer, peerUnit) - - playerStats[stat] = {'Server': serverBytes, 'Peers': peerBytes} - - if 'Buffer State' == stat: - del(playerStats[stat]) - - if 'Live Latency' == stat: - latency, edge = playerStats[stat].split(' (from edge: ') - latency = sum(int(x) * 60 ** i for i, x in enumerate(reversed([part for part in latency.replace('s', '').split('m') if part]))) - edge = sum(int(x) * 60 ** i for i, x in enumerate(reversed([part for part in edge.replace('s', '').replace(')', '').split('m') if part]))) - playerStats[stat] = {'Latency': latency, 'Edge': edge} - - stats = { - 'player': playerStats, - 'peers': peersDict, - 'url': driver.current_url, - 'timestamp': int(time.time() * 1000), - 'session': driver.session_id - } - - saveStats([stats], socket_url, socket_port) - def convert_to_bytes(down, downUnit): return float(down) * (1024 ** {'B': 0, 'KB': 1, 'MB': 2, 'GB': 3}[downUnit]) -def setupStats(driver: webdriver.Remote, url: str, retries: int = 5) -> webdriver.Remote: - logger.log(logging.INFO, 'Setting up stats.') - actions = ActionChains(driver) - wait = WebDriverWait(driver, 30, poll_frequency=0.2) - - sleep(2) +# Default Plugin Implementations +class DefaultStatsSetupPlugin(StatsSetupPlugin): + def setup_stats(self, driver: webdriver.Remote, url: str, retries: int = 5) -> webdriver.Remote: + logger.log(logging.INFO, 'Setting up stats.') + actions = ActionChains(driver) + wait = WebDriverWait(driver, 30, poll_frequency=0.2) + + sleep(2) - for attempt in range(retries): - driver.get(url) + for attempt in range(retries): + driver.get(url) + try: + wait.until(ec.presence_of_element_located((By.CLASS_NAME, 'vjs-big-play-button'))) + break + except Exception: + logger.error(f'Timeout while waiting for the big play button to be present. Attempt {attempt + 1} of {retries}') + if attempt == retries - 1: + logger.error('Timeout limit reached. Exiting.') + driver.quit() + raise SystemExit(1) + + actions.click(driver.find_element(By.CLASS_NAME ,'video-js')).perform() + wait.until(ec.visibility_of_element_located((By.CLASS_NAME, 'vjs-control-bar'))) + actions.context_click(driver.find_element(By.CLASS_NAME ,'video-js')).perform() + statsForNerds = driver.find_elements(By.CLASS_NAME ,'vjs-menu-item') + actions.click(statsForNerds[-1]).perform() + wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, 'div.vjs-stats-content[style="display: block;"]'))) + actions.move_to_element(driver.find_element(By.CLASS_NAME ,'vjs-control-bar')).perform() + logger.log(logging.INFO, 'Stats setup complete.') + + return driver + +class DefaultStatsDownloadPlugin(StatsDownloadPlugin): + def download_stats(self, driver: webdriver.Remote, peersDict: dict, socket_url: str, socket_port: int): + html = driver.find_element(By.CLASS_NAME ,'vjs-stats-list').get_attribute('innerHTML') + if html is not None: + htmlBS = bs(html, 'html.parser') + else: + raise ValueError("html is None") + + stats = htmlBS.find_all('div', attrs={'style': 'display: block;'}) + + playerStats = { + stat.div.text: stat.span.text.replace('\u21d3', 'down').replace('down/', 'down /').replace('\u21d1 ', 'up').replace('\u21d1', 'up').replace('\u00b7', '-').strip() # type: ignore + for stat in stats + } + + keys = list(playerStats.keys()) + for stat in keys: + if 'Viewport / Frames' == stat: + viewport, frames = playerStats[stat].split(' / ') + width, height = viewport.split('x') + height, devicePixelRatio = height.split('*') + dropped, total = frames.split(' of ')[0].split()[0], frames.split(' of ')[1].split()[0] + playerStats[stat] = {'Width': int(width), 'Height': int(height), 'Pixel ratio': float(devicePixelRatio), 'Frames': {'Dropped': int(dropped), 'Total': int(total)}} + + if 'Codecs' == stat: + video, audio = playerStats[stat].split(' / ') + playerStats[stat] = {'Video': video, 'Audio': audio} + + if 'Volume' == stat: + if ' (' in playerStats[stat]: + volume, muted = playerStats[stat].split(' (') + playerStats[stat] = {'Volume': int(volume), 'Muted': 'muted' in muted} + else: + playerStats[stat] = {'Volume': int(playerStats[stat]), 'Muted': False} + + if 'Connection Speed' == stat: + speed, unit = playerStats[stat].split() + + speedBytes = float(speed) * (1024 ** {'B/s': 0, 'KB/s': 1, 'MB/s': 2, 'GB/s': 3}[unit]) + + playerStats[stat] = {'Speed': speedBytes, 'Granularity': 's'} + + if 'Network Activity' == stat: + downString, upString = playerStats[stat].split(' / ') + + down, downUnit = downString.replace('down', '').strip().split() + up, upUnit = upString.replace('up', '').strip().split() + + downBytes = convert_to_bytes(down, downUnit) + upBytes = convert_to_bytes(up, upUnit) + + playerStats[stat] = {'Down': downBytes, 'Up': upBytes} + + if 'Total Transfered' == stat: + downString, upString = playerStats[stat].split(' / ') + + down, downUnit = downString.replace('down', '').strip().split() + up, upUnit = upString.replace('up', '').strip().split() + + downBytes = convert_to_bytes(down, downUnit) + upBytes = convert_to_bytes(up, upUnit) + + playerStats[stat] = {'Down': downBytes, 'Up': upBytes} + + if 'Download Breakdown' == stat: + server, peer = playerStats[stat].split(' - ') + + server, serverUnit = server.replace('from servers', '').strip().split() + peer, peerUnit = peer.replace('from peers', '').strip().split() + + serverBytes = convert_to_bytes(server, serverUnit) + peerBytes = convert_to_bytes(peer, peerUnit) + + playerStats[stat] = {'Server': serverBytes, 'Peers': peerBytes} + + if 'Buffer State' == stat: + del(playerStats[stat]) + + if 'Live Latency' == stat: + latency, edge = playerStats[stat].split(' (from edge: ') + latency = sum(int(x) * 60 ** i for i, x in enumerate(reversed([part for part in latency.replace('s', '').split('m') if part]))) + edge = sum(int(x) * 60 ** i for i, x in enumerate(reversed([part for part in edge.replace('s', '').replace(')', '').split('m') if part]))) + playerStats[stat] = {'Latency': latency, 'Edge': edge} + + stats = { + 'player': playerStats, + 'peers': peersDict, + 'url': driver.current_url, + 'timestamp': int(time.time() * 1000), + 'session': driver.session_id + } + + super().saveStats([stats], socket_url, socket_port) + +# Plugin loading mechanism +def load_plugins(plugin_dir: str) -> tuple[StatsSetupPlugin | None, StatsDownloadPlugin | None]: + """ + Loads plugins from the specified directory. + + Args: + plugin_dir: The directory to search for plugins. + + Returns: + A tuple containing the loaded StatsSetupPlugin and StatsDownloadPlugin, or (None, None) if no plugins were found. + """ + + logger.info(f"Loading plugins from {plugin_dir}") + + setup_plugin = None + download_plugin = None + + plugin_files = glob.glob(os.path.join(plugin_dir, "*.py")) + + # Log the contents of the plugin directory + logger.debug(f"Plugin directory contents: {os.listdir(plugin_dir)}") + + for plugin_file in plugin_files: + module_name = os.path.basename(plugin_file)[:-3] # Remove .py extension + logger.debug(f"Loading plugin file {plugin_file}") try: - wait.until(ec.presence_of_element_located((By.CLASS_NAME, 'vjs-big-play-button'))) - break - except Exception: - logger.error(f'Timeout while waiting for the big play button to be present. Attempt {attempt + 1} of {retries}') - if attempt == retries - 1: - logger.error('Timeout limit reached. Exiting.') - driver.quit() - raise SystemExit(1) + spec = importlib.util.spec_from_file_location(module_name, plugin_file) + logger.debug(f"Spec: {spec}") + if spec is None: + logger.warning(f"Can't load plugin file {plugin_file}") + continue + module = importlib.util.module_from_spec(spec) + logger.debug(f"Module: {module}") + if spec.loader is not None: + spec.loader.exec_module(module) + else: + logger.warning(f"Can't load module {module_name} from {plugin_file}") + + for name, obj in inspect.getmembers(module): + logger.debug(f"Found member: {name} in module {module_name}") + if inspect.isclass(obj): + if issubclass(obj, StatsSetupPlugin) and obj is not StatsSetupPlugin: + logger.info(f"Found StatsSetupPlugin: {obj.__name__}") + setup_plugin = obj() + logger.debug(f"Loaded StatsSetupPlugin: {obj.__name__} from {plugin_file}") + elif issubclass(obj, StatsDownloadPlugin) and obj is not StatsDownloadPlugin: + logger.info(f"Found StatsDownloadPlugin: {obj.__name__}") + download_plugin = obj() + logger.debug(f"Loaded StatsDownloadPlugin: {obj.__name__} from {plugin_file}") + else: + logger.debug(f"Class {obj.__name__} is not a subclass of StatsSetupPlugin or StatsDownloadPlugin") + else: + logger.debug(f"{name} is not a class") + except Exception as e: + logger.warning(f"Error loading plugin {plugin_file}: {e}") - actions.click(driver.find_element(By.CLASS_NAME ,'video-js')).perform() - wait.until(ec.visibility_of_element_located((By.CLASS_NAME, 'vjs-control-bar'))) - actions.context_click(driver.find_element(By.CLASS_NAME ,'video-js')).perform() - statsForNerds = driver.find_elements(By.CLASS_NAME ,'vjs-menu-item') - actions.click(statsForNerds[-1]).perform() - wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, 'div.vjs-stats-content[style="display: block;"]'))) - actions.move_to_element(driver.find_element(By.CLASS_NAME ,'vjs-control-bar')).perform() - logger.log(logging.INFO, 'Stats setup complete.') - - return driver + return setup_plugin, download_plugin if __name__ == '__main__': args = setupArgParser().parse_args() setupLogger() + # Load plugins + plugin_dir = firstValid(args.plugin_dir, os.getenv('PLUGIN_DIR'), default=None) + if plugin_dir is None: + logger.info("No plugin directory provided. Using default plugins.") + setup_plugin = None + download_plugin = None + else: + setup_plugin, download_plugin = load_plugins(plugin_dir) + + # Use default plugins if none are loaded + if setup_plugin is None: + setup_plugin = DefaultStatsSetupPlugin() + logger.info("Using default StatsSetupPlugin.") + if download_plugin is None: + download_plugin = DefaultStatsDownloadPlugin() + logger.info("Using default StatsDownloadPlugin.") + command_executor = firstValid(args.hub_url, os.getenv('HUB_URL'), default=None) webrtc_internals_path = firstValid( args.webrtc_internals_path, @@ -230,7 +305,8 @@ if __name__ == '__main__': logger.error('VIDEO_URL environment variable or --url argument is required.') raise SystemExit(1) - setupStats(driver, url) + # Use the loaded plugin + driver = setup_plugin.setup_stats(driver, url) socket_url = firstValid(args.socket_url, os.getenv('SOCKET_URL'), default='localhost') try: @@ -240,5 +316,5 @@ if __name__ == '__main__': raise SystemExit(1) logger.info('Starting server collector.') - httpd = HTTPServer(('', 9092), partial(Handler, downloadStats, driver, logger, socket_url, socket_port)) + httpd = HTTPServer(('', 9092), partial(Handler, download_plugin.download_stats, driver, logger, socket_url, socket_port)) httpd.serve_forever() \ No newline at end of file diff --git a/plugins/example_plugin.py b/plugins/example_plugin.py new file mode 100644 index 0000000..9a44a24 --- /dev/null +++ b/plugins/example_plugin.py @@ -0,0 +1,31 @@ +import logging +from selenium import webdriver +from selenium.webdriver.remote.webdriver import WebDriver as Remote +from utils.plugins_base import StatsSetupPlugin, StatsDownloadPlugin + +logger = logging.getLogger(__name__) + +class ExampleStatsSetupPlugin(StatsSetupPlugin): + def setup_stats(self, driver: webdriver.Chrome, url: str, retries: int = 5) -> webdriver.Chrome: + logger.info("Running ExampleStatsSetupPlugin...") + # Here you would implement the custom logic to setup stats + # For example, you could click on a button to display stats. + # You could also wait for an element to appear before continuing. + # This is just an example + + driver.get(url) + + return driver + +class ExampleStatsDownloadPlugin(StatsDownloadPlugin): + def download_stats(self, driver: webdriver.Chrome, peersDict: dict, socket_url: str, socket_port: int): + logger.info("Running ExampleStatsDownloadPlugin...") + stats = {'message': 'Hello from ExampleStatsDownloadPlugin'} + # Here you would implement the custom logic to download stats + # and send them to the socket. + # This is just an example + + print(f"Sending stats: {stats} to {socket_url}:{socket_port}") + + # Remember to call the saveStats method to send the stats to the socket + super().saveStats([stats], socket_url, socket_port) diff --git a/utils/plugins_base.py b/utils/plugins_base.py new file mode 100644 index 0000000..c340782 --- /dev/null +++ b/utils/plugins_base.py @@ -0,0 +1,29 @@ +import abc +import json +import socket +import logging +from selenium import webdriver + +logger = logging.getLogger(__name__) + +# Abstract Base Classes for Plugins +class StatsSetupPlugin(abc.ABC): + @abc.abstractmethod + def setup_stats(self, driver: webdriver.Remote | webdriver.Chrome, url: str, retries: int = 5) -> webdriver.Remote | webdriver.Chrome: + pass + +class StatsDownloadPlugin(abc.ABC): + @abc.abstractmethod + def download_stats(self, driver: webdriver.Remote | webdriver.Chrome, peersDict: dict, socket_url: str, socket_port: int): + pass + + @staticmethod + def saveStats(stats: list, socket_url: str, socket_port: int): + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + logger.debug(f'Saving stats: {json.dumps(stats, indent=4)}') + sock.sendto(json.dumps(stats).encode(), (socket_url, socket_port)) + sock.close() + logger.debug('Sent stats to socket.') + except socket.error as e: + logger.error(f'Got socket error: {e}') \ No newline at end of file