#!/usr/bin/env python3
"""Poll configured driver channels and notify when an unseen driver appears.

Each configured channel is asked for its latest driver description. A key is
derived from selected components of that description; if the key is not yet
recorded in the database, every configured notifier is invoked and, when at
least one of them succeeds, the key is stored so the same driver is not
reported again.
"""

import sys
import json
import argparse
import hashlib
import importlib
import logging
from abc import ABC, abstractmethod


# Placed between key components before hashing so component boundaries are
# part of the hashed data (e.g. "ab"+"c" and "a"+"bc" hash differently).
HASH_DELIM = b'\x00'
HASH = hashlib.sha256


class BaseDB(ABC):
    """Interface of the store that remembers already-reported driver keys."""

    @abstractmethod
    def check_key(self, key):
        """Return a truthy value if *key* has already been stored."""

    @abstractmethod
    def set_key(self, key, value):
        """Store *value* under *key*."""


class FileDB(BaseDB):
    """Store each key as a ``<workdir>/<key>.json`` file holding the object."""

    def __init__(self, workdir):
        """Use *workdir* as the storage directory.

        Raises:
            OSError: if a test file cannot be written to and read back from
                *workdir*.
        """
        self._ospath = importlib.import_module('os.path')
        self._tempfile = importlib.import_module('tempfile')
        self._wd = workdir
        self._test_writable()

    def _test_writable(self):
        """Fail early if the storage directory does not accept writes."""
        TEST_STRING = b"test"
        with self._tempfile.NamedTemporaryFile('w+b', 0, dir=self._wd) as f:
            f.write(TEST_STRING)
            f.flush()
            with open(f.name, 'rb') as tf:
                # Explicit check rather than ``assert``: assert statements
                # are removed when Python runs with -O.
                if tf.read() != TEST_STRING:
                    raise OSError("Test write failed")

    def _get_key_filename(self, key):
        """Return the path of the file that represents *key*."""
        return self._ospath.join(self._wd, key + '.json')

    def check_key(self, key):
        """Return True if the file for *key* exists."""
        return self._ospath.isfile(self._get_key_filename(key))

    def set_key(self, key, obj):
        """Write *obj* as indented JSON into the file for *key*."""
        filename = self._get_key_filename(key)
        with open(filename, 'w') as f:
            json.dump(obj, f, indent=4)
            f.flush()


class Hasher:
    """Derive a hex-digest key from selected parts of a nested object.

    ``key_components`` is a list of paths. Each path is a sequence of
    subscripts applied in order to the object, so ``["a", "b"]`` selects
    ``obj["a"]["b"]``. The selected values are stringified, UTF-8 encoded,
    joined with ``HASH_DELIM`` and hashed with ``HASH``.
    """

    def __init__(self, key_components):
        self._key_components = key_components

    def _eval_key_component(self, obj, component_path):
        """Follow *component_path* into *obj*; return the value as bytes."""
        res = obj
        for path_component in component_path:
            res = res[path_component]
        return str(res).encode('utf-8')

    def hash_object(self, obj):
        """Return the hex digest key for *obj*.

        Raises:
            Whatever subscripting raises for a missing path element
            (e.g. KeyError, IndexError, TypeError).
        """
        return HASH(HASH_DELIM.join(
            self._eval_key_component(obj, c) for c in self._key_components)
        ).hexdigest()


class BaseNotifier(ABC):
    """Interface of a notification sink; failures are signalled by raising."""

    @abstractmethod
    def notify(self, obj):
        """Deliver *obj*; raise on failure."""


class EmailNotifier(BaseNotifier):
    """Email the object as a JSON attachment using the ``mailer`` module."""

    def __init__(self, name, *, from_addr, to_addrs, host='localhost',
                 port=None, local_hostname=None, use_ssl=False,
                 use_starttls=False, login=None, password=None,
                 timeout=10):
        self.name = name
        self._from_addr = from_addr
        # Imported lazily so the third-party ``mailer`` dependency is only
        # needed when an email notifier is actually configured.
        self._Mailer = importlib.import_module('mailer').Mailer
        self._MIMEText = importlib.import_module('email.mime.text').MIMEText
        self._MIMEMult = importlib.import_module(
            'email.mime.multipart').MIMEMultipart
        self._MIMEBase = importlib.import_module('email.mime.base').MIMEBase
        self._encoders = importlib.import_module('email.encoders')
        self._m = self._Mailer(from_addr=from_addr,
                               host=host,
                               port=port,
                               local_hostname=local_hostname,
                               use_ssl=use_ssl,
                               use_starttls=use_starttls,
                               login=login,
                               password=password,
                               timeout=timeout)
        self._to_addrs = to_addrs

    def notify(self, obj):
        """Send *obj* as a base64-encoded ``obj.json`` attachment."""
        msg = self._MIMEMult()
        msg['Subject'] = "New Nvidia driver available!"
        msg['From'] = self._from_addr
        msg['To'] = ', '.join(self._to_addrs)
        body = "See attached JSON"
        msg.attach(self._MIMEText(body, 'plain'))
        p = self._MIMEBase('application', 'octet-stream')
        p.set_payload(json.dumps(obj, indent=4).encode('utf-8'))
        self._encoders.encode_base64(p)
        p.add_header('Content-Disposition', "attachment; filename=obj.json")
        msg.attach(p)
        self._m.send(self._to_addrs, msg.as_string())


class CommandNotifier(BaseNotifier):
    """Notify by running a command with the object's JSON on its stdin."""

    def __init__(self, name, *, cmdline, timeout=10):
        self.name = name
        self._subprocess = importlib.import_module('subprocess')
        self._cmdline = cmdline
        self._timeout = timeout

    def notify(self, obj):
        """Run the command and feed it *obj* as indented JSON.

        Raises:
            subprocess.TimeoutExpired: if the command does not finish within
                ``timeout`` seconds; the process is killed and reaped first.
        """
        proc = self._subprocess.Popen(self._cmdline,
                                      stdin=self._subprocess.PIPE)
        try:
            proc.communicate(json.dumps(obj, indent=4).encode('utf-8'),
                             self._timeout)
        except self._subprocess.TimeoutExpired:
            proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            proc.communicate()
            # Propagate so the caller counts this notifier as failed instead
            # of treating an unfinished command as a delivered notification.
            raise
        # NOTE: the command's exit status is not inspected.


class BaseChannel(ABC):
    """Interface of a source of driver descriptions."""

    @abstractmethod
    def get_latest_driver(self):
        """Return the latest driver description, or None if none is found."""


class GFEClientChannel(BaseChannel):
    """Channel backed by ``gfe_get_driver.get_latest_geforce_driver``.

    Keyword arguments given to the constructor are forwarded on every call.
    """

    def __init__(self, name, **kwargs):
        self.name = name
        self._kwargs = kwargs
        gfe_get_driver = importlib.import_module('gfe_get_driver')
        self._get_latest_driver = gfe_get_driver.get_latest_geforce_driver

    def get_latest_driver(self):
        return self._get_latest_driver(**self._kwargs)


def parse_args():
    """Parse command-line arguments (currently only the config path)."""
    parser = argparse.ArgumentParser(
        description="Watches for GeForce experience driver updates for "
        "configured systems",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-c", "--config",
                        default="/etc/nv-driver-locator.json",
                        help="config file location")
    args = parser.parse_args()
    return args


class DriverLocator:
    """Poll channels, deduplicate drivers via the DB and fan out notifications.

    ``conf`` must contain:
      * ``channels``: list of ``{"type", "name", "params"}`` entries;
      * ``db``: a ``{"type", "params"}`` entry;
      * ``key_components``: paths passed to ``Hasher``;
      * ``notifiers``: list of ``{"type", "name", "params"}`` entries.
    Channels and notifiers that fail to construct are logged and skipped.
    """

    # Process exit code; _perror switches it to 3 on the instance.
    _ret_code = 0

    def __init__(self, conf):
        self._logger = logging.getLogger(self.__class__.__name__)
        self._channels = self._construct_channels(conf['channels'])
        self._db = self._construct_db(conf['db'])
        self._hasher = Hasher(conf['key_components'])
        self._notifiers = self._construct_notifiers(conf['notifiers'])

    def _construct_channels(self, channels_config):
        """Build channel objects; log and skip entries that fail."""
        channel_types = {
            'gfe_client': GFEClientChannel,
        }

        channels = []
        for ch in channels_config:
            try:
                ctor = channel_types[ch['type']]
                C = ctor(ch['name'], **ch['params'])
            except Exception as e:
                self._perror("Channel construction failed with exception: %s. "
                             "Skipping..." % (str(e),))
            else:
                channels.append(C)
        return channels

    def _construct_db(self, db_config):
        """Build the key database; errors propagate to the caller."""
        db_types = {
            'file': FileDB,
        }

        ctor = db_types[db_config['type']]
        db = ctor(**db_config['params'])
        return db

    def _construct_notifiers(self, notifiers_config):
        """Build notifier objects; log and skip entries that fail."""
        notifier_types = {
            'email': EmailNotifier,
            'command': CommandNotifier,
        }

        notifiers = []
        for nc in notifiers_config:
            try:
                ctor = notifier_types[nc['type']]
                N = ctor(nc['name'], **nc['params'])
            except Exception as e:
                self._perror("Notifier construction failed with exception: %s."
                             " Skipping..." % (str(e),))
            else:
                notifiers.append(N)
        return notifiers

    def _perror(self, err):
        """Log *err* and mark the run as failed (exit code 3)."""
        self._ret_code = 3
        self._logger.error(err)

    def _notify_all(self, obj):
        """Invoke every notifier with *obj*.

        Returns:
            True if at least one notifier succeeded. With no notifiers
            configured this is False (0 failures is not < 0 notifiers).
        """
        fails = 0
        for n in self._notifiers:
            try:
                n.notify(obj)
            except Exception as e:
                self._perror("Notify channel %s failed with exception: %s." %
                             (n.name, str(e)))
                fails += 1
        return fails < len(self._notifiers)

    def run(self):
        """Poll every channel once and notify about drivers not yet stored.

        Errors for one channel are logged and the loop moves on to the next.

        Returns:
            The exit code: 0, or 3 if any error was reported via _perror
            (including during construction).
        """
        for ch in self._channels:
            try:
                drv = ch.get_latest_driver()
            except Exception as e:
                self._perror("get_latest_driver() invocation failed for "
                             "channel %s. Exception: %s. Continuing..."
                             % (repr(ch.name), str(e)))
                continue

            if drv is None:
                self._perror("Driver not found for channel %s" %
                             (repr(ch.name),))
                continue

            try:
                key = self._hasher.hash_object(drv)
            except Exception as e:
                self._perror("Key evaluation failed for channel %s. "
                             "Exception: %s" % (repr(ch.name), str(e)))
                continue

            try:
                seen = self._db.check_key(key)
            except Exception as e:
                self._perror("Key lookup failed for channel %s. "
                             "Exception: %s" % (repr(ch.name), str(e)))
                continue
            if seen:
                continue

            # Record the key only when at least one notifier succeeded, so an
            # undelivered driver is retried on the next run.
            if not self._notify_all(drv):
                continue

            try:
                self._db.set_key(key, drv)
            except Exception as e:
                self._perror("Key store failed for channel %s. "
                             "Exception: %s" % (repr(ch.name), str(e)))
        return self._ret_code


def setup_logger(name, verbosity):
    """Attach a timestamped stream handler at *verbosity* to logger *name*."""
    logger = logging.getLogger(name)
    logger.setLevel(verbosity)
    handler = logging.StreamHandler()
    handler.setLevel(verbosity)
    handler.setFormatter(logging.Formatter('%(asctime)s '
                                           '%(levelname)-8s '
                                           '%(name)s: %(message)s',
                                           '%Y-%m-%d %H:%M:%S'))
    logger.addHandler(handler)
    return logger


def main():
    """Load the config, run one polling pass and exit with its status."""
    args = parse_args()
    setup_logger(DriverLocator.__name__, logging.ERROR)
    with open(args.config, 'r') as conf_file:
        conf = json.load(conf_file)
    ret = DriverLocator(conf).run()
    sys.exit(ret)


if __name__ == '__main__':
    main()