from datetime import datetime
import logging
import typing as t
import sys
import os
import re
from importlib import import_module
from functools import wraps

try:
    from functools import cached_property
except ImportError:
    from cached_property import cached_property

from app.classes.shared.helpers import helper
from app.classes.shared.console import console

logger = logging.getLogger(__name__)

try:
    import peewee
    from playhouse.migrate import (
        SchemaMigrator as ScM,
        SqliteMigrator as SqM,
        Operation,
        SQL,
        operation,
        SqliteDatabase,
        make_index_name,
        Context,
    )
except ModuleNotFoundError as import_error:
    # peewee / playhouse are hard requirements for this module: report the
    # missing module on both the log and the console, then stop the process.
    missing_message = "Import Error: Unable to load {} module".format(
        import_error.name)
    logger.critical(missing_message, exc_info=True)
    console.critical(missing_message)
    sys.exit(1)


class MigrateHistory(peewee.Model):
    """
    Presents the migration history in a database.

    Each row records one applied migration. The database and table this
    model uses are assigned at runtime by ``MigrationManager.model``.
    """

    # Migration name (migration file name without ".py"); unique constraint.
    name = peewee.CharField(unique=True)
    # Time the row was created; naive timestamp from datetime.utcnow.
    migrated_at = peewee.DateTimeField(default=datetime.utcnow)

    def __str__(self) -> str:
        """
        String representation of this migration: its name.
        """
        return self.name

    # Python 2 spelling of the same hook; Python 3 never calls it implicitly,
    # but it is kept so any direct caller keeps working.
    __unicode__ = __str__


# Table that stores the names of applied migrations.
MIGRATE_TABLE = 'migratehistory'
# Skeleton written by MigrationManager.compile(). "{migrate}" and
# "{rollback}" are substituted with str.format and land at column 0, so the
# substituted code must already be indented to function-body level.
MIGRATE_TEMPLATE = '''# Generated by database migrator


def migrate(migrator, database, **kwargs):
    """
    Write your migrations here.
    """
{migrate}


def rollback(migrator, database, **kwargs):
    """
    Write your rollback migrations here.
    """
{rollback}'''


def _void_migration(migrator, database):
    """No-op stand-in for a migrate()/rollback() a migration file omits."""
    return None


VOID: t.Callable = _void_migration


def get_model(method):
    """
    Decorator for ``Migrator`` methods whose first argument after the
    migrator is a model.

    The model may be passed either as a model class or as a table name; a
    name is resolved through ``migrator.orm`` (``KeyError`` if unknown)
    before the wrapped method runs. All other arguments pass through as-is.
    """

    @wraps(method)
    def wrapper(migrator, model, *args, **kwargs):
        target = migrator.orm[model] if isinstance(model, str) else model
        return method(migrator, target, *args, **kwargs)

    return wrapper


class Migrator(object):
    """
    Queues schema operations for a migration and keeps an in-memory ORM
    (``orm``: table name -> model class) in step with them.

    Methods decorated with ``get_model`` accept either a model class or the
    table name of a model registered in ``orm``. Nothing touches the
    database until ``run`` is called.
    """

    def __init__(self, database: t.Union[peewee.Database, peewee.Proxy]):
        """
        Initializes the migrator.

        :param database: target database; a ``peewee.Proxy`` is unwrapped to
            the object it currently points at.
        """
        if isinstance(database, peewee.Proxy):
            database = database.obj
        self.database: SqliteDatabase = database
        # table name -> model class, built up by the migrations seen so far
        self.orm: t.Dict[str, peewee.Model] = {}
        # pending work: playhouse Operations or plain callables (see run)
        self.operations: t.List[Operation] = []
        self.migrator = SqliteMigrator(database)

    def run(self):
        """
        Runs the queued operations in order, then empties the queue.
        """
        for op in self.operations:
            if isinstance(op, Operation):
                op.run()
            else:
                # create_table / drop_table queue plain callables
                op()
        self.clean()

    def clean(self):
        """
        Discards all queued operations.
        """
        self.operations = list()

    def sql(self, sql: str, *params):
        """
        Queues a raw SQL statement; *params* are forwarded to
        ``SqliteMigrator.sql``.
        """
        self.operations.append(self.migrator.sql(sql, *params))

    def create_table(self, model: peewee.Model) -> peewee.Model:
        """
        Registers *model* in ``orm``, binds it to this database and queues
        the creation of its table.
        """
        self.orm[model._meta.table_name] = model
        model._meta.database = self.database
        self.operations.append(model.create_table)
        return model

    @get_model
    def drop_table(self, model: peewee.Model):
        """
        Unregisters *model* from ``orm`` and queues dropping its table.

        :raises KeyError: if the model's table is not registered in ``orm``.
        """
        del self.orm[model._meta.table_name]
        self.operations.append(self.migrator.drop_table(model))

    @get_model
    def add_columns(self, model: peewee.Model, **fields: peewee.Field) -> peewee.Model:
        """
        Adds each keyword-named field to *model* and queues the matching
        column addition; unique fields also get a unique index queued.
        """
        table = model._meta.table_name
        for name, field in fields.items():
            model._meta.add_field(name, field)
            self.operations.append(self.migrator.add_column(
                table, field.column_name, field))
            if field.unique:
                # NOTE(review): newer playhouse versions may already queue an
                # index for unique fields inside add_column, in which case
                # this second index would collide. Verify against the
                # installed peewee before relying on unique add_columns.
                self.operations.append(self.migrator.add_index(
                    table, (field.column_name,), unique=True))
        return model

    @get_model
    def change_columns(self, model: peewee.Model, **fields: peewee.Field) -> peewee.Model:
        """
        Replaces model fields with the given definitions and queues the
        schema changes: column rename, foreign-key constraint swap, column
        change and unique-index add/drop, as applicable.
        """
        table = model._meta.table_name
        for name, field in fields.items():
            # A name not yet on the model is treated as its own "old" field,
            # so no rename and no unique-index change are queued for it.
            old_field = model._meta.fields.get(name, field)
            old_column_name = old_field.column_name

            model._meta.add_field(name, field)

            # NOTE(review): playhouse's SqliteMigrator may not implement
            # foreign-key constraint add/drop; confirm before relying on it.
            if isinstance(old_field, peewee.ForeignKeyField):
                self.operations.append(self.migrator.drop_foreign_key_constraint(
                    table, old_column_name))

            if old_column_name != field.column_name:
                self.operations.append(self.migrator.rename_column(
                    table, old_column_name, field.column_name))

            if isinstance(field, peewee.ForeignKeyField):
                on_delete = field.on_delete if field.on_delete else 'RESTRICT'
                on_update = field.on_update if field.on_update else 'RESTRICT'
                self.operations.append(self.migrator.add_foreign_key_constraint(
                    table, field.column_name,
                    field.rel_model._meta.table_name, field.rel_field.name,
                    on_delete, on_update))
                continue

            self.operations.append(self.migrator.change_column(
                table, field.column_name, field))

            if field.unique == old_field.unique:
                continue

            columns = (field.column_name,)
            if field.unique:
                self.operations.append(self.migrator.add_index(table, columns, True))
                model._meta.indexes.append((columns, True))
            else:
                # drop_index takes an index *name*, not (columns, unique);
                # derive it the same way drop_columns does.
                self.operations.append(self.migrator.drop_index(
                    table, make_index_name(table, list(columns))))
                self._forget_index(model, columns)

        return model

    @get_model
    def drop_columns(self, model: peewee.Model,
                     names: t.Union[str, t.Iterable[str]], **kwargs) -> peewee.Model:
        """
        Removes fields from model and queues dropping their columns (and
        the unique index of unique fields).

        :param names: one field name, or an iterable of field names. Names
            are matched exactly; a single string is not substring-matched.
        :param cascade: optional keyword, forwarded to ``drop_column``
            (default ``False``).
        """
        wanted = {names} if isinstance(names, str) else set(names)
        fields = [field for field in model._meta.fields.values()
                  if field.name in wanted]
        cascade = kwargs.pop('cascade', False)
        table = model._meta.table_name
        for field in fields:
            self.__del_field__(model, field)
            if field.unique:
                index_name = make_index_name(table, [field.column_name])
                self.operations.append(self.migrator.drop_index(table, index_name))
            self.operations.append(
                self.migrator.drop_column(table, field.column_name, cascade=cascade))
        return model

    def __del_field__(self, model: peewee.Model, field: peewee.Field):
        """
        Deletes field from model: its metadata entry, its class attribute
        and, for a ``ForeignKeyField``, the object-id accessor and the
        back-reference accessor on the related model.
        """
        model._meta.remove_field(field.name)
        delattr(model, field.name)
        if isinstance(field, peewee.ForeignKeyField):
            obj_id_name = field.column_name
            if field.column_name == field.name:
                obj_id_name += '_id'
            delattr(model, obj_id_name)
            # No accessor exists for a disabled back-reference (e.g.
            # backref='+'), so only delete one that is actually present.
            backref = getattr(field, 'backref', None)
            if backref and backref in vars(field.rel_model):
                delattr(field.rel_model, backref)

    @get_model
    def rename_column(self, model: peewee.Model, old_name: str, new_name: str) -> peewee.Model:
        """
        Renames field *old_name* of *model* to *new_name* and queues the
        column rename.

        For a ``ForeignKeyField`` the column is renamed from the field's
        current column name to ``new_name + '_id'``.
        """
        field = model._meta.fields[old_name]
        if isinstance(field, peewee.ForeignKeyField):
            old_name = field.column_name
        self.__del_field__(model, field)
        field.name = field.column_name = new_name
        model._meta.add_field(new_name, field)
        if isinstance(field, peewee.ForeignKeyField):
            field.column_name = new_name = field.column_name + '_id'
        self.operations.append(self.migrator.rename_column(
            model._meta.table_name, old_name, new_name))
        return model

    @get_model
    def rename_table(self, model: peewee.Model, new_name: str) -> peewee.Model:
        """
        Renames *model*'s table, re-keys it in ``orm`` and queues the rename.
        """
        old_name = model._meta.table_name
        del self.orm[old_name]
        model._meta.table_name = new_name
        self.orm[new_name] = model
        self.operations.append(self.migrator.rename_table(old_name, new_name))
        return model

    @staticmethod
    def _forget_index(model: peewee.Model, columns: t.Sequence[str]):
        """Remove ``(columns, unique)`` tuples for *columns* from the model's index list."""
        columns = tuple(columns)
        model._meta.indexes = [
            index for index in model._meta.indexes
            if not (isinstance(index, (tuple, list)) and len(index) == 2
                    and isinstance(index[0], (tuple, list))
                    and tuple(index[0]) == columns)]

    @get_model
    def add_index(self, model: peewee.Model, *columns: str, **kwargs) -> peewee.Model:
        """
        Create indexes over *columns* (field names).

        :param unique: optional keyword, default ``False``.

        For a single known field, its ``unique``/``index`` flags are updated
        as well. ``ForeignKeyField`` names get an ``_id`` suffix to form the
        column name; names that are not fields of *model* are used as given.
        """
        unique = kwargs.pop('unique', False)
        model._meta.indexes.append((columns, unique))
        columns_ = []
        for col in columns:
            field = model._meta.fields.get(col)
            if field is not None:
                if len(columns) == 1:
                    field.unique = unique
                    field.index = not unique
                if isinstance(field, peewee.ForeignKeyField):
                    col = col + '_id'
            columns_.append(col)
        self.operations.append(self.migrator.add_index(
            model._meta.table_name, columns_, unique=unique))
        return model

    @get_model
    def drop_index(self, model: peewee.Model, *columns: str) -> peewee.Model:
        """
        Drop indexes over *columns* (field names).

        Column names are mapped as in ``add_index`` so the derived index
        name matches; names that are not fields of *model* are kept as given
        rather than dropped from the name.
        """
        columns_ = []
        for col in columns:
            field = model._meta.fields.get(col)
            if field is not None:
                if len(columns) == 1:
                    field.unique = field.index = False
                if isinstance(field, peewee.ForeignKeyField):
                    col = col + '_id'
            columns_.append(col)
        index_name = make_index_name(model._meta.table_name, columns_)
        self._forget_index(model, columns)
        self.operations.append(self.migrator.drop_index(
            model._meta.table_name, index_name))
        return model

    @get_model
    def add_not_null(self, model: peewee.Model, *names: str) -> peewee.Model:
        """Mark the named fields NOT NULL and queue the constraint change."""
        for name in names:
            field = model._meta.fields[name]
            field.null = False
            self.operations.append(self.migrator.add_not_null(
                model._meta.table_name, field.column_name))
        return model

    @get_model
    def drop_not_null(self, model: peewee.Model, *names: str) -> peewee.Model:
        """Mark the named fields nullable and queue the constraint change."""
        for name in names:
            field = model._meta.fields[name]
            field.null = True
            self.operations.append(self.migrator.drop_not_null(
                model._meta.table_name, field.column_name))
        return model

    @get_model
    def add_default(self, model: peewee.Model, name: str, default: t.Any) -> peewee.Model:
        """
        Set the default of field *name* and queue applying it to the column.

        The operation targets the field's database column name, as every
        other method here does (they differ for e.g. foreign keys).
        """
        field = model._meta.fields[name]
        model._meta.defaults[field] = field.default = default
        self.operations.append(self.migrator.apply_default(
            model._meta.table_name, field.column_name, field))
        return model


class SqliteMigrator(SqM):
    """
    playhouse ``SqliteMigrator`` extended with the operations ``Migrator``
    relies on.
    """

    def drop_table(self, model):
        """
        Return a callable that drops *model*'s table with ``cascade=False``.

        This returns a plain callable rather than an ``Operation``;
        ``Migrator.run`` invokes such callables directly.
        """
        return lambda: model.drop_table(cascade=False)

    @operation
    def change_column(self, table: str, column_name: str, field: peewee.Field):
        """
        Change a column, re-adding NOT NULL when *field* is not nullable.
        """
        operations = [self.alter_change_column(table, column_name, field)]
        if not field.null:
            operations.append(self.add_not_null(table, column_name))
        return operations

    def alter_change_column(self, table: str, column_name: str, field: peewee.Field) -> Operation:
        """
        Rebuild *table* via playhouse's ``_update_column`` for *column_name*.
        """
        # NOTE(review): the callback returns the existing column definition
        # unchanged, so *field*'s new type/constraints are NOT applied here.
        # Confirm whether this no-op is intended before relying on
        # change_columns to alter column types.
        return self._update_column(table, column_name, lambda x, y: y)

    @operation
    def sql(self, sql: str, *params) -> SQL:
        """
        Executes raw SQL.

        Bind values may be passed as separate positional arguments
        (``sql(q, a, b)``) or as a single list/tuple (``sql(q, [a, b])``).
        ``peewee.SQL`` takes one iterable of params. Unpacking them into it
        failed for several values and iterated a single scalar (a string
        was bound character by character).
        """
        if len(params) == 1 and isinstance(params[0], (list, tuple)):
            params = params[0]
        return SQL(sql, list(params) or None)

    def alter_add_column(
            self, table: str, column_name: str, field: peewee.Field, **kwargs) -> Operation:
        """
        Fixes field name for ForeignKeys.

        The field's ``name`` is saved before delegating to playhouse and
        restored afterwards for ``ForeignKeyField`` instances.
        """
        name = field.name
        op = super().alter_add_column(
            table, column_name, field, **kwargs)
        if isinstance(field, peewee.ForeignKeyField):
            field.name = name
        return op


class MigrationManager(object):
    """
    Applies and rolls back migration files found in ``helper.migration_dir``,
    recording applied ones in the ``MigrateHistory`` table.
    """

    # Migration files are named "<digits>_<name>.py".
    filemask = re.compile(r"[\d]+_[^\.]+\.py$")

    def __init__(self, database: t.Union[peewee.Database, peewee.Proxy]):
        """
        Initializes the migration manager.

        :raises RuntimeError: if *database* is not a peewee Database or Proxy.
        """
        if not isinstance(database, (peewee.Database, peewee.Proxy)):
            raise RuntimeError('Invalid database: {}'.format(database))
        self.database = database

    @cached_property
    def model(self) -> peewee.Model:
        """
        Bind ``MigrateHistory`` to this database and ``MIGRATE_TABLE``,
        create its table if missing, and cache the model.
        """
        MigrateHistory._meta.database = self.database
        MigrateHistory._meta.table_name = MIGRATE_TABLE
        MigrateHistory._meta.schema = None
        MigrateHistory.create_table(True)  # safe=True: no error if it exists
        return MigrateHistory

    @property
    def done(self) -> t.List[str]:
        """
        Names of applied migrations, ordered by history row id.
        """
        return [mm.name for mm in self.model.select().order_by(self.model.id)]

    @property
    def todo(self) -> t.List[str]:
        """
        Sorted names (without ".py") of migration files on disk.

        Creates the migration directory, logging a warning, if it is missing.
        """
        if not os.path.exists(helper.migration_dir):
            logger.warning('Migration directory: {} does not exist.'.format(
                helper.migration_dir))
            os.makedirs(helper.migration_dir)
        return sorted(f[:-3] for f in os.listdir(helper.migration_dir) if self.filemask.match(f))

    @property
    def diff(self) -> t.List[str]:
        """
        Migrations present on disk but not yet recorded as applied, in
        ``todo`` order.
        """
        done = set(self.done)
        return [name for name in self.todo if name not in done]

    @cached_property
    def migrator(self) -> Migrator:
        """
        Create a migrator and replay every applied migration on it in fake
        mode, so its ORM reflects the current schema. Cached per manager.
        """
        migrator = Migrator(self.database)
        for name in self.done:
            self.up_one(name, migrator, True)
        return migrator

    def compile(self, name, migrate='', rollback=''):
        """
        Write a new migration file named "<UTC timestamp>_<name>.py".

        :param migrate: code inserted into the template's ``migrate`` body.
        :param rollback: code inserted into the template's ``rollback`` body.
        :returns: the migration name (file name without ".py").
        """
        name = datetime.utcnow().strftime('%Y%m%d%H%M%S') + '_' + name
        filename = name + '.py'
        path = os.path.join(helper.migration_dir, filename)
        # Always UTF-8 so the file round-trips through read() on every OS.
        with open(path, 'w', encoding='utf-8') as f:
            f.write(MIGRATE_TEMPLATE.format(
                migrate=migrate, rollback=rollback, name=filename))

        return name

    def create(self, name: str = 'auto', auto: bool = False) -> t.Optional[str]:
        """
        Creates an empty migration file.

        :raises NotImplementedError: when *auto* is true.
        :returns: the generated migration name.
        """
        migrate = rollback = ''
        if auto:
            raise NotImplementedError

        logger.info('Creating migration "{}"'.format(name))
        name = self.compile(name, migrate, rollback)
        logger.info('Migration has been created as "{}"'.format(name))
        return name

    def clear(self):
        """Clear migrations: delete every migration history row."""
        self.model.delete().execute()

    def up(self, name: t.Optional[str] = None):
        """
        Runs unapplied migrations in order, stopping after *name* if given.

        :returns: list of the migration names that were applied.
        """
        logger.info('Starting migrations')
        console.info('Starting migrations')

        done = []
        diff = self.diff
        if not diff:
            logger.info('There is nothing to migrate')
            console.info('There is nothing to migrate')
            return done

        migrator = self.migrator
        for mname in diff:
            done.append(self.up_one(mname, migrator))
            if name and name == mname:
                break

        return done

    def read(self, name: str):
        """
        Load a migration file and return its ``(migrate, rollback)`` functions.

        Functions the file does not define are replaced by ``VOID``. The file
        is read as UTF-8 on every platform, matching ``compile``.

        NOTE: the file is executed as Python code; migration files are
        trusted project sources.
        """
        path = os.path.join(helper.migration_dir, name + '.py')
        with open(path, encoding='utf-8') as f:
            source = f.read()
        scope = {}
        # Compile with the real path so tracebacks point at the migration file.
        code = compile(source, path, 'exec', dont_inherit=True)
        exec(code, scope, None)
        return scope.get('migrate', VOID), scope.get('rollback', VOID)

    def up_one(self, name: str, migrator: Migrator,
               fake: bool = False, rollback: bool = False) -> str:
        """
        Runs (or rolls back) a single migration.

        :param name: migration name (file name without ".py").
        :param migrator: migrator whose ORM state the migration updates.
        :param fake: only call ``migrate`` to update the migrator's ORM; its
            queued operations are discarded and no history row is written.
        :param rollback: run the file's ``rollback`` and delete its history
            row instead of migrating.
        :returns: *name*.
        :raises Exception: any error is logged, the database is rolled back,
            and the error is re-raised.
        """
        try:
            migrate_fn, rollback_fn = self.read(name)
            if fake:
                migrate_fn(migrator, self.database)
                migrator.clean()
                return name
            with self.database.transaction():
                if rollback:
                    logger.info('Rolling back "{}"'.format(name))
                    rollback_fn(migrator, self.database)
                    migrator.run()
                    self.model.delete().where(self.model.name == name).execute()
                else:
                    logger.info('Migrate "{}"'.format(name))
                    migrate_fn(migrator, self.database)
                    migrator.run()
                    if name not in self.done:
                        self.model.create(name=name)

                logger.info('Done "{}"'.format(name))
                return name

        except Exception:
            self.database.rollback()
            action = 'Rollback' if rollback else 'Migration'
            logger.exception('{} failed: {}'.format(action, name))
            raise

    def down(self, name: t.Optional[str] = None):
        """
        Rolls back the most recently applied migration.

        *name* is accepted for interface compatibility but is currently
        ignored.

        :raises RuntimeError: if no migration has been applied.
        """
        done = self.done
        if not done:
            raise RuntimeError('No migrations are found.')

        name = done[-1]

        migrator = self.migrator
        self.up_one(name, migrator, False, True)
        logger.warning('Rolled back migration: {}'.format(name))