Add database migrations!

This commit is contained in:
luukas 2021-08-18 18:11:53 +03:00
parent aebf50dfc6
commit f8aad8b33d
9 changed files with 883 additions and 95 deletions

BIN
.DS_Store vendored

Binary file not shown.

71
DBCHANGES.md Normal file
View File

@ -0,0 +1,71 @@
# Database change guide for contributors
When updating a database schema modify the schema in `app/classes/shared/models.py` and create a new migration with the `migration add <name>` command (in Crafty's prompt).
A full list of helper functions you can find in `app/classes/shared/models.py`
## Example migration files
### Rename column/field
```py
def migrate(migrator, database, **kwargs):
migrator.rename_column('my_table', 'old_name', 'new_name') # First argument can be model class OR table name
def rollback(migrator, database, **kwargs):
migrator.rename_column('my_table', 'new_name', 'old_name') # First argument can be model class OR table name
```
### Rename table/model
```py
def migrate(migrator, database, **kwargs):
migrator.rename_table('old_name', 'new_name') # First argument can be model class OR table name
def rollback(migrator, database, **kwargs):
migrator.rename_table('new_name', 'old_name') # First argument can be model class OR table name
```
### Create table/model
```py
import peewee
def migrate(migrator, database, **kwargs):
class NewTable(peewee.Model):
my_id = peewee.IntegerField(unique=True, primary_key=True)
class Meta:
table_name = 'new_table'
database = database
create_table(NewTable)
def rollback(migrator, database, **kwargs):
drop_table('new_table') # Can be model class OR table name
```
### Add columns/fields
```py
import peewee
def migrate(migrator, database, **kwargs):
migrator.add_columns('table_name', new_field_name=peewee.CharField(default="")) # First argument can be model class OR table name
def rollback(migrator, database, **kwargs):
migrator.drop_columns('table_name', ['new_field_name']) # First argument can be model class OR table name
```

View File

@ -22,9 +22,10 @@ except ModuleNotFoundError as e:
class MainPrompt(cmd.Cmd, object): class MainPrompt(cmd.Cmd, object):
def __init__(self, tasks_manager): def __init__(self, tasks_manager, migration_manager):
super().__init__() super().__init__()
self.tasks_manager = tasks_manager self.tasks_manager = tasks_manager
self.migration_manager = migration_manager
# overrides the default Prompt # overrides the default Prompt
prompt = "Crafty Controller v{} > ".format(helper.get_version_string()) prompt = "Crafty Controller v{} > ".format(helper.get_version_string())
@ -47,6 +48,27 @@ class MainPrompt(cmd.Cmd, object):
def do_exit(self, line): def do_exit(self, line):
self.universal_exit() self.universal_exit()
def do_migrations(self, line):
if (line == 'up'):
self.migration_manager.up()
elif (line == 'down'):
self.migration_manager.down()
elif (line == 'done'):
console.info(self.migration_manager.done)
elif (line == 'todo'):
console.info(self.migration_manager.todo)
elif (line == 'diff'):
console.info(self.migration_manager.diff)
elif (line == 'info'):
console.info('Done: {}'.format(self.migration_manager.done))
console.info('FS: {}'.format(self.migration_manager.todo))
console.info('Todo: {}'.format(self.migration_manager.diff))
elif (line.startswith('add ')):
migration_name = line[len('add '):]
self.migration_manager.create(migration_name, False)
else:
console.info('Unknown migration command')
def universal_exit(self): def universal_exit(self):
logger.info("Stopping all server daemons / threads") logger.info("Stopping all server daemons / threads")
console.info("Stopping all server daemons / threads - This may take a few seconds") console.info("Stopping all server daemons / threads - This may take a few seconds")
@ -62,3 +84,7 @@ class MainPrompt(cmd.Cmd, object):
@staticmethod @staticmethod
def help_exit(): def help_exit():
console.help("Stops the server if running, Exits the program") console.help("Stops the server if running, Exits the program")
@staticmethod
def help_migrations():
console.help("Only for advanced users. Use with caution")

View File

@ -40,6 +40,7 @@ class Helpers:
self.webroot = os.path.join(self.root_dir, 'app', 'frontend') self.webroot = os.path.join(self.root_dir, 'app', 'frontend')
self.servers_dir = os.path.join(self.root_dir, 'servers') self.servers_dir = os.path.join(self.root_dir, 'servers')
self.backup_path = os.path.join(self.root_dir, 'backups') self.backup_path = os.path.join(self.root_dir, 'backups')
self.migration_dir = os.path.join(self.root_dir, 'app', 'migrations')
self.session_file = os.path.join(self.root_dir, 'app', 'config', 'session.lock') self.session_file = os.path.join(self.root_dir, 'app', 'config', 'session.lock')
self.settings_file = os.path.join(self.root_dir, 'app', 'config', 'config.json') self.settings_file = os.path.join(self.root_dir, 'app', 'config', 'config.json')

View File

@ -0,0 +1,532 @@
from datetime import datetime
import logging
import typing as t
import sys
import os
import re
from importlib import import_module
from functools import wraps
try:
from functools import cached_property
except ImportError:
from cached_property import cached_property
from app.classes.shared.helpers import helper
from app.classes.shared.console import console
logger = logging.getLogger(__name__)
try:
import peewee
from playhouse.migrate import (
SchemaMigrator as ScM,
SqliteMigrator as SqM,
Operation, SQL, operation, SqliteDatabase,
make_index_name, Context
)
except ModuleNotFoundError as e:
logger.critical("Import Error: Unable to load {} module".format(
e.name), exc_info=True)
console.critical("Import Error: Unable to load {} module".format(e.name))
sys.exit(1)
class MigrateHistory(peewee.Model):
"""
Presents the migration history in a database.
"""
name = peewee.CharField(unique=True)
migrated_at = peewee.DateTimeField(default=datetime.utcnow)
def __unicode__(self) -> str:
"""
String representation of this migration
"""
return self.name
MIGRATE_TABLE = 'migratehistory'
MIGRATE_TEMPLATE = '''# Generated by database migrator
def migrate(migrator, database, **kwargs):
"""
Write your migrations here.
"""
{migrate}
def rollback(migrator, database, **kwargs):
"""
Write your rollback migrations here.
"""
{rollback}'''
VOID: t.Callable = lambda m, d: None
def get_model(method):
"""
Convert string to model class.
"""
@wraps(method)
def wrapper(migrator, model, *args, **kwargs):
if isinstance(model, str):
return method(migrator, migrator.orm[model], *args, **kwargs)
return method(migrator, model, *args, **kwargs)
return wrapper
class Migrator(object):
def __init__(self, database: t.Union[peewee.Database, peewee.Proxy]):
"""
Initializes the migrator
"""
if isinstance(database, peewee.Proxy):
database = database.obj
self.database: SqliteDatabase = database
self.orm: t.Dict[str, peewee.Model] = {}
self.operations: t.List[Operation] = []
self.migrator = SqliteMigrator(database)
def run(self):
"""
Runs operations.
"""
for op in self.operations:
if isinstance(op, Operation):
op.run()
else:
op()
self.clean()
def clean(self):
"""
Cleans the operations.
"""
self.operations = list()
def sql(self, sql: str, *params):
"""
Executes raw SQL.
"""
self.operations.append(self.migrator.sql(sql, *params))
def create_table(self, model: peewee.Model) -> peewee.Model:
"""
Creates model and table in database.
"""
self.orm[model._meta.table_name] = model
model._meta.database = self.database
self.operations.append(model.create_table)
return model
@get_model
def drop_table(self, model: peewee.Model):
"""
Drops model and table from database.
"""
del self.orm[model._meta.table_name]
self.operations.append(self.migrator.drop_table(model))
@get_model
def add_columns(self, model: peewee.Model, **fields: peewee.Field) -> peewee.Model:
"""
Creates new fields.
"""
for name, field in fields.items():
model._meta.add_field(name, field)
self.operations.append(self.migrator.add_column(
model._meta.table_name, field.column_name, field))
if field.unique:
self.operations.append(self.migrator.add_index(
model._meta.table_name, (field.column_name,), unique=True))
return model
@get_model
def change_columns(self, model: peewee.Model, **fields: peewee.Field) -> peewee.Model:
"""
Changes fields.
"""
for name, field in fields.items():
old_field = model._meta.fields.get(name, field)
old_column_name = old_field and old_field.column_name
model._meta.add_field(name, field)
if isinstance(old_field, peewee.ForeignKeyField):
self.operations.append(self.migrator.drop_foreign_key_constraint(
model._meta.table_name, old_column_name))
if old_column_name != field.column_name:
self.operations.append(
self.migrator.rename_column(
model._meta.table_name, old_column_name, field.column_name))
if isinstance(field, peewee.ForeignKeyField):
on_delete = field.on_delete if field.on_delete else 'RESTRICT'
on_update = field.on_update if field.on_update else 'RESTRICT'
self.operations.append(self.migrator.add_foreign_key_constraint(
model._meta.table_name, field.column_name,
field.rel_model._meta.table_name, field.rel_field.name,
on_delete, on_update))
continue
self.operations.append(self.migrator.change_column(
model._meta.table_name, field.column_name, field))
if field.unique == old_field.unique:
continue
if field.unique:
index = (field.column_name,), field.unique
self.operations.append(self.migrator.add_index(
model._meta.table_name, *index))
model._meta.indexes.append(index)
else:
index = (field.column_name,), old_field.unique
self.operations.append(self.migrator.drop_index(
model._meta.table_name, *index))
model._meta.indexes.remove(index)
return model
@get_model
def drop_columns(self, model: peewee.Model, names: str, **kwargs) -> peewee.Model:
"""
Removes fields from model.
"""
fields = [field for field in model._meta.fields.values()
if field.name in names]
cascade = kwargs.pop('cascade', True)
for field in fields:
self.__del_field__(model, field)
if field.unique:
index_name = make_index_name(
model._meta.table_name, [field.column_name])
self.operations.append(self.migrator.drop_index(
model._meta.table_name, index_name))
self.operations.append(
self.migrator.drop_column(
model._meta.table_name, field.column_name, cascade=False))
return model
def __del_field__(self, model: peewee.Model, field: peewee.Field):
"""
Deletes field from model.
"""
model._meta.remove_field(field.name)
delattr(model, field.name)
if isinstance(field, peewee.ForeignKeyField):
obj_id_name = field.column_name
if field.column_name == field.name:
obj_id_name += '_id'
delattr(model, obj_id_name)
delattr(field.rel_model, field.backref)
@get_model
def rename_column(self, model: peewee.Model, old_name: str, new_name: str) -> peewee.Model:
"""
Renames field in model.
"""
field = model._meta.fields[old_name]
if isinstance(field, peewee.ForeignKeyField):
old_name = field.column_name
self.__del_field__(model, field)
field.name = field.column_name = new_name
model._meta.add_field(new_name, field)
if isinstance(field, peewee.ForeignKeyField):
field.column_name = new_name = field.column_name + '_id'
self.operations.append(self.migrator.rename_column(
model._meta.table_name, old_name, new_name))
return model
@get_model
def rename_table(self, model: peewee.Model, new_name: str) -> peewee.Model:
"""
Renames table in database.
"""
old_name = model._meta.table_name
del self.orm[model._meta.table_name]
model._meta.table_name = new_name
self.orm[model._meta.table_name] = model
self.operations.append(self.migrator.rename_table(old_name, new_name))
return model
@get_model
def add_index(self, model: peewee.Model, *columns: str, **kwargs) -> peewee.Model:
"""Create indexes."""
unique = kwargs.pop('unique', False)
model._meta.indexes.append((columns, unique))
columns_ = []
for col in columns:
field = model._meta.fields.get(col)
if len(columns) == 1:
field.unique = unique
field.index = not unique
if isinstance(field, peewee.ForeignKeyField):
col = col + '_id'
columns_.append(col)
self.operations.append(self.migrator.add_index(
model._meta.table_name, columns_, unique=unique))
return model
@get_model
def drop_index(self, model: peewee.Model, *columns: str) -> peewee.Model:
"""Drop indexes."""
columns_ = []
for col in columns:
field = model._meta.fields.get(col)
if not field:
continue
if len(columns) == 1:
field.unique = field.index = False
if isinstance(field, peewee.ForeignKeyField):
col = col + '_id'
columns_.append(col)
index_name = make_index_name(model._meta.table_name, columns_)
model._meta.indexes = [(cols, _) for (
cols, _) in model._meta.indexes if columns != cols]
self.operations.append(self.migrator.drop_index(
model._meta.table_name, index_name))
return model
@get_model
def add_not_null(self, model: peewee.Model, *names: str) -> peewee.Model:
"""Add not null."""
for name in names:
field = model._meta.fields[name]
field.null = False
self.operations.append(self.migrator.add_not_null(
model._meta.table_name, field.column_name))
return model
@get_model
def drop_not_null(self, model: peewee.Model, *names: str) -> peewee.Model:
"""Drop not null."""
for name in names:
field = model._meta.fields[name]
field.null = True
self.operations.append(self.migrator.drop_not_null(
model._meta.table_name, field.column_name))
return model
@get_model
def add_default(self, model: peewee.Model, name: str, default: t.Any) -> peewee.Model:
"""Add default."""
field = model._meta.fields[name]
model._meta.defaults[field] = field.default = default
self.operations.append(self.migrator.apply_default(
model._meta.table_name, name, field))
return model
class SqliteMigrator(SqM):
def drop_table(self, model):
return lambda: model.drop_table(cascade=False)
@operation
def change_column(self, table: str, column_name: str, field: peewee.Field):
operations = [self.alter_change_column(table, column_name, field)]
if not field.null:
operations.extend([self.add_not_null(table, column_name)])
return operations
def alter_change_column(self, table: str, column_name: str, field: peewee.Field) -> Operation:
return self._update_column(table, column_name, lambda x, y: y)
@operation
def sql(self, sql: str, *params) -> SQL:
"""
Executes raw SQL.
"""
return SQL(sql, *params)
def alter_add_column(
self, table: str, column_name: str, field: peewee.Field, **kwargs) -> Operation:
"""
Fixes field name for ForeignKeys.
"""
name = field.name
op = super().alter_add_column(
table, column_name, field, **kwargs)
if isinstance(field, peewee.ForeignKeyField):
field.name = name
return op
class MigrationManager(object):
filemask = re.compile(r"[\d]+_[^\.]+\.py$")
def __init__(self, database: t.Union[peewee.Database, peewee.Proxy]):
"""
Initializes the migration manager.
"""
if not isinstance(database, (peewee.Database, peewee.Proxy)):
raise RuntimeError('Invalid database: {}'.format(database))
self.database = database
@cached_property
def model(self) -> peewee.Model:
"""
Initialize and cache the MigrationHistory model.
"""
MigrateHistory._meta.database = self.database
MigrateHistory._meta.table_name = 'migratehistory'
MigrateHistory._meta.schema = None
MigrateHistory.create_table(True)
return MigrateHistory
@property
def done(self) -> t.List[str]:
"""
Scans migrations in the database.
"""
return [mm.name for mm in self.model.select().order_by(self.model.id)]
@property
def todo(self):
"""
Scans migrations in the file system.
"""
if not os.path.exists(helper.migration_dir):
logger.warning('Migration directory: {} does not exist.'.format(
helper.migration_dir))
os.makedirs(helper.migration_dir)
return sorted(f[:-3] for f in os.listdir(helper.migration_dir) if self.filemask.match(f))
@property
def diff(self) -> t.List[str]:
"""
Calculates difference between the filesystem and the database.
"""
done = set(self.done)
return [name for name in self.todo if name not in done]
@cached_property
def migrator(self) -> Migrator:
"""
Create migrator and setup it with fake migrations.
"""
migrator = Migrator(self.database)
for name in self.done:
self.up_one(name, migrator)
return migrator
def compile(self, name, migrate='', rollback=''):
"""
Compiles a migration.
"""
name = datetime.utcnow().strftime('%Y%m%d%H%M%S') + '_' + name
filename = name + '.py'
path = os.path.join(helper.migration_dir, filename)
with open(path, 'w') as f:
f.write(MIGRATE_TEMPLATE.format(
migrate=migrate, rollback=rollback, name=filename))
return name
def create(self, name: str = 'auto', auto: bool = False) -> t.Optional[str]:
"""
Creates a migration.
"""
migrate = rollback = ''
if auto:
raise NotImplementedError
logger.info('Creating migration "{}"'.format(name))
name = self.compile(name, migrate, rollback)
logger.info('Migration has been created as "{}"'.format(name))
return name
def clear(self):
"""Clear migrations."""
self.model.delete().execute()
def up(self, name: t.Optional[str] = None):
"""
Runs all unapplied migrations.
"""
logger.info('Starting migrations')
console.info('Starting migrations')
done = []
diff = self.diff
if not diff:
logger.info('There is nothing to migrate')
console.info('There is nothing to migrate')
return done
migrator = self.migrator
for mname in diff:
done.append(self.up_one(mname, migrator))
if name and name == mname:
break
return done
def read(self, name: str):
"""
Reads a migration from a file.
"""
call_params = dict()
if os.name == 'nt' and sys.version_info >= (3, 0):
# if system is windows - force utf-8 encoding
call_params['encoding'] = 'utf-8'
with open(os.path.join(helper.migration_dir, name + '.py'), **call_params) as f:
code = f.read()
scope = {}
code = compile(code, '<string>', 'exec', dont_inherit=True)
exec(code, scope, None)
return scope.get('migrate', VOID), scope.get('rollback', VOID)
def up_one(self, name: str, migrator: Migrator,
rollback: bool = False) -> str:
"""
Runs a migration with a given name.
"""
try:
migrate_fn, rollback_fn = self.read(name)
with self.database.transaction():
if rollback:
logger.info('Rolling back "{}"'.format(name))
rollback_fn(migrator, self.database)
migrator.run()
self.model.delete().where(self.model.name == name).execute()
else:
logger.info('Migrate "{}"'.format(name))
migrate_fn(migrator, self.database)
migrator.run()
if name not in self.done:
self.model.create(name=name)
logger.info('Done "{}"'.format(name))
return name
except Exception:
self.database.rollback()
operation = 'Rollback' if rollback else 'Migration'
logger.exception('{} failed: {}'.format(operation, name))
raise
def down(self, name: t.Optional[str] = None):
"""
Rolls back migrations.
"""
if not self.done:
raise RuntimeError('No migrations are found.')
name = self.done[-1]
migrator = self.migrator
self.up_one(name, migrator, True)
logger.warning('Rolled back migration: {}'.format(name))

View File

@ -22,32 +22,12 @@ except ModuleNotFoundError as e:
console.critical("Import Error: Unable to load {} module".format(e.name)) console.critical("Import Error: Unable to load {} module".format(e.name))
sys.exit(1) sys.exit(1)
schema_version = (0, 1, 0) # major, minor, patch semver
database = SqliteDatabase(helper.db_path, pragmas={ database = SqliteDatabase(helper.db_path, pragmas={
'journal_mode': 'wal', 'journal_mode': 'wal',
'cache_size': -1024 * 10}) 'cache_size': -1024 * 10})
class BaseModel(Model):
class Meta:
database = database
class SchemaVersion(BaseModel): class Users(Model):
# DO NOT EVER CHANGE THE SCHEMA OF THIS TABLE
# (unless we have a REALLY good reason to)
# There will only ever be one row, and it allows the database loader to detect
# what it needs to do on major version upgrades so you don't have to wipe the DB
# every time you upgrade
schema_major = IntegerField()
schema_minor = IntegerField()
schema_patch = IntegerField()
class Meta:
table_name = 'schema_version'
primary_key = CompositeKey('schema_major', 'schema_minor', 'schema_patch')
class Users(BaseModel):
user_id = AutoField() user_id = AutoField()
created = DateTimeField(default=datetime.datetime.now) created = DateTimeField(default=datetime.datetime.now)
last_login = DateTimeField(default=datetime.datetime.now) last_login = DateTimeField(default=datetime.datetime.now)
@ -61,9 +41,10 @@ class Users(BaseModel):
class Meta: class Meta:
table_name = "users" table_name = "users"
database = database
class Roles(BaseModel): class Roles(Model):
role_id = AutoField() role_id = AutoField()
created = DateTimeField(default=datetime.datetime.now) created = DateTimeField(default=datetime.datetime.now)
last_update = DateTimeField(default=datetime.datetime.now) last_update = DateTimeField(default=datetime.datetime.now)
@ -71,18 +52,20 @@ class Roles(BaseModel):
class Meta: class Meta:
table_name = "roles" table_name = "roles"
database = database
class User_Roles(BaseModel): class User_Roles(Model):
user_id = ForeignKeyField(Users, backref='user_role') user_id = ForeignKeyField(Users, backref='user_role')
role_id = ForeignKeyField(Roles, backref='user_role') role_id = ForeignKeyField(Roles, backref='user_role')
class Meta: class Meta:
table_name = 'user_roles' table_name = 'user_roles'
primary_key = CompositeKey('user_id', 'role_id') primary_key = CompositeKey('user_id', 'role_id')
database = database
class Audit_Log(BaseModel): class Audit_Log(Model):
audit_id = AutoField() audit_id = AutoField()
created = DateTimeField(default=datetime.datetime.now) created = DateTimeField(default=datetime.datetime.now)
user_name = CharField(default="") user_name = CharField(default="")
@ -91,8 +74,11 @@ class Audit_Log(BaseModel):
server_id = IntegerField(default=None, index=True) # When auditing global events, use server ID 0 server_id = IntegerField(default=None, index=True) # When auditing global events, use server ID 0
log_msg = TextField(default='') log_msg = TextField(default='')
class Meta:
database = database
class Host_Stats(BaseModel):
class Host_Stats(Model):
time = DateTimeField(default=datetime.datetime.now, index=True) time = DateTimeField(default=datetime.datetime.now, index=True)
boot_time = CharField(default="") boot_time = CharField(default="")
cpu_usage = FloatField(default=0) cpu_usage = FloatField(default=0)
@ -106,9 +92,10 @@ class Host_Stats(BaseModel):
class Meta: class Meta:
table_name = "host_stats" table_name = "host_stats"
database = database
class Servers(BaseModel): class Servers(Model):
server_id = AutoField() server_id = AutoField()
created = DateTimeField(default=datetime.datetime.now) created = DateTimeField(default=datetime.datetime.now)
server_uuid = CharField(default="", index=True) server_uuid = CharField(default="", index=True)
@ -129,27 +116,30 @@ class Servers(BaseModel):
class Meta: class Meta:
table_name = "servers" table_name = "servers"
database = database
class User_Servers(BaseModel): class User_Servers(Model):
user_id = ForeignKeyField(Users, backref='user_server') user_id = ForeignKeyField(Users, backref='user_server')
server_id = ForeignKeyField(Servers, backref='user_server') server_id = ForeignKeyField(Servers, backref='user_server')
class Meta: class Meta:
table_name = 'user_servers' table_name = 'user_servers'
primary_key = CompositeKey('user_id', 'server_id') primary_key = CompositeKey('user_id', 'server_id')
database = database
class Role_Servers(BaseModel): class Role_Servers(Model):
role_id = ForeignKeyField(Roles, backref='role_server') role_id = ForeignKeyField(Roles, backref='role_server')
server_id = ForeignKeyField(Servers, backref='role_server') server_id = ForeignKeyField(Servers, backref='role_server')
class Meta: class Meta:
table_name = 'role_servers' table_name = 'role_servers'
primary_key = CompositeKey('role_id', 'server_id') primary_key = CompositeKey('role_id', 'server_id')
database = database
class Server_Stats(BaseModel): class Server_Stats(Model):
stats_id = AutoField() stats_id = AutoField()
created = DateTimeField(default=datetime.datetime.now) created = DateTimeField(default=datetime.datetime.now)
server_id = ForeignKeyField(Servers, backref='server', index=True) server_id = ForeignKeyField(Servers, backref='server', index=True)
@ -172,9 +162,10 @@ class Server_Stats(BaseModel):
class Meta: class Meta:
table_name = "server_stats" table_name = "server_stats"
database = database
class Commands(BaseModel): class Commands(Model):
command_id = AutoField() command_id = AutoField()
created = DateTimeField(default=datetime.datetime.now) created = DateTimeField(default=datetime.datetime.now)
server_id = ForeignKeyField(Servers, backref='server', index=True) server_id = ForeignKeyField(Servers, backref='server', index=True)
@ -185,9 +176,10 @@ class Commands(BaseModel):
class Meta: class Meta:
table_name = "commands" table_name = "commands"
database = database
class Webhooks(BaseModel): class Webhooks(Model):
id = AutoField() id = AutoField()
name = CharField(max_length=64, unique=True, index=True) name = CharField(max_length=64, unique=True, index=True)
method = CharField(default="POST") method = CharField(default="POST")
@ -197,8 +189,10 @@ class Webhooks(BaseModel):
class Meta: class Meta:
table_name = "webhooks" table_name = "webhooks"
database = database
class Schedules(BaseModel):
class Schedules(Model):
schedule_id = IntegerField(unique=True, primary_key=True) schedule_id = IntegerField(unique=True, primary_key=True)
server_id = ForeignKeyField(Servers, backref='schedule_server') server_id = ForeignKeyField(Servers, backref='schedule_server')
enabled = BooleanField() enabled = BooleanField()
@ -211,8 +205,10 @@ class Schedules(BaseModel):
class Meta: class Meta:
table_name = 'schedules' table_name = 'schedules'
database = database
class Backups(BaseModel):
class Backups(Model):
directories = CharField(null=True) directories = CharField(null=True)
max_backups = IntegerField() max_backups = IntegerField()
server_id = ForeignKeyField(Servers, backref='backups_server') server_id = ForeignKeyField(Servers, backref='backups_server')
@ -220,39 +216,15 @@ class Backups(BaseModel):
class Meta: class Meta:
table_name = 'backups' table_name = 'backups'
database = database
class db_builder: class db_builder:
@staticmethod
def create_tables():
with database:
database.create_tables([
Backups,
Users,
Roles,
User_Roles,
User_Servers,
Host_Stats,
Webhooks,
Servers,
Role_Servers,
Server_Stats,
Commands,
Audit_Log,
SchemaVersion,
Schedules
])
@staticmethod @staticmethod
def default_settings(): def default_settings():
logger.info("Fresh Install Detected - Creating Default Settings") logger.info("Fresh Install Detected - Creating Default Settings")
console.info("Fresh Install Detected - Creating Default Settings") console.info("Fresh Install Detected - Creating Default Settings")
SchemaVersion.insert({
SchemaVersion.schema_major: schema_version[0],
SchemaVersion.schema_minor: schema_version[1],
SchemaVersion.schema_patch: schema_version[2]
}).execute()
default_data = helper.find_default_password() default_data = helper.find_default_password()
username = default_data.get("username", 'admin') username = default_data.get("username", 'admin')
@ -279,39 +251,8 @@ class db_builder:
return True return True
pass pass
@staticmethod
def check_schema_version():
svs = SchemaVersion.select().execute()
if len(svs) != 1:
raise exceptions.SchemaError("Multiple or no schema versions detected - potentially a failed upgrade?")
sv = svs[0]
svt = (sv.schema_major, sv.schema_minor, sv.schema_patch)
logger.debug("Schema: found {}, expected {}".format(svt, schema_version))
console.debug("Schema: found {}, expected {}".format(svt, schema_version))
if sv.schema_major > schema_version[0]:
raise exceptions.SchemaError("Major version mismatch - possible code reversion")
elif sv.schema_major < schema_version[0]:
db_shortcuts.upgrade_schema()
if sv.schema_minor > schema_version[1]:
logger.warning("Schema minor mismatch detected: found {}, expected {}. Proceed with caution".format(svt, schema_version))
console.warning("Schema minor mismatch detected: found {}, expected {}. Proceed with caution".format(svt, schema_version))
elif sv.schema_minor < schema_version[1]:
db_shortcuts.upgrade_schema()
if sv.schema_patch > schema_version[2]:
logger.info("Schema patch mismatch detected: found {}, expected {}. Proceed with caution".format(svt, schema_version))
console.info("Schema patch mismatch detected: found {}, expected {}. Proceed with caution".format(svt, schema_version))
elif sv.schema_patch < schema_version[2]:
db_shortcuts.upgrade_schema()
logger.info("Schema validation successful! {}".format(schema_version))
class db_shortcuts: class db_shortcuts:
@staticmethod
def upgrade_schema():
raise NotImplemented("I don't know who you are or how you reached this code, but this should NOT have happened. Please report it to the developer with due haste.")
@staticmethod @staticmethod
def return_rows(query): def return_rows(query):
rows = [] rows = []

View File

@ -0,0 +1,215 @@
import peewee
import datetime
def migrate(migrator, database, **kwargs):
db = database
class Users(peewee.Model):
user_id = peewee.AutoField()
created = peewee.DateTimeField(default=datetime.datetime.now)
last_login = peewee.DateTimeField(default=datetime.datetime.now)
last_update = peewee.DateTimeField(default=datetime.datetime.now)
last_ip = peewee.CharField(default="")
username = peewee.CharField(default="", unique=True, index=True)
password = peewee.CharField(default="")
enabled = peewee.BooleanField(default=True)
superuser = peewee.BooleanField(default=False)
# we may need to revisit this
api_token = peewee.CharField(default="", unique=True, index=True)
class Meta:
table_name = "users"
database = db
class Roles(peewee.Model):
role_id = peewee.AutoField()
created = peewee.DateTimeField(default=datetime.datetime.now)
last_update = peewee.DateTimeField(default=datetime.datetime.now)
role_name = peewee.CharField(default="", unique=True, index=True)
class Meta:
table_name = "roles"
database = db
class User_Roles(peewee.Model):
user_id = peewee.ForeignKeyField(Users, backref='user_role')
role_id = peewee.ForeignKeyField(Roles, backref='user_role')
class Meta:
table_name = 'user_roles'
primary_key = peewee.CompositeKey('user_id', 'role_id')
database = db
class Audit_Log(peewee.Model):
audit_id = peewee.AutoField()
created = peewee.DateTimeField(default=datetime.datetime.now)
user_name = peewee.CharField(default="")
user_id = peewee.IntegerField(default=0, index=True)
source_ip = peewee.CharField(default='127.0.0.1')
# When auditing global events, use server ID 0
server_id = peewee.IntegerField(default=None, index=True)
log_msg = peewee.TextField(default='')
class Meta:
database = db
class Host_Stats(peewee.Model):
time = peewee.DateTimeField(default=datetime.datetime.now, index=True)
boot_time = peewee.CharField(default="")
cpu_usage = peewee.FloatField(default=0)
cpu_cores = peewee.IntegerField(default=0)
cpu_cur_freq = peewee.FloatField(default=0)
cpu_max_freq = peewee.FloatField(default=0)
mem_percent = peewee.FloatField(default=0)
mem_usage = peewee.CharField(default="")
mem_total = peewee.CharField(default="")
disk_json = peewee.TextField(default="")
class Meta:
table_name = "host_stats"
database = db
class Servers(peewee.Model):
server_id = peewee.AutoField()
created = peewee.DateTimeField(default=datetime.datetime.now)
server_uuid = peewee.CharField(default="", index=True)
server_name = peewee.CharField(default="Server", index=True)
path = peewee.CharField(default="")
backup_path = peewee.CharField(default="")
executable = peewee.CharField(default="")
log_path = peewee.CharField(default="")
execution_command = peewee.CharField(default="")
auto_start = peewee.BooleanField(default=0)
auto_start_delay = peewee.IntegerField(default=10)
crash_detection = peewee.BooleanField(default=0)
stop_command = peewee.CharField(default="stop")
executable_update_url = peewee.CharField(default="")
server_ip = peewee.CharField(default="127.0.0.1")
server_port = peewee.IntegerField(default=25565)
logs_delete_after = peewee.IntegerField(default=0)
class Meta:
table_name = "servers"
database = db
class User_Servers(peewee.Model):
user_id = peewee.ForeignKeyField(Users, backref='user_server')
server_id = peewee.ForeignKeyField(Servers, backref='user_server')
class Meta:
table_name = 'user_servers'
primary_key = peewee.CompositeKey('user_id', 'server_id')
database = db
class Role_Servers(peewee.Model):
role_id = peewee.ForeignKeyField(Roles, backref='role_server')
server_id = peewee.ForeignKeyField(Servers, backref='role_server')
class Meta:
table_name = 'role_servers'
primary_key = peewee.CompositeKey('role_id', 'server_id')
database = db
class Server_Stats(peewee.Model):
stats_id = peewee.AutoField()
created = peewee.DateTimeField(default=datetime.datetime.now)
server_id = peewee.ForeignKeyField(Servers, backref='server', index=True)
started = peewee.CharField(default="")
running = peewee.BooleanField(default=False)
cpu = peewee.FloatField(default=0)
mem = peewee.FloatField(default=0)
mem_percent = peewee.FloatField(default=0)
world_name = peewee.CharField(default="")
world_size = peewee.CharField(default="")
server_port = peewee.IntegerField(default=25565)
int_ping_results = peewee.CharField(default="")
online = peewee.IntegerField(default=0)
max = peewee.IntegerField(default=0)
players = peewee.CharField(default="")
desc = peewee.CharField(default="Unable to Connect")
version = peewee.CharField(default="")
updating = peewee.BooleanField(default=False)
class Meta:
table_name = "server_stats"
database = db
class Commands(peewee.Model):
command_id = peewee.AutoField()
created = peewee.DateTimeField(default=datetime.datetime.now)
server_id = peewee.ForeignKeyField(Servers, backref='server', index=True)
user = peewee.ForeignKeyField(Users, backref='user', index=True)
source_ip = peewee.CharField(default='127.0.0.1')
command = peewee.CharField(default='')
executed = peewee.BooleanField(default=False)
class Meta:
table_name = "commands"
database = db
class Webhooks(peewee.Model):
id = peewee.AutoField()
name = peewee.CharField(max_length=64, unique=True, index=True)
method = peewee.CharField(default="POST")
url = peewee.CharField(unique=True)
event = peewee.CharField(default="")
send_data = peewee.BooleanField(default=True)
class Meta:
table_name = "webhooks"
database = db
class Schedules(peewee.Model):
schedule_id = peewee.IntegerField(unique=True, primary_key=True)
server_id = peewee.ForeignKeyField(Servers, backref='schedule_server')
enabled = peewee.BooleanField()
action = peewee.CharField()
interval = peewee.IntegerField()
interval_type = peewee.CharField()
start_time = peewee.CharField(null=True)
command = peewee.CharField(null=True)
comment = peewee.CharField()
class Meta:
table_name = 'schedules'
database = db
class Backups(peewee.Model):
directories = peewee.CharField(null=True)
max_backups = peewee.IntegerField()
server_id = peewee.ForeignKeyField(Servers, backref='backups_server')
schedule_id = peewee.ForeignKeyField(Schedules, backref='backups_schedule')
class Meta:
table_name = 'backups'
database = db
migrator.create_table(Backups)
migrator.create_table(Users)
migrator.create_table(Roles)
migrator.create_table(User_Roles)
migrator.create_table(User_Servers)
migrator.create_table(Host_Stats)
migrator.create_table(Webhooks)
migrator.create_table(Servers)
migrator.create_table(Role_Servers)
migrator.create_table(Server_Stats)
migrator.create_table(Commands)
migrator.create_table(Audit_Log)
migrator.create_table(Schedules)
def rollback(migrator, database, **kwargs):
migrator.drop_table('users')
migrator.drop_table('roles')
migrator.drop_table('user_roles')
migrator.drop_table('audit_log') # ? Not 100% sure of the table name, please specify in the schema
migrator.drop_table('host_stats')
migrator.drop_table('servers')
migrator.drop_table('user_servers')
migrator.drop_table('role_servers')
migrator.drop_table('server_stats')
migrator.drop_table('commands')
migrator.drop_table('webhooks')
migrator.drop_table('schedules')
migrator.drop_table('backups')

11
main.py
View File

@ -8,10 +8,11 @@ import logging.config
""" Our custom classes / pip packages """ """ Our custom classes / pip packages """
from app.classes.shared.console import console from app.classes.shared.console import console
from app.classes.shared.helpers import helper from app.classes.shared.helpers import helper
from app.classes.shared.models import installer from app.classes.shared.models import installer, database
from app.classes.shared.tasks import TasksManager from app.classes.shared.tasks import TasksManager
from app.classes.shared.controller import Controller from app.classes.shared.controller import Controller
from app.classes.shared.migration import MigrationManager
from app.classes.shared.cmd import MainPrompt from app.classes.shared.cmd import MainPrompt
@ -90,16 +91,18 @@ if __name__ == '__main__':
# our session file, helps prevent multiple controller agents on the same machine. # our session file, helps prevent multiple controller agents on the same machine.
helper.create_session_file(ignore=args.ignore) helper.create_session_file(ignore=args.ignore)
migration_manager = MigrationManager(database)
migration_manager.up() # Automatically runs migrations
# do our installer stuff # do our installer stuff
fresh_install = installer.is_fresh_install() fresh_install = installer.is_fresh_install()
if fresh_install: if fresh_install:
console.debug("Fresh install detected") console.debug("Fresh install detected")
installer.create_tables()
installer.default_settings() installer.default_settings()
else: else:
console.debug("Existing install detected") console.debug("Existing install detected")
installer.check_schema_version()
# now the tables are created, we can load the tasks_manger and server controller # now the tables are created, we can load the tasks_manger and server controller
controller = Controller() controller = Controller()
@ -127,7 +130,7 @@ if __name__ == '__main__':
# this should always be last # this should always be last
tasks_manager.start_main_kill_switch_watcher() tasks_manager.start_main_kill_switch_watcher()
Crafty = MainPrompt(tasks_manager) Crafty = MainPrompt(tasks_manager, migration_manager)
if not args.daemon: if not args.daemon:
Crafty.cmdloop() Crafty.cmdloop()
else: else:

View File

@ -23,4 +23,3 @@ termcolor==1.1.0
tornado==6.0.4 tornado==6.0.4
urllib3==1.25.10 urllib3==1.25.10
webencodings==0.5.1 webencodings==0.5.1
peewee_migrate==1.4.6