mirror of
https://github.com/keylase/nvidia-patch.git
synced 2024-08-30 18:32:50 +00:00
291 lines
9.0 KiB
Python
291 lines
9.0 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
import sys
|
||
|
import json
|
||
|
import argparse
|
||
|
import hashlib
|
||
|
import importlib
|
||
|
import logging
|
||
|
from abc import ABC, abstractmethod
|
||
|
|
||
|
|
||
|
# Byte string inserted between the encoded key components before hashing
# (presumably so that adjacent components cannot run into each other).
HASH_DELIM = b"\x00"
# Hash constructor used to turn the joined key components into a DB key.
HASH = hashlib.sha256
|
||
|
|
||
|
|
||
|
class BaseDB(ABC):
    """Abstract key/value store used to remember drivers already reported."""

    @abstractmethod
    def check_key(self, key):
        """Return a truthy value when *key* is already present in the store."""
        ...

    @abstractmethod
    def set_key(self, key, value):
        """Persist *value* under *key*."""
        ...
|
||
|
|
||
|
|
||
|
class FileDB(BaseDB):
    """Database that stores each key as a ``<key>.json`` file in *workdir*.

    Keys are used verbatim as file name stems, so they must be valid file
    names (the hex digests produced by ``Hasher`` are).
    """

    def __init__(self, workdir):
        self._os = importlib.import_module('os')
        self._ospath = importlib.import_module('os.path')
        self._tempfile = importlib.import_module('tempfile')
        self._wd = workdir
        self._test_writable()

    def _test_writable(self):
        """Check that a file can be written to and read back from workdir.

        Raises:
            OSError: if a temporary file cannot be created in workdir.
            AssertionError: if the data read back differs from what was
                written.
        """
        TEST_STRING = b"test"
        with self._tempfile.NamedTemporaryFile('w+b', 0, dir=self._wd) as f:
            f.write(TEST_STRING)
            f.flush()
            with open(f.name, 'rb') as tf:
                # Explicit check instead of ``assert`` so the probe still
                # runs when Python is started with -O.
                if tf.read() != TEST_STRING:
                    raise AssertionError("Test write failed")

    def _get_key_filename(self, key):
        """Return the path of the JSON file that backs *key*."""
        return self._ospath.join(self._wd, key + '.json')

    def check_key(self, key):
        """Return True if a file for *key* exists in workdir."""
        filename = self._get_key_filename(key)
        return self._ospath.isfile(filename)

    def set_key(self, key, obj):
        """Serialize *obj* as indented JSON into the file for *key*.

        The data goes to a temporary ``.tmp`` file first and is then
        atomically renamed over the target. An interrupted write therefore
        never leaves a truncated ``.json`` file, which ``check_key`` would
        otherwise treat as a valid entry.
        """
        filename = self._get_key_filename(key)
        tmp_filename = filename + '.tmp'
        try:
            with open(tmp_filename, 'w') as f:
                json.dump(obj, f, indent=4)
                f.flush()
                self._os.fsync(f.fileno())
            self._os.replace(tmp_filename, filename)
        except BaseException:
            # Do not leave a half-written temporary file behind.
            try:
                self._os.unlink(tmp_filename)
            except OSError:
                pass
            raise
|
||
|
|
||
|
|
||
|
class Hasher:
    """Derives a database key from selected fields of an object.

    ``key_components`` is a sequence of paths. Each path is a sequence of
    subscripts applied one after another to the object; for example,
    ``["a", "b"]`` selects ``obj["a"]["b"]``.
    """

    def __init__(self, key_components):
        self._key_components = key_components

    def _eval_key_component(self, obj, component_path):
        """Follow *component_path* into *obj*; return UTF-8 bytes of str()."""
        node = obj
        for step in component_path:
            node = node[step]
        text = str(node)
        return text.encode('utf-8')

    def hash_object(self, obj):
        """Return the hex digest of all selected components joined by HASH_DELIM."""
        parts = [self._eval_key_component(obj, path)
                 for path in self._key_components]
        digest = HASH(HASH_DELIM.join(parts))
        return digest.hexdigest()
|
||
|
|
||
|
|
||
|
class BaseNotifier(ABC):
    """Abstract sink that is told about a newly found driver object."""

    @abstractmethod
    def notify(self, obj):
        """Deliver a notification about *obj*; failures should raise."""
        ...
|
||
|
|
||
|
|
||
|
class EmailNotifier(BaseNotifier):
    """Notifier that e-mails the driver object as a JSON attachment.

    Args:
        name: identifier of this notifier, used in error log messages.
        from_addr: sender address; passed to ``mailer.Mailer`` and used as
            the ``From`` header.
        to_addrs: a single recipient address (str) or an iterable of
            recipient addresses.
        host, port, local_hostname, use_ssl, use_starttls, login, password,
        timeout: forwarded unchanged to ``mailer.Mailer``.
        subject: subject line of the notification e-mail.
    """

    def __init__(self, name, *,
                 from_addr,
                 to_addrs,
                 host='localhost',
                 port=None,
                 local_hostname=None,
                 use_ssl=False,
                 use_starttls=False,
                 login=None,
                 password=None,
                 timeout=10,
                 subject="New Nvidia driver available!"):
        self.name = name
        self._from_addr = from_addr
        self._subject = subject
        # Modules are imported at construction time, so ``mailer`` is only
        # required when an e-mail notifier is actually configured.
        self._Mailer = importlib.import_module('mailer').Mailer
        self._MIMEText = importlib.import_module('email.mime.text').MIMEText
        self._MIMEMult = importlib.import_module(
            'email.mime.multipart').MIMEMultipart
        self._MIMEBase = importlib.import_module('email.mime.base').MIMEBase
        self._encoders = importlib.import_module('email.encoders')
        self._m = self._Mailer(from_addr=from_addr,
                               host=host,
                               port=port,
                               local_hostname=local_hostname,
                               use_ssl=use_ssl,
                               use_starttls=use_starttls,
                               login=login,
                               password=password,
                               timeout=timeout)
        # Otherwise a bare string would be joined character by character
        # into the "To" header below.
        if isinstance(to_addrs, str):
            to_addrs = [to_addrs]
        self._to_addrs = list(to_addrs)

    def notify(self, obj):
        """Send one e-mail with *obj* serialized as an ``obj.json`` attachment."""
        msg = self._MIMEMult()
        msg['Subject'] = self._subject
        msg['From'] = self._from_addr
        msg['To'] = ', '.join(self._to_addrs)
        body = "See attached JSON"
        msg.attach(self._MIMEText(body, 'plain'))
        p = self._MIMEBase('application', 'octet-stream')
        p.set_payload(json.dumps(obj, indent=4).encode('utf-8'))
        self._encoders.encode_base64(p)
        p.add_header('Content-Disposition', "attachment; filename=obj.json")
        msg.attach(p)
        self._m.send(self._to_addrs, msg.as_string())
|
||
|
|
||
|
|
||
|
class CommandNotifier(BaseNotifier):
    """Notifier that runs an external command and feeds it the driver JSON.

    Args:
        name: identifier of this notifier, used in error log messages.
        cmdline: command to execute, passed to ``subprocess.Popen`` as is.
        timeout: seconds to wait for the command to finish.
    """

    def __init__(self, name, *,
                 cmdline,
                 timeout=10):
        self.name = name
        self._subprocess = importlib.import_module('subprocess')
        self._cmdline = cmdline
        self._timeout = timeout

    def notify(self, obj):
        """Run the command with *obj* serialized as indented JSON on its stdin.

        Raises:
            subprocess.TimeoutExpired: if the command does not finish within
                the timeout. The process is killed and reaped first.
            subprocess.CalledProcessError: if the command exits with a
                non-zero status.
        """
        payload = json.dumps(obj, indent=4).encode('utf-8')
        with self._subprocess.Popen(self._cmdline,
                                    stdin=self._subprocess.PIPE) as proc:
            try:
                proc.communicate(payload, self._timeout)
            except self._subprocess.TimeoutExpired:
                proc.kill()
                proc.communicate()
                # Propagate, so that the caller counts this delivery as
                # failed instead of silently treating it as a success.
                raise
        if proc.returncode != 0:
            raise self._subprocess.CalledProcessError(proc.returncode,
                                                      self._cmdline)
|
||
|
|
||
|
|
||
|
class BaseChannel(ABC):
    """Abstract source that can be asked for the latest available driver."""

    @abstractmethod
    def get_latest_driver(self):
        """Return a driver description object, or None if none was found."""
        ...
|
||
|
|
||
|
|
||
|
class GFEClientChannel(BaseChannel):
    """Channel backed by ``gfe_get_driver.get_latest_geforce_driver``.

    Every keyword argument given to the constructor is stored and forwarded
    unchanged on each query.
    """

    def __init__(self, name, **kwargs):
        self.name = name
        self._kwargs = kwargs
        lookup_module = importlib.import_module('gfe_get_driver')
        self._get_latest_driver = lookup_module.get_latest_geforce_driver

    def get_latest_driver(self):
        """Return whatever the GFE lookup yields for the stored arguments."""
        query = self._kwargs
        return self._get_latest_driver(**query)
|
||
|
|
||
|
|
||
|
def parse_args(argv=None):
    """Parse the command-line options of the locator.

    Args:
        argv: list of arguments to parse, excluding the program name. When
            None (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace whose ``config`` attribute holds the path of the
        JSON configuration file.
    """
    parser = argparse.ArgumentParser(
        description="Watches for GeForce experience driver updates for "
                    "configured systems",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("-c", "--config",
                        default="/etc/nv-driver-locator.json",
                        help="config file location")
    return parser.parse_args(argv)
|
||
|
|
||
|
|
||
|
class DriverLocator:
    """Polls driver channels and notifies about drivers not seen before.

    ``conf`` is the parsed configuration mapping. It must contain the keys
    ``channels``, ``db``, ``key_components`` and ``notifiers``.
    """

    # Exit status returned by run(); _perror switches it to 3 on any error.
    _ret_code = 0

    def __init__(self, conf):
        self._logger = logging.getLogger(self.__class__.__name__)
        self._channels = self._construct_channels(conf['channels'])
        self._db = self._construct_db(conf['db'])
        self._hasher = Hasher(conf['key_components'])
        self._notifiers = self._construct_notifiers(conf['notifiers'])

    def _construct_channels(self, channels_config):
        """Build the configured channels.

        Entries that fail to construct are logged and skipped.
        """
        channel_types = {
            'gfe_client': GFEClientChannel,
        }

        channels = []
        for ch in channels_config:
            try:
                ctor = channel_types[ch['type']]
                C = ctor(ch['name'], **ch['params'])
            except Exception as e:
                self._perror("Channel construction failed with exception: %s. "
                             "Skipping..." % (str(e),))
            else:
                channels.append(C)
        return channels

    def _construct_db(self, db_config):
        """Build the configured database.

        Errors propagate, because the locator cannot work without a DB.
        """
        db_types = {
            'file': FileDB,
        }
        ctor = db_types[db_config['type']]
        db = ctor(**db_config['params'])
        return db

    def _construct_notifiers(self, notifiers_config):
        """Build the configured notifiers.

        Entries that fail to construct are logged and skipped.
        """
        notifier_types = {
            'email': EmailNotifier,
            'command': CommandNotifier,
        }

        notifiers = []
        for nc in notifiers_config:
            try:
                ctor = notifier_types[nc['type']]
                N = ctor(nc['name'], **nc['params'])
            except Exception as e:
                self._perror("Notifier construction failed with exception: %s."
                             " Skipping..." % (str(e),))
            else:
                notifiers.append(N)
        return notifiers

    def _perror(self, err):
        """Log *err* and mark the run as failed (exit status 3)."""
        self._ret_code = 3
        self._logger.error(err)

    def _notify_all(self, obj):
        """Send *obj* to every notifier.

        Returns:
            True if at least one notifier succeeded. With no notifiers
            configured this is always False.
        """
        fails = 0
        for n in self._notifiers:
            try:
                n.notify(obj)
            except Exception as e:
                self._perror("Notify channel %s failed with exception: %s." %
                             (n.name, str(e)))
                fails += 1
        return fails < len(self._notifiers)

    def run(self):
        """Query every channel once and report drivers whose key is new.

        A driver's key is stored in the DB only after at least one notifier
        succeeded, so undelivered drivers are retried on the next run.

        Returns:
            0 if no error occurred, otherwise 3.
        """
        for ch in self._channels:
            try:
                drv = ch.get_latest_driver()
            except Exception as e:
                self._perror("get_latest_driver() invocation failed for "
                             "channel %s. Exception: %s. Continuing..." %
                             (repr(ch.name), str(e)))
                continue
            if drv is None:
                self._perror("Driver not found for channel %s" %
                             (repr(ch.name),))
                continue
            try:
                key = self._hasher.hash_object(drv)
            except Exception as e:
                # Previously referenced an undefined ``name``, which raised
                # NameError here and aborted the whole run.
                self._perror("Key evaluation failed for channel %s. "
                             "Exception: %s" % (repr(ch.name), str(e)))
                continue
            if self._db.check_key(key):
                continue
            if not self._notify_all(drv):
                continue
            try:
                self._db.set_key(key, drv)
            except Exception as e:
                # Keep processing the remaining channels.
                self._perror("Saving key %s for channel %s failed. "
                             "Exception: %s" % (key, repr(ch.name), str(e)))
        return self._ret_code
|
||
|
|
||
|
|
||
|
def setup_logger(name, verbosity):
    """Attach a timestamped StreamHandler to the logger called *name*.

    Both the logger and the new handler are set to *verbosity*. Every call
    adds one more handler. Returns the configured logger.
    """
    formatter = logging.Formatter(
        fmt='%(asctime)s %(levelname)-8s %(name)s: %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')
    log = logging.getLogger(name)
    log.setLevel(verbosity)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(verbosity)
    stream_handler.setFormatter(formatter)
    log.addHandler(stream_handler)
    return log
|
||
|
|
||
|
|
||
|
def main():
    """Entry point: load the JSON config, run the locator, exit with its status."""
    options = parse_args()
    setup_logger(DriverLocator.__name__, logging.ERROR)

    with open(options.config, 'r') as config_fp:
        configuration = json.load(config_fp)

    sys.exit(DriverLocator(configuration).run())


if __name__ == '__main__':
    # Only run when executed as a script, not when imported.
    main()
|