From f8aad8b33df021faf92f9d0262a2f1143cf68313 Mon Sep 17 00:00:00 2001 From: luukas Date: Wed, 18 Aug 2021 18:11:53 +0300 Subject: [PATCH] Add database migrations! --- .DS_Store | Bin 6148 -> 0 bytes DBCHANGES.md | 71 ++++ app/classes/shared/cmd.py | 28 +- app/classes/shared/helpers.py | 1 + app/classes/shared/migration.py | 532 ++++++++++++++++++++++++++ app/classes/shared/models.py | 119 ++---- app/migrations/20210813111015_init.py | 215 +++++++++++ main.py | 11 +- requirements.txt | 1 - 9 files changed, 883 insertions(+), 95 deletions(-) delete mode 100644 .DS_Store create mode 100644 DBCHANGES.md create mode 100644 app/classes/shared/migration.py create mode 100644 app/migrations/20210813111015_init.py diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 5008ddfcf53c02e82d7eee2e57c38e5672ef89f6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0` command (in Crafty's prompt). + +A full list of helper functions you can find in `app/classes/shared/models.py` + +## Example migration files + +### Rename column/field + +```py +def migrate(migrator, database, **kwargs): + migrator.rename_column('my_table', 'old_name', 'new_name') # First argument can be model class OR table name + + + +def rollback(migrator, database, **kwargs): + migrator.rename_column('my_table', 'new_name', 'old_name') # First argument can be model class OR table name + +``` + +### Rename table/model + +```py +def migrate(migrator, database, **kwargs): + migrator.rename_table('old_name', 'new_name') # First argument can be model class OR table name + + + +def rollback(migrator, database, **kwargs): + migrator.rename_table('new_name', 'old_name') # First argument can be model class OR table name + +``` + +### Create table/model + +```py +import peewee + + +def migrate(migrator, database, **kwargs): + class NewTable(peewee.Model): + my_id = peewee.IntegerField(unique=True, primary_key=True) + + class Meta: + table_name = 'new_table' + database = database + create_table(NewTable) + + + +def rollback(migrator, database, **kwargs): + drop_table('new_table') # Can be model class OR table name + +``` + +### Add columns/fields + +```py +import peewee + + +def migrate(migrator, database, **kwargs): + migrator.add_columns('table_name', new_field_name=peewee.CharField(default="")) # First argument can be model class OR table name + + + +def rollback(migrator, database, **kwargs): + migrator.drop_columns('table_name', ['new_field_name']) # First argument can be model class OR table name + +``` diff --git a/app/classes/shared/cmd.py b/app/classes/shared/cmd.py index 5c65c46d..c769c0d8 100644 --- a/app/classes/shared/cmd.py +++ b/app/classes/shared/cmd.py @@ -22,9 +22,10 @@ except ModuleNotFoundError as e: class MainPrompt(cmd.Cmd, object): - def __init__(self, tasks_manager): + def __init__(self, tasks_manager, migration_manager): super().__init__() self.tasks_manager = tasks_manager + self.migration_manager = migration_manager # overrides the default Prompt prompt = "Crafty Controller v{} > ".format(helper.get_version_string()) @@ -47,6 +48,27 @@ class MainPrompt(cmd.Cmd, object): def do_exit(self, line): self.universal_exit() + def do_migrations(self, line): + if (line == 'up'): + self.migration_manager.up() + elif (line == 'down'): + self.migration_manager.down() + elif (line == 'done'): + console.info(self.migration_manager.done) + elif (line == 'todo'): + console.info(self.migration_manager.todo) + elif (line == 'diff'): + console.info(self.migration_manager.diff) + elif (line == 'info'): + console.info('Done: {}'.format(self.migration_manager.done)) + console.info('FS: {}'.format(self.migration_manager.todo)) + console.info('Todo: {}'.format(self.migration_manager.diff)) + elif (line.startswith('add ')): + migration_name = line[len('add '):] + self.migration_manager.create(migration_name, False) + else: + console.info('Unknown migration command') + def universal_exit(self): logger.info("Stopping all server daemons / threads") console.info("Stopping all server daemons / threads - This may take a few seconds") @@ -62,3 +84,7 @@ class MainPrompt(cmd.Cmd, object): @staticmethod def help_exit(): console.help("Stops the server if running, Exits the program") + + @staticmethod + def help_migrations(): + console.help("Only for advanced users. Use with caution") diff --git a/app/classes/shared/helpers.py b/app/classes/shared/helpers.py index 5d124d32..5c3c8343 100644 --- a/app/classes/shared/helpers.py +++ b/app/classes/shared/helpers.py @@ -40,6 +40,7 @@ class Helpers: self.webroot = os.path.join(self.root_dir, 'app', 'frontend') self.servers_dir = os.path.join(self.root_dir, 'servers') self.backup_path = os.path.join(self.root_dir, 'backups') + self.migration_dir = os.path.join(self.root_dir, 'app', 'migrations') self.session_file = os.path.join(self.root_dir, 'app', 'config', 'session.lock') self.settings_file = os.path.join(self.root_dir, 'app', 'config', 'config.json') diff --git a/app/classes/shared/migration.py b/app/classes/shared/migration.py new file mode 100644 index 00000000..1f1f8d95 --- /dev/null +++ b/app/classes/shared/migration.py @@ -0,0 +1,532 @@ +from datetime import datetime +import logging +import typing as t +import sys +import os +import re +from importlib import import_module +from functools import wraps + +try: + from functools import cached_property +except ImportError: + from cached_property import cached_property + +from app.classes.shared.helpers import helper +from app.classes.shared.console import console + +logger = logging.getLogger(__name__) + +try: + import peewee + from playhouse.migrate import ( + SchemaMigrator as ScM, + SqliteMigrator as SqM, + Operation, SQL, operation, SqliteDatabase, + make_index_name, Context + ) + +except ModuleNotFoundError as e: + logger.critical("Import Error: Unable to load {} module".format( + e.name), exc_info=True) + console.critical("Import Error: Unable to load {} module".format(e.name)) + sys.exit(1) + + +class MigrateHistory(peewee.Model): + """ + Presents the migration history in a database. + """ + + name = peewee.CharField(unique=True) + migrated_at = peewee.DateTimeField(default=datetime.utcnow) + + def __unicode__(self) -> str: + """ + String representation of this migration + """ + return self.name + + +MIGRATE_TABLE = 'migratehistory' +MIGRATE_TEMPLATE = '''# Generated by database migrator + + +def migrate(migrator, database, **kwargs): + """ + Write your migrations here. + """ +{migrate} + + +def rollback(migrator, database, **kwargs): + """ + Write your rollback migrations here. + """ +{rollback}''' +VOID: t.Callable = lambda m, d: None + + +def get_model(method): + """ + Convert string to model class. + """ + + @wraps(method) + def wrapper(migrator, model, *args, **kwargs): + if isinstance(model, str): + return method(migrator, migrator.orm[model], *args, **kwargs) + return method(migrator, model, *args, **kwargs) + return wrapper + + +class Migrator(object): + def __init__(self, database: t.Union[peewee.Database, peewee.Proxy]): + """ + Initializes the migrator + """ + if isinstance(database, peewee.Proxy): + database = database.obj + self.database: SqliteDatabase = database + self.orm: t.Dict[str, peewee.Model] = {} + self.operations: t.List[Operation] = [] + self.migrator = SqliteMigrator(database) + + def run(self): + """ + Runs operations. + """ + for op in self.operations: + if isinstance(op, Operation): + op.run() + else: + op() + self.clean() + + def clean(self): + """ + Cleans the operations. + """ + self.operations = list() + + def sql(self, sql: str, *params): + """ + Executes raw SQL. + """ + self.operations.append(self.migrator.sql(sql, *params)) + + def create_table(self, model: peewee.Model) -> peewee.Model: + """ + Creates model and table in database. + """ + self.orm[model._meta.table_name] = model + model._meta.database = self.database + self.operations.append(model.create_table) + return model + + @get_model + def drop_table(self, model: peewee.Model): + """ + Drops model and table from database. + """ + del self.orm[model._meta.table_name] + self.operations.append(self.migrator.drop_table(model)) + + @get_model + def add_columns(self, model: peewee.Model, **fields: peewee.Field) -> peewee.Model: + """ + Creates new fields. + """ + for name, field in fields.items(): + model._meta.add_field(name, field) + self.operations.append(self.migrator.add_column( + model._meta.table_name, field.column_name, field)) + if field.unique: + self.operations.append(self.migrator.add_index( + model._meta.table_name, (field.column_name,), unique=True)) + return model + + @get_model + def change_columns(self, model: peewee.Model, **fields: peewee.Field) -> peewee.Model: + """ + Changes fields. + """ + for name, field in fields.items(): + old_field = model._meta.fields.get(name, field) + old_column_name = old_field and old_field.column_name + + model._meta.add_field(name, field) + + if isinstance(old_field, peewee.ForeignKeyField): + self.operations.append(self.migrator.drop_foreign_key_constraint( + model._meta.table_name, old_column_name)) + + if old_column_name != field.column_name: + self.operations.append( + self.migrator.rename_column( + model._meta.table_name, old_column_name, field.column_name)) + + if isinstance(field, peewee.ForeignKeyField): + on_delete = field.on_delete if field.on_delete else 'RESTRICT' + on_update = field.on_update if field.on_update else 'RESTRICT' + self.operations.append(self.migrator.add_foreign_key_constraint( + model._meta.table_name, field.column_name, + field.rel_model._meta.table_name, field.rel_field.name, + on_delete, on_update)) + continue + + self.operations.append(self.migrator.change_column( + model._meta.table_name, field.column_name, field)) + + if field.unique == old_field.unique: + continue + + if field.unique: + index = (field.column_name,), field.unique + self.operations.append(self.migrator.add_index( + model._meta.table_name, *index)) + model._meta.indexes.append(index) + else: + index = (field.column_name,), old_field.unique + self.operations.append(self.migrator.drop_index( + model._meta.table_name, *index)) + model._meta.indexes.remove(index) + + return model + + @get_model + def drop_columns(self, model: peewee.Model, names: str, **kwargs) -> peewee.Model: + """ + Removes fields from model. + """ + fields = [field for field in model._meta.fields.values() + if field.name in names] + cascade = kwargs.pop('cascade', True) + for field in fields: + self.__del_field__(model, field) + if field.unique: + index_name = make_index_name( + model._meta.table_name, [field.column_name]) + self.operations.append(self.migrator.drop_index( + model._meta.table_name, index_name)) + self.operations.append( + self.migrator.drop_column( + model._meta.table_name, field.column_name, cascade=False)) + return model + + def __del_field__(self, model: peewee.Model, field: peewee.Field): + """ + Deletes field from model. + """ + model._meta.remove_field(field.name) + delattr(model, field.name) + if isinstance(field, peewee.ForeignKeyField): + obj_id_name = field.column_name + if field.column_name == field.name: + obj_id_name += '_id' + delattr(model, obj_id_name) + delattr(field.rel_model, field.backref) + + @get_model + def rename_column(self, model: peewee.Model, old_name: str, new_name: str) -> peewee.Model: + """ + Renames field in model. + """ + field = model._meta.fields[old_name] + if isinstance(field, peewee.ForeignKeyField): + old_name = field.column_name + self.__del_field__(model, field) + field.name = field.column_name = new_name + model._meta.add_field(new_name, field) + if isinstance(field, peewee.ForeignKeyField): + field.column_name = new_name = field.column_name + '_id' + self.operations.append(self.migrator.rename_column( + model._meta.table_name, old_name, new_name)) + return model + + @get_model + def rename_table(self, model: peewee.Model, new_name: str) -> peewee.Model: + """ + Renames table in database. + """ + old_name = model._meta.table_name + del self.orm[model._meta.table_name] + model._meta.table_name = new_name + self.orm[model._meta.table_name] = model + self.operations.append(self.migrator.rename_table(old_name, new_name)) + return model + + @get_model + def add_index(self, model: peewee.Model, *columns: str, **kwargs) -> peewee.Model: + """Create indexes.""" + unique = kwargs.pop('unique', False) + model._meta.indexes.append((columns, unique)) + columns_ = [] + for col in columns: + field = model._meta.fields.get(col) + + if len(columns) == 1: + field.unique = unique + field.index = not unique + + if isinstance(field, peewee.ForeignKeyField): + col = col + '_id' + + columns_.append(col) + self.operations.append(self.migrator.add_index( + model._meta.table_name, columns_, unique=unique)) + return model + + @get_model + def drop_index(self, model: peewee.Model, *columns: str) -> peewee.Model: + """Drop indexes.""" + columns_ = [] + for col in columns: + field = model._meta.fields.get(col) + if not field: + continue + + if len(columns) == 1: + field.unique = field.index = False + + if isinstance(field, peewee.ForeignKeyField): + col = col + '_id' + columns_.append(col) + index_name = make_index_name(model._meta.table_name, columns_) + model._meta.indexes = [(cols, _) for ( + cols, _) in model._meta.indexes if columns != cols] + self.operations.append(self.migrator.drop_index( + model._meta.table_name, index_name)) + return model + + @get_model + def add_not_null(self, model: peewee.Model, *names: str) -> peewee.Model: + """Add not null.""" + for name in names: + field = model._meta.fields[name] + field.null = False + self.operations.append(self.migrator.add_not_null( + model._meta.table_name, field.column_name)) + return model + + @get_model + def drop_not_null(self, model: peewee.Model, *names: str) -> peewee.Model: + """Drop not null.""" + for name in names: + field = model._meta.fields[name] + field.null = True + self.operations.append(self.migrator.drop_not_null( + model._meta.table_name, field.column_name)) + return model + + @get_model + def add_default(self, model: peewee.Model, name: str, default: t.Any) -> peewee.Model: + """Add default.""" + field = model._meta.fields[name] + model._meta.defaults[field] = field.default = default + self.operations.append(self.migrator.apply_default( + model._meta.table_name, name, field)) + return model + + +class SqliteMigrator(SqM): + def drop_table(self, model): + return lambda: model.drop_table(cascade=False) + + @operation + def change_column(self, table: str, column_name: str, field: peewee.Field): + operations = [self.alter_change_column(table, column_name, field)] + if not field.null: + operations.extend([self.add_not_null(table, column_name)]) + return operations + + def alter_change_column(self, table: str, column_name: str, field: peewee.Field) -> Operation: + return self._update_column(table, column_name, lambda x, y: y) + + @operation + def sql(self, sql: str, *params) -> SQL: + """ + Executes raw SQL. + """ + return SQL(sql, *params) + + def alter_add_column( + self, table: str, column_name: str, field: peewee.Field, **kwargs) -> Operation: + """ + Fixes field name for ForeignKeys. + """ + name = field.name + op = super().alter_add_column( + table, column_name, field, **kwargs) + if isinstance(field, peewee.ForeignKeyField): + field.name = name + return op + + +class MigrationManager(object): + + filemask = re.compile(r"[\d]+_[^\.]+\.py$") + + def __init__(self, database: t.Union[peewee.Database, peewee.Proxy]): + """ + Initializes the migration manager. + """ + if not isinstance(database, (peewee.Database, peewee.Proxy)): + raise RuntimeError('Invalid database: {}'.format(database)) + self.database = database + + @cached_property + def model(self) -> peewee.Model: + """ + Initialize and cache the MigrationHistory model. + """ + MigrateHistory._meta.database = self.database + MigrateHistory._meta.table_name = 'migratehistory' + MigrateHistory._meta.schema = None + MigrateHistory.create_table(True) + return MigrateHistory + + @property + def done(self) -> t.List[str]: + """ + Scans migrations in the database. + """ + return [mm.name for mm in self.model.select().order_by(self.model.id)] + + @property + def todo(self): + """ + Scans migrations in the file system. + """ + if not os.path.exists(helper.migration_dir): + logger.warning('Migration directory: {} does not exist.'.format( + helper.migration_dir)) + os.makedirs(helper.migration_dir) + return sorted(f[:-3] for f in os.listdir(helper.migration_dir) if self.filemask.match(f)) + + @property + def diff(self) -> t.List[str]: + """ + Calculates difference between the filesystem and the database. + """ + done = set(self.done) + return [name for name in self.todo if name not in done] + + @cached_property + def migrator(self) -> Migrator: + """ + Create migrator and setup it with fake migrations. + """ + migrator = Migrator(self.database) + for name in self.done: + self.up_one(name, migrator) + return migrator + + def compile(self, name, migrate='', rollback=''): + """ + Compiles a migration. + """ + name = datetime.utcnow().strftime('%Y%m%d%H%M%S') + '_' + name + filename = name + '.py' + path = os.path.join(helper.migration_dir, filename) + with open(path, 'w') as f: + f.write(MIGRATE_TEMPLATE.format( + migrate=migrate, rollback=rollback, name=filename)) + + return name + + def create(self, name: str = 'auto', auto: bool = False) -> t.Optional[str]: + """ + Creates a migration. + """ + migrate = rollback = '' + if auto: + raise NotImplementedError + + logger.info('Creating migration "{}"'.format(name)) + name = self.compile(name, migrate, rollback) + logger.info('Migration has been created as "{}"'.format(name)) + return name + + def clear(self): + """Clear migrations.""" + self.model.delete().execute() + + def up(self, name: t.Optional[str] = None): + """ + Runs all unapplied migrations. + """ + logger.info('Starting migrations') + console.info('Starting migrations') + + done = [] + diff = self.diff + if not diff: + logger.info('There is nothing to migrate') + console.info('There is nothing to migrate') + return done + + migrator = self.migrator + for mname in diff: + done.append(self.up_one(mname, migrator)) + if name and name == mname: + break + + return done + + def read(self, name: str): + """ + Reads a migration from a file. + """ + call_params = dict() + if os.name == 'nt' and sys.version_info >= (3, 0): + # if system is windows - force utf-8 encoding + call_params['encoding'] = 'utf-8' + with open(os.path.join(helper.migration_dir, name + '.py'), **call_params) as f: + code = f.read() + scope = {} + code = compile(code, '', 'exec', dont_inherit=True) + exec(code, scope, None) + return scope.get('migrate', VOID), scope.get('rollback', VOID) + + def up_one(self, name: str, migrator: Migrator, + rollback: bool = False) -> str: + """ + Runs a migration with a given name. + """ + try: + migrate_fn, rollback_fn = self.read(name) + with self.database.transaction(): + if rollback: + logger.info('Rolling back "{}"'.format(name)) + rollback_fn(migrator, self.database) + migrator.run() + self.model.delete().where(self.model.name == name).execute() + else: + logger.info('Migrate "{}"'.format(name)) + migrate_fn(migrator, self.database) + migrator.run() + if name not in self.done: + self.model.create(name=name) + + logger.info('Done "{}"'.format(name)) + return name + + except Exception: + self.database.rollback() + operation = 'Rollback' if rollback else 'Migration' + logger.exception('{} failed: {}'.format(operation, name)) + raise + + def down(self, name: t.Optional[str] = None): + """ + Rolls back migrations. + """ + if not self.done: + raise RuntimeError('No migrations are found.') + + name = self.done[-1] + + migrator = self.migrator + self.up_one(name, migrator, True) + logger.warning('Rolled back migration: {}'.format(name)) diff --git a/app/classes/shared/models.py b/app/classes/shared/models.py index 5d9924e0..3b8b81eb 100644 --- a/app/classes/shared/models.py +++ b/app/classes/shared/models.py @@ -22,32 +22,12 @@ except ModuleNotFoundError as e: console.critical("Import Error: Unable to load {} module".format(e.name)) sys.exit(1) -schema_version = (0, 1, 0) # major, minor, patch semver - database = SqliteDatabase(helper.db_path, pragmas={ 'journal_mode': 'wal', 'cache_size': -1024 * 10}) -class BaseModel(Model): - class Meta: - database = database -class SchemaVersion(BaseModel): - # DO NOT EVER CHANGE THE SCHEMA OF THIS TABLE - # (unless we have a REALLY good reason to) - # There will only ever be one row, and it allows the database loader to detect - # what it needs to do on major version upgrades so you don't have to wipe the DB - # every time you upgrade - schema_major = IntegerField() - schema_minor = IntegerField() - schema_patch = IntegerField() - - class Meta: - table_name = 'schema_version' - primary_key = CompositeKey('schema_major', 'schema_minor', 'schema_patch') - - -class Users(BaseModel): +class Users(Model): user_id = AutoField() created = DateTimeField(default=datetime.datetime.now) last_login = DateTimeField(default=datetime.datetime.now) @@ -61,9 +41,10 @@ class Users(BaseModel): class Meta: table_name = "users" + database = database -class Roles(BaseModel): +class Roles(Model): role_id = AutoField() created = DateTimeField(default=datetime.datetime.now) last_update = DateTimeField(default=datetime.datetime.now) @@ -71,18 +52,20 @@ class Roles(BaseModel): class Meta: table_name = "roles" + database = database -class User_Roles(BaseModel): +class User_Roles(Model): user_id = ForeignKeyField(Users, backref='user_role') role_id = ForeignKeyField(Roles, backref='user_role') class Meta: table_name = 'user_roles' primary_key = CompositeKey('user_id', 'role_id') + database = database -class Audit_Log(BaseModel): +class Audit_Log(Model): audit_id = AutoField() created = DateTimeField(default=datetime.datetime.now) user_name = CharField(default="") @@ -91,8 +74,11 @@ class Audit_Log(BaseModel): server_id = IntegerField(default=None, index=True) # When auditing global events, use server ID 0 log_msg = TextField(default='') + class Meta: + database = database -class Host_Stats(BaseModel): + +class Host_Stats(Model): time = DateTimeField(default=datetime.datetime.now, index=True) boot_time = CharField(default="") cpu_usage = FloatField(default=0) @@ -106,9 +92,10 @@ class Host_Stats(BaseModel): class Meta: table_name = "host_stats" + database = database -class Servers(BaseModel): +class Servers(Model): server_id = AutoField() created = DateTimeField(default=datetime.datetime.now) server_uuid = CharField(default="", index=True) @@ -129,27 +116,30 @@ class Servers(BaseModel): class Meta: table_name = "servers" + database = database -class User_Servers(BaseModel): +class User_Servers(Model): user_id = ForeignKeyField(Users, backref='user_server') server_id = ForeignKeyField(Servers, backref='user_server') class Meta: table_name = 'user_servers' primary_key = CompositeKey('user_id', 'server_id') + database = database -class Role_Servers(BaseModel): +class Role_Servers(Model): role_id = ForeignKeyField(Roles, backref='role_server') server_id = ForeignKeyField(Servers, backref='role_server') class Meta: table_name = 'role_servers' primary_key = CompositeKey('role_id', 'server_id') + database = database -class Server_Stats(BaseModel): +class Server_Stats(Model): stats_id = AutoField() created = DateTimeField(default=datetime.datetime.now) server_id = ForeignKeyField(Servers, backref='server', index=True) @@ -172,9 +162,10 @@ class Server_Stats(BaseModel): class Meta: table_name = "server_stats" + database = database -class Commands(BaseModel): +class Commands(Model): command_id = AutoField() created = DateTimeField(default=datetime.datetime.now) server_id = ForeignKeyField(Servers, backref='server', index=True) @@ -185,9 +176,10 @@ class Commands(BaseModel): class Meta: table_name = "commands" + database = database -class Webhooks(BaseModel): +class Webhooks(Model): id = AutoField() name = CharField(max_length=64, unique=True, index=True) method = CharField(default="POST") @@ -197,8 +189,10 @@ class Webhooks(BaseModel): class Meta: table_name = "webhooks" + database = database -class Schedules(BaseModel): + +class Schedules(Model): schedule_id = IntegerField(unique=True, primary_key=True) server_id = ForeignKeyField(Servers, backref='schedule_server') enabled = BooleanField() @@ -211,8 +205,10 @@ class Schedules(BaseModel): class Meta: table_name = 'schedules' + database = database -class Backups(BaseModel): + +class Backups(Model): directories = CharField(null=True) max_backups = IntegerField() server_id = ForeignKeyField(Servers, backref='backups_server') @@ -220,39 +216,15 @@ class Backups(BaseModel): class Meta: table_name = 'backups' + database = database class db_builder: - @staticmethod - def create_tables(): - with database: - database.create_tables([ - Backups, - Users, - Roles, - User_Roles, - User_Servers, - Host_Stats, - Webhooks, - Servers, - Role_Servers, - Server_Stats, - Commands, - Audit_Log, - SchemaVersion, - Schedules - ]) - @staticmethod def default_settings(): logger.info("Fresh Install Detected - Creating Default Settings") console.info("Fresh Install Detected - Creating Default Settings") - SchemaVersion.insert({ - SchemaVersion.schema_major: schema_version[0], - SchemaVersion.schema_minor: schema_version[1], - SchemaVersion.schema_patch: schema_version[2] - }).execute() default_data = helper.find_default_password() username = default_data.get("username", 'admin') @@ -279,39 +251,8 @@ class db_builder: return True pass - @staticmethod - def check_schema_version(): - svs = SchemaVersion.select().execute() - if len(svs) != 1: - raise exceptions.SchemaError("Multiple or no schema versions detected - potentially a failed upgrade?") - sv = svs[0] - svt = (sv.schema_major, sv.schema_minor, sv.schema_patch) - logger.debug("Schema: found {}, expected {}".format(svt, schema_version)) - console.debug("Schema: found {}, expected {}".format(svt, schema_version)) - if sv.schema_major > schema_version[0]: - raise exceptions.SchemaError("Major version mismatch - possible code reversion") - elif sv.schema_major < schema_version[0]: - db_shortcuts.upgrade_schema() - - if sv.schema_minor > schema_version[1]: - logger.warning("Schema minor mismatch detected: found {}, expected {}. Proceed with caution".format(svt, schema_version)) - console.warning("Schema minor mismatch detected: found {}, expected {}. Proceed with caution".format(svt, schema_version)) - elif sv.schema_minor < schema_version[1]: - db_shortcuts.upgrade_schema() - - if sv.schema_patch > schema_version[2]: - logger.info("Schema patch mismatch detected: found {}, expected {}. Proceed with caution".format(svt, schema_version)) - console.info("Schema patch mismatch detected: found {}, expected {}. Proceed with caution".format(svt, schema_version)) - elif sv.schema_patch < schema_version[2]: - db_shortcuts.upgrade_schema() - logger.info("Schema validation successful! {}".format(schema_version)) - class db_shortcuts: - @staticmethod - def upgrade_schema(): - raise NotImplemented("I don't know who you are or how you reached this code, but this should NOT have happened. Please report it to the developer with due haste.") - @staticmethod def return_rows(query): rows = [] diff --git a/app/migrations/20210813111015_init.py b/app/migrations/20210813111015_init.py new file mode 100644 index 00000000..66e7c83a --- /dev/null +++ b/app/migrations/20210813111015_init.py @@ -0,0 +1,215 @@ +import peewee +import datetime + + +def migrate(migrator, database, **kwargs): + db = database + class Users(peewee.Model): + user_id = peewee.AutoField() + created = peewee.DateTimeField(default=datetime.datetime.now) + last_login = peewee.DateTimeField(default=datetime.datetime.now) + last_update = peewee.DateTimeField(default=datetime.datetime.now) + last_ip = peewee.CharField(default="") + username = peewee.CharField(default="", unique=True, index=True) + password = peewee.CharField(default="") + enabled = peewee.BooleanField(default=True) + superuser = peewee.BooleanField(default=False) + # we may need to revisit this + api_token = peewee.CharField(default="", unique=True, index=True) + + class Meta: + table_name = "users" + database = db + + class Roles(peewee.Model): + role_id = peewee.AutoField() + created = peewee.DateTimeField(default=datetime.datetime.now) + last_update = peewee.DateTimeField(default=datetime.datetime.now) + role_name = peewee.CharField(default="", unique=True, index=True) + + class Meta: + table_name = "roles" + database = db + + class User_Roles(peewee.Model): + user_id = peewee.ForeignKeyField(Users, backref='user_role') + role_id = peewee.ForeignKeyField(Roles, backref='user_role') + + class Meta: + table_name = 'user_roles' + primary_key = peewee.CompositeKey('user_id', 'role_id') + database = db + + class Audit_Log(peewee.Model): + audit_id = peewee.AutoField() + created = peewee.DateTimeField(default=datetime.datetime.now) + user_name = peewee.CharField(default="") + user_id = peewee.IntegerField(default=0, index=True) + source_ip = peewee.CharField(default='127.0.0.1') + # When auditing global events, use server ID 0 + server_id = peewee.IntegerField(default=None, index=True) + log_msg = peewee.TextField(default='') + + class Meta: + database = db + + class Host_Stats(peewee.Model): + time = peewee.DateTimeField(default=datetime.datetime.now, index=True) + boot_time = peewee.CharField(default="") + cpu_usage = peewee.FloatField(default=0) + cpu_cores = peewee.IntegerField(default=0) + cpu_cur_freq = peewee.FloatField(default=0) + cpu_max_freq = peewee.FloatField(default=0) + mem_percent = peewee.FloatField(default=0) + mem_usage = peewee.CharField(default="") + mem_total = peewee.CharField(default="") + disk_json = peewee.TextField(default="") + + class Meta: + table_name = "host_stats" + database = db + + class Servers(peewee.Model): + server_id = peewee.AutoField() + created = peewee.DateTimeField(default=datetime.datetime.now) + server_uuid = peewee.CharField(default="", index=True) + server_name = peewee.CharField(default="Server", index=True) + path = peewee.CharField(default="") + backup_path = peewee.CharField(default="") + executable = peewee.CharField(default="") + log_path = peewee.CharField(default="") + execution_command = peewee.CharField(default="") + auto_start = peewee.BooleanField(default=0) + auto_start_delay = peewee.IntegerField(default=10) + crash_detection = peewee.BooleanField(default=0) + stop_command = peewee.CharField(default="stop") + executable_update_url = peewee.CharField(default="") + server_ip = peewee.CharField(default="127.0.0.1") + server_port = peewee.IntegerField(default=25565) + logs_delete_after = peewee.IntegerField(default=0) + + class Meta: + table_name = "servers" + database = db + + class User_Servers(peewee.Model): + user_id = peewee.ForeignKeyField(Users, backref='user_server') + server_id = peewee.ForeignKeyField(Servers, backref='user_server') + + class Meta: + table_name = 'user_servers' + primary_key = peewee.CompositeKey('user_id', 'server_id') + database = db + + class Role_Servers(peewee.Model): + role_id = peewee.ForeignKeyField(Roles, backref='role_server') + server_id = peewee.ForeignKeyField(Servers, backref='role_server') + + class Meta: + table_name = 'role_servers' + primary_key = peewee.CompositeKey('role_id', 'server_id') + database = db + + class Server_Stats(peewee.Model): + stats_id = peewee.AutoField() + created = peewee.DateTimeField(default=datetime.datetime.now) + server_id = peewee.ForeignKeyField(Servers, backref='server', index=True) + started = peewee.CharField(default="") + running = peewee.BooleanField(default=False) + cpu = peewee.FloatField(default=0) + mem = peewee.FloatField(default=0) + mem_percent = peewee.FloatField(default=0) + world_name = peewee.CharField(default="") + world_size = peewee.CharField(default="") + server_port = peewee.IntegerField(default=25565) + int_ping_results = peewee.CharField(default="") + online = peewee.IntegerField(default=0) + max = peewee.IntegerField(default=0) + players = peewee.CharField(default="") + desc = peewee.CharField(default="Unable to Connect") + version = peewee.CharField(default="") + updating = peewee.BooleanField(default=False) + + class Meta: + table_name = "server_stats" + database = db + + class Commands(peewee.Model): + command_id = peewee.AutoField() + created = peewee.DateTimeField(default=datetime.datetime.now) + server_id = peewee.ForeignKeyField(Servers, backref='server', index=True) + user = peewee.ForeignKeyField(Users, backref='user', index=True) + source_ip = peewee.CharField(default='127.0.0.1') + command = peewee.CharField(default='') + executed = peewee.BooleanField(default=False) + + class Meta: + table_name = "commands" + database = db + + class Webhooks(peewee.Model): + id = peewee.AutoField() + name = peewee.CharField(max_length=64, unique=True, index=True) + method = peewee.CharField(default="POST") + url = peewee.CharField(unique=True) + event = peewee.CharField(default="") + send_data = peewee.BooleanField(default=True) + + class Meta: + table_name = "webhooks" + database = db + + class Schedules(peewee.Model): + schedule_id = peewee.IntegerField(unique=True, primary_key=True) + server_id = peewee.ForeignKeyField(Servers, backref='schedule_server') + enabled = peewee.BooleanField() + action = peewee.CharField() + interval = peewee.IntegerField() + interval_type = peewee.CharField() + start_time = peewee.CharField(null=True) + command = peewee.CharField(null=True) + comment = peewee.CharField() + + class Meta: + table_name = 'schedules' + database = db + + class Backups(peewee.Model): + directories = peewee.CharField(null=True) + max_backups = peewee.IntegerField() + server_id = peewee.ForeignKeyField(Servers, backref='backups_server') + schedule_id = peewee.ForeignKeyField(Schedules, backref='backups_schedule') + + class Meta: + table_name = 'backups' + database = db + + migrator.create_table(Backups) + migrator.create_table(Users) + migrator.create_table(Roles) + migrator.create_table(User_Roles) + migrator.create_table(User_Servers) + migrator.create_table(Host_Stats) + migrator.create_table(Webhooks) + migrator.create_table(Servers) + migrator.create_table(Role_Servers) + migrator.create_table(Server_Stats) + migrator.create_table(Commands) + migrator.create_table(Audit_Log) + migrator.create_table(Schedules) + + +def rollback(migrator, database, **kwargs): + migrator.drop_table('users') + migrator.drop_table('roles') + migrator.drop_table('user_roles') + migrator.drop_table('audit_log') # ? Not 100% sure of the table name, please specify in the schema + migrator.drop_table('host_stats') + migrator.drop_table('servers') + migrator.drop_table('user_servers') + migrator.drop_table('role_servers') + migrator.drop_table('server_stats') + migrator.drop_table('commands') + migrator.drop_table('webhooks') + migrator.drop_table('schedules') + migrator.drop_table('backups') diff --git a/main.py b/main.py index 71fd2489..2ca05097 100644 --- a/main.py +++ b/main.py @@ -8,10 +8,11 @@ import logging.config """ Our custom classes / pip packages """ from app.classes.shared.console import console from app.classes.shared.helpers import helper -from app.classes.shared.models import installer +from app.classes.shared.models import installer, database from app.classes.shared.tasks import TasksManager from app.classes.shared.controller import Controller +from app.classes.shared.migration import MigrationManager from app.classes.shared.cmd import MainPrompt @@ -90,16 +91,18 @@ if __name__ == '__main__': # our session file, helps prevent multiple controller agents on the same machine. helper.create_session_file(ignore=args.ignore) + + migration_manager = MigrationManager(database) + migration_manager.up() # Automatically runs migrations + # do our installer stuff fresh_install = installer.is_fresh_install() if fresh_install: console.debug("Fresh install detected") - installer.create_tables() installer.default_settings() else: console.debug("Existing install detected") - installer.check_schema_version() # now the tables are created, we can load the tasks_manger and server controller controller = Controller() @@ -127,7 +130,7 @@ if __name__ == '__main__': # this should always be last tasks_manager.start_main_kill_switch_watcher() - Crafty = MainPrompt(tasks_manager) + Crafty = MainPrompt(tasks_manager, migration_manager) if not args.daemon: Crafty.cmdloop() else: diff --git a/requirements.txt b/requirements.txt index 7bb0410b..3f6b29f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,4 +23,3 @@ termcolor==1.1.0 tornado==6.0.4 urllib3==1.25.10 webencodings==0.5.1 -peewee_migrate==1.4.6