[CI] Enable python autoformat (#6169)

* Squashed commit of the following:

commit f5cf7b2e78
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:36:57 2024 +0100

    fixed reqs

commit 9d845bee98
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:32:35 2024 +0100

    disable autofix/format

commit aff5f27148
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:50 2024 +0100

    adjust checks

commit 47271cf1ef
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:22 2024 +0100

    reorder order of operations

commit e1bf178b40
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:01:09 2024 +0100

    adapted ruff settings to better fit code base

commit ad7d88a6f4
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:59:45 2024 +0100

    auto fixed docstring

commit a2e54a760e
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:46:35 2024 +0100

    fix getattr useage

commit cb80c73bc6
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:25:09 2024 +0100

    fix requirements file

commit b7780bbd21
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:42:28 2024 +0100

    fix removed sections

commit 71f1681f55
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:41:21 2024 +0100

    fix djlint syntax

commit a0bcf1bcce
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:35:28 2024 +0100

    remove flake8 from code base

commit 22475b31cc
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:34:56 2024 +0100

    remove flake8 from code base

commit 0413350f14
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:39 2024 +0100

    moved ruff section

commit d90c48a0bf
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:24 2024 +0100

    move djlint config to pyproject

commit c5ce55d511
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:20:39 2024 +0100

    added isort again

commit 42a41d23af
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:19:02 2024 +0100

    move config section

commit 8569233181
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:17:52 2024 +0100

    fix codespell error

commit 2897c6704d
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 17:29:21 2024 +0100

    replaced flake8 with ruff
    mostly for speed improvements

* enable autoformat

* added autofixes

* switched to single quotes everywhere

* switched to ruff for import sorting

* fix wrong url response

* switched to pathlib for lookup

* fixed lookup

* Squashed commit of the following:

commit d3b795824b
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:56:17 2024 +0100

    fixed source path

commit 0bac0c19b8
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:47:53 2024 +0100

    fixed req

commit 9f61f01d9c
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:45:18 2024 +0100

    added missing toml req

commit 91b71ed24a
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:49:50 2024 +0100

    moved isort config

commit 12460b0419
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:43:22 2024 +0100

    remove flake8 section from setup.cfg

commit f5cf7b2e78
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:36:57 2024 +0100

    fixed reqs

commit 9d845bee98
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:32:35 2024 +0100

    disable autofix/format

commit aff5f27148
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:50 2024 +0100

    adjust checks

commit 47271cf1ef
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:22 2024 +0100

    reorder order of operations

commit e1bf178b40
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:01:09 2024 +0100

    adapted ruff settings to better fit code base

commit ad7d88a6f4
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:59:45 2024 +0100

    auto fixed docstring

commit a2e54a760e
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:46:35 2024 +0100

    fix getattr useage

commit cb80c73bc6
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:25:09 2024 +0100

    fix requirements file

commit b7780bbd21
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:42:28 2024 +0100

    fix removed sections

commit 71f1681f55
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:41:21 2024 +0100

    fix djlint syntax

commit a0bcf1bcce
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:35:28 2024 +0100

    remove flake8 from code base

commit 22475b31cc
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:34:56 2024 +0100

    remove flake8 from code base

commit 0413350f14
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:39 2024 +0100

    moved ruff section

commit d90c48a0bf
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:24 2024 +0100

    move djlint config to pyproject

commit c5ce55d511
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:20:39 2024 +0100

    added isort again

commit 42a41d23af
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:19:02 2024 +0100

    move config section

commit 8569233181
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:17:52 2024 +0100

    fix codespell error

commit 2897c6704d
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 17:29:21 2024 +0100

    replaced flake8 with ruff
    mostly for speed improvements

* fix coverage souce format

---------

Co-authored-by: Oliver Walters <oliver.henry.walters@gmail.com>
This commit is contained in:
Matthias Mair 2024-01-11 01:28:58 +01:00 committed by GitHub
parent 9715af564f
commit 4b14986591
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
257 changed files with 13422 additions and 12200 deletions

View File

@ -15,18 +15,14 @@ repos:
- id: end-of-file-fixer - id: end-of-file-fixer
- id: check-yaml - id: check-yaml
- id: mixed-line-ending - id: mixed-line-ending
- repo: https://github.com/pycqa/isort
rev: '5.12.0'
hooks:
- id: isort
- repo: https://github.com/astral-sh/ruff-pre-commit - repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11 rev: v0.1.11
hooks: hooks:
# - id: ruff-format - id: ruff-format
# args: [--preview] args: [--preview]
- id: ruff - id: ruff
args: [ args: [
#--fix, --fix,
--preview --preview
] ]
- repo: https://github.com/jazzband/pip-tools - repo: https://github.com/jazzband/pip-tools

View File

@ -30,14 +30,18 @@ class InvenTreeResource(ModelResource):
using_transactions, using_transactions,
collect_failed_rows, collect_failed_rows,
rollback_on_validation_errors=None, rollback_on_validation_errors=None,
**kwargs **kwargs,
): ):
"""Override the default import_data_inner function to provide better error handling""" """Override the default import_data_inner function to provide better error handling"""
if len(dataset) > self.MAX_IMPORT_ROWS: if len(dataset) > self.MAX_IMPORT_ROWS:
raise ImportExportError(f"Dataset contains too many rows (max {self.MAX_IMPORT_ROWS})") raise ImportExportError(
f'Dataset contains too many rows (max {self.MAX_IMPORT_ROWS})'
)
if len(dataset.headers) > self.MAX_IMPORT_COLS: if len(dataset.headers) > self.MAX_IMPORT_COLS:
raise ImportExportError(f"Dataset contains too many columns (max {self.MAX_IMPORT_COLS})") raise ImportExportError(
f'Dataset contains too many columns (max {self.MAX_IMPORT_COLS})'
)
return super().import_data_inner( return super().import_data_inner(
dataset, dataset,
@ -46,7 +50,7 @@ class InvenTreeResource(ModelResource):
using_transactions, using_transactions,
collect_failed_rows, collect_failed_rows,
rollback_on_validation_errors=rollback_on_validation_errors, rollback_on_validation_errors=rollback_on_validation_errors,
**kwargs **kwargs,
) )
def export_resource(self, obj): def export_resource(self, obj):
@ -75,10 +79,7 @@ class InvenTreeResource(ModelResource):
"""Return fields, with some common exclusions""" """Return fields, with some common exclusions"""
fields = super().get_fields(**kwargs) fields = super().get_fields(**kwargs)
fields_to_exclude = [ fields_to_exclude = ['metadata', 'lft', 'rght', 'tree_id', 'level']
'metadata',
'lft', 'rght', 'tree_id', 'level',
]
return [f for f in fields if f.column_name not in fields_to_exclude] return [f for f in fields if f.column_name not in fields_to_exclude]

View File

@ -31,9 +31,7 @@ from .views import AjaxView
class VersionView(APIView): class VersionView(APIView):
"""Simple JSON endpoint for InvenTree version information.""" """Simple JSON endpoint for InvenTree version information."""
permission_classes = [ permission_classes = [permissions.IsAdminUser]
permissions.IsAdminUser,
]
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Return information about the InvenTree server.""" """Return information about the InvenTree server."""
@ -47,20 +45,21 @@ class VersionView(APIView):
'commit_date': InvenTree.version.inventreeCommitDate(), 'commit_date': InvenTree.version.inventreeCommitDate(),
'commit_branch': InvenTree.version.inventreeBranch(), 'commit_branch': InvenTree.version.inventreeBranch(),
'python': InvenTree.version.inventreePythonVersion(), 'python': InvenTree.version.inventreePythonVersion(),
'django': InvenTree.version.inventreeDjangoVersion() 'django': InvenTree.version.inventreeDjangoVersion(),
}, },
'links': { 'links': {
'doc': InvenTree.version.inventreeDocUrl(), 'doc': InvenTree.version.inventreeDocUrl(),
'code': InvenTree.version.inventreeGithubUrl(), 'code': InvenTree.version.inventreeGithubUrl(),
'credit': InvenTree.version.inventreeCreditsUrl(), 'credit': InvenTree.version.inventreeCreditsUrl(),
'app': InvenTree.version.inventreeAppUrl(), 'app': InvenTree.version.inventreeAppUrl(),
'bug': f'{InvenTree.version.inventreeGithubUrl()}/issues' 'bug': f'{InvenTree.version.inventreeGithubUrl()}/issues',
} },
}) })
class VersionSerializer(serializers.Serializer): class VersionSerializer(serializers.Serializer):
"""Serializer for a single version.""" """Serializer for a single version."""
version = serializers.CharField() version = serializers.CharField()
date = serializers.CharField() date = serializers.CharField()
gh = serializers.CharField() gh = serializers.CharField()
@ -69,16 +68,19 @@ class VersionSerializer(serializers.Serializer):
class Meta: class Meta:
"""Meta class for VersionSerializer.""" """Meta class for VersionSerializer."""
fields = ['version', 'date', 'gh', 'text', 'latest'] fields = ['version', 'date', 'gh', 'text', 'latest']
class VersionApiSerializer(serializers.Serializer): class VersionApiSerializer(serializers.Serializer):
"""Serializer for the version api endpoint.""" """Serializer for the version api endpoint."""
VersionSerializer(many=True) VersionSerializer(many=True)
class VersionTextView(ListAPI): class VersionTextView(ListAPI):
"""Simple JSON endpoint for InvenTree version text.""" """Simple JSON endpoint for InvenTree version text."""
permission_classes = [permissions.IsAdminUser] permission_classes = [permissions.IsAdminUser]
@extend_schema(responses={200: OpenApiResponse(response=VersionApiSerializer)}) @extend_schema(responses={200: OpenApiResponse(response=VersionApiSerializer)})
@ -119,10 +121,10 @@ class InfoView(AjaxView):
'debug_mode': settings.DEBUG, 'debug_mode': settings.DEBUG,
'docker_mode': settings.DOCKER, 'docker_mode': settings.DOCKER,
'system_health': check_system_health() if is_staff else None, 'system_health': check_system_health() if is_staff else None,
'database': InvenTree.version.inventreeDatabase()if is_staff else None, 'database': InvenTree.version.inventreeDatabase() if is_staff else None,
'platform': InvenTree.version.inventreePlatform() if is_staff else None, 'platform': InvenTree.version.inventreePlatform() if is_staff else None,
'installer': InvenTree.version.inventreeInstaller() if is_staff else None, 'installer': InvenTree.version.inventreeInstaller() if is_staff else None,
'target': InvenTree.version.inventreeTarget()if is_staff else None, 'target': InvenTree.version.inventreeTarget() if is_staff else None,
} }
return JsonResponse(data) return JsonResponse(data)
@ -130,7 +132,9 @@ class InfoView(AjaxView):
def check_auth_header(self, request): def check_auth_header(self, request):
"""Check if user is authenticated via a token in the header.""" """Check if user is authenticated via a token in the header."""
# TODO @matmair: remove after refacgtor of Token check is done # TODO @matmair: remove after refacgtor of Token check is done
headers = request.headers.get('Authorization', request.headers.get('authorization')) headers = request.headers.get(
'Authorization', request.headers.get('authorization')
)
if not headers: if not headers:
return False return False
@ -160,7 +164,7 @@ class NotFoundView(AjaxView):
'detail': _('API endpoint not found'), 'detail': _('API endpoint not found'),
'url': request.build_absolute_uri(), 'url': request.build_absolute_uri(),
}, },
status=404 status=404,
) )
def options(self, request, *args, **kwargs): def options(self, request, *args, **kwargs):
@ -228,24 +232,25 @@ class BulkDeleteMixin:
if not items and not filters: if not items and not filters:
raise ValidationError({ raise ValidationError({
"non_field_errors": ["List of items or filters must be provided for bulk deletion"], 'non_field_errors': [
'List of items or filters must be provided for bulk deletion'
]
}) })
if items and type(items) is not list: if items and type(items) is not list:
raise ValidationError({ raise ValidationError({
"items": ["'items' must be supplied as a list object"] 'items': ["'items' must be supplied as a list object"]
}) })
if filters and type(filters) is not dict: if filters and type(filters) is not dict:
raise ValidationError({ raise ValidationError({
"filters": ["'filters' must be supplied as a dict object"] 'filters': ["'filters' must be supplied as a dict object"]
}) })
# Keep track of how many items we deleted # Keep track of how many items we deleted
n_deleted = 0 n_deleted = 0
with transaction.atomic(): with transaction.atomic():
# Start with *all* models and perform basic filtering # Start with *all* models and perform basic filtering
queryset = model.objects.all() queryset = model.objects.all()
queryset = self.filter_delete_queryset(queryset, request) queryset = self.filter_delete_queryset(queryset, request)
@ -261,16 +266,12 @@ class BulkDeleteMixin:
n_deleted = queryset.count() n_deleted = queryset.count()
queryset.delete() queryset.delete()
return Response( return Response({'success': f'Deleted {n_deleted} items'}, status=204)
{
'success': f"Deleted {n_deleted} items",
},
status=204
)
class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI): class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI):
"""Custom API endpoint which provides BulkDelete functionality in addition to List and Create""" """Custom API endpoint which provides BulkDelete functionality in addition to List and Create"""
... ...
@ -307,24 +308,17 @@ class APIDownloadMixin:
def download_queryset(self, queryset, export_format): def download_queryset(self, queryset, export_format):
"""This function must be implemented to provide a downloadFile request.""" """This function must be implemented to provide a downloadFile request."""
raise NotImplementedError("download_queryset method not implemented!") raise NotImplementedError('download_queryset method not implemented!')
class AttachmentMixin: class AttachmentMixin:
"""Mixin for creating attachment objects, and ensuring the user information is saved correctly.""" """Mixin for creating attachment objects, and ensuring the user information is saved correctly."""
permission_classes = [ permission_classes = [permissions.IsAuthenticated, RolePermission]
permissions.IsAuthenticated,
RolePermission,
]
filter_backends = SEARCH_ORDER_FILTER filter_backends = SEARCH_ORDER_FILTER
search_fields = [ search_fields = ['attachment', 'comment', 'link']
'attachment',
'comment',
'link',
]
def perform_create(self, serializer): def perform_create(self, serializer):
"""Save the user information when a file is uploaded.""" """Save the user information when a file is uploaded."""
@ -342,9 +336,7 @@ class APISearchView(APIView):
Is much more efficient and simplifies code! Is much more efficient and simplifies code!
""" """
permission_classes = [ permission_classes = [permissions.IsAuthenticated]
permissions.IsAuthenticated,
]
def get_result_types(self): def get_result_types(self):
"""Construct a list of search types we can return""" """Construct a list of search types we can return"""
@ -384,14 +376,11 @@ class APISearchView(APIView):
} }
if 'search' not in data: if 'search' not in data:
raise ValidationError({ raise ValidationError({'search': 'Search term must be provided'})
'search': 'Search term must be provided',
})
for key, cls in self.get_result_types().items(): for key, cls in self.get_result_types().items():
# Only return results which are specifically requested # Only return results which are specifically requested
if key in data: if key in data:
params = data[key] params = data[key]
for k, v in pass_through_params.items(): for k, v in pass_through_params.items():
@ -418,16 +407,18 @@ class APISearchView(APIView):
table = f'{app_label}_{model_name}' table = f'{app_label}_{model_name}'
try: try:
if users.models.RuleSet.check_table_permission(request.user, table, 'view'): if users.models.RuleSet.check_table_permission(
request.user, table, 'view'
):
results[key] = view.list(request, *args, **kwargs).data results[key] = view.list(request, *args, **kwargs).data
else: else:
results[key] = { results[key] = {
'error': _('User does not have permission to view this model') 'error': _(
'User does not have permission to view this model'
)
} }
except Exception as exc: except Exception as exc:
results[key] = { results[key] = {'error': str(exc)}
'error': str(exc)
}
return Response(results) return Response(results)
@ -442,7 +433,9 @@ class MetadataView(RetrieveUpdateAPI):
model = self.kwargs.get(self.MODEL_REF, None) model = self.kwargs.get(self.MODEL_REF, None)
if model is None: if model is None:
raise ValidationError(f"MetadataView called without '{self.MODEL_REF}' parameter") raise ValidationError(
f"MetadataView called without '{self.MODEL_REF}' parameter"
)
return model return model

View File

@ -1,6 +1,5 @@
"""InvenTree API version information.""" """InvenTree API version information."""
# InvenTree API version # InvenTree API version
INVENTREE_API_VERSION = 160 INVENTREE_API_VERSION = 160
"""Increment this API version number whenever there is a significant change to the API that any clients need to know about.""" """Increment this API version number whenever there is a significant change to the API that any clients need to know about."""

View File

@ -16,11 +16,12 @@ import InvenTree.ready
import InvenTree.tasks import InvenTree.tasks
from InvenTree.config import get_setting from InvenTree.config import get_setting
logger = logging.getLogger("inventree") logger = logging.getLogger('inventree')
class InvenTreeConfig(AppConfig): class InvenTreeConfig(AppConfig):
"""AppConfig for inventree app.""" """AppConfig for inventree app."""
name = 'InvenTree' name = 'InvenTree'
def ready(self): def ready(self):
@ -36,7 +37,10 @@ class InvenTreeConfig(AppConfig):
- Adding users set in the current environment - Adding users set in the current environment
""" """
# skip loading if plugin registry is not loaded or we run in a background thread # skip loading if plugin registry is not loaded or we run in a background thread
if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread(): if (
not InvenTree.ready.isPluginRegistryLoaded()
or not InvenTree.ready.isInMainThread()
):
return return
# Skip if running migrations # Skip if running migrations
@ -44,7 +48,6 @@ class InvenTreeConfig(AppConfig):
return return
if InvenTree.ready.canAppAccessDatabase() or settings.TESTING_ENV: if InvenTree.ready.canAppAccessDatabase() or settings.TESTING_ENV:
self.remove_obsolete_tasks() self.remove_obsolete_tasks()
self.collect_tasks() self.collect_tasks()
@ -81,11 +84,11 @@ class InvenTreeConfig(AppConfig):
try: try:
Schedule.objects.filter(func__in=obsolete).delete() Schedule.objects.filter(func__in=obsolete).delete()
except Exception: except Exception:
logger.exception("Failed to remove obsolete tasks - database not ready") logger.exception('Failed to remove obsolete tasks - database not ready')
def start_background_tasks(self): def start_background_tasks(self):
"""Start all background tests for InvenTree.""" """Start all background tests for InvenTree."""
logger.info("Starting background tasks...") logger.info('Starting background tasks...')
from django_q.models import Schedule from django_q.models import Schedule
@ -102,15 +105,16 @@ class InvenTreeConfig(AppConfig):
tasks = InvenTree.tasks.tasks.task_list tasks = InvenTree.tasks.tasks.task_list
for task in tasks: for task in tasks:
ref_name = f'{task.func.__module__}.{task.func.__name__}' ref_name = f'{task.func.__module__}.{task.func.__name__}'
if ref_name in existing_tasks.keys(): if ref_name in existing_tasks.keys():
# This task already exists - update the details if required # This task already exists - update the details if required
existing_task = existing_tasks[ref_name] existing_task = existing_tasks[ref_name]
if existing_task.schedule_type != task.interval or existing_task.minutes != task.minutes: if (
existing_task.schedule_type != task.interval
or existing_task.minutes != task.minutes
):
existing_task.schedule_type = task.interval existing_task.schedule_type = task.interval
existing_task.minutes = task.minutes existing_task.minutes = task.minutes
tasks_to_update.append(existing_task) tasks_to_update.append(existing_task)
@ -128,20 +132,17 @@ class InvenTreeConfig(AppConfig):
if len(tasks_to_create) > 0: if len(tasks_to_create) > 0:
Schedule.objects.bulk_create(tasks_to_create) Schedule.objects.bulk_create(tasks_to_create)
logger.info("Created %s new scheduled tasks", len(tasks_to_create)) logger.info('Created %s new scheduled tasks', len(tasks_to_create))
if len(tasks_to_update) > 0: if len(tasks_to_update) > 0:
Schedule.objects.bulk_update(tasks_to_update, ['schedule_type', 'minutes']) Schedule.objects.bulk_update(tasks_to_update, ['schedule_type', 'minutes'])
logger.info("Updated %s existing scheduled tasks", len(tasks_to_update)) logger.info('Updated %s existing scheduled tasks', len(tasks_to_update))
# Put at least one task onto the background worker stack, # Put at least one task onto the background worker stack,
# which will be processed as soon as the worker comes online # which will be processed as soon as the worker comes online
InvenTree.tasks.offload_task( InvenTree.tasks.offload_task(InvenTree.tasks.heartbeat, force_async=True)
InvenTree.tasks.heartbeat,
force_async=True,
)
logger.info("Started %s scheduled background tasks...", len(tasks)) logger.info('Started %s scheduled background tasks...', len(tasks))
def collect_tasks(self): def collect_tasks(self):
"""Collect all background tasks.""" """Collect all background tasks."""
@ -153,7 +154,7 @@ class InvenTreeConfig(AppConfig):
try: try:
import_module(f'{app.module.__package__}.tasks') import_module(f'{app.module.__package__}.tasks')
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.exception("Error loading tasks for %s: %s", app_name, e) logger.exception('Error loading tasks for %s: %s', app_name, e)
def update_exchange_rates(self): # pragma: no cover def update_exchange_rates(self): # pragma: no cover
"""Update exchange rates each time the server is started. """Update exchange rates each time the server is started.
@ -184,16 +185,20 @@ class InvenTreeConfig(AppConfig):
if last_update is None: if last_update is None:
# Never been updated # Never been updated
logger.info("Exchange backend has never been updated") logger.info('Exchange backend has never been updated')
update = True update = True
# Backend currency has changed? # Backend currency has changed?
if base_currency != backend.base_currency: if base_currency != backend.base_currency:
logger.info("Base currency changed from %s to %s", backend.base_currency, base_currency) logger.info(
'Base currency changed from %s to %s',
backend.base_currency,
base_currency,
)
update = True update = True
except (ExchangeBackend.DoesNotExist): except ExchangeBackend.DoesNotExist:
logger.info("Exchange backend not found - updating") logger.info('Exchange backend not found - updating')
update = True update = True
except Exception: except Exception:
@ -204,9 +209,9 @@ class InvenTreeConfig(AppConfig):
try: try:
update_exchange_rates() update_exchange_rates()
except OperationalError: except OperationalError:
logger.warning("Could not update exchange rates - database not ready") logger.warning('Could not update exchange rates - database not ready')
except Exception as e: except Exception as e:
logger.exception("Error updating exchange rates: %s (%s)", e, type(e)) logger.exception('Error updating exchange rates: %s (%s)', e, type(e))
def add_user_on_startup(self): def add_user_on_startup(self):
"""Add a user on startup.""" """Add a user on startup."""
@ -218,7 +223,9 @@ class InvenTreeConfig(AppConfig):
add_user = get_setting('INVENTREE_ADMIN_USER', 'admin_user') add_user = get_setting('INVENTREE_ADMIN_USER', 'admin_user')
add_email = get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email') add_email = get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email')
add_password = get_setting('INVENTREE_ADMIN_PASSWORD', 'admin_password') add_password = get_setting('INVENTREE_ADMIN_PASSWORD', 'admin_password')
add_password_file = get_setting("INVENTREE_ADMIN_PASSWORD_FILE", "admin_password_file", None) add_password_file = get_setting(
'INVENTREE_ADMIN_PASSWORD_FILE', 'admin_password_file', None
)
# check if all values are present # check if all values are present
set_variables = 0 set_variables = 0
@ -239,7 +246,9 @@ class InvenTreeConfig(AppConfig):
# if a password file is present, do not warn - will be handled later # if a password file is present, do not warn - will be handled later
if add_password_file: if add_password_file:
return return
logger.warning('Not all required settings for adding a user on startup are present:\nINVENTREE_ADMIN_USER, INVENTREE_ADMIN_EMAIL, INVENTREE_ADMIN_PASSWORD') logger.warning(
'Not all required settings for adding a user on startup are present:\nINVENTREE_ADMIN_USER, INVENTREE_ADMIN_EMAIL, INVENTREE_ADMIN_PASSWORD'
)
return return
# good to go -> create user # good to go -> create user
@ -253,9 +262,11 @@ class InvenTreeConfig(AppConfig):
try: try:
with transaction.atomic(): with transaction.atomic():
if user.objects.filter(username=add_user).exists(): if user.objects.filter(username=add_user).exists():
logger.info("User %s already exists - skipping creation", add_user) logger.info('User %s already exists - skipping creation', add_user)
else: else:
new_user = user.objects.create_superuser(add_user, add_email, add_password) new_user = user.objects.create_superuser(
add_user, add_email, add_password
)
logger.info('User %s was created!', str(new_user)) logger.info('User %s was created!', str(new_user))
except IntegrityError: except IntegrityError:
logger.warning('The user "%s" could not be created', add_user) logger.warning('The user "%s" could not be created', add_user)
@ -263,12 +274,12 @@ class InvenTreeConfig(AppConfig):
def add_user_from_file(self): def add_user_from_file(self):
"""Add the superuser from a file.""" """Add the superuser from a file."""
# stop if checks were already created # stop if checks were already created
if hasattr(settings, "USER_ADDED_FILE") and settings.USER_ADDED_FILE: if hasattr(settings, 'USER_ADDED_FILE') and settings.USER_ADDED_FILE:
return return
# get values # get values
add_password_file = get_setting( add_password_file = get_setting(
"INVENTREE_ADMIN_PASSWORD_FILE", "admin_password_file", None 'INVENTREE_ADMIN_PASSWORD_FILE', 'admin_password_file', None
) )
# no variable set -> do not try anything # no variable set -> do not try anything
@ -284,7 +295,11 @@ class InvenTreeConfig(AppConfig):
return return
# good to go -> create user # good to go -> create user
self._create_admin_user(get_setting('INVENTREE_ADMIN_USER', 'admin_user', 'admin'), get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email', ''), add_password_file.read_text(encoding="utf-8")) self._create_admin_user(
get_setting('INVENTREE_ADMIN_USER', 'admin_user', 'admin'),
get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email', ''),
add_password_file.read_text(encoding='utf-8'),
)
# do not try again # do not try again
settings.USER_ADDED_FILE = True settings.USER_ADDED_FILE = True

View File

@ -24,22 +24,14 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))
output_dir = os.path.join( output_dir = os.path.join(here, '..', '..', 'js_tmp')
here,
'..',
'..',
'js_tmp',
)
output_dir = os.path.abspath(output_dir) output_dir = os.path.abspath(output_dir)
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.mkdir(output_dir) os.mkdir(output_dir)
output_file = os.path.join( output_file = os.path.join(output_dir, filename)
output_dir,
filename,
)
with open(output_file, 'wb') as output: with open(output_file, 'wb') as output:
output.write(response.content) output.write(response.content)
@ -48,12 +40,7 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
"""Download files in directory.""" """Download files in directory."""
here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))
js_template_dir = os.path.join( js_template_dir = os.path.join(here, '..', 'templates', 'js')
here,
'..',
'templates',
'js',
)
directory = os.path.join(js_template_dir, subdir) directory = os.path.join(js_template_dir, subdir)
@ -76,9 +63,9 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
"""Look for all javascript files.""" """Look for all javascript files."""
n = 0 n = 0
print("Rendering javascript files...") print('Rendering javascript files...')
n += self.download_files('translated', '/js/i18n') n += self.download_files('translated', '/js/i18n')
n += self.download_files('dynamic', '/js/dynamic') n += self.download_files('dynamic', '/js/dynamic')
print(f"Rendered {n} javascript files.") print(f'Rendered {n} javascript files.')

View File

@ -51,7 +51,11 @@ def to_dict(value):
try: try:
return json.loads(value) return json.loads(value)
except Exception as error: except Exception as error:
logger.exception("Failed to parse value '%s' as JSON with error %s. Ensure value is a valid JSON string.", value, error) logger.exception(
"Failed to parse value '%s' as JSON with error %s. Ensure value is a valid JSON string.",
value,
error,
)
return {} return {}
@ -90,12 +94,14 @@ def get_config_file(create=True) -> Path:
cfg_filename = base_dir.joinpath('config.yaml').resolve() cfg_filename = base_dir.joinpath('config.yaml').resolve()
if not cfg_filename.exists() and create: if not cfg_filename.exists() and create:
print("InvenTree configuration file 'config.yaml' not found - creating default file") print(
"InvenTree configuration file 'config.yaml' not found - creating default file"
)
ensure_dir(cfg_filename.parent) ensure_dir(cfg_filename.parent)
cfg_template = base_dir.joinpath("config_template.yaml") cfg_template = base_dir.joinpath('config_template.yaml')
shutil.copyfile(cfg_template, cfg_filename) shutil.copyfile(cfg_template, cfg_filename)
print(f"Created config file {cfg_filename}") print(f'Created config file {cfg_filename}')
return cfg_filename return cfg_filename
@ -153,7 +159,13 @@ def do_typecast(value, type, var_name=None):
return val return val
except Exception as error: except Exception as error:
if var_name: if var_name:
logger.exception("Failed to typecast '%s' with value '%s' to type '%s' with error %s", var_name, value, type, error) logger.exception(
"Failed to typecast '%s' with value '%s' to type '%s' with error %s",
var_name,
value,
type,
error,
)
return value return value
@ -174,7 +186,12 @@ def get_setting(env_var=None, config_key=None, default_value=None, typecast=None
def set_metadata(source: str): def set_metadata(source: str):
"""Set lookup metadata for the setting.""" """Set lookup metadata for the setting."""
key = env_var or config_key key = env_var or config_key
CONFIG_LOOKUPS[key] = {'env_var': env_var, 'config_key': config_key, 'source': source, 'accessed': datetime.datetime.now()} CONFIG_LOOKUPS[key] = {
'env_var': env_var,
'config_key': config_key,
'source': source,
'accessed': datetime.datetime.now(),
}
# First, try to load from the environment variables # First, try to load from the environment variables
if env_var is not None: if env_var is not None:
@ -192,7 +209,6 @@ def get_setting(env_var=None, config_key=None, default_value=None, typecast=None
# Hack to allow 'path traversal' in configuration file # Hack to allow 'path traversal' in configuration file
for key in config_key.strip().split('.'): for key in config_key.strip().split('.'):
if type(cfg_data) is not dict or key not in cfg_data: if type(cfg_data) is not dict or key not in cfg_data:
result = None result = None
break break
@ -276,12 +292,16 @@ def get_plugin_file():
plugin_file = Path(plugin_file) plugin_file = Path(plugin_file)
if not plugin_file.exists(): if not plugin_file.exists():
logger.warning("Plugin configuration file does not exist - creating default file") logger.warning(
'Plugin configuration file does not exist - creating default file'
)
logger.info("Creating plugin file at '%s'", plugin_file) logger.info("Creating plugin file at '%s'", plugin_file)
ensure_dir(plugin_file.parent) ensure_dir(plugin_file.parent)
# If opening the file fails (no write permission, for example), then this will throw an error # If opening the file fails (no write permission, for example), then this will throw an error
plugin_file.write_text("# InvenTree Plugins (uses PIP framework to install)\n\n") plugin_file.write_text(
'# InvenTree Plugins (uses PIP framework to install)\n\n'
)
return plugin_file return plugin_file
@ -303,7 +323,7 @@ def get_secret_key():
""" """
# Look for environment variable # Look for environment variable
if secret_key := get_setting('INVENTREE_SECRET_KEY', 'secret_key'): if secret_key := get_setting('INVENTREE_SECRET_KEY', 'secret_key'):
logger.info("SECRET_KEY loaded by INVENTREE_SECRET_KEY") # pragma: no cover logger.info('SECRET_KEY loaded by INVENTREE_SECRET_KEY') # pragma: no cover
return secret_key return secret_key
# Look for secret key file # Look for secret key file
@ -311,7 +331,7 @@ def get_secret_key():
secret_key_file = Path(secret_key_file).resolve() secret_key_file = Path(secret_key_file).resolve()
else: else:
# Default location for secret key file # Default location for secret key file
secret_key_file = get_base_dir().joinpath("secret_key.txt").resolve() secret_key_file = get_base_dir().joinpath('secret_key.txt').resolve()
if not secret_key_file.exists(): if not secret_key_file.exists():
logger.info("Generating random key file at '%s'", secret_key_file) logger.info("Generating random key file at '%s'", secret_key_file)
@ -329,7 +349,9 @@ def get_secret_key():
return key_data return key_data
def get_custom_file(env_ref: str, conf_ref: str, log_ref: str, lookup_media: bool = False): def get_custom_file(
env_ref: str, conf_ref: str, log_ref: str, lookup_media: bool = False
):
"""Returns the checked path to a custom file. """Returns the checked path to a custom file.
Set lookup_media to True to also search in the media folder. Set lookup_media to True to also search in the media folder.
@ -345,12 +367,17 @@ def get_custom_file(env_ref: str, conf_ref: str, log_ref: str, lookup_media: boo
static_storage = StaticFilesStorage() static_storage = StaticFilesStorage()
if static_storage.exists(value): if static_storage.exists(value):
logger.info("Loading %s from %s directory: %s", log_ref, 'static', value) logger.info('Loading %s from %s directory: %s', log_ref, 'static', value)
elif lookup_media and default_storage.exists(value): elif lookup_media and default_storage.exists(value):
logger.info("Loading %s from %s directory: %s", log_ref, 'media', value) logger.info('Loading %s from %s directory: %s', log_ref, 'media', value)
else: else:
add_dir_str = ' or media' if lookup_media else '' add_dir_str = ' or media' if lookup_media else ''
logger.warning("The %s file '%s' could not be found in the static %s directories", log_ref, value, add_dir_str) logger.warning(
"The %s file '%s' could not be found in the static %s directories",
log_ref,
value,
add_dir_str,
)
value = False value = False
return value return value
@ -364,16 +391,21 @@ def get_frontend_settings(debug=True):
""" """
# Legacy settings # Legacy settings
pui_settings = get_setting('INVENTREE_PUI_SETTINGS', 'pui_settings', {}, typecast=dict) pui_settings = get_setting(
'INVENTREE_PUI_SETTINGS', 'pui_settings', {}, typecast=dict
)
if len(pui_settings) > 0: if len(pui_settings) > 0:
warnings.warn( warnings.warn(
"The 'INVENTREE_PUI_SETTINGS' key is deprecated. Please use 'INVENTREE_FRONTEND_SETTINGS' instead", "The 'INVENTREE_PUI_SETTINGS' key is deprecated. Please use 'INVENTREE_FRONTEND_SETTINGS' instead",
DeprecationWarning, stacklevel=2 DeprecationWarning,
stacklevel=2,
) )
# New settings # New settings
frontend_settings = get_setting('INVENTREE_FRONTEND_SETTINGS', 'frontend_settings', {}, typecast=dict) frontend_settings = get_setting(
'INVENTREE_FRONTEND_SETTINGS', 'frontend_settings', {}, typecast=dict
)
# Merge settings # Merge settings
settings = {**pui_settings, **frontend_settings} settings = {**pui_settings, **frontend_settings}
@ -385,10 +417,13 @@ def get_frontend_settings(debug=True):
if base_url: if base_url:
warnings.warn( warnings.warn(
"The 'INVENTREE_PUI_URL_BASE' key is deprecated. Please use 'INVENTREE_FRONTEND_URL_BASE' instead", "The 'INVENTREE_PUI_URL_BASE' key is deprecated. Please use 'INVENTREE_FRONTEND_URL_BASE' instead",
DeprecationWarning, stacklevel=2 DeprecationWarning,
stacklevel=2,
) )
else: else:
base_url = get_setting('INVENTREE_FRONTEND_URL_BASE', 'frontend_url_base', 'platform') base_url = get_setting(
'INVENTREE_FRONTEND_URL_BASE', 'frontend_url_base', 'platform'
)
settings['base_url'] = base_url settings['base_url'] = base_url

View File

@ -31,9 +31,7 @@ def health_status(request):
} }
# The following keys are required to denote system health # The following keys are required to denote system health
health_keys = [ health_keys = ['django_q_running']
'django_q_running',
]
all_healthy = True all_healthy = True
@ -72,11 +70,9 @@ def user_roles(request):
""" """
user = request.user user = request.user
roles = { roles = {}
}
for role in RuleSet.RULESET_MODELS.keys(): for role in RuleSet.RULESET_MODELS.keys():
permissions = {} permissions = {}
for perm in ['view', 'add', 'change', 'delete']: for perm in ['view', 'add', 'change', 'delete']:

View File

@ -29,6 +29,7 @@ def reload_unit_registry():
This function is called at startup, and whenever the database is updated. This function is called at startup, and whenever the database is updated.
""" """
import time import time
t_start = time.time() t_start = time.time()
global _unit_registry global _unit_registry
@ -52,7 +53,9 @@ def reload_unit_registry():
try: try:
reg.define(cu.fmt_string()) reg.define(cu.fmt_string())
except Exception as e: except Exception as e:
logger.exception('Failed to load custom unit: %s - %s', cu.fmt_string(), e) logger.exception(
'Failed to load custom unit: %s - %s', cu.fmt_string(), e
)
# Once custom units are loaded, save registry # Once custom units are loaded, save registry
_unit_registry = reg _unit_registry = reg
@ -124,7 +127,7 @@ def convert_physical_value(value: str, unit: str = None, strip_units=True):
if unit: if unit:
raise ValidationError(_(f'Could not convert {original} to {unit}')) raise ValidationError(_(f'Could not convert {original} to {unit}'))
else: else:
raise ValidationError(_("Invalid quantity supplied")) raise ValidationError(_('Invalid quantity supplied'))
# Calculate the "magnitude" of the value, as a float # Calculate the "magnitude" of the value, as a float
# If the value is specified strangely (e.g. as a fraction or a dozen), this can cause issues # If the value is specified strangely (e.g. as a fraction or a dozen), this can cause issues

View File

@ -30,22 +30,22 @@ def is_email_configured():
# Display warning unless in test mode # Display warning unless in test mode
if not testing: # pragma: no cover if not testing: # pragma: no cover
logger.debug("EMAIL_HOST is not configured") logger.debug('EMAIL_HOST is not configured')
# Display warning unless in test mode # Display warning unless in test mode
if not settings.EMAIL_HOST_USER and not testing: # pragma: no cover if not settings.EMAIL_HOST_USER and not testing: # pragma: no cover
logger.debug("EMAIL_HOST_USER is not configured") logger.debug('EMAIL_HOST_USER is not configured')
# Display warning unless in test mode # Display warning unless in test mode
if not settings.EMAIL_HOST_PASSWORD and testing: # pragma: no cover if not settings.EMAIL_HOST_PASSWORD and testing: # pragma: no cover
logger.debug("EMAIL_HOST_PASSWORD is not configured") logger.debug('EMAIL_HOST_PASSWORD is not configured')
# Email sender must be configured # Email sender must be configured
if not settings.DEFAULT_FROM_EMAIL: if not settings.DEFAULT_FROM_EMAIL:
configured = False configured = False
if not testing: # pragma: no cover if not testing: # pragma: no cover
logger.debug("DEFAULT_FROM_EMAIL is not configured") logger.debug('DEFAULT_FROM_EMAIL is not configured')
return configured return configured
@ -75,7 +75,7 @@ def send_email(subject, body, recipients, from_email=None, html_message=None):
if settings.TESTING: if settings.TESTING:
from_email = 'from@test.com' from_email = 'from@test.com'
else: else:
logger.error("send_email failed: DEFAULT_FROM_EMAIL not specified") logger.error('send_email failed: DEFAULT_FROM_EMAIL not specified')
return return
InvenTree.tasks.offload_task( InvenTree.tasks.offload_task(
@ -85,5 +85,5 @@ def send_email(subject, body, recipients, from_email=None, html_message=None):
from_email, from_email,
recipients, recipients,
fail_silently=False, fail_silently=False,
html_message=html_message html_message=html_message,
) )

View File

@ -86,7 +86,7 @@ def exception_handler(exc, context):
# If in DEBUG mode, provide error information in the response # If in DEBUG mode, provide error information in the response
error_detail = str(exc) error_detail = str(exc)
else: else:
error_detail = _("Error details can be found in the admin panel") error_detail = _('Error details can be found in the admin panel')
response_data = { response_data = {
'error': type(exc).__name__, 'error': type(exc).__name__,

View File

@ -18,7 +18,7 @@ class InvenTreeExchange(SimpleExchangeBackend):
Uses the plugin system to actually fetch the rates from an external API. Uses the plugin system to actually fetch the rates from an external API.
""" """
name = "InvenTreeExchange" name = 'InvenTreeExchange'
def get_rates(self, **kwargs) -> None: def get_rates(self, **kwargs) -> None:
"""Set the requested currency codes and get rates.""" """Set the requested currency codes and get rates."""
@ -44,7 +44,9 @@ class InvenTreeExchange(SimpleExchangeBackend):
plugin = plugins[0] plugin = plugins[0]
if not plugin: if not plugin:
logger.warning('No active currency exchange plugins found - skipping update') logger.warning(
'No active currency exchange plugins found - skipping update'
)
return {} return {}
logger.info("Running exchange rate update using plugin '%s'", plugin.name) logger.info("Running exchange rate update using plugin '%s'", plugin.name)
@ -53,16 +55,22 @@ class InvenTreeExchange(SimpleExchangeBackend):
try: try:
rates = plugin.update_exchange_rates(base_currency, symbols) rates = plugin.update_exchange_rates(base_currency, symbols)
except Exception as exc: except Exception as exc:
logger.exception("Exchange rate update failed: %s", exc) logger.exception('Exchange rate update failed: %s', exc)
return {} return {}
if not rates: if not rates:
logger.warning("Exchange rate update failed - no data returned from plugin %s", slug) logger.warning(
'Exchange rate update failed - no data returned from plugin %s', slug
)
return {} return {}
# Update exchange rates based on returned data # Update exchange rates based on returned data
if type(rates) is not dict: if type(rates) is not dict:
logger.warning("Invalid exchange rate data returned from plugin %s (type %s)", slug, type(rates)) logger.warning(
'Invalid exchange rate data returned from plugin %s (type %s)',
slug,
type(rates),
)
return {} return {}
# Ensure base currency is provided # Ensure base currency is provided
@ -73,14 +81,20 @@ class InvenTreeExchange(SimpleExchangeBackend):
@atomic @atomic
def update_rates(self, base_currency=None, **kwargs): def update_rates(self, base_currency=None, **kwargs):
"""Call to update all exchange rates""" """Call to update all exchange rates"""
backend, _ = ExchangeBackend.objects.update_or_create(name=self.name, defaults={"base_currency": base_currency}) backend, _ = ExchangeBackend.objects.update_or_create(
name=self.name, defaults={'base_currency': base_currency}
)
if base_currency is None: if base_currency is None:
base_currency = currency_code_default() base_currency = currency_code_default()
symbols = currency_codes() symbols = currency_codes()
logger.info("Updating exchange rates for %s (%s currencies)", base_currency, len(symbols)) logger.info(
'Updating exchange rates for %s (%s currencies)',
base_currency,
len(symbols),
)
# Fetch new rates from the backend # Fetch new rates from the backend
# If the backend fails, the existing rates will not be updated # If the backend fails, the existing rates will not be updated
@ -95,6 +109,8 @@ class InvenTreeExchange(SimpleExchangeBackend):
for currency, amount in rates.items() for currency, amount in rates.items()
]) ])
else: else:
logger.info("No exchange rates returned from backend - currencies not updated") logger.info(
'No exchange rates returned from backend - currencies not updated'
)
logger.info("Updated exchange rates for %s", base_currency) logger.info('Updated exchange rates for %s', base_currency)

View File

@ -35,7 +35,9 @@ class InvenTreeRestURLField(RestURLField):
import common.models import common.models
strict_urls = common.models.InvenTreeSetting.get_setting('INVENTREE_STRICT_URLS', True, cache=False) strict_urls = common.models.InvenTreeSetting.get_setting(
'INVENTREE_STRICT_URLS', True, cache=False
)
if not strict_urls and data is not empty: if not strict_urls and data is not empty:
if '://' not in data: if '://' not in data:
@ -97,11 +99,8 @@ class InvenTreeModelMoneyField(ModelMoneyField):
# If no validators are provided, add some "standard" ones # If no validators are provided, add some "standard" ones
if len(validators) == 0: if len(validators) == 0:
if not allow_negative: if not allow_negative:
validators.append( validators.append(MinMoneyValidator(0))
MinMoneyValidator(0),
)
kwargs['validators'] = validators kwargs['validators'] = validators
@ -144,11 +143,7 @@ class DatePickerFormField(forms.DateField):
required = kwargs.get('required', False) required = kwargs.get('required', False)
initial = kwargs.get('initial', None) initial = kwargs.get('initial', None)
widget = forms.DateInput( widget = forms.DateInput(attrs={'type': 'date'})
attrs={
'type': 'date',
}
)
forms.DateField.__init__( forms.DateField.__init__(
self, self,
@ -156,7 +151,7 @@ class DatePickerFormField(forms.DateField):
initial=initial, initial=initial,
help_text=help_text, help_text=help_text,
widget=widget, widget=widget,
label=label label=label,
) )

View File

@ -36,7 +36,9 @@ class InvenTreeSearchFilter(filters.SearchFilter):
The following query params are available to 'augment' the search (in decreasing order of priority) The following query params are available to 'augment' the search (in decreasing order of priority)
- search_regex: If True, search is performed on 'regex' comparison - search_regex: If True, search is performed on 'regex' comparison
""" """
regex = InvenTree.helpers.str2bool(request.query_params.get('search_regex', False)) regex = InvenTree.helpers.str2bool(
request.query_params.get('search_regex', False)
)
search_fields = super().get_search_fields(view, request) search_fields = super().get_search_fields(view, request)
@ -56,7 +58,9 @@ class InvenTreeSearchFilter(filters.SearchFilter):
Depending on the request parameters, we may "augment" these somewhat Depending on the request parameters, we may "augment" these somewhat
""" """
whole = InvenTree.helpers.str2bool(request.query_params.get('search_whole', False)) whole = InvenTree.helpers.str2bool(
request.query_params.get('search_whole', False)
)
terms = [] terms = []
@ -72,7 +76,7 @@ class InvenTreeSearchFilter(filters.SearchFilter):
if whole: if whole:
# Wrap the search term to enable word-boundary matching # Wrap the search term to enable word-boundary matching
term = r"\y" + term + r"\y" term = r'\y' + term + r'\y'
terms.append(term) terms.append(term)
@ -110,7 +114,6 @@ class InvenTreeOrderingFilter(filters.OrderingFilter):
ordering = [] ordering = []
for field in ordering_initial: for field in ordering_initial:
reverse = field.startswith('-') reverse = field.startswith('-')
if reverse: if reverse:
@ -164,7 +167,4 @@ SEARCH_ORDER_FILTER_ALIAS = [
InvenTreeOrderingFilter, InvenTreeOrderingFilter,
] ]
ORDER_FILTER = [ ORDER_FILTER = [rest_filters.DjangoFilterBackend, filters.OrderingFilter]
rest_filters.DjangoFilterBackend,
filters.OrderingFilter,
]

View File

@ -36,10 +36,7 @@ def parse_format_string(fmt_string: str) -> dict:
else: else:
seen_groups.add(name) seen_groups.add(name)
info[group[1]] = { info[group[1]] = {'format': group[1], 'prefix': group[0]}
'format': group[1],
'prefix': group[0],
}
return info return info
@ -67,7 +64,7 @@ def construct_format_regex(fmt_string: str) -> str:
Raises: Raises:
ValueError: Format string is invalid ValueError: Format string is invalid
""" """
pattern = "^" pattern = '^'
for group in string.Formatter().parse(fmt_string): for group in string.Formatter().parse(fmt_string):
prefix = group[0] # Prefix (literal text appearing before this group) prefix = group[0] # Prefix (literal text appearing before this group)
@ -75,9 +72,23 @@ def construct_format_regex(fmt_string: str) -> str:
format = group[2] # Format specifier e.g :04d format = group[2] # Format specifier e.g :04d
rep = [ rep = [
'+', '-', '.', '+',
'{', '}', '(', ')', '-',
'^', '$', '~', '!', '@', ':', ';', '|', '\'', '"', '.',
'{',
'}',
'(',
')',
'^',
'$',
'~',
'!',
'@',
':',
';',
'|',
"'",
'"',
] ]
# Escape any special regex characters # Escape any special regex characters
@ -94,7 +105,6 @@ def construct_format_regex(fmt_string: str) -> str:
# Add a named capture group for the format entry # Add a named capture group for the format entry
if name: if name:
# Check if integer values are required # Check if integer values are required
if format.endswith('d'): if format.endswith('d'):
chr = '\d' chr = '\d'
@ -105,9 +115,9 @@ def construct_format_regex(fmt_string: str) -> str:
# TODO: Introspect required width # TODO: Introspect required width
w = '+' w = '+'
pattern += f"(?P<{name}>{chr}{w})" pattern += f'(?P<{name}>{chr}{w})'
pattern += "$" pattern += '$'
return pattern return pattern
@ -161,7 +171,9 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str:
result = re.match(pattern, value) result = re.match(pattern, value)
if not result: if not result:
raise ValueError(_("Provided value does not match required pattern: ") + fmt_string) raise ValueError(
_('Provided value does not match required pattern: ') + fmt_string
)
# And return the value we are interested in # And return the value we are interested in
# Note: This will raise an IndexError if the named group was not matched # Note: This will raise an IndexError if the named group was not matched
@ -186,7 +198,7 @@ def format_money(money: Money, decimal_places: int = None, format: str = None) -
if format: if format:
pattern = parse_pattern(format) pattern = parse_pattern(format)
else: else:
pattern = locale.currency_formats["standard"] pattern = locale.currency_formats['standard']
if decimal_places is not None: if decimal_places is not None:
pattern.frac_prec = (decimal_places, decimal_places) pattern.frac_prec = (decimal_places, decimal_places)

View File

@ -17,8 +17,7 @@ from allauth.core.exceptions import ImmediateHttpResponse
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
from allauth_2fa.adapter import OTPAdapter from allauth_2fa.adapter import OTPAdapter
from allauth_2fa.utils import user_has_valid_totp_device from allauth_2fa.utils import user_has_valid_totp_device
from crispy_forms.bootstrap import (AppendedText, PrependedAppendedText, from crispy_forms.bootstrap import AppendedText, PrependedAppendedText, PrependedText
PrependedText)
from crispy_forms.helper import FormHelper from crispy_forms.helper import FormHelper
from crispy_forms.layout import Field, Layout from crispy_forms.layout import Field, Layout
from dj_rest_auth.registration.serializers import RegisterSerializer from dj_rest_auth.registration.serializers import RegisterSerializer
@ -80,31 +79,19 @@ class HelperForm(forms.ModelForm):
field, field,
prepended_text=prefix, prepended_text=prefix,
appended_text=suffix, appended_text=suffix,
placeholder=placeholder placeholder=placeholder,
) )
) )
) )
elif prefix: elif prefix:
layouts.append( layouts.append(
Field( Field(PrependedText(field, prefix, placeholder=placeholder))
PrependedText(
field,
prefix,
placeholder=placeholder
)
)
) )
elif suffix: elif suffix:
layouts.append( layouts.append(
Field( Field(AppendedText(field, suffix, placeholder=placeholder))
AppendedText(
field,
suffix,
placeholder=placeholder
)
)
) )
else: else:
@ -120,10 +107,7 @@ class EditUserForm(HelperForm):
"""Metaclass options.""" """Metaclass options."""
model = User model = User
fields = [ fields = ['first_name', 'last_name']
'first_name',
'last_name',
]
class SetPasswordForm(HelperForm): class SetPasswordForm(HelperForm):
@ -133,11 +117,7 @@ class SetPasswordForm(HelperForm):
"""Metaclass options.""" """Metaclass options."""
model = User model = User
fields = [ fields = ['enter_password', 'confirm_password', 'old_password']
'enter_password',
'confirm_password',
'old_password',
]
enter_password = forms.CharField( enter_password = forms.CharField(
max_length=100, max_length=100,
@ -146,7 +126,7 @@ class SetPasswordForm(HelperForm):
initial='', initial='',
widget=forms.PasswordInput(attrs={'autocomplete': 'off'}), widget=forms.PasswordInput(attrs={'autocomplete': 'off'}),
label=_('Enter password'), label=_('Enter password'),
help_text=_('Enter new password') help_text=_('Enter new password'),
) )
confirm_password = forms.CharField( confirm_password = forms.CharField(
@ -156,14 +136,16 @@ class SetPasswordForm(HelperForm):
initial='', initial='',
widget=forms.PasswordInput(attrs={'autocomplete': 'off'}), widget=forms.PasswordInput(attrs={'autocomplete': 'off'}),
label=_('Confirm password'), label=_('Confirm password'),
help_text=_('Confirm new password') help_text=_('Confirm new password'),
) )
old_password = forms.CharField( old_password = forms.CharField(
label=_("Old password"), label=_('Old password'),
strip=False, strip=False,
required=False, required=False,
widget=forms.PasswordInput(attrs={'autocomplete': 'current-password', 'autofocus': True}), widget=forms.PasswordInput(
attrs={'autocomplete': 'current-password', 'autofocus': True}
),
) )
@ -196,22 +178,24 @@ class CustomSignupForm(SignupForm):
# check for two mail fields # check for two mail fields
if InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_TWICE'): if InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_TWICE'):
self.fields["email2"] = forms.EmailField( self.fields['email2'] = forms.EmailField(
label=_("Email (again)"), label=_('Email (again)'),
widget=forms.TextInput( widget=forms.TextInput(
attrs={ attrs={
"type": "email", 'type': 'email',
"placeholder": _("Email address confirmation"), 'placeholder': _('Email address confirmation'),
} }
), ),
) )
# check for two password fields # check for two password fields
if not InvenTreeSetting.get_setting('LOGIN_SIGNUP_PWD_TWICE'): if not InvenTreeSetting.get_setting('LOGIN_SIGNUP_PWD_TWICE'):
self.fields.pop("password2") self.fields.pop('password2')
# reorder fields # reorder fields
set_form_field_order(self, ["username", "email", "email2", "password1", "password2", ]) set_form_field_order(
self, ['username', 'email', 'email2', 'password1', 'password2']
)
def clean(self): def clean(self):
"""Make sure the supplied emails match if enabled in settings.""" """Make sure the supplied emails match if enabled in settings."""
@ -219,21 +203,26 @@ class CustomSignupForm(SignupForm):
# check for two mail fields # check for two mail fields
if InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_TWICE'): if InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_TWICE'):
email = cleaned_data.get("email") email = cleaned_data.get('email')
email2 = cleaned_data.get("email2") email2 = cleaned_data.get('email2')
if (email and email2) and email != email2: if (email and email2) and email != email2:
self.add_error("email2", _("You must type the same email each time.")) self.add_error('email2', _('You must type the same email each time.'))
return cleaned_data return cleaned_data
def registration_enabled(): def registration_enabled():
"""Determine whether user registration is enabled.""" """Determine whether user registration is enabled."""
if InvenTreeSetting.get_setting('LOGIN_ENABLE_REG') or InvenTree.sso.registration_enabled(): if (
InvenTreeSetting.get_setting('LOGIN_ENABLE_REG')
or InvenTree.sso.registration_enabled()
):
if settings.EMAIL_HOST: if settings.EMAIL_HOST:
return True return True
else: else:
logger.error("Registration cannot be enabled, because EMAIL_HOST is not configured.") logger.error(
'Registration cannot be enabled, because EMAIL_HOST is not configured.'
)
return False return False
@ -251,20 +240,26 @@ class RegistratonMixin:
def clean_email(self, email): def clean_email(self, email):
"""Check if the mail is valid to the pattern in LOGIN_SIGNUP_MAIL_RESTRICTION (if enabled in settings).""" """Check if the mail is valid to the pattern in LOGIN_SIGNUP_MAIL_RESTRICTION (if enabled in settings)."""
mail_restriction = InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_RESTRICTION', None) mail_restriction = InvenTreeSetting.get_setting(
'LOGIN_SIGNUP_MAIL_RESTRICTION', None
)
if not mail_restriction: if not mail_restriction:
return super().clean_email(email) return super().clean_email(email)
split_email = email.split('@') split_email = email.split('@')
if len(split_email) != 2: if len(split_email) != 2:
logger.error('The user %s has an invalid email address', email) logger.error('The user %s has an invalid email address', email)
raise forms.ValidationError(_('The provided primary email address is not valid.')) raise forms.ValidationError(
_('The provided primary email address is not valid.')
)
mailoptions = mail_restriction.split(',') mailoptions = mail_restriction.split(',')
for option in mailoptions: for option in mailoptions:
if not option.startswith('@'): if not option.startswith('@'):
log_error('LOGIN_SIGNUP_MAIL_RESTRICTION is not configured correctly') log_error('LOGIN_SIGNUP_MAIL_RESTRICTION is not configured correctly')
raise forms.ValidationError(_('The provided primary email address is not valid.')) raise forms.ValidationError(
_('The provided primary email address is not valid.')
)
else: else:
if split_email[1] == option[1:]: if split_email[1] == option[1:]:
return super().clean_email(email) return super().clean_email(email)
@ -284,7 +279,10 @@ class RegistratonMixin:
group = Group.objects.get(id=start_group) group = Group.objects.get(id=start_group)
user.groups.add(group) user.groups.add(group)
except Group.DoesNotExist: except Group.DoesNotExist:
logger.exception('The setting `SIGNUP_GROUP` contains an non existent group', start_group) logger.exception(
'The setting `SIGNUP_GROUP` contains an non existent group',
start_group,
)
user.save() user.save()
return user return user
@ -294,11 +292,13 @@ class CustomUrlMixin:
def get_email_confirmation_url(self, request, emailconfirmation): def get_email_confirmation_url(self, request, emailconfirmation):
"""Custom email confirmation (activation) url.""" """Custom email confirmation (activation) url."""
url = reverse("account_confirm_email", args=[emailconfirmation.key]) url = reverse('account_confirm_email', args=[emailconfirmation.key])
return Site.objects.get_current().domain + url return Site.objects.get_current().domain + url
class CustomAccountAdapter(CustomUrlMixin, RegistratonMixin, OTPAdapter, DefaultAccountAdapter): class CustomAccountAdapter(
CustomUrlMixin, RegistratonMixin, OTPAdapter, DefaultAccountAdapter
):
"""Override of adapter to use dynamic settings.""" """Override of adapter to use dynamic settings."""
def send_mail(self, template_prefix, email, context): def send_mail(self, template_prefix, email, context):
@ -325,7 +325,9 @@ class CustomAccountAdapter(CustomUrlMixin, RegistratonMixin, OTPAdapter, Default
return url return url
class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocialAccountAdapter): class CustomSocialAccountAdapter(
CustomUrlMixin, RegistratonMixin, DefaultSocialAccountAdapter
):
"""Override of adapter to use dynamic settings.""" """Override of adapter to use dynamic settings."""
def is_auto_signup_allowed(self, request, sociallogin): def is_auto_signup_allowed(self, request, sociallogin):
@ -352,14 +354,14 @@ class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocial
if request.GET: if request.GET:
redirect_url += '?' + urlencode(request.GET) redirect_url += '?' + urlencode(request.GET)
raise ImmediateHttpResponse( raise ImmediateHttpResponse(response=HttpResponseRedirect(redirect_url))
response=HttpResponseRedirect(redirect_url)
)
# Otherwise defer to the original allauth adapter. # Otherwise defer to the original allauth adapter.
return super().login(request, user) return super().login(request, user)
def authentication_error(self, request, provider_id, error=None, exception=None, extra_context=None): def authentication_error(
self, request, provider_id, error=None, exception=None, extra_context=None
):
"""Callback method for authentication errors.""" """Callback method for authentication errors."""
# Log the error to the database # Log the error to the database
@ -370,6 +372,7 @@ class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocial
# override dj-rest-auth # override dj-rest-auth
class CustomRegisterSerializer(RegisterSerializer): class CustomRegisterSerializer(RegisterSerializer):
"""Override of serializer to use dynamic settings.""" """Override of serializer to use dynamic settings."""
email = serializers.EmailField() email = serializers.EmailField()
def __init__(self, instance=None, data=..., **kwargs): def __init__(self, instance=None, data=..., **kwargs):

View File

@ -36,7 +36,7 @@ def generateTestKey(test_name):
Tests must be named such that they will have unique keys. Tests must be named such that they will have unique keys.
""" """
key = test_name.strip().lower() key = test_name.strip().lower()
key = key.replace(" ", "") key = key.replace(' ', '')
# Remove any characters that cannot be used to represent a variable # Remove any characters that cannot be used to represent a variable
key = re.sub(r'[^a-zA-Z0-9]', '', key) key = re.sub(r'[^a-zA-Z0-9]', '', key)
@ -56,7 +56,7 @@ def constructPathString(path, max_chars=250):
# Replace middle elements to limit the pathstring # Replace middle elements to limit the pathstring
if len(pathstring) > max_chars: if len(pathstring) > max_chars:
n = int(max_chars / 2 - 2) n = int(max_chars / 2 - 2)
pathstring = pathstring[:n] + "..." + pathstring[-n:] pathstring = pathstring[:n] + '...' + pathstring[-n:]
return pathstring return pathstring
@ -82,19 +82,18 @@ def TestIfImage(img):
def getBlankImage(): def getBlankImage():
"""Return the qualified path for the 'blank image' placeholder.""" """Return the qualified path for the 'blank image' placeholder."""
return getStaticUrl("img/blank_image.png") return getStaticUrl('img/blank_image.png')
def getBlankThumbnail(): def getBlankThumbnail():
"""Return the qualified path for the 'blank image' thumbnail placeholder.""" """Return the qualified path for the 'blank image' thumbnail placeholder."""
return getStaticUrl("img/blank_image.thumbnail.png") return getStaticUrl('img/blank_image.thumbnail.png')
def getLogoImage(as_file=False, custom=True): def getLogoImage(as_file=False, custom=True):
"""Return the InvenTree logo image, or a custom logo if available.""" """Return the InvenTree logo image, or a custom logo if available."""
"""Return the path to the logo-file.""" """Return the path to the logo-file."""
if custom and settings.CUSTOM_LOGO: if custom and settings.CUSTOM_LOGO:
static_storage = StaticFilesStorage() static_storage = StaticFilesStorage()
if static_storage.exists(settings.CUSTOM_LOGO): if static_storage.exists(settings.CUSTOM_LOGO):
@ -106,13 +105,13 @@ def getLogoImage(as_file=False, custom=True):
if storage is not None: if storage is not None:
if as_file: if as_file:
return f"file://{storage.path(settings.CUSTOM_LOGO)}" return f'file://{storage.path(settings.CUSTOM_LOGO)}'
return storage.url(settings.CUSTOM_LOGO) return storage.url(settings.CUSTOM_LOGO)
# If we have got to this point, return the default logo # If we have got to this point, return the default logo
if as_file: if as_file:
path = settings.STATIC_ROOT.joinpath('img/inventree.png') path = settings.STATIC_ROOT.joinpath('img/inventree.png')
return f"file://{path}" return f'file://{path}'
return getStaticUrl('img/inventree.png') return getStaticUrl('img/inventree.png')
@ -121,12 +120,11 @@ def getSplashScreen(custom=True):
static_storage = StaticFilesStorage() static_storage = StaticFilesStorage()
if custom and settings.CUSTOM_SPLASH: if custom and settings.CUSTOM_SPLASH:
if static_storage.exists(settings.CUSTOM_SPLASH): if static_storage.exists(settings.CUSTOM_SPLASH):
return static_storage.url(settings.CUSTOM_SPLASH) return static_storage.url(settings.CUSTOM_SPLASH)
# No custom splash screen # No custom splash screen
return static_storage.url("img/inventree_splash.jpg") return static_storage.url('img/inventree_splash.jpg')
def TestIfImageURL(url): def TestIfImageURL(url):
@ -135,10 +133,15 @@ def TestIfImageURL(url):
Simply tests the extension against a set of allowed values Simply tests the extension against a set of allowed values
""" """
return os.path.splitext(os.path.basename(url))[-1].lower() in [ return os.path.splitext(os.path.basename(url))[-1].lower() in [
'.jpg', '.jpeg', '.j2k', '.jpg',
'.png', '.bmp', '.jpeg',
'.tif', '.tiff', '.j2k',
'.webp', '.gif', '.png',
'.bmp',
'.tif',
'.tiff',
'.webp',
'.gif',
] ]
@ -153,8 +156,8 @@ def str2bool(text, test=True):
True if the text looks like the selected boolean value True if the text looks like the selected boolean value
""" """
if test: if test:
return str(text).lower() in ['1', 'y', 'yes', 't', 'true', 'ok', 'on', ] return str(text).lower() in ['1', 'y', 'yes', 't', 'true', 'ok', 'on']
return str(text).lower() in ['0', 'n', 'no', 'none', 'f', 'false', 'off', ] return str(text).lower() in ['0', 'n', 'no', 'none', 'f', 'false', 'off']
def str2int(text, default=None): def str2int(text, default=None):
@ -191,7 +194,15 @@ def isNull(text):
Returns: Returns:
True if the text looks like a null value True if the text looks like a null value
""" """
return str(text).strip().lower() in ['top', 'null', 'none', 'empty', 'false', '-1', ''] return str(text).strip().lower() in [
'top',
'null',
'none',
'empty',
'false',
'-1',
'',
]
def normalize(d): def normalize(d):
@ -223,7 +234,7 @@ def increment(value):
# Provide a default value if provided with a null input # Provide a default value if provided with a null input
return '1' return '1'
pattern = r"(.*?)(\d+)?$" pattern = r'(.*?)(\d+)?$'
result = re.search(pattern, value) result = re.search(pattern, value)
@ -282,7 +293,7 @@ def decimal2string(d):
if '.' not in s: if '.' not in s:
return s return s
return s.rstrip("0").rstrip(".") return s.rstrip('0').rstrip('.')
def decimal2money(d, currency=None): def decimal2money(d, currency=None):
@ -354,17 +365,12 @@ def MakeBarcode(cls_name, object_pk: int, object_data=None, **kwargs):
def GetExportFormats(): def GetExportFormats():
"""Return a list of allowable file formats for exporting data.""" """Return a list of allowable file formats for exporting data."""
return [ return ['csv', 'tsv', 'xls', 'xlsx', 'json', 'yaml']
'csv',
'tsv',
'xls',
'xlsx',
'json',
'yaml',
]
def DownloadFile(data, filename, content_type='application/text', inline=False) -> StreamingHttpResponse: def DownloadFile(
data, filename, content_type='application/text', inline=False
) -> StreamingHttpResponse:
"""Create a dynamic file for the user to download. """Create a dynamic file for the user to download.
Args: Args:
@ -389,7 +395,7 @@ def DownloadFile(data, filename, content_type='application/text', inline=False)
length = len(bytes(data, response.charset)) length = len(bytes(data, response.charset))
response['Content-Length'] = length response['Content-Length'] = length
disposition = "inline" if inline else "attachment" disposition = 'inline' if inline else 'attachment'
response['Content-Disposition'] = f'{disposition}; filename={filename}' response['Content-Disposition'] = f'{disposition}; filename={filename}'
@ -449,7 +455,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
try: try:
expected_quantity = int(expected_quantity) expected_quantity = int(expected_quantity)
except ValueError: except ValueError:
raise ValidationError([_("Invalid quantity provided")]) raise ValidationError([_('Invalid quantity provided')])
if input_string: if input_string:
input_string = str(input_string).strip() input_string = str(input_string).strip()
@ -457,7 +463,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
input_string = '' input_string = ''
if len(input_string) == 0: if len(input_string) == 0:
raise ValidationError([_("Empty serial number string")]) raise ValidationError([_('Empty serial number string')])
next_value = increment_serial_number(starting_value) next_value = increment_serial_number(starting_value)
@ -467,7 +473,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
next_value = increment_serial_number(next_value) next_value = increment_serial_number(next_value)
# Split input string by whitespace or comma (,) characters # Split input string by whitespace or comma (,) characters
groups = re.split(r"[\s,]+", input_string) groups = re.split(r'[\s,]+', input_string)
serials = [] serials = []
errors = [] errors = []
@ -487,7 +493,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
return return
if serial in serials: if serial in serials:
add_error(_("Duplicate serial") + f": {serial}") add_error(_('Duplicate serial') + f': {serial}')
else: else:
serials.append(serial) serials.append(serial)
@ -502,7 +508,6 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
return serials return serials
for group in groups: for group in groups:
# Calculate the "remaining" quantity of serial numbers # Calculate the "remaining" quantity of serial numbers
remaining = expected_quantity - len(serials) remaining = expected_quantity - len(serials)
@ -520,7 +525,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
if a == b: if a == b:
# Invalid group # Invalid group
add_error(_(f"Invalid group range: {group}")) add_error(_(f'Invalid group range: {group}'))
continue continue
group_items = [] group_items = []
@ -549,13 +554,21 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
break break
if len(group_items) > remaining: if len(group_items) > remaining:
add_error(_(f"Group range {group} exceeds allowed quantity ({expected_quantity})")) add_error(
elif len(group_items) > 0 and group_items[0] == a and group_items[-1] == b: _(
f'Group range {group} exceeds allowed quantity ({expected_quantity})'
)
)
elif (
len(group_items) > 0
and group_items[0] == a
and group_items[-1] == b
):
# In this case, the range extraction looks like it has worked # In this case, the range extraction looks like it has worked
for item in group_items: for item in group_items:
add_serial(item) add_serial(item)
else: else:
add_error(_(f"Invalid group range: {group}")) add_error(_(f'Invalid group range: {group}'))
else: else:
# In the case of a different number of hyphens, simply add the entire group # In the case of a different number of hyphens, simply add the entire group
@ -573,20 +586,24 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
sequence_count = max(0, expected_quantity - len(serials)) sequence_count = max(0, expected_quantity - len(serials))
if len(items) > 2 or len(items) == 0: if len(items) > 2 or len(items) == 0:
add_error(_(f"Invalid group sequence: {group}")) add_error(_(f'Invalid group sequence: {group}'))
continue continue
elif len(items) == 2: elif len(items) == 2:
try: try:
if items[1]: if items[1]:
sequence_count = int(items[1]) + 1 sequence_count = int(items[1]) + 1
except ValueError: except ValueError:
add_error(_(f"Invalid group sequence: {group}")) add_error(_(f'Invalid group sequence: {group}'))
continue continue
value = items[0] value = items[0]
# Keep incrementing up to the specified quantity # Keep incrementing up to the specified quantity
while value is not None and value not in sequence_items and counter < sequence_count: while (
value is not None
and value not in sequence_items
and counter < sequence_count
):
sequence_items.append(value) sequence_items.append(value)
value = increment_serial_number(value) value = increment_serial_number(value)
counter += 1 counter += 1
@ -595,7 +612,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
for item in sequence_items: for item in sequence_items:
add_serial(item) add_serial(item)
else: else:
add_error(_(f"Invalid group sequence: {group}")) add_error(_(f'Invalid group sequence: {group}'))
else: else:
# At this point, we assume that the 'group' is just a single serial value # At this point, we assume that the 'group' is just a single serial value
@ -605,10 +622,14 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
raise ValidationError(errors) raise ValidationError(errors)
if len(serials) == 0: if len(serials) == 0:
raise ValidationError([_("No serial numbers found")]) raise ValidationError([_('No serial numbers found')])
if len(errors) == 0 and len(serials) != expected_quantity: if len(errors) == 0 and len(serials) != expected_quantity:
raise ValidationError([_(f"Number of unique serial numbers ({len(serials)}) must match quantity ({expected_quantity})")]) raise ValidationError([
_(
f'Number of unique serial numbers ({len(serials)}) must match quantity ({expected_quantity})'
)
])
return serials return serials
@ -645,9 +666,7 @@ def validateFilterString(value, model=None):
pair = group.split('=') pair = group.split('=')
if len(pair) != 2: if len(pair) != 2:
raise ValidationError( raise ValidationError(f'Invalid group: {group}')
f"Invalid group: {group}"
)
k, v = pair k, v = pair
@ -655,9 +674,7 @@ def validateFilterString(value, model=None):
v = v.strip() v = v.strip()
if not k or not v: if not k or not v:
raise ValidationError( raise ValidationError(f'Invalid group: {group}')
f"Invalid group: {group}"
)
results[k] = v results[k] = v
@ -666,9 +683,7 @@ def validateFilterString(value, model=None):
try: try:
model.objects.filter(**results) model.objects.filter(**results)
except FieldError as e: except FieldError as e:
raise ValidationError( raise ValidationError(str(e))
str(e),
)
return results return results
@ -706,7 +721,11 @@ def clean_decimal(number):
# Number cannot be converted to Decimal (eg. a string containing letters) # Number cannot be converted to Decimal (eg. a string containing letters)
return Decimal(0) return Decimal(0)
return clean_number.quantize(Decimal(1)) if clean_number == clean_number.to_integral() else clean_number.normalize() return (
clean_number.quantize(Decimal(1))
if clean_number == clean_number.to_integral()
else clean_number.normalize()
)
def strip_html_tags(value: str, raise_error=True, field_name=None): def strip_html_tags(value: str, raise_error=True, field_name=None):
@ -714,55 +733,45 @@ def strip_html_tags(value: str, raise_error=True, field_name=None):
If raise_error is True, a ValidationError will be thrown if HTML tags are detected If raise_error is True, a ValidationError will be thrown if HTML tags are detected
""" """
cleaned = clean( cleaned = clean(value, strip=True, tags=[], attributes=[])
value,
strip=True,
tags=[],
attributes=[],
)
# Add escaped characters back in # Add escaped characters back in
replacements = { replacements = {'&gt;': '>', '&lt;': '<', '&amp;': '&'}
'&gt;': '>',
'&lt;': '<',
'&amp;': '&',
}
for o, r in replacements.items(): for o, r in replacements.items():
cleaned = cleaned.replace(o, r) cleaned = cleaned.replace(o, r)
# If the length changed, it means that HTML tags were removed! # If the length changed, it means that HTML tags were removed!
if len(cleaned) != len(value) and raise_error: if len(cleaned) != len(value) and raise_error:
field = field_name or 'non_field_errors' field = field_name or 'non_field_errors'
raise ValidationError({ raise ValidationError({field: [_('Remove HTML tags from this value')]})
field: [_("Remove HTML tags from this value")]
})
return cleaned return cleaned
def remove_non_printable_characters(value: str, remove_newline=True, remove_ascii=True, remove_unicode=True): def remove_non_printable_characters(
value: str, remove_newline=True, remove_ascii=True, remove_unicode=True
):
"""Remove non-printable / control characters from the provided string""" """Remove non-printable / control characters from the provided string"""
cleaned = value cleaned = value
if remove_ascii: if remove_ascii:
# Remove ASCII control characters # Remove ASCII control characters
# Note that we do not sub out 0x0A (\n) here, it is done separately below # Note that we do not sub out 0x0A (\n) here, it is done separately below
cleaned = regex.sub(u'[\x00-\x09]+', '', cleaned) cleaned = regex.sub('[\x00-\x09]+', '', cleaned)
cleaned = regex.sub(u'[\x0b-\x1F\x7F]+', '', cleaned) cleaned = regex.sub('[\x0b-\x1f\x7f]+', '', cleaned)
if remove_newline: if remove_newline:
cleaned = regex.sub(u'[\x0a]+', '', cleaned) cleaned = regex.sub('[\x0a]+', '', cleaned)
if remove_unicode: if remove_unicode:
# Remove Unicode control characters # Remove Unicode control characters
if remove_newline: if remove_newline:
cleaned = regex.sub(u'[^\P{C}]+', '', cleaned) cleaned = regex.sub('[^\P{C}]+', '', cleaned)
else: else:
# Use 'negative-lookahead' to exclude newline character # Use 'negative-lookahead' to exclude newline character
cleaned = regex.sub(u'(?![\x0A])[^\P{C}]+', '', cleaned) cleaned = regex.sub('(?![\x0a])[^\P{C}]+', '', cleaned)
return cleaned return cleaned
@ -783,7 +792,9 @@ def hash_barcode(barcode_data):
return str(hash.hexdigest()) return str(hash.hexdigest())
def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = 'object_id'): def get_objectreference(
obj, type_ref: str = 'content_type', object_ref: str = 'object_id'
):
"""Lookup method for the GenericForeignKey fields. """Lookup method for the GenericForeignKey fields.
Attributes: Attributes:
@ -821,11 +832,7 @@ def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = '
ret = {} ret = {}
if url_fnc: if url_fnc:
ret['link'] = url_fnc() ret['link'] = url_fnc()
return { return {'name': str(item), 'model': str(model_cls._meta.verbose_name), **ret}
'name': str(item),
'model': str(model_cls._meta.verbose_name),
**ret
}
def inheritors(cls): def inheritors(cls):

View File

@ -19,8 +19,11 @@ import common.models
import InvenTree import InvenTree
import InvenTree.helpers_model import InvenTree.helpers_model
import InvenTree.version import InvenTree.version
from common.notifications import (InvenTreeNotificationBodies, from common.notifications import (
NotificationBody, trigger_notification) InvenTreeNotificationBodies,
NotificationBody,
trigger_notification,
)
from InvenTree.format import format_money from InvenTree.format import format_money
logger = logging.getLogger('inventree') logger = logging.getLogger('inventree')
@ -49,7 +52,9 @@ def construct_absolute_url(*arg, **kwargs):
if not site_url: if not site_url:
# Otherwise, try to use the InvenTree setting # Otherwise, try to use the InvenTree setting
try: try:
site_url = common.models.InvenTreeSetting.get_setting('INVENTREE_BASE_URL', create=False, cache=False) site_url = common.models.InvenTreeSetting.get_setting(
'INVENTREE_BASE_URL', create=False, cache=False
)
except (ProgrammingError, OperationalError): except (ProgrammingError, OperationalError):
pass pass
@ -100,12 +105,22 @@ def download_image_from_url(remote_url, timeout=2.5):
validator(remote_url) validator(remote_url)
# Calculate maximum allowable image size (in bytes) # Calculate maximum allowable image size (in bytes)
max_size = int(common.models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE')) * 1024 * 1024 max_size = (
int(
common.models.InvenTreeSetting.get_setting(
'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE'
)
)
* 1024
* 1024
)
# Add user specified user-agent to request (if specified) # Add user specified user-agent to request (if specified)
user_agent = common.models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT') user_agent = common.models.InvenTreeSetting.get_setting(
'INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT'
)
if user_agent: if user_agent:
headers = {"User-Agent": user_agent} headers = {'User-Agent': user_agent}
else: else:
headers = None headers = None
@ -120,24 +135,28 @@ def download_image_from_url(remote_url, timeout=2.5):
# Throw an error if anything goes wrong # Throw an error if anything goes wrong
response.raise_for_status() response.raise_for_status()
except requests.exceptions.ConnectionError as exc: except requests.exceptions.ConnectionError as exc:
raise Exception(_("Connection error") + f": {str(exc)}") raise Exception(_('Connection error') + f': {str(exc)}')
except requests.exceptions.Timeout as exc: except requests.exceptions.Timeout as exc:
raise exc raise exc
except requests.exceptions.HTTPError: except requests.exceptions.HTTPError:
raise requests.exceptions.HTTPError(_("Server responded with invalid status code") + f": {response.status_code}") raise requests.exceptions.HTTPError(
_('Server responded with invalid status code') + f': {response.status_code}'
)
except Exception as exc: except Exception as exc:
raise Exception(_("Exception occurred") + f": {str(exc)}") raise Exception(_('Exception occurred') + f': {str(exc)}')
if response.status_code != 200: if response.status_code != 200:
raise Exception(_("Server responded with invalid status code") + f": {response.status_code}") raise Exception(
_('Server responded with invalid status code') + f': {response.status_code}'
)
try: try:
content_length = int(response.headers.get('Content-Length', 0)) content_length = int(response.headers.get('Content-Length', 0))
except ValueError: except ValueError:
raise ValueError(_("Server responded with invalid Content-Length value")) raise ValueError(_('Server responded with invalid Content-Length value'))
if content_length > max_size: if content_length > max_size:
raise ValueError(_("Image size is too large")) raise ValueError(_('Image size is too large'))
# Download the file, ensuring we do not exceed the reported size # Download the file, ensuring we do not exceed the reported size
file = io.BytesIO() file = io.BytesIO()
@ -149,12 +168,12 @@ def download_image_from_url(remote_url, timeout=2.5):
dl_size += len(chunk) dl_size += len(chunk)
if dl_size > max_size: if dl_size > max_size:
raise ValueError(_("Image download exceeded maximum size")) raise ValueError(_('Image download exceeded maximum size'))
file.write(chunk) file.write(chunk)
if dl_size == 0: if dl_size == 0:
raise ValueError(_("Remote server returned empty response")) raise ValueError(_('Remote server returned empty response'))
# Now, attempt to convert the downloaded data to a valid image file # Now, attempt to convert the downloaded data to a valid image file
# img.verify() will throw an exception if the image is not valid # img.verify() will throw an exception if the image is not valid
@ -162,12 +181,18 @@ def download_image_from_url(remote_url, timeout=2.5):
img = Image.open(file).convert() img = Image.open(file).convert()
img.verify() img.verify()
except Exception: except Exception:
raise TypeError(_("Supplied URL is not a valid image file")) raise TypeError(_('Supplied URL is not a valid image file'))
return img return img
def render_currency(money, decimal_places=None, currency=None, min_decimal_places=None, max_decimal_places=None): def render_currency(
money,
decimal_places=None,
currency=None,
min_decimal_places=None,
max_decimal_places=None,
):
"""Render a currency / Money object to a formatted string (e.g. for reports) """Render a currency / Money object to a formatted string (e.g. for reports)
Arguments: Arguments:
@ -192,13 +217,19 @@ def render_currency(money, decimal_places=None, currency=None, min_decimal_place
pass pass
if decimal_places is None: if decimal_places is None:
decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES', 6) decimal_places = common.models.InvenTreeSetting.get_setting(
'PRICING_DECIMAL_PLACES', 6
)
if min_decimal_places is None: if min_decimal_places is None:
min_decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES_MIN', 0) min_decimal_places = common.models.InvenTreeSetting.get_setting(
'PRICING_DECIMAL_PLACES_MIN', 0
)
if max_decimal_places is None: if max_decimal_places is None:
max_decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES', 6) max_decimal_places = common.models.InvenTreeSetting.get_setting(
'PRICING_DECIMAL_PLACES', 6
)
value = Decimal(str(money.amount)).normalize() value = Decimal(str(money.amount)).normalize()
value = str(value) value = str(value)
@ -229,7 +260,9 @@ def getModelsWithMixin(mixin_class) -> list:
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
try: try:
db_models = [x.model_class() for x in ContentType.objects.all() if x is not None] db_models = [
x.model_class() for x in ContentType.objects.all() if x is not None
]
except (OperationalError, ProgrammingError): except (OperationalError, ProgrammingError):
# Database is likely not yet ready # Database is likely not yet ready
db_models = [] db_models = []
@ -237,7 +270,12 @@ def getModelsWithMixin(mixin_class) -> list:
return [x for x in db_models if x is not None and issubclass(x, mixin_class)] return [x for x in db_models if x is not None and issubclass(x, mixin_class)]
def notify_responsible(instance, sender, content: NotificationBody = InvenTreeNotificationBodies.NewOrder, exclude=None): def notify_responsible(
instance,
sender,
content: NotificationBody = InvenTreeNotificationBodies.NewOrder,
exclude=None,
):
"""Notify all responsible parties of a change in an instance. """Notify all responsible parties of a change in an instance.
Parses the supplied content with the provided instance and sender and sends a notification to all responsible users, Parses the supplied content with the provided instance and sender and sends a notification to all responsible users,
@ -249,10 +287,18 @@ def notify_responsible(instance, sender, content: NotificationBody = InvenTreeNo
content (NotificationBody, optional): _description_. Defaults to InvenTreeNotificationBodies.NewOrder. content (NotificationBody, optional): _description_. Defaults to InvenTreeNotificationBodies.NewOrder.
exclude (User, optional): User instance that should be excluded. Defaults to None. exclude (User, optional): User instance that should be excluded. Defaults to None.
""" """
notify_users([instance.responsible], instance, sender, content=content, exclude=exclude) notify_users(
[instance.responsible], instance, sender, content=content, exclude=exclude
)
def notify_users(users, instance, sender, content: NotificationBody = InvenTreeNotificationBodies.NewOrder, exclude=None): def notify_users(
users,
instance,
sender,
content: NotificationBody = InvenTreeNotificationBodies.NewOrder,
exclude=None,
):
"""Notify all passed users or groups. """Notify all passed users or groups.
Parses the supplied content with the provided instance and sender and sends a notification to all users, Parses the supplied content with the provided instance and sender and sends a notification to all users,
@ -278,10 +324,10 @@ def notify_users(users, instance, sender, content: NotificationBody = InvenTreeN
'instance': instance, 'instance': instance,
'name': content.name.format(**content_context), 'name': content.name.format(**content_context),
'message': content.message.format(**content_context), 'message': content.message.format(**content_context),
'link': InvenTree.helpers_model.construct_absolute_url(instance.get_absolute_url()), 'link': InvenTree.helpers_model.construct_absolute_url(
'template': { instance.get_absolute_url()
'subject': content.name.format(**content_context), ),
} 'template': {'subject': content.name.format(**content_context)},
} }
if content.template: if content.template:

View File

@ -1,4 +1,5 @@
"""Functions for magic login.""" """Functions for magic login."""
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.sites.models import Site from django.contrib.sites.models import Site
@ -17,15 +18,13 @@ def send_simple_login_email(user, link):
"""Send an email with the login link to this user.""" """Send an email with the login link to this user."""
site = Site.objects.get_current() site = Site.objects.get_current()
context = { context = {'username': user.username, 'site_name': site.name, 'link': link}
"username": user.username, email_plaintext_message = render_to_string(
"site_name": site.name, 'InvenTree/user_simple_login.txt', context
"link": link, )
}
email_plaintext_message = render_to_string("InvenTree/user_simple_login.txt", context)
send_mail( send_mail(
_(f"[{site.name}] Log in to the app"), _(f'[{site.name}] Log in to the app'),
email_plaintext_message, email_plaintext_message,
settings.DEFAULT_FROM_EMAIL, settings.DEFAULT_FROM_EMAIL,
[user.email], [user.email],
@ -35,7 +34,7 @@ def send_simple_login_email(user, link):
class GetSimpleLoginSerializer(serializers.Serializer): class GetSimpleLoginSerializer(serializers.Serializer):
"""Serializer for the simple login view.""" """Serializer for the simple login view."""
email = serializers.CharField(label=_("Email")) email = serializers.CharField(label=_('Email'))
class GetSimpleLoginView(APIView): class GetSimpleLoginView(APIView):
@ -48,14 +47,14 @@ class GetSimpleLoginView(APIView):
"""Get the token for the current user or fail.""" """Get the token for the current user or fail."""
serializer = self.serializer_class(data=request.data) serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.email_submitted(email=serializer.data["email"]) self.email_submitted(email=serializer.data['email'])
return Response({"status": "ok"}) return Response({'status': 'ok'})
def email_submitted(self, email): def email_submitted(self, email):
"""Notify user about link.""" """Notify user about link."""
user = self.get_user(email) user = self.get_user(email)
if user is None: if user is None:
print("user not found:", email) print('user not found:', email)
return return
link = self.create_link(user) link = self.create_link(user)
send_simple_login_email(user, link) send_simple_login_email(user, link)
@ -69,7 +68,7 @@ class GetSimpleLoginView(APIView):
def create_link(self, user): def create_link(self, user):
"""Create a login link for this user.""" """Create a login link for this user."""
link = reverse("sesame-login") link = reverse('sesame-login')
link = self.request.build_absolute_uri(link) link = self.request.build_absolute_uri(link)
link += sesame.utils.get_query_string(user) link += sesame.utils.get_query_string(user)
return link return link

View File

@ -12,7 +12,7 @@ class Command(BaseCommand):
def handle(self, *args, **kwargs): def handle(self, *args, **kwargs):
"""Cleanup old (undefined) settings in the database.""" """Cleanup old (undefined) settings in the database."""
logger.info("Collecting settings") logger.info('Collecting settings')
from common.models import InvenTreeSetting, InvenTreeUserSetting from common.models import InvenTreeSetting, InvenTreeUserSetting
# general settings # general settings
@ -35,4 +35,4 @@ class Command(BaseCommand):
setting.delete() setting.delete()
logger.info("deleted user setting '%s'", setting.key) logger.info("deleted user setting '%s'", setting.key)
logger.info("checked all settings") logger.info('checked all settings')

View File

@ -13,7 +13,6 @@ from django.utils.translation import override as lang_over
def render_file(file_name, source, target, locales, ctx): def render_file(file_name, source, target, locales, ctx):
"""Renders a file into all provided locales.""" """Renders a file into all provided locales."""
for locale in locales: for locale in locales:
# Enforce lower-case for locale names # Enforce lower-case for locale names
locale = locale.lower() locale = locale.lower()
locale = locale.replace('_', '-') locale = locale.replace('_', '-')
@ -50,15 +49,19 @@ class Command(BaseCommand):
# render! # render!
request = HttpRequest() request = HttpRequest()
ctx = {} ctx = {}
processors = tuple(import_string(path) for path in settings.STATFILES_I18_PROCESSORS) processors = tuple(
import_string(path) for path in settings.STATFILES_I18_PROCESSORS
)
for processor in processors: for processor in processors:
ctx.update(processor(request)) ctx.update(processor(request))
for file in os.listdir(SOURCE_DIR, ): for file in os.listdir(SOURCE_DIR):
path = os.path.join(SOURCE_DIR, file) path = os.path.join(SOURCE_DIR, file)
if os.path.exists(path) and os.path.isfile(path): if os.path.exists(path) and os.path.isfile(path):
print(f"render {file}") print(f'render {file}')
render_file(file, SOURCE_DIR, TARGET_DIR, locales, ctx) render_file(file, SOURCE_DIR, TARGET_DIR, locales, ctx)
else: else:
raise NotImplementedError('Using multi-level directories is not implemented at this point') # TODO multilevel dir if needed raise NotImplementedError(
print(f"rendered all files in {SOURCE_DIR}") 'Using multi-level directories is not implemented at this point'
) # TODO multilevel dir if needed
print(f'rendered all files in {SOURCE_DIR}')

View File

@ -13,45 +13,50 @@ class Command(BaseCommand):
"""Rebuild all database models which leverage the MPTT structure.""" """Rebuild all database models which leverage the MPTT structure."""
# Part model # Part model
try: try:
print("Rebuilding Part objects") print('Rebuilding Part objects')
from part.models import Part from part.models import Part
Part.objects.rebuild() Part.objects.rebuild()
except Exception: except Exception:
print("Error rebuilding Part objects") print('Error rebuilding Part objects')
# Part category # Part category
try: try:
print("Rebuilding PartCategory objects") print('Rebuilding PartCategory objects')
from part.models import PartCategory from part.models import PartCategory
PartCategory.objects.rebuild() PartCategory.objects.rebuild()
except Exception: except Exception:
print("Error rebuilding PartCategory objects") print('Error rebuilding PartCategory objects')
# StockItem model # StockItem model
try: try:
print("Rebuilding StockItem objects") print('Rebuilding StockItem objects')
from stock.models import StockItem from stock.models import StockItem
StockItem.objects.rebuild() StockItem.objects.rebuild()
except Exception: except Exception:
print("Error rebuilding StockItem objects") print('Error rebuilding StockItem objects')
# StockLocation model # StockLocation model
try: try:
print("Rebuilding StockLocation objects") print('Rebuilding StockLocation objects')
from stock.models import StockLocation from stock.models import StockLocation
StockLocation.objects.rebuild() StockLocation.objects.rebuild()
except Exception: except Exception:
print("Error rebuilding StockLocation objects") print('Error rebuilding StockLocation objects')
# Build model # Build model
try: try:
print("Rebuilding Build objects") print('Rebuilding Build objects')
from build.models import Build from build.models import Build
Build.objects.rebuild() Build.objects.rebuild()
except Exception: except Exception:
print("Error rebuilding Build objects") print('Error rebuilding Build objects')

View File

@ -37,20 +37,20 @@ class Command(BaseCommand):
def handle(self, *args, **kwargs): def handle(self, *args, **kwargs):
"""Rebuild all thumbnail images.""" """Rebuild all thumbnail images."""
logger.info("Rebuilding Part thumbnails") logger.info('Rebuilding Part thumbnails')
for part in Part.objects.exclude(image=None): for part in Part.objects.exclude(image=None):
try: try:
self.rebuild_thumbnail(part) self.rebuild_thumbnail(part)
except (OperationalError, ProgrammingError): except (OperationalError, ProgrammingError):
logger.exception("ERROR: Database read error.") logger.exception('ERROR: Database read error.')
break break
logger.info("Rebuilding Company thumbnails") logger.info('Rebuilding Company thumbnails')
for company in Company.objects.exclude(image=None): for company in Company.objects.exclude(image=None):
try: try:
self.rebuild_thumbnail(company) self.rebuild_thumbnail(company)
except (OperationalError, ProgrammingError): except (OperationalError, ProgrammingError):
logger.exception("ERROR: abase read error.") logger.exception('ERROR: abase read error.')
break break

View File

@ -18,7 +18,12 @@ class Command(BaseCommand):
if not mail: if not mail:
raise KeyError('A mail is required') raise KeyError('A mail is required')
user = get_user_model() user = get_user_model()
mfa_user = [*set(user.objects.filter(email=mail) | user.objects.filter(emailaddress__email=mail))] mfa_user = [
*set(
user.objects.filter(email=mail)
| user.objects.filter(emailaddress__email=mail)
)
]
if len(mfa_user) == 0: if len(mfa_user) == 0:
print('No user with this mail associated') print('No user with this mail associated')

View File

@ -12,12 +12,11 @@ class Command(BaseCommand):
def handle(self, *args, **kwargs): def handle(self, *args, **kwargs):
"""Wait till the database is ready.""" """Wait till the database is ready."""
self.stdout.write("Waiting for database...") self.stdout.write('Waiting for database...')
connected = False connected = False
while not connected: while not connected:
time.sleep(2) time.sleep(2)
try: try:
@ -26,12 +25,12 @@ class Command(BaseCommand):
connected = True connected = True
except OperationalError as e: except OperationalError as e:
self.stdout.write(f"Could not connect to database: {e}") self.stdout.write(f'Could not connect to database: {e}')
except ImproperlyConfigured as e: except ImproperlyConfigured as e:
self.stdout.write(f"Improperly configured: {e}") self.stdout.write(f'Improperly configured: {e}')
else: else:
if not connection.is_usable(): if not connection.is_usable():
self.stdout.write("Database configuration is not usable") self.stdout.write('Database configuration is not usable')
if connected: if connected:
self.stdout.write("Database connection successful!") self.stdout.write('Database connection successful!')

View File

@ -45,8 +45,9 @@ class InvenTreeMetadata(SimpleMetadata):
context = {} context = {}
if str2bool(request.query_params.get('context', False)): if str2bool(request.query_params.get('context', False)):
if hasattr(self, 'serializer') and hasattr(
if hasattr(self, 'serializer') and hasattr(self.serializer, 'get_context_data'): self.serializer, 'get_context_data'
):
context = self.serializer.get_context_data() context = self.serializer.get_context_data()
metadata['context'] = context metadata['context'] = context
@ -68,7 +69,7 @@ class InvenTreeMetadata(SimpleMetadata):
metadata['model'] = tbl_label metadata['model'] = tbl_label
table = f"{app_label}_{tbl_label}" table = f'{app_label}_{tbl_label}'
actions = metadata.get('actions', None) actions = metadata.get('actions', None)
@ -86,12 +87,11 @@ class InvenTreeMetadata(SimpleMetadata):
} }
# let the view define a custom rolemap # let the view define a custom rolemap
if hasattr(view, "rolemap"): if hasattr(view, 'rolemap'):
rolemap.update(view.rolemap) rolemap.update(view.rolemap)
# Remove any HTTP methods that the user does not have permission for # Remove any HTTP methods that the user does not have permission for
for method, permission in rolemap.items(): for method, permission in rolemap.items():
result = check(user, table, permission) result = check(user, table, permission)
if method in actions and not result: if method in actions and not result:
@ -123,10 +123,7 @@ class InvenTreeMetadata(SimpleMetadata):
model_class = None model_class = None
# Attributes to copy extra attributes from the model to the field (if they don't exist) # Attributes to copy extra attributes from the model to the field (if they don't exist)
extra_attributes = [ extra_attributes = ['help_text', 'max_length']
'help_text',
'max_length',
]
try: try:
model_class = serializer.Meta.model model_class = serializer.Meta.model
@ -142,11 +139,8 @@ class InvenTreeMetadata(SimpleMetadata):
# Iterate through simple fields # Iterate through simple fields
for name, field in model_fields.fields.items(): for name, field in model_fields.fields.items():
if name in serializer_info.keys(): if name in serializer_info.keys():
if field.has_default(): if field.has_default():
default = field.default default = field.default
if callable(default): if callable(default):
@ -162,13 +156,11 @@ class InvenTreeMetadata(SimpleMetadata):
for attr in extra_attributes: for attr in extra_attributes:
if attr not in serializer_info[name]: if attr not in serializer_info[name]:
if hasattr(field, attr): if hasattr(field, attr):
serializer_info[name][attr] = getattr(field, attr) serializer_info[name][attr] = getattr(field, attr)
# Iterate through relations # Iterate through relations
for name, relation in model_fields.relations.items(): for name, relation in model_fields.relations.items():
if name not in serializer_info.keys(): if name not in serializer_info.keys():
# Skip relation not defined in serializer # Skip relation not defined in serializer
continue continue
@ -179,11 +171,17 @@ class InvenTreeMetadata(SimpleMetadata):
# Extract and provide the "limit_choices_to" filters # Extract and provide the "limit_choices_to" filters
# This is used to automatically filter AJAX requests # This is used to automatically filter AJAX requests
serializer_info[name]['filters'] = relation.model_field.get_limit_choices_to() serializer_info[name]['filters'] = (
relation.model_field.get_limit_choices_to()
)
for attr in extra_attributes: for attr in extra_attributes:
if attr not in serializer_info[name] and hasattr(relation.model_field, attr): if attr not in serializer_info[name] and hasattr(
serializer_info[name][attr] = getattr(relation.model_field, attr) relation.model_field, attr
):
serializer_info[name][attr] = getattr(
relation.model_field, attr
)
if name in model_default_values: if name in model_default_values:
serializer_info[name]['default'] = model_default_values[name] serializer_info[name]['default'] = model_default_values[name]
@ -220,11 +218,9 @@ class InvenTreeMetadata(SimpleMetadata):
"""If there is an instance associated with this API View, introspect that instance to find any specific API info.""" """If there is an instance associated with this API View, introspect that instance to find any specific API info."""
if hasattr(instance, 'api_instance_filters'): if hasattr(instance, 'api_instance_filters'):
instance_filters = instance.api_instance_filters() instance_filters = instance.api_instance_filters()
for field_name, field_filters in instance_filters.items(): for field_name, field_filters in instance_filters.items():
if field_name not in serializer_info.keys(): if field_name not in serializer_info.keys():
# The field might be missing, but is added later on # The field might be missing, but is added later on
# This function seems to get called multiple times? # This function seems to get called multiple times?
@ -256,17 +252,20 @@ class InvenTreeMetadata(SimpleMetadata):
# Force non-nullable fields to read as "required" # Force non-nullable fields to read as "required"
# (even if there is a default value!) # (even if there is a default value!)
if not field.allow_null and not (hasattr(field, 'allow_blank') and field.allow_blank): if not field.allow_null and not (
hasattr(field, 'allow_blank') and field.allow_blank
):
field_info['required'] = True field_info['required'] = True
# Introspect writable related fields # Introspect writable related fields
if field_info['type'] == 'field' and not field_info['read_only']: if field_info['type'] == 'field' and not field_info['read_only']:
# If the field is a PrimaryKeyRelatedField, we can extract the model from the queryset # If the field is a PrimaryKeyRelatedField, we can extract the model from the queryset
if isinstance(field, serializers.PrimaryKeyRelatedField): if isinstance(field, serializers.PrimaryKeyRelatedField):
model = field.queryset.model model = field.queryset.model
else: else:
logger.debug("Could not extract model for:", field_info.get('label'), '->', field) logger.debug(
'Could not extract model for:', field_info.get('label'), '->', field
)
model = None model = None
if model: if model:
@ -287,4 +286,4 @@ class InvenTreeMetadata(SimpleMetadata):
return field_info return field_info
InvenTreeMetadata.label_lookup[DependentField] = "dependent field" InvenTreeMetadata.label_lookup[DependentField] = 'dependent field'

View File

@ -9,14 +9,13 @@ from django.http import HttpResponse
from django.shortcuts import redirect from django.shortcuts import redirect
from django.urls import Resolver404, include, re_path, resolve, reverse_lazy from django.urls import Resolver404, include, re_path, resolve, reverse_lazy
from allauth_2fa.middleware import (AllauthTwoFactorMiddleware, from allauth_2fa.middleware import AllauthTwoFactorMiddleware, BaseRequire2FAMiddleware
BaseRequire2FAMiddleware)
from error_report.middleware import ExceptionProcessor from error_report.middleware import ExceptionProcessor
from InvenTree.urls import frontendpatterns from InvenTree.urls import frontendpatterns
from users.models import ApiToken from users.models import ApiToken
logger = logging.getLogger("inventree") logger = logging.getLogger('inventree')
class AuthRequiredMiddleware(object): class AuthRequiredMiddleware(object):
@ -64,11 +63,20 @@ class AuthRequiredMiddleware(object):
elif request.path_info.startswith('/accounts/'): elif request.path_info.startswith('/accounts/'):
authorized = True authorized = True
elif request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/') or request.path_info.startswith('/assets/') or request.path_info == f'/{settings.FRONTEND_URL_BASE}': elif (
request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/')
or request.path_info.startswith('/assets/')
or request.path_info == f'/{settings.FRONTEND_URL_BASE}'
):
authorized = True authorized = True
elif 'Authorization' in request.headers.keys() or 'authorization' in request.headers.keys(): elif (
auth = request.headers.get('Authorization', request.headers.get('authorization')).strip() 'Authorization' in request.headers.keys()
or 'authorization' in request.headers.keys()
):
auth = request.headers.get(
'Authorization', request.headers.get('authorization')
).strip()
if auth.lower().startswith('token') and len(auth.split()) == 2: if auth.lower().startswith('token') and len(auth.split()) == 2:
token_key = auth.split()[1] token_key = auth.split()[1]
@ -78,13 +86,12 @@ class AuthRequiredMiddleware(object):
token = ApiToken.objects.get(key=token_key) token = ApiToken.objects.get(key=token_key)
if token.active and token.user: if token.active and token.user:
# Provide the user information to the request # Provide the user information to the request
request.user = token.user request.user = token.user
authorized = True authorized = True
except ApiToken.DoesNotExist: except ApiToken.DoesNotExist:
logger.warning("Access denied for unknown token %s", token_key) logger.warning('Access denied for unknown token %s', token_key)
# No authorization was found for the request # No authorization was found for the request
if not authorized: if not authorized:
@ -99,17 +106,16 @@ class AuthRequiredMiddleware(object):
] ]
# Do not redirect requests to any of these paths # Do not redirect requests to any of these paths
paths_ignore = [ paths_ignore = ['/api/', '/js/', '/media/', '/static/']
'/api/',
'/js/',
'/media/',
'/static/',
]
if path not in urls and not any(path.startswith(p) for p in paths_ignore): if path not in urls and not any(
path.startswith(p) for p in paths_ignore
):
# Save the 'next' parameter to pass through to the login view # Save the 'next' parameter to pass through to the login view
return redirect(f'{reverse_lazy("account_login")}?next={request.path}') return redirect(
f'{reverse_lazy("account_login")}?next={request.path}'
)
# Return a 401 (Unauthorized) response code for this request # Return a 401 (Unauthorized) response code for this request
return HttpResponse('Unauthorized', status=401) return HttpResponse('Unauthorized', status=401)
@ -123,6 +129,7 @@ url_matcher = re_path('', include(frontendpatterns))
class Check2FAMiddleware(BaseRequire2FAMiddleware): class Check2FAMiddleware(BaseRequire2FAMiddleware):
"""Check if user is required to have MFA enabled.""" """Check if user is required to have MFA enabled."""
def require_2fa(self, request): def require_2fa(self, request):
"""Use setting to check if MFA should be enforced for frontend page.""" """Use setting to check if MFA should be enforced for frontend page."""
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
@ -137,6 +144,7 @@ class Check2FAMiddleware(BaseRequire2FAMiddleware):
class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware): class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware):
"""This function ensures only frontend code triggers the MFA auth cycle.""" """This function ensures only frontend code triggers the MFA auth cycle."""
def process_request(self, request): def process_request(self, request):
"""Check if requested url is forntend and enforce MFA check.""" """Check if requested url is forntend and enforce MFA check."""
try: try:
@ -148,6 +156,7 @@ class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware):
class InvenTreeRemoteUserMiddleware(PersistentRemoteUserMiddleware): class InvenTreeRemoteUserMiddleware(PersistentRemoteUserMiddleware):
"""Middleware to check if HTTP-header based auth is enabled and to set it up.""" """Middleware to check if HTTP-header based auth is enabled and to set it up."""
header = settings.REMOTE_LOGIN_HEADER header = settings.REMOTE_LOGIN_HEADER
def process_request(self, request): def process_request(self, request):

View File

@ -62,7 +62,7 @@ class DiffMixin:
return field_name in self.get_field_deltas() return field_name in self.get_field_deltas()
class CleanMixin(): class CleanMixin:
"""Model mixin class which cleans inputs using the Mozilla bleach tools.""" """Model mixin class which cleans inputs using the Mozilla bleach tools."""
# Define a list of field names which will *not* be cleaned # Define a list of field names which will *not* be cleaned
@ -74,13 +74,17 @@ class CleanMixin():
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_create(serializer) self.perform_create(serializer)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers
)
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
"""Override to clean data before processing it.""" """Override to clean data before processing it."""
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
instance = self.get_object() instance = self.get_object()
serializer = self.get_serializer(instance, data=self.clean_data(request.data), partial=partial) serializer = self.get_serializer(
instance, data=self.clean_data(request.data), partial=partial
)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_update(serializer) self.perform_update(serializer)
@ -113,9 +117,7 @@ class CleanMixin():
field = model._meta.get_field(field) field = model._meta.get_field(field)
# The following field types allow newline characters # The following field types allow newline characters
allow_newline = [ allow_newline = [InvenTreeNotesField]
InvenTreeNotesField,
]
for field_type in allow_newline: for field_type in allow_newline:
if issubclass(type(field), field_type): if issubclass(type(field), field_type):
@ -127,7 +129,9 @@ class CleanMixin():
except FieldDoesNotExist: except FieldDoesNotExist:
pass pass
cleaned = remove_non_printable_characters(cleaned, remove_newline=remove_newline) cleaned = remove_non_printable_characters(
cleaned, remove_newline=remove_newline
)
return cleaned return cleaned
@ -148,7 +152,6 @@ class CleanMixin():
clean_data = {} clean_data = {}
for k, v in data.items(): for k, v in data.items():
if k in self.SAFE_FIELDS: if k in self.SAFE_FIELDS:
ret = v ret = v
elif isinstance(v, str): elif isinstance(v, str):
@ -177,16 +180,19 @@ class CreateAPI(CleanMixin, generics.CreateAPIView):
class RetrieveAPI(generics.RetrieveAPIView): class RetrieveAPI(generics.RetrieveAPIView):
"""View for retrieve API.""" """View for retrieve API."""
pass pass
class RetrieveUpdateAPI(CleanMixin, generics.RetrieveUpdateAPIView): class RetrieveUpdateAPI(CleanMixin, generics.RetrieveUpdateAPIView):
"""View for retrieve and update API.""" """View for retrieve and update API."""
pass pass
class CustomDestroyModelMixin: class CustomDestroyModelMixin:
"""This mixin was created pass the kwargs from the API to the models.""" """This mixin was created pass the kwargs from the API to the models."""
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
"""Custom destroy method to pass kwargs.""" """Custom destroy method to pass kwargs."""
instance = self.get_object() instance = self.get_object()
@ -198,11 +204,14 @@ class CustomDestroyModelMixin:
instance.delete(**kwargs) instance.delete(**kwargs)
class CustomRetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, class CustomRetrieveUpdateDestroyAPIView(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin, mixins.UpdateModelMixin,
CustomDestroyModelMixin, CustomDestroyModelMixin,
generics.GenericAPIView): generics.GenericAPIView,
):
"""This APIView was created pass the kwargs from the API to the models.""" """This APIView was created pass the kwargs from the API to the models."""
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Custom get method to pass kwargs.""" """Custom get method to pass kwargs."""
return self.retrieve(request, *args, **kwargs) return self.retrieve(request, *args, **kwargs)

View File

@ -58,6 +58,7 @@ class MetadataMixin(models.Model):
class Meta: class Meta:
"""Meta for MetadataMixin.""" """Meta for MetadataMixin."""
abstract = True abstract = True
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
@ -78,10 +79,13 @@ class MetadataMixin(models.Model):
self.metadata = {} self.metadata = {}
if type(self.metadata) is not dict: if type(self.metadata) is not dict:
raise ValidationError({'metadata': _('Metadata must be a python dict object')}) raise ValidationError({
'metadata': _('Metadata must be a python dict object')
})
metadata = models.JSONField( metadata = models.JSONField(
blank=True, null=True, blank=True,
null=True,
verbose_name=_('Plugin Metadata'), verbose_name=_('Plugin Metadata'),
help_text=_('JSON metadata field, for use by external plugins'), help_text=_('JSON metadata field, for use by external plugins'),
) )
@ -100,7 +104,9 @@ class MetadataMixin(models.Model):
return self.metadata.get(key, backup_value) return self.metadata.get(key, backup_value)
def set_metadata(self, key: str, data, commit: bool = True, overwrite: bool = False): def set_metadata(
self, key: str, data, commit: bool = True, overwrite: bool = False
):
"""Save the provided metadata under the provided key. """Save the provided metadata under the provided key.
Args: Args:
@ -137,7 +143,6 @@ class DataImportMixin(object):
fields = cls.IMPORT_FIELDS fields = cls.IMPORT_FIELDS
for name, field in fields.items(): for name, field in fields.items():
# Attempt to extract base field information from the model # Attempt to extract base field information from the model
base_field = None base_field = None
@ -207,7 +212,10 @@ class ReferenceIndexingMixin(models.Model):
# import at function level to prevent cyclic imports # import at function level to prevent cyclic imports
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
return InvenTreeSetting.get_setting(cls.REFERENCE_PATTERN_SETTING, create=False).strip()
return InvenTreeSetting.get_setting(
cls.REFERENCE_PATTERN_SETTING, create=False
).strip()
@classmethod @classmethod
def get_reference_context(cls): def get_reference_context(cls):
@ -216,10 +224,7 @@ class ReferenceIndexingMixin(models.Model):
- Returns a python dict object which contains the context data for formatting the reference string. - Returns a python dict object which contains the context data for formatting the reference string.
- The default implementation provides some default context information - The default implementation provides some default context information
""" """
return { return {'ref': cls.get_next_reference(), 'date': datetime.now()}
'ref': cls.get_next_reference(),
'date': datetime.now(),
}
@classmethod @classmethod
def get_most_recent_item(cls): def get_most_recent_item(cls):
@ -246,7 +251,9 @@ class ReferenceIndexingMixin(models.Model):
reference = latest.reference.strip reference = latest.reference.strip
try: try:
reference = InvenTree.format.extract_named_group('ref', reference, cls.get_reference_pattern()) reference = InvenTree.format.extract_named_group(
'ref', reference, cls.get_reference_pattern()
)
except Exception: except Exception:
# If reference cannot be extracted using the pattern, try just the integer value # If reference cannot be extracted using the pattern, try just the integer value
reference = str(latest.reference_int) reference = str(latest.reference_int)
@ -296,7 +303,7 @@ class ReferenceIndexingMixin(models.Model):
if recent: if recent:
reference = recent.reference reference = recent.reference
else: else:
reference = "" reference = ''
return reference return reference
@ -309,20 +316,20 @@ class ReferenceIndexingMixin(models.Model):
info = InvenTree.format.parse_format_string(pattern) info = InvenTree.format.parse_format_string(pattern)
except Exception as exc: except Exception as exc:
raise ValidationError({ raise ValidationError({
"value": _("Improperly formatted pattern") + ": " + str(exc) 'value': _('Improperly formatted pattern') + ': ' + str(exc)
}) })
# Check that only 'allowed' keys are provided # Check that only 'allowed' keys are provided
for key in info.keys(): for key in info.keys():
if key not in ctx.keys(): if key not in ctx.keys():
raise ValidationError({ raise ValidationError({
"value": _("Unknown format key specified") + f": '{key}'" 'value': _('Unknown format key specified') + f": '{key}'"
}) })
# Check that the 'ref' variable is specified # Check that the 'ref' variable is specified
if 'ref' not in info.keys(): if 'ref' not in info.keys():
raise ValidationError({ raise ValidationError({
'value': _("Missing required format key") + ": 'ref'" 'value': _('Missing required format key') + ": 'ref'"
}) })
@classmethod @classmethod
@ -333,14 +340,16 @@ class ReferenceIndexingMixin(models.Model):
value = str(value).strip() value = str(value).strip()
if len(value) == 0: if len(value) == 0:
raise ValidationError(_("Reference field cannot be empty")) raise ValidationError(_('Reference field cannot be empty'))
# An 'empty' pattern means no further validation is required # An 'empty' pattern means no further validation is required
if not pattern: if not pattern:
return return
if not InvenTree.format.validate_string(value, pattern): if not InvenTree.format.validate_string(value, pattern):
raise ValidationError(_("Reference must match required pattern") + ": " + pattern) raise ValidationError(
_('Reference must match required pattern') + ': ' + pattern
)
# Check that the reference field can be rebuild # Check that the reference field can be rebuild
cls.rebuild_reference_field(value, validate=True) cls.rebuild_reference_field(value, validate=True)
@ -361,7 +370,9 @@ class ReferenceIndexingMixin(models.Model):
""" """
try: try:
# Extract named group based on provided pattern # Extract named group based on provided pattern
reference = InvenTree.format.extract_named_group('ref', reference, cls.get_reference_pattern()) reference = InvenTree.format.extract_named_group(
'ref', reference, cls.get_reference_pattern()
)
except Exception: except Exception:
pass pass
@ -369,16 +380,14 @@ class ReferenceIndexingMixin(models.Model):
if validate: if validate:
if reference_int > models.BigIntegerField.MAX_BIGINT: if reference_int > models.BigIntegerField.MAX_BIGINT:
raise ValidationError({ raise ValidationError({'reference': _('Reference number is too large')})
"reference": _("Reference number is too large")
})
return reference_int return reference_int
reference_int = models.BigIntegerField(default=0) reference_int = models.BigIntegerField(default=0)
def extract_int(reference, clip=0x7fffffff, allow_negative=False): def extract_int(reference, clip=0x7FFFFFFF, allow_negative=False):
"""Extract an integer out of reference.""" """Extract an integer out of reference."""
# Default value if we cannot convert to an integer # Default value if we cannot convert to an integer
ref_int = 0 ref_int = 0
@ -390,7 +399,7 @@ def extract_int(reference, clip=0x7fffffff, allow_negative=False):
return 0 return 0
# Look at the start of the string - can it be "integerized"? # Look at the start of the string - can it be "integerized"?
result = re.match(r"^(\d+)", reference) result = re.match(r'^(\d+)', reference)
if result and len(result.groups()) == 1: if result and len(result.groups()) == 1:
ref = result.groups()[0] ref = result.groups()[0]
@ -438,6 +447,7 @@ class InvenTreeAttachment(models.Model):
class Meta: class Meta:
"""Metaclass options. Abstract ensures no database table is created.""" """Metaclass options. Abstract ensures no database table is created."""
abstract = True abstract = True
def getSubdir(self): def getSubdir(self):
@ -445,7 +455,7 @@ class InvenTreeAttachment(models.Model):
Note: Re-implement this for each subclass of InvenTreeAttachment Note: Re-implement this for each subclass of InvenTreeAttachment
""" """
return "attachments" return 'attachments'
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Provide better validation error.""" """Provide better validation error."""
@ -472,28 +482,40 @@ class InvenTreeAttachment(models.Model):
return os.path.basename(self.attachment.name) return os.path.basename(self.attachment.name)
return str(self.link) return str(self.link)
attachment = models.FileField(upload_to=rename_attachment, verbose_name=_('Attachment'), attachment = models.FileField(
upload_to=rename_attachment,
verbose_name=_('Attachment'),
help_text=_('Select file to attach'), help_text=_('Select file to attach'),
blank=True, null=True blank=True,
null=True,
) )
link = InvenTree.fields.InvenTreeURLField( link = InvenTree.fields.InvenTreeURLField(
blank=True, null=True, blank=True,
null=True,
verbose_name=_('Link'), verbose_name=_('Link'),
help_text=_('Link to external URL') help_text=_('Link to external URL'),
) )
comment = models.CharField(blank=True, max_length=100, verbose_name=_('Comment'), help_text=_('File comment')) comment = models.CharField(
blank=True,
max_length=100,
verbose_name=_('Comment'),
help_text=_('File comment'),
)
user = models.ForeignKey( user = models.ForeignKey(
User, User,
on_delete=models.SET_NULL, on_delete=models.SET_NULL,
blank=True, null=True, blank=True,
null=True,
verbose_name=_('User'), verbose_name=_('User'),
help_text=_('User'), help_text=_('User'),
) )
upload_date = models.DateField(auto_now_add=True, null=True, blank=True, verbose_name=_('upload date')) upload_date = models.DateField(
auto_now_add=True, null=True, blank=True, verbose_name=_('upload date')
)
@property @property
def basename(self): def basename(self):
@ -522,35 +544,59 @@ class InvenTreeAttachment(models.Model):
# Check that there are no directory tricks going on... # Check that there are no directory tricks going on...
if new_file.parent != attachment_dir: if new_file.parent != attachment_dir:
logger.error("Attempted to rename attachment outside valid directory: '%s'", new_file) logger.error(
raise ValidationError(_("Invalid attachment directory")) "Attempted to rename attachment outside valid directory: '%s'", new_file
)
raise ValidationError(_('Invalid attachment directory'))
# Ignore further checks if the filename is not actually being renamed # Ignore further checks if the filename is not actually being renamed
if new_file == old_file: if new_file == old_file:
return return
forbidden = ["'", '"', "#", "@", "!", "&", "^", "<", ">", ":", ";", "/", "\\", "|", "?", "*", "%", "~", "`"] forbidden = [
"'",
'"',
'#',
'@',
'!',
'&',
'^',
'<',
'>',
':',
';',
'/',
'\\',
'|',
'?',
'*',
'%',
'~',
'`',
]
for c in forbidden: for c in forbidden:
if c in fn: if c in fn:
raise ValidationError(_(f"Filename contains illegal character '{c}'")) raise ValidationError(_(f"Filename contains illegal character '{c}'"))
if len(fn.split('.')) < 2: if len(fn.split('.')) < 2:
raise ValidationError(_("Filename missing extension")) raise ValidationError(_('Filename missing extension'))
if not old_file.exists(): if not old_file.exists():
logger.error("Trying to rename attachment '%s' which does not exist", old_file) logger.error(
"Trying to rename attachment '%s' which does not exist", old_file
)
return return
if new_file.exists(): if new_file.exists():
raise ValidationError(_("Attachment with this filename already exists")) raise ValidationError(_('Attachment with this filename already exists'))
try: try:
os.rename(old_file, new_file) os.rename(old_file, new_file)
self.attachment.name = os.path.join(self.getSubdir(), fn) self.attachment.name = os.path.join(self.getSubdir(), fn)
self.save() self.save()
except Exception: except Exception:
raise ValidationError(_("Error renaming file")) raise ValidationError(_('Error renaming file'))
def fully_qualified_url(self): def fully_qualified_url(self):
"""Return a 'fully qualified' URL for this attachment. """Return a 'fully qualified' URL for this attachment.
@ -586,10 +632,12 @@ class InvenTreeTree(MPTTModel):
class Meta: class Meta:
"""Metaclass defines extra model properties.""" """Metaclass defines extra model properties."""
abstract = True abstract = True
class MPTTMeta: class MPTTMeta:
"""Set insert order.""" """Set insert order."""
order_insertion_by = ['name'] order_insertion_by = ['name']
def delete(self, delete_children=False, delete_items=False): def delete(self, delete_children=False, delete_items=False):
@ -607,13 +655,19 @@ class InvenTreeTree(MPTTModel):
self.refresh_from_db() self.refresh_from_db()
except self.__class__.DoesNotExist: except self.__class__.DoesNotExist:
# If the object no longer exists, raise a ValidationError # If the object no longer exists, raise a ValidationError
raise ValidationError("Object %s of type %s no longer exists", str(self), str(self.__class__)) raise ValidationError(
'Object %s of type %s no longer exists', str(self), str(self.__class__)
)
# Cache node ID values for lower nodes, before we delete this one # Cache node ID values for lower nodes, before we delete this one
lower_nodes = list(self.get_descendants(include_self=False).values_list('pk', flat=True)) lower_nodes = list(
self.get_descendants(include_self=False).values_list('pk', flat=True)
)
# 1. Update nodes and items under the current node # 1. Update nodes and items under the current node
self.handle_tree_delete(delete_children=delete_children, delete_items=delete_items) self.handle_tree_delete(
delete_children=delete_children, delete_items=delete_items
)
# 2. Delete *this* node # 2. Delete *this* node
super().delete() super().delete()
@ -673,9 +727,7 @@ class InvenTreeTree(MPTTModel):
# - Move all items at any lower level to the parent of this item # - Move all items at any lower level to the parent of this item
# - Delete all descendant nodes # - Delete all descendant nodes
elif delete_children and not delete_items: elif delete_children and not delete_items:
self.get_items(cascade=True).update(**{ self.get_items(cascade=True).update(**{self.ITEM_PARENT_KEY: self.parent})
self.ITEM_PARENT_KEY: self.parent
})
self.delete_nodes(child_nodes) self.delete_nodes(child_nodes)
@ -690,9 +742,7 @@ class InvenTreeTree(MPTTModel):
# - Move all items directly associated with this node up one level # - Move all items directly associated with this node up one level
# - Move any direct child nodes up one level # - Move any direct child nodes up one level
elif not delete_children and not delete_items: elif not delete_children and not delete_items:
self.get_items(cascade=False).update(**{ self.get_items(cascade=False).update(**{self.ITEM_PARENT_KEY: self.parent})
self.ITEM_PARENT_KEY: self.parent
})
self.get_children().update(parent=self.parent) self.get_children().update(parent=self.parent)
def delete_nodes(self, nodes): def delete_nodes(self, nodes):
@ -719,8 +769,7 @@ class InvenTreeTree(MPTTModel):
super().validate_unique(exclude) super().validate_unique(exclude)
results = self.__class__.objects.filter( results = self.__class__.objects.filter(
name=self.name, name=self.name, parent=self.parent
parent=self.parent
).exclude(pk=self.pk) ).exclude(pk=self.pk)
if results.exists(): if results.exists():
@ -730,17 +779,11 @@ class InvenTreeTree(MPTTModel):
def api_instance_filters(self): def api_instance_filters(self):
"""Instance filters for InvenTreeTree models.""" """Instance filters for InvenTreeTree models."""
return { return {'parent': {'exclude_tree': self.pk}}
'parent': {
'exclude_tree': self.pk,
}
}
def construct_pathstring(self): def construct_pathstring(self):
"""Construct the pathstring for this tree node""" """Construct the pathstring for this tree node"""
return InvenTree.helpers.constructPathString( return InvenTree.helpers.constructPathString([item.name for item in self.path])
[item.name for item in self.path]
)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Custom save method for InvenTreeTree abstract model""" """Custom save method for InvenTreeTree abstract model"""
@ -748,15 +791,12 @@ class InvenTreeTree(MPTTModel):
super().save(*args, **kwargs) super().save(*args, **kwargs)
except InvalidMove: except InvalidMove:
# Provide better error for parent selection # Provide better error for parent selection
raise ValidationError({ raise ValidationError({'parent': _('Invalid choice')})
'parent': _("Invalid choice"),
})
# Re-calculate the 'pathstring' field # Re-calculate the 'pathstring' field
pathstring = self.construct_pathstring() pathstring = self.construct_pathstring()
if pathstring != self.pathstring: if pathstring != self.pathstring:
if 'force_insert' in kwargs: if 'force_insert' in kwargs:
del kwargs['force_insert'] del kwargs['force_insert']
@ -781,33 +821,29 @@ class InvenTreeTree(MPTTModel):
self.__class__.objects.bulk_update(nodes_to_update, ['pathstring']) self.__class__.objects.bulk_update(nodes_to_update, ['pathstring'])
name = models.CharField( name = models.CharField(
blank=False, blank=False, max_length=100, verbose_name=_('Name'), help_text=_('Name')
max_length=100,
verbose_name=_("Name"),
help_text=_("Name"),
) )
description = models.CharField( description = models.CharField(
blank=True, blank=True,
max_length=250, max_length=250,
verbose_name=_("Description"), verbose_name=_('Description'),
help_text=_("Description (optional)") help_text=_('Description (optional)'),
) )
# When a category is deleted, graft the children onto its parent # When a category is deleted, graft the children onto its parent
parent = TreeForeignKey('self', parent = TreeForeignKey(
'self',
on_delete=models.DO_NOTHING, on_delete=models.DO_NOTHING,
blank=True, blank=True,
null=True, null=True,
verbose_name=_("parent"), verbose_name=_('parent'),
related_name='children') related_name='children',
)
# The 'pathstring' field is calculated each time the model is saved # The 'pathstring' field is calculated each time the model is saved
pathstring = models.CharField( pathstring = models.CharField(
blank=True, blank=True, max_length=250, verbose_name=_('Path'), help_text=_('Path')
max_length=250,
verbose_name=_('Path'),
help_text=_('Path')
) )
def get_items(self, cascade=False): def get_items(self, cascade=False):
@ -818,7 +854,7 @@ class InvenTreeTree(MPTTModel):
The default implementation returns an empty list The default implementation returns an empty list
""" """
raise NotImplementedError(f"items() method not implemented for {type(self)}") raise NotImplementedError(f'items() method not implemented for {type(self)}')
def getUniqueParents(self): def getUniqueParents(self):
"""Return a flat set of all parent items that exist above this node. """Return a flat set of all parent items that exist above this node.
@ -889,16 +925,11 @@ class InvenTreeTree(MPTTModel):
name: <name>, name: <name>,
} }
""" """
return [ return [{'pk': item.pk, 'name': item.name} for item in self.path]
{
'pk': item.pk,
'name': item.name
} for item in self.path
]
def __str__(self): def __str__(self):
"""String representation of a category is the full path to that category.""" """String representation of a category is the full path to that category."""
return f"{self.pathstring} - {self.description}" return f'{self.pathstring} - {self.description}'
class InvenTreeNotesMixin(models.Model): class InvenTreeNotesMixin(models.Model):
@ -914,11 +945,11 @@ class InvenTreeNotesMixin(models.Model):
Note: abstract must be true, as this is only a mixin, not a separate table Note: abstract must be true, as this is only a mixin, not a separate table
""" """
abstract = True abstract = True
notes = InvenTree.fields.InvenTreeNotesField( notes = InvenTree.fields.InvenTreeNotesField(
verbose_name=_('Notes'), verbose_name=_('Notes'), help_text=_('Markdown notes (optional)')
help_text=_('Markdown notes (optional)'),
) )
@ -941,18 +972,21 @@ class InvenTreeBarcodeMixin(models.Model):
Note: abstract must be true, as this is only a mixin, not a separate table Note: abstract must be true, as this is only a mixin, not a separate table
""" """
abstract = True abstract = True
barcode_data = models.CharField( barcode_data = models.CharField(
blank=True, max_length=500, blank=True,
max_length=500,
verbose_name=_('Barcode Data'), verbose_name=_('Barcode Data'),
help_text=_('Third party barcode data'), help_text=_('Third party barcode data'),
) )
barcode_hash = models.CharField( barcode_hash = models.CharField(
blank=True, max_length=128, blank=True,
max_length=128,
verbose_name=_('Barcode Hash'), verbose_name=_('Barcode Hash'),
help_text=_('Unique hash of barcode data') help_text=_('Unique hash of barcode data'),
) )
@classmethod @classmethod
@ -964,21 +998,17 @@ class InvenTreeBarcodeMixin(models.Model):
def format_barcode(self, **kwargs): def format_barcode(self, **kwargs):
"""Return a JSON string for formatting a QR code for this model instance.""" """Return a JSON string for formatting a QR code for this model instance."""
return InvenTree.helpers.MakeBarcode( return InvenTree.helpers.MakeBarcode(
self.__class__.barcode_model_type(), self.__class__.barcode_model_type(), self.pk, **kwargs
self.pk,
**kwargs
) )
def format_matched_response(self): def format_matched_response(self):
"""Format a standard response for a matched barcode.""" """Format a standard response for a matched barcode."""
data = { data = {'pk': self.pk}
'pk': self.pk,
}
if hasattr(self, 'get_api_url'): if hasattr(self, 'get_api_url'):
api_url = self.get_api_url() api_url = self.get_api_url()
data['api_url'] = f"{api_url}{self.pk}/" data['api_url'] = f'{api_url}{self.pk}/'
if hasattr(self, 'get_absolute_url'): if hasattr(self, 'get_absolute_url'):
data['web_url'] = self.get_absolute_url() data['web_url'] = self.get_absolute_url()
@ -995,7 +1025,9 @@ class InvenTreeBarcodeMixin(models.Model):
"""Check if a model instance exists with the specified third-party barcode hash.""" """Check if a model instance exists with the specified third-party barcode hash."""
return cls.objects.filter(barcode_hash=barcode_hash).first() return cls.objects.filter(barcode_hash=barcode_hash).first()
def assign_barcode(self, barcode_hash=None, barcode_data=None, raise_error=True, save=True): def assign_barcode(
self, barcode_hash=None, barcode_data=None, raise_error=True, save=True
):
"""Assign an external (third-party) barcode to this object.""" """Assign an external (third-party) barcode to this object."""
# Must provide either barcode_hash or barcode_data # Must provide either barcode_hash or barcode_data
if barcode_hash is None and barcode_data is None: if barcode_hash is None and barcode_data is None:
@ -1008,7 +1040,7 @@ class InvenTreeBarcodeMixin(models.Model):
# Check for existing item # Check for existing item
if self.__class__.lookup_barcode(barcode_hash) is not None: if self.__class__.lookup_barcode(barcode_hash) is not None:
if raise_error: if raise_error:
raise ValidationError(_("Existing barcode found")) raise ValidationError(_('Existing barcode found'))
else: else:
return False return False
@ -1044,20 +1076,24 @@ def after_error_logged(sender, instance: Error, created: bool, **kwargs):
users = get_user_model().objects.filter(is_staff=True) users = get_user_model().objects.filter(is_staff=True)
link = InvenTree.helpers_model.construct_absolute_url( link = InvenTree.helpers_model.construct_absolute_url(
reverse('admin:error_report_error_change', kwargs={'object_id': instance.pk}) reverse(
'admin:error_report_error_change', kwargs={'object_id': instance.pk}
)
) )
context = { context = {
'error': instance, 'error': instance,
'name': _('Server Error'), 'name': _('Server Error'),
'message': _('An error has been logged by the server.'), 'message': _('An error has been logged by the server.'),
'link': link 'link': link,
} }
target_users = [] target_users = []
for user in users: for user in users:
if common.models.InvenTreeUserSetting.get_setting('NOTIFICATION_ERROR_REPORT', True, user=user): if common.models.InvenTreeUserSetting.get_setting(
'NOTIFICATION_ERROR_REPORT', True, user=user
):
target_users.append(user) target_users.append(user)
if len(target_users) > 0: if len(target_users) > 0:
@ -1066,7 +1102,7 @@ def after_error_logged(sender, instance: Error, created: bool, **kwargs):
'inventree.error_log', 'inventree.error_log',
context=context, context=context,
targets=target_users, targets=target_users,
delivery_methods={common.notifications.UIMessageNotification, }, delivery_methods={common.notifications.UIMessageNotification},
) )
except Exception as exc: except Exception as exc:

View File

@ -18,7 +18,7 @@ def get_model_for_view(view, raise_error=True):
if hasattr(view, 'get_serializer_class'): if hasattr(view, 'get_serializer_class'):
return view.get_serializr_class().Meta.model return view.get_serializr_class().Meta.model
raise AttributeError(f"Serializer class not specified for {view.__class__}") raise AttributeError(f'Serializer class not specified for {view.__class__}')
class RolePermission(permissions.BasePermission): class RolePermission(permissions.BasePermission):
@ -62,7 +62,7 @@ class RolePermission(permissions.BasePermission):
} }
# let the view define a custom rolemap # let the view define a custom rolemap
if hasattr(view, "rolemap"): if hasattr(view, 'rolemap'):
rolemap.update(view.rolemap) rolemap.update(view.rolemap)
permission = rolemap[request.method] permission = rolemap[request.method]
@ -78,7 +78,7 @@ class RolePermission(permissions.BasePermission):
app_label = model._meta.app_label app_label = model._meta.app_label
model_name = model._meta.model_name model_name = model._meta.model_name
table = f"{app_label}_{model_name}" table = f'{app_label}_{model_name}'
except AttributeError: except AttributeError:
# We will assume that if the serializer class does *not* have a Meta, # We will assume that if the serializer class does *not* have a Meta,
# then we don't need a permission # then we don't need a permission
@ -100,12 +100,18 @@ class IsStaffOrReadOnly(permissions.IsAdminUser):
def has_permission(self, request, view): def has_permission(self, request, view):
"""Check if the user is a superuser.""" """Check if the user is a superuser."""
return bool(request.user and request.user.is_staff or request.method in permissions.SAFE_METHODS) return bool(
request.user
and request.user.is_staff
or request.method in permissions.SAFE_METHODS
)
def auth_exempt(view_func): def auth_exempt(view_func):
"""Mark a view function as being exempt from auth requirements.""" """Mark a view function as being exempt from auth requirements."""
def wrapped_view(*args, **kwargs): def wrapped_view(*args, **kwargs):
return view_func(*args, **kwargs) return view_func(*args, **kwargs)
wrapped_view.auth_exempt = True wrapped_view.auth_exempt = True
return wraps(view_func)(wrapped_view) return wraps(view_func)(wrapped_view)

View File

@ -16,11 +16,7 @@ def isImportingData():
def isRunningMigrations(): def isRunningMigrations():
"""Return True if the database is currently running migrations.""" """Return True if the database is currently running migrations."""
return any((x in sys.argv for x in [ return any((x in sys.argv for x in ['migrate', 'makemigrations', 'showmigrations']))
'migrate',
'makemigrations',
'showmigrations'
]))
def isInMainThread(): def isInMainThread():
@ -29,13 +25,15 @@ def isInMainThread():
- The RUN_MAIN env is set in that case. However if --noreload is applied, this variable - The RUN_MAIN env is set in that case. However if --noreload is applied, this variable
is not set because there are no different threads. is not set because there are no different threads.
""" """
if "runserver" in sys.argv and "--noreload" not in sys.argv: if 'runserver' in sys.argv and '--noreload' not in sys.argv:
return os.environ.get('RUN_MAIN', None) == "true" return os.environ.get('RUN_MAIN', None) == 'true'
return True return True
def canAppAccessDatabase(allow_test: bool = False, allow_plugins: bool = False, allow_shell: bool = False): def canAppAccessDatabase(
allow_test: bool = False, allow_plugins: bool = False, allow_shell: bool = False
):
"""Returns True if the apps.py file can access database records. """Returns True if the apps.py file can access database records.
There are some circumstances where we don't want the ready function in apps.py There are some circumstances where we don't want the ready function in apps.py

View File

@ -1,49 +1,196 @@
"""Functions to sanitize user input files.""" """Functions to sanitize user input files."""
from bleach import clean from bleach import clean
from bleach.css_sanitizer import CSSSanitizer from bleach.css_sanitizer import CSSSanitizer
ALLOWED_ELEMENTS_SVG = [ ALLOWED_ELEMENTS_SVG = [
'a', 'animate', 'animateColor', 'animateMotion', 'a',
'animateTransform', 'circle', 'defs', 'desc', 'ellipse', 'font-face', 'animate',
'font-face-name', 'font-face-src', 'g', 'glyph', 'hkern', 'animateColor',
'linearGradient', 'line', 'marker', 'metadata', 'missing-glyph', 'animateMotion',
'mpath', 'path', 'polygon', 'polyline', 'radialGradient', 'rect', 'animateTransform',
'set', 'stop', 'svg', 'switch', 'text', 'title', 'tspan', 'use' 'circle',
'defs',
'desc',
'ellipse',
'font-face',
'font-face-name',
'font-face-src',
'g',
'glyph',
'hkern',
'linearGradient',
'line',
'marker',
'metadata',
'missing-glyph',
'mpath',
'path',
'polygon',
'polyline',
'radialGradient',
'rect',
'set',
'stop',
'svg',
'switch',
'text',
'title',
'tspan',
'use',
] ]
ALLOWED_ATTRIBUTES_SVG = [ ALLOWED_ATTRIBUTES_SVG = [
'accent-height', 'accumulate', 'additive', 'alphabetic', 'accent-height',
'arabic-form', 'ascent', 'attributeName', 'attributeType', 'accumulate',
'baseProfile', 'bbox', 'begin', 'by', 'calcMode', 'cap-height', 'additive',
'class', 'color', 'color-rendering', 'content', 'cx', 'cy', 'd', 'dx', 'alphabetic',
'dy', 'descent', 'display', 'dur', 'end', 'fill', 'fill-opacity', 'arabic-form',
'fill-rule', 'font-family', 'font-size', 'font-stretch', 'font-style', 'ascent',
'font-variant', 'font-weight', 'from', 'fx', 'fy', 'g1', 'g2', 'attributeName',
'glyph-name', 'gradientUnits', 'hanging', 'height', 'horiz-adv-x', 'attributeType',
'horiz-origin-x', 'id', 'ideographic', 'k', 'keyPoints', 'baseProfile',
'keySplines', 'keyTimes', 'lang', 'marker-end', 'marker-mid', 'bbox',
'marker-start', 'markerHeight', 'markerUnits', 'markerWidth', 'begin',
'mathematical', 'max', 'min', 'name', 'offset', 'opacity', 'orient', 'by',
'origin', 'overline-position', 'overline-thickness', 'panose-1', 'calcMode',
'path', 'pathLength', 'points', 'preserveAspectRatio', 'r', 'refX', 'cap-height',
'refY', 'repeatCount', 'repeatDur', 'requiredExtensions', 'class',
'requiredFeatures', 'restart', 'rotate', 'rx', 'ry', 'slope', 'color',
'stemh', 'stemv', 'stop-color', 'stop-opacity', 'color-rendering',
'strikethrough-position', 'strikethrough-thickness', 'stroke', 'content',
'stroke-dasharray', 'stroke-dashoffset', 'stroke-linecap', 'cx',
'stroke-linejoin', 'stroke-miterlimit', 'stroke-opacity', 'cy',
'stroke-width', 'systemLanguage', 'target', 'text-anchor', 'to', 'd',
'transform', 'type', 'u1', 'u2', 'underline-position', 'dx',
'underline-thickness', 'unicode', 'unicode-range', 'units-per-em', 'dy',
'values', 'version', 'viewBox', 'visibility', 'width', 'widths', 'x', 'descent',
'x-height', 'x1', 'x2', 'xlink:actuate', 'xlink:arcrole', 'display',
'xlink:href', 'xlink:role', 'xlink:show', 'xlink:title', 'dur',
'xlink:type', 'xml:base', 'xml:lang', 'xml:space', 'xmlns', 'end',
'xmlns:xlink', 'y', 'y1', 'y2', 'zoomAndPan', 'style' 'fill',
'fill-opacity',
'fill-rule',
'font-family',
'font-size',
'font-stretch',
'font-style',
'font-variant',
'font-weight',
'from',
'fx',
'fy',
'g1',
'g2',
'glyph-name',
'gradientUnits',
'hanging',
'height',
'horiz-adv-x',
'horiz-origin-x',
'id',
'ideographic',
'k',
'keyPoints',
'keySplines',
'keyTimes',
'lang',
'marker-end',
'marker-mid',
'marker-start',
'markerHeight',
'markerUnits',
'markerWidth',
'mathematical',
'max',
'min',
'name',
'offset',
'opacity',
'orient',
'origin',
'overline-position',
'overline-thickness',
'panose-1',
'path',
'pathLength',
'points',
'preserveAspectRatio',
'r',
'refX',
'refY',
'repeatCount',
'repeatDur',
'requiredExtensions',
'requiredFeatures',
'restart',
'rotate',
'rx',
'ry',
'slope',
'stemh',
'stemv',
'stop-color',
'stop-opacity',
'strikethrough-position',
'strikethrough-thickness',
'stroke',
'stroke-dasharray',
'stroke-dashoffset',
'stroke-linecap',
'stroke-linejoin',
'stroke-miterlimit',
'stroke-opacity',
'stroke-width',
'systemLanguage',
'target',
'text-anchor',
'to',
'transform',
'type',
'u1',
'u2',
'underline-position',
'underline-thickness',
'unicode',
'unicode-range',
'units-per-em',
'values',
'version',
'viewBox',
'visibility',
'width',
'widths',
'x',
'x-height',
'x1',
'x2',
'xlink:actuate',
'xlink:arcrole',
'xlink:href',
'xlink:role',
'xlink:show',
'xlink:title',
'xlink:type',
'xml:base',
'xml:lang',
'xml:space',
'xmlns',
'xmlns:xlink',
'y',
'y1',
'y2',
'zoomAndPan',
'style',
] ]
def sanitize_svg(file_data, strip: bool = True, elements: str = ALLOWED_ELEMENTS_SVG, attributes: str = ALLOWED_ATTRIBUTES_SVG) -> str: def sanitize_svg(
file_data,
strip: bool = True,
elements: str = ALLOWED_ELEMENTS_SVG,
attributes: str = ALLOWED_ATTRIBUTES_SVG,
) -> str:
"""Sanitize a SVG file. """Sanitize a SVG file.
Args: Args:
@ -65,7 +212,7 @@ def sanitize_svg(file_data, strip: bool = True, elements: str = ALLOWED_ELEMENTS
attributes=attributes, attributes=attributes,
strip=strip, strip=strip,
strip_comments=strip, strip_comments=strip,
css_sanitizer=CSSSanitizer() css_sanitizer=CSSSanitizer(),
) )
return cleaned return cleaned

View File

@ -37,7 +37,7 @@ def sentry_ignore_errors():
def init_sentry(dsn, sample_rate, tags): def init_sentry(dsn, sample_rate, tags):
"""Initialize sentry.io error reporting""" """Initialize sentry.io error reporting"""
logger.info("Initializing sentry.io integration") logger.info('Initializing sentry.io integration')
sentry_sdk.init( sentry_sdk.init(
dsn=dsn, dsn=dsn,
@ -46,7 +46,9 @@ def init_sentry(dsn, sample_rate, tags):
send_default_pii=True, send_default_pii=True,
ignore_errors=sentry_ignore_errors(), ignore_errors=sentry_ignore_errors(),
release=InvenTree.version.INVENTREE_SW_VERSION, release=InvenTree.version.INVENTREE_SW_VERSION,
environment='development' if InvenTree.version.isInvenTreeDevelopmentVersion() else 'production' environment='development'
if InvenTree.version.isInvenTreeDevelopmentVersion()
else 'production',
) )
for key, val in tags.items(): for key, val in tags.items():
@ -62,11 +64,10 @@ def init_sentry(dsn, sample_rate, tags):
def report_exception(exc): def report_exception(exc):
"""Report an exception to sentry.io""" """Report an exception to sentry.io"""
if settings.SENTRY_ENABLED and settings.SENTRY_DSN: if settings.SENTRY_ENABLED and settings.SENTRY_DSN:
if not any(isinstance(exc, e) for e in sentry_ignore_errors()): if not any(isinstance(exc, e) for e in sentry_ignore_errors()):
logger.info("Reporting exception to sentry.io: %s", exc) logger.info('Reporting exception to sentry.io: %s', exc)
try: try:
sentry_sdk.capture_exception(exc) sentry_sdk.capture_exception(exc)
except Exception: except Exception:
logger.warning("Failed to report exception to sentry.io") logger.warning('Failed to report exception to sentry.io')

View File

@ -37,9 +37,9 @@ class InvenTreeMoneySerializer(MoneyField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Override default values.""" """Override default values."""
kwargs["max_digits"] = kwargs.get("max_digits", 19) kwargs['max_digits'] = kwargs.get('max_digits', 19)
self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 6) self.decimal_places = kwargs['decimal_places'] = kwargs.get('decimal_places', 6)
kwargs["required"] = kwargs.get("required", False) kwargs['required'] = kwargs.get('required', False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -57,13 +57,18 @@ class InvenTreeMoneySerializer(MoneyField):
amount = Decimal(amount) amount = Decimal(amount)
amount = round(amount, self.decimal_places) amount = round(amount, self.decimal_places)
except Exception: except Exception:
raise ValidationError({ raise ValidationError({self.field_name: [_('Must be a valid number')]})
self.field_name: [_("Must be a valid number")],
})
currency = data.get(get_currency_field_name(self.field_name), self.default_currency) currency = data.get(
get_currency_field_name(self.field_name), self.default_currency
)
if currency and amount is not None and not isinstance(amount, MONEY_CLASSES) and amount is not empty: if (
currency
and amount is not None
and not isinstance(amount, MONEY_CLASSES)
and amount is not empty
):
return Money(amount, currency) return Money(amount, currency)
return amount return amount
@ -76,7 +81,9 @@ class InvenTreeCurrencySerializer(serializers.ChoiceField):
"""Initialize the currency serializer""" """Initialize the currency serializer"""
choices = currency_code_mappings() choices = currency_code_mappings()
allow_blank = kwargs.get('allow_blank', False) or kwargs.get('allow_null', False) allow_blank = kwargs.get('allow_blank', False) or kwargs.get(
'allow_null', False
)
if allow_blank: if allow_blank:
choices = [('', '---------')] + choices choices = [('', '---------')] + choices
@ -97,6 +104,7 @@ class InvenTreeCurrencySerializer(serializers.ChoiceField):
class DependentField(serializers.Field): class DependentField(serializers.Field):
"""A dependent field can be used to dynamically return child fields based on the value of other fields.""" """A dependent field can be used to dynamically return child fields based on the value of other fields."""
child = None child = None
def __init__(self, *args, depends_on, field_serializer, **kwargs): def __init__(self, *args, depends_on, field_serializer, **kwargs):
@ -126,7 +134,7 @@ class DependentField(serializers.Field):
def get_child(self, raise_exception=False): def get_child(self, raise_exception=False):
"""This method tries to extract the child based on the provided data in the request by the client.""" """This method tries to extract the child based on the provided data in the request by the client."""
data = deepcopy(self.context["request"].data) data = deepcopy(self.context['request'].data)
def visit_parent(node): def visit_parent(node):
"""Recursively extract the data for the parent field/serializer in reverse.""" """Recursively extract the data for the parent field/serializer in reverse."""
@ -136,8 +144,9 @@ class DependentField(serializers.Field):
visit_parent(node.parent) visit_parent(node.parent)
# only do for composite fields and stop right before the current field # only do for composite fields and stop right before the current field
if hasattr(node, "child") and node is not self and isinstance(data, dict): if hasattr(node, 'child') and node is not self and isinstance(data, dict):
data = data.get(node.field_name, None) data = data.get(node.field_name, None)
visit_parent(self) visit_parent(self)
# ensure that data is a dictionary and that a parent exists # ensure that data is a dictionary and that a parent exists
@ -152,7 +161,9 @@ class DependentField(serializers.Field):
# partially validate the data for options requests that set raise_exception while calling .get_child(...) # partially validate the data for options requests that set raise_exception while calling .get_child(...)
if raise_exception: if raise_exception:
validation_data = {k: v for k, v in data.items() if k in self.depends_on} validation_data = {k: v for k, v in data.items() if k in self.depends_on}
serializer = self.parent.__class__(context=self.context, data=validation_data, partial=True) serializer = self.parent.__class__(
context=self.context, data=validation_data, partial=True
)
serializer.is_valid(raise_exception=raise_exception) serializer.is_valid(raise_exception=raise_exception)
# try to get the field serializer # try to get the field serializer
@ -196,7 +207,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
"""Custom __init__ routine to ensure that *default* values (as specified in the ORM) are used by the DRF serializers, *if* the values are not provided by the user.""" """Custom __init__ routine to ensure that *default* values (as specified in the ORM) are used by the DRF serializers, *if* the values are not provided by the user."""
# If instance is None, we are creating a new instance # If instance is None, we are creating a new instance
if instance is None and data is not empty: if instance is None and data is not empty:
if data is None: if data is None:
data = OrderedDict() data = OrderedDict()
else: else:
@ -211,7 +221,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
fields = model_meta.get_field_info(ModelClass) fields = model_meta.get_field_info(ModelClass)
for field_name, field in fields.fields.items(): for field_name, field in fields.fields.items():
""" """
Update the field IF (and ONLY IF): Update the field IF (and ONLY IF):
@ -219,7 +228,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
- The field does not already have a value set - The field does not already have a value set
""" """
if field.has_default() and field_name not in data: if field.has_default() and field_name not in data:
value = field.default value = field.default
# Account for callable functions # Account for callable functions
@ -247,9 +255,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
fields = model_meta.get_field_info(ModelClass) fields = model_meta.get_field_info(ModelClass)
for field_name, field in fields.fields.items(): for field_name, field in fields.fields.items():
if field.has_default() and field_name not in initials: if field.has_default() and field_name not in initials:
value = field.default value = field.default
# Account for callable functions # Account for callable functions
@ -337,7 +343,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
try: try:
instance.full_clean() instance.full_clean()
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
data = exc.message_dict data = exc.message_dict
# Change '__all__' key (django style) to 'non_field_errors' (DRF style) # Change '__all__' key (django style) to 'non_field_errors' (DRF style)
@ -369,6 +374,7 @@ class InvenTreeTaggitSerializer(TaggitSerializer):
class InvenTreeTagModelSerializer(InvenTreeTaggitSerializer, InvenTreeModelSerializer): class InvenTreeTagModelSerializer(InvenTreeTaggitSerializer, InvenTreeModelSerializer):
"""Combination of InvenTreeTaggitSerializer and InvenTreeModelSerializer.""" """Combination of InvenTreeTaggitSerializer and InvenTreeModelSerializer."""
pass pass
@ -377,38 +383,31 @@ class UserSerializer(InvenTreeModelSerializer):
class Meta: class Meta:
"""Metaclass defines serializer fields.""" """Metaclass defines serializer fields."""
model = User
fields = [
'pk',
'username',
'first_name',
'last_name',
'email',
]
read_only_fields = [ model = User
'username', fields = ['pk', 'username', 'first_name', 'last_name', 'email']
]
read_only_fields = ['username']
class ExendedUserSerializer(UserSerializer): class ExendedUserSerializer(UserSerializer):
"""Serializer for a User with a bit more info.""" """Serializer for a User with a bit more info."""
from users.serializers import GroupSerializer from users.serializers import GroupSerializer
groups = GroupSerializer(read_only=True, many=True) groups = GroupSerializer(read_only=True, many=True)
class Meta(UserSerializer.Meta): class Meta(UserSerializer.Meta):
"""Metaclass defines serializer fields.""" """Metaclass defines serializer fields."""
fields = UserSerializer.Meta.fields + [ fields = UserSerializer.Meta.fields + [
'groups', 'groups',
'is_staff', 'is_staff',
'is_superuser', 'is_superuser',
'is_active' 'is_active',
] ]
read_only_fields = UserSerializer.Meta.read_only_fields + [ read_only_fields = UserSerializer.Meta.read_only_fields + ['groups']
'groups',
]
def validate(self, attrs): def validate(self, attrs):
"""Expanded validation for changing user role.""" """Expanded validation for changing user role."""
@ -424,17 +423,20 @@ class ExendedUserSerializer(UserSerializer):
# Staff can change any role except is_superuser # Staff can change any role except is_superuser
pass pass
else: else:
raise PermissionDenied(_("You do not have permission to change this user role.")) raise PermissionDenied(
_('You do not have permission to change this user role.')
)
return super().validate(attrs) return super().validate(attrs)
class UserCreateSerializer(ExendedUserSerializer): class UserCreateSerializer(ExendedUserSerializer):
"""Serializer for creating a new User.""" """Serializer for creating a new User."""
def validate(self, attrs): def validate(self, attrs):
"""Expanded valiadation for auth.""" """Expanded valiadation for auth."""
# Check that the user trying to create a new user is a superuser # Check that the user trying to create a new user is a superuser
if not self.context['request'].user.is_superuser: if not self.context['request'].user.is_superuser:
raise serializers.ValidationError(_("Only superusers can create new users")) raise serializers.ValidationError(_('Only superusers can create new users'))
# Generate a random password # Generate a random password
password = User.objects.make_random_password(length=14) password = User.objects.make_random_password(length=14)
@ -451,8 +453,10 @@ class UserCreateSerializer(ExendedUserSerializer):
current_site = Site.objects.get_current() current_site = Site.objects.get_current()
domain = current_site.domain domain = current_site.domain
instance.email_user( instance.email_user(
subject=_(f"Welcome to {current_site.name}"), subject=_(f'Welcome to {current_site.name}'),
message=_(f"Your account has been created.\n\nPlease use the password reset function to get access (at https://{domain})."), message=_(
f'Your account has been created.\n\nPlease use the password reset function to get access (at https://{domain}).'
),
) )
return instance return instance
@ -509,17 +513,11 @@ class InvenTreeAttachmentSerializer(InvenTreeModelSerializer):
user_detail = UserSerializer(source='user', read_only=True, many=False) user_detail = UserSerializer(source='user', read_only=True, many=False)
attachment = InvenTreeAttachmentSerializerField( attachment = InvenTreeAttachmentSerializerField(required=False, allow_null=False)
required=False,
allow_null=False,
)
# The 'filename' field must be present in the serializer # The 'filename' field must be present in the serializer
filename = serializers.CharField( filename = serializers.CharField(
label=_('Filename'), label=_('Filename'), required=False, source='basename', allow_blank=False
required=False,
source='basename',
allow_blank=False,
) )
upload_date = serializers.DateField(read_only=True) upload_date = serializers.DateField(read_only=True)
@ -553,7 +551,7 @@ class InvenTreeDecimalField(serializers.FloatField):
try: try:
return Decimal(str(data)) return Decimal(str(data))
except Exception: except Exception:
raise serializers.ValidationError(_("Invalid value")) raise serializers.ValidationError(_('Invalid value'))
class DataFileUploadSerializer(serializers.Serializer): class DataFileUploadSerializer(serializers.Serializer):
@ -570,13 +568,11 @@ class DataFileUploadSerializer(serializers.Serializer):
class Meta: class Meta:
"""Metaclass options.""" """Metaclass options."""
fields = [ fields = ['data_file']
'data_file',
]
data_file = serializers.FileField( data_file = serializers.FileField(
label=_("Data File"), label=_('Data File'),
help_text=_("Select data file for upload"), help_text=_('Select data file for upload'),
required=True, required=True,
allow_empty_file=False, allow_empty_file=False,
) )
@ -590,20 +586,16 @@ class DataFileUploadSerializer(serializers.Serializer):
# Remove the leading . from the extension # Remove the leading . from the extension
ext = ext[1:] ext = ext[1:]
accepted_file_types = [ accepted_file_types = ['xls', 'xlsx', 'csv', 'tsv', 'xml']
'xls', 'xlsx',
'csv', 'tsv',
'xml',
]
if ext not in accepted_file_types: if ext not in accepted_file_types:
raise serializers.ValidationError(_("Unsupported file type")) raise serializers.ValidationError(_('Unsupported file type'))
# Impose a 50MB limit on uploaded BOM files # Impose a 50MB limit on uploaded BOM files
max_upload_file_size = 50 * 1024 * 1024 max_upload_file_size = 50 * 1024 * 1024
if data_file.size > max_upload_file_size: if data_file.size > max_upload_file_size:
raise serializers.ValidationError(_("File is too large")) raise serializers.ValidationError(_('File is too large'))
# Read file data into memory (bytes object) # Read file data into memory (bytes object)
try: try:
@ -624,10 +616,10 @@ class DataFileUploadSerializer(serializers.Serializer):
raise serializers.ValidationError(str(e)) raise serializers.ValidationError(str(e))
if len(self.dataset.headers) == 0: if len(self.dataset.headers) == 0:
raise serializers.ValidationError(_("No columns found in file")) raise serializers.ValidationError(_('No columns found in file'))
if len(self.dataset) == 0: if len(self.dataset) == 0:
raise serializers.ValidationError(_("No data rows found in file")) raise serializers.ValidationError(_('No data rows found in file'))
return data_file return data_file
@ -721,24 +713,14 @@ class DataFileExtractSerializer(serializers.Serializer):
class Meta: class Meta:
"""Metaclass options.""" """Metaclass options."""
fields = [ fields = ['columns', 'rows']
'columns',
'rows',
]
# Mapping of columns # Mapping of columns
columns = serializers.ListField( columns = serializers.ListField(child=serializers.CharField(allow_blank=True))
child=serializers.CharField(
allow_blank=True,
),
)
rows = serializers.ListField( rows = serializers.ListField(
child=serializers.ListField( child=serializers.ListField(
child=serializers.CharField( child=serializers.CharField(allow_blank=True, allow_null=True)
allow_blank=True,
allow_null=True,
),
) )
) )
@ -750,10 +732,10 @@ class DataFileExtractSerializer(serializers.Serializer):
self.rows = data.get('rows', []) self.rows = data.get('rows', [])
if len(self.rows) == 0: if len(self.rows) == 0:
raise serializers.ValidationError(_("No data rows provided")) raise serializers.ValidationError(_('No data rows provided'))
if len(self.columns) == 0: if len(self.columns) == 0:
raise serializers.ValidationError(_("No data columns supplied")) raise serializers.ValidationError(_('No data columns supplied'))
self.validate_extracted_columns() self.validate_extracted_columns()
@ -776,16 +758,9 @@ class DataFileExtractSerializer(serializers.Serializer):
processed_row = self.process_row(self.row_to_dict(row)) processed_row = self.process_row(self.row_to_dict(row))
if processed_row: if processed_row:
rows.append({ rows.append({'original': row, 'data': processed_row})
"original": row,
"data": processed_row,
})
return { return {'fields': model_fields, 'columns': self.columns, 'rows': rows}
'fields': model_fields,
'columns': self.columns,
'rows': rows,
}
def process_row(self, row): def process_row(self, row):
"""Process a 'row' of data, which is a mapped column:value dict. """Process a 'row' of data, which is a mapped column:value dict.
@ -799,12 +774,9 @@ class DataFileExtractSerializer(serializers.Serializer):
def row_to_dict(self, row): def row_to_dict(self, row):
"""Convert a "row" to a named data dict.""" """Convert a "row" to a named data dict."""
row_dict = { row_dict = {'errors': {}}
'errors': {},
}
for idx, value in enumerate(row): for idx, value in enumerate(row):
if idx < len(self.columns): if idx < len(self.columns):
col = self.columns[idx] col = self.columns[idx]
@ -824,16 +796,16 @@ class DataFileExtractSerializer(serializers.Serializer):
cols_seen = set() cols_seen = set()
for name, field in model_fields.items(): for name, field in model_fields.items():
required = field.get('required', False) required = field.get('required', False)
# Check for missing required columns # Check for missing required columns
if required: if required:
if name not in self.columns: if name not in self.columns:
raise serializers.ValidationError(_(f"Missing required column: '{name}'")) raise serializers.ValidationError(
_(f"Missing required column: '{name}'")
)
for col in self.columns: for col in self.columns:
if not col: if not col:
continue continue
@ -856,16 +828,14 @@ class RemoteImageMixin(metaclass=serializers.SerializerMetaclass):
def skip_create_fields(self): def skip_create_fields(self):
"""Ensure the 'remote_image' field is skipped when creating a new instance""" """Ensure the 'remote_image' field is skipped when creating a new instance"""
return [ return ['remote_image']
'remote_image',
]
remote_image = serializers.URLField( remote_image = serializers.URLField(
required=False, required=False,
allow_blank=False, allow_blank=False,
write_only=True, write_only=True,
label=_("Remote Image"), label=_('Remote Image'),
help_text=_("URL of remote image file"), help_text=_('URL of remote image file'),
) )
def validate_remote_image(self, url): def validate_remote_image(self, url):
@ -877,8 +847,12 @@ class RemoteImageMixin(metaclass=serializers.SerializerMetaclass):
if not url: if not url:
return return
if not common_models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_FROM_URL'): if not common_models.InvenTreeSetting.get_setting(
raise ValidationError(_("Downloading images from remote URL is not enabled")) 'INVENTREE_DOWNLOAD_FROM_URL'
):
raise ValidationError(
_('Downloading images from remote URL is not enabled')
)
try: try:
self.remote_image_file = download_image_from_url(url) self.remote_image_file = download_image_from_url(url)

View File

@ -38,9 +38,8 @@ INVENTREE_NEWS_URL = 'https://inventree.org/news/feed.atom'
TESTING = 'test' in sys.argv or 'TESTING' in os.environ TESTING = 'test' in sys.argv or 'TESTING' in os.environ
if TESTING: if TESTING:
# Use a weaker password hasher for testing (improves testing speed) # Use a weaker password hasher for testing (improves testing speed)
PASSWORD_HASHERS = ['django.contrib.auth.hashers.MD5PasswordHasher',] PASSWORD_HASHERS = ['django.contrib.auth.hashers.MD5PasswordHasher']
# Enable slow-test-runner # Enable slow-test-runner
TEST_RUNNER = 'django_slowtests.testrunner.DiscoverSlowestTestsRunner' TEST_RUNNER = 'django_slowtests.testrunner.DiscoverSlowestTestsRunner'
@ -53,7 +52,7 @@ if TESTING:
site_packages = '/usr/local/lib/python3.9/site-packages' site_packages = '/usr/local/lib/python3.9/site-packages'
if site_packages not in sys.path: if site_packages not in sys.path:
print("Adding missing site-packages path:", site_packages) print('Adding missing site-packages path:', site_packages)
sys.path.append(site_packages) sys.path.append(site_packages)
# Are environment variables manipulated by tests? Needs to be set by testing code # Are environment variables manipulated by tests? Needs to be set by testing code
@ -78,16 +77,17 @@ if version_file.exists():
# SECURITY WARNING: don't run with debug turned on in production! # SECURITY WARNING: don't run with debug turned on in production!
DEBUG = get_boolean_setting('INVENTREE_DEBUG', 'debug', True) DEBUG = get_boolean_setting('INVENTREE_DEBUG', 'debug', True)
ENABLE_CLASSIC_FRONTEND = get_boolean_setting('INVENTREE_CLASSIC_FRONTEND', 'classic_frontend', True) ENABLE_CLASSIC_FRONTEND = get_boolean_setting(
ENABLE_PLATFORM_FRONTEND = get_boolean_setting('INVENTREE_PLATFORM_FRONTEND', 'platform_frontend', True) 'INVENTREE_CLASSIC_FRONTEND', 'classic_frontend', True
)
ENABLE_PLATFORM_FRONTEND = get_boolean_setting(
'INVENTREE_PLATFORM_FRONTEND', 'platform_frontend', True
)
# Configure logging settings # Configure logging settings
log_level = get_setting('INVENTREE_LOG_LEVEL', 'log_level', 'WARNING') log_level = get_setting('INVENTREE_LOG_LEVEL', 'log_level', 'WARNING')
logging.basicConfig( logging.basicConfig(level=log_level, format='%(asctime)s %(levelname)s %(message)s')
level=log_level,
format="%(asctime)s %(levelname)s %(message)s",
)
if log_level not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']: if log_level not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
log_level = 'WARNING' # pragma: no cover log_level = 'WARNING' # pragma: no cover
@ -95,32 +95,21 @@ if log_level not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']:
LOGGING = { LOGGING = {
'version': 1, 'version': 1,
'disable_existing_loggers': False, 'disable_existing_loggers': False,
'handlers': { 'handlers': {'console': {'class': 'logging.StreamHandler'}},
'console': { 'root': {'handlers': ['console'], 'level': log_level},
'class': 'logging.StreamHandler',
},
},
'root': {
'handlers': ['console'],
'level': log_level,
},
'filters': { 'filters': {
'require_not_maintenance_mode_503': { 'require_not_maintenance_mode_503': {
'()': 'maintenance_mode.logging.RequireNotMaintenanceMode503', '()': 'maintenance_mode.logging.RequireNotMaintenanceMode503'
}, }
}, },
} }
# Optionally add database-level logging # Optionally add database-level logging
if get_setting('INVENTREE_DB_LOGGING', 'db_logging', False): if get_setting('INVENTREE_DB_LOGGING', 'db_logging', False):
LOGGING['loggers'] = { LOGGING['loggers'] = {'django.db.backends': {'level': log_level or 'DEBUG'}}
'django.db.backends': {
'level': log_level or 'DEBUG',
},
}
# Get a logger instance for this setup file # Get a logger instance for this setup file
logger = logging.getLogger("inventree") logger = logging.getLogger('inventree')
# Load SECRET_KEY # Load SECRET_KEY
SECRET_KEY = config.get_secret_key() SECRET_KEY = config.get_secret_key()
@ -133,7 +122,7 @@ MEDIA_ROOT = config.get_media_dir()
# List of allowed hosts (default = allow all) # List of allowed hosts (default = allow all)
ALLOWED_HOSTS = get_setting( ALLOWED_HOSTS = get_setting(
"INVENTREE_ALLOWED_HOSTS", 'INVENTREE_ALLOWED_HOSTS',
config_key='allowed_hosts', config_key='allowed_hosts',
default_value=['*'], default_value=['*'],
typecast=list, typecast=list,
@ -146,13 +135,11 @@ CORS_URLS_REGEX = r'^/(api|media|static)/.*$'
# Extract CORS options from configuration file # Extract CORS options from configuration file
CORS_ORIGIN_ALLOW_ALL = get_boolean_setting( CORS_ORIGIN_ALLOW_ALL = get_boolean_setting(
"INVENTREE_CORS_ORIGIN_ALLOW_ALL", 'INVENTREE_CORS_ORIGIN_ALLOW_ALL', config_key='cors.allow_all', default_value=False
config_key='cors.allow_all',
default_value=False,
) )
CORS_ORIGIN_WHITELIST = get_setting( CORS_ORIGIN_WHITELIST = get_setting(
"INVENTREE_CORS_ORIGIN_WHITELIST", 'INVENTREE_CORS_ORIGIN_WHITELIST',
config_key='cors.whitelist', config_key='cors.whitelist',
default_value=[], default_value=[],
typecast=list, typecast=list,
@ -173,9 +160,7 @@ STATICFILES_I18_TRG = BASE_DIR.joinpath('InvenTree', 'static_i18n')
STATICFILES_DIRS.append(STATICFILES_I18_TRG) STATICFILES_DIRS.append(STATICFILES_I18_TRG)
STATICFILES_I18_TRG = STATICFILES_I18_TRG.joinpath(STATICFILES_I18_PREFIX) STATICFILES_I18_TRG = STATICFILES_I18_TRG.joinpath(STATICFILES_I18_PREFIX)
STATFILES_I18_PROCESSORS = [ STATFILES_I18_PROCESSORS = ['InvenTree.context.status_codes']
'InvenTree.context.status_codes',
]
# Color Themes Directory # Color Themes Directory
STATIC_COLOR_THEMES_DIR = STATIC_ROOT.joinpath('css', 'color-themes').resolve() STATIC_COLOR_THEMES_DIR = STATIC_ROOT.joinpath('css', 'color-themes').resolve()
@ -189,33 +174,28 @@ DBBACKUP_SEND_EMAIL = False
DBBACKUP_STORAGE = get_setting( DBBACKUP_STORAGE = get_setting(
'INVENTREE_BACKUP_STORAGE', 'INVENTREE_BACKUP_STORAGE',
'backup_storage', 'backup_storage',
'django.core.files.storage.FileSystemStorage' 'django.core.files.storage.FileSystemStorage',
) )
# Default backup configuration # Default backup configuration
DBBACKUP_STORAGE_OPTIONS = get_setting('INVENTREE_BACKUP_OPTIONS', 'backup_options', None) DBBACKUP_STORAGE_OPTIONS = get_setting(
'INVENTREE_BACKUP_OPTIONS', 'backup_options', None
)
if DBBACKUP_STORAGE_OPTIONS is None: if DBBACKUP_STORAGE_OPTIONS is None:
DBBACKUP_STORAGE_OPTIONS = { DBBACKUP_STORAGE_OPTIONS = {'location': config.get_backup_dir()}
'location': config.get_backup_dir(),
}
INVENTREE_ADMIN_ENABLED = get_boolean_setting( INVENTREE_ADMIN_ENABLED = get_boolean_setting(
'INVENTREE_ADMIN_ENABLED', 'INVENTREE_ADMIN_ENABLED', config_key='admin_enabled', default_value=True
config_key='admin_enabled',
default_value=True
) )
# Base URL for admin pages (default="admin") # Base URL for admin pages (default="admin")
INVENTREE_ADMIN_URL = get_setting( INVENTREE_ADMIN_URL = get_setting(
'INVENTREE_ADMIN_URL', 'INVENTREE_ADMIN_URL', config_key='admin_url', default_value='admin'
config_key='admin_url',
default_value='admin'
) )
INSTALLED_APPS = [ INSTALLED_APPS = [
# Admin site integration # Admin site integration
'django.contrib.admin', 'django.contrib.admin',
# InvenTree apps # InvenTree apps
'build.apps.BuildConfig', 'build.apps.BuildConfig',
'common.apps.CommonConfig', 'common.apps.CommonConfig',
@ -230,7 +210,6 @@ INSTALLED_APPS = [
'web', 'web',
'generic', 'generic',
'InvenTree.apps.InvenTreeConfig', # InvenTree app runs last 'InvenTree.apps.InvenTreeConfig', # InvenTree app runs last
# Core django modules # Core django modules
'django.contrib.auth', 'django.contrib.auth',
'django.contrib.contenttypes', 'django.contrib.contenttypes',
@ -238,10 +217,8 @@ INSTALLED_APPS = [
'django.contrib.messages', 'django.contrib.messages',
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'django.contrib.sites', 'django.contrib.sites',
# Maintenance # Maintenance
'maintenance_mode', 'maintenance_mode',
# Third part add-ons # Third part add-ons
'django_filters', # Extended filter functionality 'django_filters', # Extended filter functionality
'rest_framework', # DRF (Django Rest Framework) 'rest_framework', # DRF (Django Rest Framework)
@ -259,24 +236,22 @@ INSTALLED_APPS = [
'dbbackup', # Backups - django-dbbackup 'dbbackup', # Backups - django-dbbackup
'taggit', # Tagging 'taggit', # Tagging
'flags', # Flagging - django-flags 'flags', # Flagging - django-flags
'allauth', # Base app for SSO 'allauth', # Base app for SSO
'allauth.account', # Extend user with accounts 'allauth.account', # Extend user with accounts
'allauth.socialaccount', # Use 'social' providers 'allauth.socialaccount', # Use 'social' providers
'django_otp', # OTP is needed for MFA - base package 'django_otp', # OTP is needed for MFA - base package
'django_otp.plugins.otp_totp', # Time based OTP 'django_otp.plugins.otp_totp', # Time based OTP
'django_otp.plugins.otp_static', # Backup codes 'django_otp.plugins.otp_static', # Backup codes
'allauth_2fa', # MFA flow for allauth 'allauth_2fa', # MFA flow for allauth
'dj_rest_auth', # Authentication APIs - dj-rest-auth 'dj_rest_auth', # Authentication APIs - dj-rest-auth
'dj_rest_auth.registration', # Registration APIs - dj-rest-auth' 'dj_rest_auth.registration', # Registration APIs - dj-rest-auth'
'drf_spectacular', # API documentation 'drf_spectacular', # API documentation
'django_ical', # For exporting calendars 'django_ical', # For exporting calendars
] ]
MIDDLEWARE = CONFIG.get('middleware', [ MIDDLEWARE = CONFIG.get(
'middleware',
[
'django.middleware.security.SecurityMiddleware', 'django.middleware.security.SecurityMiddleware',
'x_forwarded_for.middleware.XForwardedForMiddleware', 'x_forwarded_for.middleware.XForwardedForMiddleware',
'user_sessions.middleware.SessionMiddleware', # db user sessions 'user_sessions.middleware.SessionMiddleware', # db user sessions
@ -295,43 +270,52 @@ MIDDLEWARE = CONFIG.get('middleware', [
'InvenTree.middleware.Check2FAMiddleware', # Check if the user should be forced to use MFA 'InvenTree.middleware.Check2FAMiddleware', # Check if the user should be forced to use MFA
'maintenance_mode.middleware.MaintenanceModeMiddleware', 'maintenance_mode.middleware.MaintenanceModeMiddleware',
'InvenTree.middleware.InvenTreeExceptionProcessor', # Error reporting 'InvenTree.middleware.InvenTreeExceptionProcessor', # Error reporting
]) ],
)
AUTHENTICATION_BACKENDS = CONFIG.get('authentication_backends', [ AUTHENTICATION_BACKENDS = CONFIG.get(
'authentication_backends',
[
'django.contrib.auth.backends.RemoteUserBackend', # proxy login 'django.contrib.auth.backends.RemoteUserBackend', # proxy login
'django.contrib.auth.backends.ModelBackend', 'django.contrib.auth.backends.ModelBackend',
'allauth.account.auth_backends.AuthenticationBackend', # SSO login via external providers 'allauth.account.auth_backends.AuthenticationBackend', # SSO login via external providers
"sesame.backends.ModelBackend", # Magic link login django-sesame 'sesame.backends.ModelBackend', # Magic link login django-sesame
]) ],
)
# LDAP support # LDAP support
LDAP_AUTH = get_boolean_setting("INVENTREE_LDAP_ENABLED", "ldap.enabled", False) LDAP_AUTH = get_boolean_setting('INVENTREE_LDAP_ENABLED', 'ldap.enabled', False)
if LDAP_AUTH: if LDAP_AUTH:
import ldap import ldap
from django_auth_ldap.config import GroupOfUniqueNamesType, LDAPSearch from django_auth_ldap.config import GroupOfUniqueNamesType, LDAPSearch
AUTHENTICATION_BACKENDS.append("django_auth_ldap.backend.LDAPBackend") AUTHENTICATION_BACKENDS.append('django_auth_ldap.backend.LDAPBackend')
# debug mode to troubleshoot configuration # debug mode to troubleshoot configuration
LDAP_DEBUG = get_boolean_setting("INVENTREE_LDAP_DEBUG", "ldap.debug", False) LDAP_DEBUG = get_boolean_setting('INVENTREE_LDAP_DEBUG', 'ldap.debug', False)
if LDAP_DEBUG: if LDAP_DEBUG:
if "loggers" not in LOGGING: if 'loggers' not in LOGGING:
LOGGING["loggers"] = {} LOGGING['loggers'] = {}
LOGGING["loggers"]["django_auth_ldap"] = {"level": "DEBUG", "handlers": ["console"]} LOGGING['loggers']['django_auth_ldap'] = {
'level': 'DEBUG',
'handlers': ['console'],
}
# get global options from dict and use ldap.OPT_* as keys and values # get global options from dict and use ldap.OPT_* as keys and values
global_options_dict = get_setting("INVENTREE_LDAP_GLOBAL_OPTIONS", "ldap.global_options", {}, dict) global_options_dict = get_setting(
'INVENTREE_LDAP_GLOBAL_OPTIONS', 'ldap.global_options', {}, dict
)
global_options = {} global_options = {}
for k, v in global_options_dict.items(): for k, v in global_options_dict.items():
# keys are always ldap.OPT_* constants # keys are always ldap.OPT_* constants
k_attr = getattr(ldap, k, None) k_attr = getattr(ldap, k, None)
if not k.startswith("OPT_") or k_attr is None: if not k.startswith('OPT_') or k_attr is None:
print(f"[LDAP] ldap.global_options, key '{k}' not found, skipping...") print(f"[LDAP] ldap.global_options, key '{k}' not found, skipping...")
continue continue
# values can also be other strings, e.g. paths # values can also be other strings, e.g. paths
v_attr = v v_attr = v
if v.startswith("OPT_"): if v.startswith('OPT_'):
v_attr = getattr(ldap, v, None) v_attr = getattr(ldap, v, None)
if v_attr is None: if v_attr is None:
@ -341,41 +325,60 @@ if LDAP_AUTH:
global_options[k_attr] = v_attr global_options[k_attr] = v_attr
AUTH_LDAP_GLOBAL_OPTIONS = global_options AUTH_LDAP_GLOBAL_OPTIONS = global_options
if LDAP_DEBUG: if LDAP_DEBUG:
print("[LDAP] ldap.global_options =", global_options) print('[LDAP] ldap.global_options =', global_options)
AUTH_LDAP_SERVER_URI = get_setting("INVENTREE_LDAP_SERVER_URI", "ldap.server_uri") AUTH_LDAP_SERVER_URI = get_setting('INVENTREE_LDAP_SERVER_URI', 'ldap.server_uri')
AUTH_LDAP_START_TLS = get_boolean_setting("INVENTREE_LDAP_START_TLS", "ldap.start_tls", False) AUTH_LDAP_START_TLS = get_boolean_setting(
AUTH_LDAP_BIND_DN = get_setting("INVENTREE_LDAP_BIND_DN", "ldap.bind_dn") 'INVENTREE_LDAP_START_TLS', 'ldap.start_tls', False
AUTH_LDAP_BIND_PASSWORD = get_setting("INVENTREE_LDAP_BIND_PASSWORD", "ldap.bind_password") )
AUTH_LDAP_USER_SEARCH = LDAPSearch( AUTH_LDAP_BIND_DN = get_setting('INVENTREE_LDAP_BIND_DN', 'ldap.bind_dn')
get_setting("INVENTREE_LDAP_SEARCH_BASE_DN", "ldap.search_base_dn"), AUTH_LDAP_BIND_PASSWORD = get_setting(
ldap.SCOPE_SUBTREE, 'INVENTREE_LDAP_BIND_PASSWORD', 'ldap.bind_password'
str(get_setting("INVENTREE_LDAP_SEARCH_FILTER_STR", "ldap.search_filter_str", "(uid= %(user)s)")) )
AUTH_LDAP_USER_SEARCH = LDAPSearch(
get_setting('INVENTREE_LDAP_SEARCH_BASE_DN', 'ldap.search_base_dn'),
ldap.SCOPE_SUBTREE,
str(
get_setting(
'INVENTREE_LDAP_SEARCH_FILTER_STR',
'ldap.search_filter_str',
'(uid= %(user)s)',
)
),
)
AUTH_LDAP_USER_DN_TEMPLATE = get_setting(
'INVENTREE_LDAP_USER_DN_TEMPLATE', 'ldap.user_dn_template'
)
AUTH_LDAP_USER_ATTR_MAP = get_setting(
'INVENTREE_LDAP_USER_ATTR_MAP',
'ldap.user_attr_map',
{'first_name': 'givenName', 'last_name': 'sn', 'email': 'mail'},
dict,
)
AUTH_LDAP_ALWAYS_UPDATE_USER = get_boolean_setting(
'INVENTREE_LDAP_ALWAYS_UPDATE_USER', 'ldap.always_update_user', True
)
AUTH_LDAP_CACHE_TIMEOUT = get_setting(
'INVENTREE_LDAP_CACHE_TIMEOUT', 'ldap.cache_timeout', 3600, int
) )
AUTH_LDAP_USER_DN_TEMPLATE = get_setting("INVENTREE_LDAP_USER_DN_TEMPLATE", "ldap.user_dn_template")
AUTH_LDAP_USER_ATTR_MAP = get_setting("INVENTREE_LDAP_USER_ATTR_MAP", "ldap.user_attr_map", {
'first_name': 'givenName',
'last_name': 'sn',
'email': 'mail',
}, dict)
AUTH_LDAP_ALWAYS_UPDATE_USER = get_boolean_setting("INVENTREE_LDAP_ALWAYS_UPDATE_USER", "ldap.always_update_user", True)
AUTH_LDAP_CACHE_TIMEOUT = get_setting("INVENTREE_LDAP_CACHE_TIMEOUT", "ldap.cache_timeout", 3600, int)
AUTH_LDAP_GROUP_SEARCH = LDAPSearch( AUTH_LDAP_GROUP_SEARCH = LDAPSearch(
get_setting("INVENTREE_LDAP_GROUP_SEARCH", "ldap.group_search"), get_setting('INVENTREE_LDAP_GROUP_SEARCH', 'ldap.group_search'),
ldap.SCOPE_SUBTREE, ldap.SCOPE_SUBTREE,
"(objectClass=groupOfUniqueNames)", '(objectClass=groupOfUniqueNames)',
)
AUTH_LDAP_GROUP_TYPE = GroupOfUniqueNamesType(name_attr='cn')
AUTH_LDAP_REQUIRE_GROUP = get_setting(
'INVENTREE_LDAP_REQUIRE_GROUP', 'ldap.require_group'
)
AUTH_LDAP_DENY_GROUP = get_setting('INVENTREE_LDAP_DENY_GROUP', 'ldap.deny_group')
AUTH_LDAP_USER_FLAGS_BY_GROUP = get_setting(
'INVENTREE_LDAP_USER_FLAGS_BY_GROUP', 'ldap.user_flags_by_group', {}, dict
) )
AUTH_LDAP_GROUP_TYPE = GroupOfUniqueNamesType(name_attr="cn")
AUTH_LDAP_REQUIRE_GROUP = get_setting("INVENTREE_LDAP_REQUIRE_GROUP", "ldap.require_group")
AUTH_LDAP_DENY_GROUP = get_setting("INVENTREE_LDAP_DENY_GROUP", "ldap.deny_group")
AUTH_LDAP_USER_FLAGS_BY_GROUP = get_setting("INVENTREE_LDAP_USER_FLAGS_BY_GROUP", "ldap.user_flags_by_group", {}, dict)
AUTH_LDAP_FIND_GROUP_PERMS = True AUTH_LDAP_FIND_GROUP_PERMS = True
# Internal IP addresses allowed to see the debug toolbar # Internal IP addresses allowed to see the debug toolbar
INTERNAL_IPS = [ INTERNAL_IPS = ['127.0.0.1']
'127.0.0.1',
]
# Internal flag to determine if we are running in docker mode # Internal flag to determine if we are running in docker mode
DOCKER = get_boolean_setting('INVENTREE_DOCKER', default_value=False) DOCKER = get_boolean_setting('INVENTREE_DOCKER', default_value=False)
@ -383,7 +386,10 @@ DOCKER = get_boolean_setting('INVENTREE_DOCKER', default_value=False)
if DOCKER: # pragma: no cover if DOCKER: # pragma: no cover
# Internal IP addresses are different when running under docker # Internal IP addresses are different when running under docker
hostname, ___, ips = socket.gethostbyname_ex(socket.gethostname()) hostname, ___, ips = socket.gethostbyname_ex(socket.gethostname())
INTERNAL_IPS = [ip[: ip.rfind(".")] + ".1" for ip in ips] + ["127.0.0.1", "10.0.2.2"] INTERNAL_IPS = [ip[: ip.rfind('.')] + '.1' for ip in ips] + [
'127.0.0.1',
'10.0.2.2',
]
# Allow secure http developer server in debug mode # Allow secure http developer server in debug mode
if DEBUG: if DEBUG:
@ -413,15 +419,18 @@ TEMPLATES = [
'InvenTree.context.status_codes', 'InvenTree.context.status_codes',
'InvenTree.context.user_roles', 'InvenTree.context.user_roles',
], ],
'loaders': [( 'loaders': [
'InvenTree.template.InvenTreeTemplateLoader', [ (
'InvenTree.template.InvenTreeTemplateLoader',
[
'plugin.template.PluginTemplateLoader', 'plugin.template.PluginTemplateLoader',
'django.template.loaders.filesystem.Loader', 'django.template.loaders.filesystem.Loader',
'django.template.loaders.app_directories.Loader', 'django.template.loaders.app_directories.Loader',
]) ],
)
], ],
}, },
}, }
] ]
REST_FRAMEWORK = { REST_FRAMEWORK = {
@ -440,22 +449,24 @@ REST_FRAMEWORK = {
), ),
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
'DEFAULT_METADATA_CLASS': 'InvenTree.metadata.InvenTreeMetadata', 'DEFAULT_METADATA_CLASS': 'InvenTree.metadata.InvenTreeMetadata',
'DEFAULT_RENDERER_CLASSES': [ 'DEFAULT_RENDERER_CLASSES': ['rest_framework.renderers.JSONRenderer'],
'rest_framework.renderers.JSONRenderer',
],
'TOKEN_MODEL': 'users.models.ApiToken', 'TOKEN_MODEL': 'users.models.ApiToken',
} }
if DEBUG: if DEBUG:
# Enable browsable API if in DEBUG mode # Enable browsable API if in DEBUG mode
REST_FRAMEWORK['DEFAULT_RENDERER_CLASSES'].append('rest_framework.renderers.BrowsableAPIRenderer') REST_FRAMEWORK['DEFAULT_RENDERER_CLASSES'].append(
'rest_framework.renderers.BrowsableAPIRenderer'
)
# dj-rest-auth # dj-rest-auth
# JWT switch # JWT switch
USE_JWT = get_boolean_setting('INVENTREE_USE_JWT', 'use_jwt', False) USE_JWT = get_boolean_setting('INVENTREE_USE_JWT', 'use_jwt', False)
REST_USE_JWT = USE_JWT REST_USE_JWT = USE_JWT
OLD_PASSWORD_FIELD_ENABLED = True OLD_PASSWORD_FIELD_ENABLED = True
REST_AUTH_REGISTER_SERIALIZERS = {'REGISTER_SERIALIZER': 'InvenTree.forms.CustomRegisterSerializer'} REST_AUTH_REGISTER_SERIALIZERS = {
'REGISTER_SERIALIZER': 'InvenTree.forms.CustomRegisterSerializer'
}
# JWT settings - rest_framework_simplejwt # JWT settings - rest_framework_simplejwt
if USE_JWT: if USE_JWT:
@ -471,7 +482,10 @@ SPECTACULAR_SETTINGS = {
'TITLE': 'InvenTree API', 'TITLE': 'InvenTree API',
'DESCRIPTION': 'API for InvenTree - the intuitive open source inventory management system', 'DESCRIPTION': 'API for InvenTree - the intuitive open source inventory management system',
'LICENSE': {'MIT': 'https://github.com/inventree/InvenTree/blob/master/LICENSE'}, 'LICENSE': {'MIT': 'https://github.com/inventree/InvenTree/blob/master/LICENSE'},
'EXTERNAL_DOCS': {'docs': 'https://docs.inventree.org', 'web': 'https://inventree.org'}, 'EXTERNAL_DOCS': {
'docs': 'https://docs.inventree.org',
'web': 'https://inventree.org',
},
'VERSION': inventreeApiVersion(), 'VERSION': inventreeApiVersion(),
'SERVE_INCLUDE_SCHEMA': False, 'SERVE_INCLUDE_SCHEMA': False,
} }
@ -486,7 +500,7 @@ Configure the database backend based on the user-specified values.
- The following code lets the user "mix and match" database configuration - The following code lets the user "mix and match" database configuration
""" """
logger.debug("Configuring database backend:") logger.debug('Configuring database backend:')
# Extract database configuration from the config.yaml file # Extract database configuration from the config.yaml file
db_config = CONFIG.get('database', {}) db_config = CONFIG.get('database', {})
@ -500,7 +514,7 @@ db_keys = ['ENGINE', 'NAME', 'USER', 'PASSWORD', 'HOST', 'PORT']
for key in db_keys: for key in db_keys:
# First, check the environment variables # First, check the environment variables
env_key = f"INVENTREE_DB_{key}" env_key = f'INVENTREE_DB_{key}'
env_var = os.environ.get(env_key, None) env_var = os.environ.get(env_key, None)
if env_var: if env_var:
@ -509,7 +523,7 @@ for key in db_keys:
try: try:
env_var = int(env_var) env_var = int(env_var)
except ValueError: except ValueError:
logger.exception("Invalid number for %s: %s", env_key, env_var) logger.exception('Invalid number for %s: %s', env_key, env_var)
# Override configuration value # Override configuration value
db_config[key] = env_var db_config[key] = env_var
@ -550,9 +564,9 @@ if 'sqlite' in db_engine:
db_name = str(Path(db_name).resolve()) db_name = str(Path(db_name).resolve())
db_config['NAME'] = db_name db_config['NAME'] = db_name
logger.info("DB_ENGINE: %s", db_engine) logger.info('DB_ENGINE: %s', db_engine)
logger.info("DB_NAME: %s", db_name) logger.info('DB_NAME: %s', db_name)
logger.info("DB_HOST: %s", db_host) logger.info('DB_HOST: %s', db_host)
""" """
In addition to base-level database configuration, we may wish to specify specific options to the database backend In addition to base-level database configuration, we may wish to specify specific options to the database backend
@ -565,19 +579,21 @@ Ref: https://docs.djangoproject.com/en/3.2/ref/settings/#std:setting-OPTIONS
# connecting to the database server (such as a replica failover) don't sit and # connecting to the database server (such as a replica failover) don't sit and
# wait for possibly an hour or more, just tell the client something went wrong # wait for possibly an hour or more, just tell the client something went wrong
# and let the client retry when they want to. # and let the client retry when they want to.
db_options = db_config.get("OPTIONS", db_config.get("options", {})) db_options = db_config.get('OPTIONS', db_config.get('options', {}))
# Specific options for postgres backend # Specific options for postgres backend
if "postgres" in db_engine: # pragma: no cover if 'postgres' in db_engine: # pragma: no cover
from psycopg2.extensions import (ISOLATION_LEVEL_READ_COMMITTED, from psycopg2.extensions import (
ISOLATION_LEVEL_SERIALIZABLE) ISOLATION_LEVEL_READ_COMMITTED,
ISOLATION_LEVEL_SERIALIZABLE,
)
# Connection timeout # Connection timeout
if "connect_timeout" not in db_options: if 'connect_timeout' not in db_options:
# The DB server is in the same data center, it should not take very # The DB server is in the same data center, it should not take very
# long to connect to the database server # long to connect to the database server
# # seconds, 2 is minimum allowed by libpq # # seconds, 2 is minimum allowed by libpq
db_options["connect_timeout"] = int( db_options['connect_timeout'] = int(
get_setting('INVENTREE_DB_TIMEOUT', 'database.timeout', 2) get_setting('INVENTREE_DB_TIMEOUT', 'database.timeout', 2)
) )
@ -587,27 +603,37 @@ if "postgres" in db_engine: # pragma: no cover
# issue to resolve itself. It it that doesn't happen whatever happened # issue to resolve itself. It it that doesn't happen whatever happened
# is probably fatal and no amount of waiting is going to fix it. # is probably fatal and no amount of waiting is going to fix it.
# # 0 - TCP Keepalives disabled; 1 - enabled # # 0 - TCP Keepalives disabled; 1 - enabled
if "keepalives" not in db_options: if 'keepalives' not in db_options:
db_options["keepalives"] = int( db_options['keepalives'] = int(
get_setting('INVENTREE_DB_TCP_KEEPALIVES', 'database.tcp_keepalives', 1) get_setting('INVENTREE_DB_TCP_KEEPALIVES', 'database.tcp_keepalives', 1)
) )
# Seconds after connection is idle to send keep alive # Seconds after connection is idle to send keep alive
if "keepalives_idle" not in db_options: if 'keepalives_idle' not in db_options:
db_options["keepalives_idle"] = int( db_options['keepalives_idle'] = int(
get_setting('INVENTREE_DB_TCP_KEEPALIVES_IDLE', 'database.tcp_keepalives_idle', 1) get_setting(
'INVENTREE_DB_TCP_KEEPALIVES_IDLE', 'database.tcp_keepalives_idle', 1
)
) )
# Seconds after missing ACK to send another keep alive # Seconds after missing ACK to send another keep alive
if "keepalives_interval" not in db_options: if 'keepalives_interval' not in db_options:
db_options["keepalives_interval"] = int( db_options['keepalives_interval'] = int(
get_setting("INVENTREE_DB_TCP_KEEPALIVES_INTERVAL", "database.tcp_keepalives_internal", "1") get_setting(
'INVENTREE_DB_TCP_KEEPALIVES_INTERVAL',
'database.tcp_keepalives_internal',
'1',
)
) )
# Number of missing ACKs before we close the connection # Number of missing ACKs before we close the connection
if "keepalives_count" not in db_options: if 'keepalives_count' not in db_options:
db_options["keepalives_count"] = int( db_options['keepalives_count'] = int(
get_setting("INVENTREE_DB_TCP_KEEPALIVES_COUNT", "database.tcp_keepalives_count", "5") get_setting(
'INVENTREE_DB_TCP_KEEPALIVES_COUNT',
'database.tcp_keepalives_count',
'5',
)
) )
# # Milliseconds for how long pending data should remain unacked # # Milliseconds for how long pending data should remain unacked
@ -621,12 +647,18 @@ if "postgres" in db_engine: # pragma: no cover
# protect against simultaneous changes. # protect against simultaneous changes.
# https://www.postgresql.org/docs/devel/transaction-iso.html # https://www.postgresql.org/docs/devel/transaction-iso.html
# https://docs.djangoproject.com/en/3.2/ref/databases/#isolation-level # https://docs.djangoproject.com/en/3.2/ref/databases/#isolation-level
if "isolation_level" not in db_options: if 'isolation_level' not in db_options:
serializable = get_boolean_setting('INVENTREE_DB_ISOLATION_SERIALIZABLE', 'database.serializable', False) serializable = get_boolean_setting(
db_options["isolation_level"] = ISOLATION_LEVEL_SERIALIZABLE if serializable else ISOLATION_LEVEL_READ_COMMITTED 'INVENTREE_DB_ISOLATION_SERIALIZABLE', 'database.serializable', False
)
db_options['isolation_level'] = (
ISOLATION_LEVEL_SERIALIZABLE
if serializable
else ISOLATION_LEVEL_READ_COMMITTED
)
# Specific options for MySql / MariaDB backend # Specific options for MySql / MariaDB backend
elif "mysql" in db_engine: # pragma: no cover elif 'mysql' in db_engine: # pragma: no cover
# TODO TCP time outs and keepalives # TODO TCP time outs and keepalives
# MariaDB's default isolation level is Repeatable Read which is # MariaDB's default isolation level is Repeatable Read which is
@ -635,12 +667,16 @@ elif "mysql" in db_engine: # pragma: no cover
# protect against siumltaneous changes. # protect against siumltaneous changes.
# https://mariadb.com/kb/en/mariadb-transactions-and-isolation-levels-for-sql-server-users/#changing-the-isolation-level # https://mariadb.com/kb/en/mariadb-transactions-and-isolation-levels-for-sql-server-users/#changing-the-isolation-level
# https://docs.djangoproject.com/en/3.2/ref/databases/#mysql-isolation-level # https://docs.djangoproject.com/en/3.2/ref/databases/#mysql-isolation-level
if "isolation_level" not in db_options: if 'isolation_level' not in db_options:
serializable = get_boolean_setting('INVENTREE_DB_ISOLATION_SERIALIZABLE', 'database.serializable', False) serializable = get_boolean_setting(
db_options["isolation_level"] = "serializable" if serializable else "read committed" 'INVENTREE_DB_ISOLATION_SERIALIZABLE', 'database.serializable', False
)
db_options['isolation_level'] = (
'serializable' if serializable else 'read committed'
)
# Specific options for sqlite backend # Specific options for sqlite backend
elif "sqlite" in db_engine: elif 'sqlite' in db_engine:
# TODO: Verify timeouts are not an issue because no network is involved for SQLite # TODO: Verify timeouts are not an issue because no network is involved for SQLite
# SQLite's default isolation level is Serializable due to SQLite's # SQLite's default isolation level is Serializable due to SQLite's
@ -653,31 +689,34 @@ elif "sqlite" in db_engine:
db_config['OPTIONS'] = db_options db_config['OPTIONS'] = db_options
# Set testing options for the database # Set testing options for the database
db_config['TEST'] = { db_config['TEST'] = {'CHARSET': 'utf8'}
'CHARSET': 'utf8',
}
# Set collation option for mysql test database # Set collation option for mysql test database
if 'mysql' in db_engine: if 'mysql' in db_engine:
db_config['TEST']['COLLATION'] = 'utf8_general_ci' # pragma: no cover db_config['TEST']['COLLATION'] = 'utf8_general_ci' # pragma: no cover
DATABASES = { DATABASES = {'default': db_config}
'default': db_config
}
# login settings # login settings
REMOTE_LOGIN = get_boolean_setting('INVENTREE_REMOTE_LOGIN', 'remote_login_enabled', False) REMOTE_LOGIN = get_boolean_setting(
REMOTE_LOGIN_HEADER = get_setting('INVENTREE_REMOTE_LOGIN_HEADER', 'remote_login_header', 'REMOTE_USER') 'INVENTREE_REMOTE_LOGIN', 'remote_login_enabled', False
)
REMOTE_LOGIN_HEADER = get_setting(
'INVENTREE_REMOTE_LOGIN_HEADER', 'remote_login_header', 'REMOTE_USER'
)
# sentry.io integration for error reporting # sentry.io integration for error reporting
SENTRY_ENABLED = get_boolean_setting('INVENTREE_SENTRY_ENABLED', 'sentry_enabled', False) SENTRY_ENABLED = get_boolean_setting(
'INVENTREE_SENTRY_ENABLED', 'sentry_enabled', False
)
# Default Sentry DSN (can be overridden if user wants custom sentry integration) # Default Sentry DSN (can be overridden if user wants custom sentry integration)
SENTRY_DSN = get_setting('INVENTREE_SENTRY_DSN', 'sentry_dsn', default_sentry_dsn()) SENTRY_DSN = get_setting('INVENTREE_SENTRY_DSN', 'sentry_dsn', default_sentry_dsn())
SENTRY_SAMPLE_RATE = float(get_setting('INVENTREE_SENTRY_SAMPLE_RATE', 'sentry_sample_rate', 0.1)) SENTRY_SAMPLE_RATE = float(
get_setting('INVENTREE_SENTRY_SAMPLE_RATE', 'sentry_sample_rate', 0.1)
)
if SENTRY_ENABLED and SENTRY_DSN: # pragma: no cover if SENTRY_ENABLED and SENTRY_DSN: # pragma: no cover
inventree_tags = { inventree_tags = {
'testing': TESTING, 'testing': TESTING,
'docker': DOCKER, 'docker': DOCKER,
@ -696,53 +735,47 @@ if cache_host: # pragma: no cover
# so don't wait too long for the cache as nothing in the cache should be # so don't wait too long for the cache as nothing in the cache should be
# irreplaceable. # irreplaceable.
_cache_options = { _cache_options = {
"CLIENT_CLASS": "django_redis.client.DefaultClient", 'CLIENT_CLASS': 'django_redis.client.DefaultClient',
"SOCKET_CONNECT_TIMEOUT": int(os.getenv("CACHE_CONNECT_TIMEOUT", "2")), 'SOCKET_CONNECT_TIMEOUT': int(os.getenv('CACHE_CONNECT_TIMEOUT', '2')),
"SOCKET_TIMEOUT": int(os.getenv("CACHE_SOCKET_TIMEOUT", "2")), 'SOCKET_TIMEOUT': int(os.getenv('CACHE_SOCKET_TIMEOUT', '2')),
"CONNECTION_POOL_KWARGS": { 'CONNECTION_POOL_KWARGS': {
"socket_keepalive": config.is_true( 'socket_keepalive': config.is_true(os.getenv('CACHE_TCP_KEEPALIVE', '1')),
os.getenv("CACHE_TCP_KEEPALIVE", "1") 'socket_keepalive_options': {
), socket.TCP_KEEPCNT: int(os.getenv('CACHE_KEEPALIVES_COUNT', '5')),
"socket_keepalive_options": { socket.TCP_KEEPIDLE: int(os.getenv('CACHE_KEEPALIVES_IDLE', '1')),
socket.TCP_KEEPCNT: int( socket.TCP_KEEPINTVL: int(os.getenv('CACHE_KEEPALIVES_INTERVAL', '1')),
os.getenv("CACHE_KEEPALIVES_COUNT", "5")
),
socket.TCP_KEEPIDLE: int(
os.getenv("CACHE_KEEPALIVES_IDLE", "1")
),
socket.TCP_KEEPINTVL: int(
os.getenv("CACHE_KEEPALIVES_INTERVAL", "1")
),
socket.TCP_USER_TIMEOUT: int( socket.TCP_USER_TIMEOUT: int(
os.getenv("CACHE_TCP_USER_TIMEOUT", "1000") os.getenv('CACHE_TCP_USER_TIMEOUT', '1000')
), ),
}, },
}, },
} }
CACHES = { CACHES = {
"default": { 'default': {
"BACKEND": "django_redis.cache.RedisCache", 'BACKEND': 'django_redis.cache.RedisCache',
"LOCATION": f"redis://{cache_host}:{cache_port}/0", 'LOCATION': f'redis://{cache_host}:{cache_port}/0',
"OPTIONS": _cache_options, 'OPTIONS': _cache_options,
}, }
} }
else: else:
CACHES = { CACHES = {'default': {'BACKEND': 'django.core.cache.backends.locmem.LocMemCache'}}
"default": {
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
},
}
_q_worker_timeout = int(get_setting('INVENTREE_BACKGROUND_TIMEOUT', 'background.timeout', 90)) _q_worker_timeout = int(
get_setting('INVENTREE_BACKGROUND_TIMEOUT', 'background.timeout', 90)
)
# django-q background worker configuration # django-q background worker configuration
Q_CLUSTER = { Q_CLUSTER = {
'name': 'InvenTree', 'name': 'InvenTree',
'label': 'Background Tasks', 'label': 'Background Tasks',
'workers': int(get_setting('INVENTREE_BACKGROUND_WORKERS', 'background.workers', 4)), 'workers': int(
get_setting('INVENTREE_BACKGROUND_WORKERS', 'background.workers', 4)
),
'timeout': _q_worker_timeout, 'timeout': _q_worker_timeout,
'retry': min(120, _q_worker_timeout + 30), 'retry': min(120, _q_worker_timeout + 30),
'max_attempts': int(get_setting('INVENTREE_BACKGROUND_MAX_ATTEMPTS', 'background.max_attempts', 5)), 'max_attempts': int(
get_setting('INVENTREE_BACKGROUND_MAX_ATTEMPTS', 'background.max_attempts', 5)
),
'queue_limit': 50, 'queue_limit': 50,
'catch_up': False, 'catch_up': False,
'bulk': 10, 'bulk': 10,
@ -754,40 +787,30 @@ Q_CLUSTER = {
# Configure django-q sentry integration # Configure django-q sentry integration
if SENTRY_ENABLED and SENTRY_DSN: if SENTRY_ENABLED and SENTRY_DSN:
Q_CLUSTER['error_reporter'] = { Q_CLUSTER['error_reporter'] = {'sentry': {'dsn': SENTRY_DSN}}
'sentry': {
'dsn': SENTRY_DSN
}
}
if cache_host: # pragma: no cover if cache_host: # pragma: no cover
# If using external redis cache, make the cache the broker for Django Q # If using external redis cache, make the cache the broker for Django Q
# as well # as well
Q_CLUSTER["django_redis"] = "worker" Q_CLUSTER['django_redis'] = 'worker'
# database user sessions # database user sessions
SESSION_ENGINE = 'user_sessions.backends.db' SESSION_ENGINE = 'user_sessions.backends.db'
LOGOUT_REDIRECT_URL = get_setting('INVENTREE_LOGOUT_REDIRECT_URL', 'logout_redirect_url', 'index') LOGOUT_REDIRECT_URL = get_setting(
SILENCED_SYSTEM_CHECKS = [ 'INVENTREE_LOGOUT_REDIRECT_URL', 'logout_redirect_url', 'index'
'admin.E410', )
] SILENCED_SYSTEM_CHECKS = ['admin.E410']
# Password validation # Password validation
# https://docs.djangoproject.com/en/1.10/ref/settings/#auth-password-validators # https://docs.djangoproject.com/en/1.10/ref/settings/#auth-password-validators
AUTH_PASSWORD_VALIDATORS = [ AUTH_PASSWORD_VALIDATORS = [
{ {
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator'
},
{
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
},
{
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
}, },
{'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator'},
{'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator'},
{'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator'},
] ]
# Extra (optional) URL validators # Extra (optional) URL validators
@ -796,7 +819,7 @@ AUTH_PASSWORD_VALIDATORS = [
EXTRA_URL_SCHEMES = get_setting('INVENTREE_EXTRA_URL_SCHEMES', 'extra_url_schemes', []) EXTRA_URL_SCHEMES = get_setting('INVENTREE_EXTRA_URL_SCHEMES', 'extra_url_schemes', [])
if type(EXTRA_URL_SCHEMES) not in [list]: # pragma: no cover if type(EXTRA_URL_SCHEMES) not in [list]: # pragma: no cover
logger.warning("extra_url_schemes not correctly formatted") logger.warning('extra_url_schemes not correctly formatted')
EXTRA_URL_SCHEMES = [] EXTRA_URL_SCHEMES = []
# Internationalization # Internationalization
@ -854,26 +877,21 @@ if get_boolean_setting('TEST_TRANSLATIONS', default_value=False): # pragma: no
LANGUAGES.append(('xx', 'Test')) LANGUAGES.append(('xx', 'Test'))
# Add custom languages not provided by Django # Add custom languages not provided by Django
EXTRA_LANG_INFO = { EXTRA_LANG_INFO = {'xx': {'code': 'xx', 'name': 'Test', 'name_local': 'Test'}}
'xx': {
'code': 'xx',
'name': 'Test',
'name_local': 'Test'
},
}
LANG_INFO = dict(django.conf.locale.LANG_INFO, **EXTRA_LANG_INFO) LANG_INFO = dict(django.conf.locale.LANG_INFO, **EXTRA_LANG_INFO)
django.conf.locale.LANG_INFO = LANG_INFO django.conf.locale.LANG_INFO = LANG_INFO
# Currencies available for use # Currencies available for use
CURRENCIES = get_setting( CURRENCIES = get_setting(
'INVENTREE_CURRENCIES', 'currencies', 'INVENTREE_CURRENCIES',
'currencies',
['AUD', 'CAD', 'CNY', 'EUR', 'GBP', 'JPY', 'NZD', 'USD'], ['AUD', 'CAD', 'CNY', 'EUR', 'GBP', 'JPY', 'NZD', 'USD'],
typecast=list, typecast=list,
) )
# Ensure that at least one currency value is available # Ensure that at least one currency value is available
if len(CURRENCIES) == 0: # pragma: no cover if len(CURRENCIES) == 0: # pragma: no cover
logger.warning("No currencies selected: Defaulting to USD") logger.warning('No currencies selected: Defaulting to USD')
CURRENCIES = ['USD'] CURRENCIES = ['USD']
# Maximum number of decimal places for currency rendering # Maximum number of decimal places for currency rendering
@ -889,12 +907,18 @@ for currency in CURRENCIES:
EXCHANGE_BACKEND = 'InvenTree.exchange.InvenTreeExchange' EXCHANGE_BACKEND = 'InvenTree.exchange.InvenTreeExchange'
# Email configuration options # Email configuration options
EMAIL_BACKEND = get_setting('INVENTREE_EMAIL_BACKEND', 'email.backend', 'django.core.mail.backends.smtp.EmailBackend') EMAIL_BACKEND = get_setting(
'INVENTREE_EMAIL_BACKEND',
'email.backend',
'django.core.mail.backends.smtp.EmailBackend',
)
EMAIL_HOST = get_setting('INVENTREE_EMAIL_HOST', 'email.host', '') EMAIL_HOST = get_setting('INVENTREE_EMAIL_HOST', 'email.host', '')
EMAIL_PORT = get_setting('INVENTREE_EMAIL_PORT', 'email.port', 25, typecast=int) EMAIL_PORT = get_setting('INVENTREE_EMAIL_PORT', 'email.port', 25, typecast=int)
EMAIL_HOST_USER = get_setting('INVENTREE_EMAIL_USERNAME', 'email.username', '') EMAIL_HOST_USER = get_setting('INVENTREE_EMAIL_USERNAME', 'email.username', '')
EMAIL_HOST_PASSWORD = get_setting('INVENTREE_EMAIL_PASSWORD', 'email.password', '') EMAIL_HOST_PASSWORD = get_setting('INVENTREE_EMAIL_PASSWORD', 'email.password', '')
EMAIL_SUBJECT_PREFIX = get_setting('INVENTREE_EMAIL_PREFIX', 'email.prefix', '[InvenTree] ') EMAIL_SUBJECT_PREFIX = get_setting(
'INVENTREE_EMAIL_PREFIX', 'email.prefix', '[InvenTree] '
)
EMAIL_USE_TLS = get_boolean_setting('INVENTREE_EMAIL_TLS', 'email.tls', False) EMAIL_USE_TLS = get_boolean_setting('INVENTREE_EMAIL_TLS', 'email.tls', False)
EMAIL_USE_SSL = get_boolean_setting('INVENTREE_EMAIL_SSL', 'email.ssl', False) EMAIL_USE_SSL = get_boolean_setting('INVENTREE_EMAIL_SSL', 'email.ssl', False)
@ -907,9 +931,7 @@ if not DEFAULT_FROM_EMAIL:
EMAIL_USE_LOCALTIME = False EMAIL_USE_LOCALTIME = False
EMAIL_TIMEOUT = 60 EMAIL_TIMEOUT = 60
LOCALE_PATHS = ( LOCALE_PATHS = (BASE_DIR.joinpath('locale/'),)
BASE_DIR.joinpath('locale/'),
)
TIME_ZONE = get_setting('INVENTREE_TIMEZONE', 'timezone', 'UTC') TIME_ZONE = get_setting('INVENTREE_TIMEZONE', 'timezone', 'UTC')
@ -922,9 +944,7 @@ USE_L10N = True
if not TESTING: if not TESTING:
USE_TZ = True # pragma: no cover USE_TZ = True # pragma: no cover
DATE_INPUT_FORMATS = [ DATE_INPUT_FORMATS = ['%Y-%m-%d']
"%Y-%m-%d",
]
# crispy forms use the bootstrap templates # crispy forms use the bootstrap templates
CRISPY_TEMPLATE_PACK = 'bootstrap4' CRISPY_TEMPLATE_PACK = 'bootstrap4'
@ -935,10 +955,11 @@ IMPORT_EXPORT_USE_TRANSACTIONS = True
SITE_ID = 1 SITE_ID = 1
# Load the allauth social backends # Load the allauth social backends
SOCIAL_BACKENDS = get_setting('INVENTREE_SOCIAL_BACKENDS', 'social_backends', [], typecast=list) SOCIAL_BACKENDS = get_setting(
'INVENTREE_SOCIAL_BACKENDS', 'social_backends', [], typecast=list
)
for app in SOCIAL_BACKENDS: for app in SOCIAL_BACKENDS:
# Ensure that the app starts with 'allauth.socialaccount.providers' # Ensure that the app starts with 'allauth.socialaccount.providers'
social_prefix = 'allauth.socialaccount.providers.' social_prefix = 'allauth.socialaccount.providers.'
@ -947,14 +968,22 @@ for app in SOCIAL_BACKENDS:
INSTALLED_APPS.append(app) # pragma: no cover INSTALLED_APPS.append(app) # pragma: no cover
SOCIALACCOUNT_PROVIDERS = get_setting('INVENTREE_SOCIAL_PROVIDERS', 'social_providers', None, typecast=dict) SOCIALACCOUNT_PROVIDERS = get_setting(
'INVENTREE_SOCIAL_PROVIDERS', 'social_providers', None, typecast=dict
)
SOCIALACCOUNT_STORE_TOKENS = True SOCIALACCOUNT_STORE_TOKENS = True
# settings for allauth # settings for allauth
ACCOUNT_EMAIL_CONFIRMATION_EXPIRE_DAYS = get_setting('INVENTREE_LOGIN_CONFIRM_DAYS', 'login_confirm_days', 3, typecast=int) ACCOUNT_EMAIL_CONFIRMATION_EXPIRE_DAYS = get_setting(
ACCOUNT_LOGIN_ATTEMPTS_LIMIT = get_setting('INVENTREE_LOGIN_ATTEMPTS', 'login_attempts', 5, typecast=int) 'INVENTREE_LOGIN_CONFIRM_DAYS', 'login_confirm_days', 3, typecast=int
ACCOUNT_DEFAULT_HTTP_PROTOCOL = get_setting('INVENTREE_LOGIN_DEFAULT_HTTP_PROTOCOL', 'login_default_protocol', 'http') )
ACCOUNT_LOGIN_ATTEMPTS_LIMIT = get_setting(
'INVENTREE_LOGIN_ATTEMPTS', 'login_attempts', 5, typecast=int
)
ACCOUNT_DEFAULT_HTTP_PROTOCOL = get_setting(
'INVENTREE_LOGIN_DEFAULT_HTTP_PROTOCOL', 'login_default_protocol', 'http'
)
ACCOUNT_LOGOUT_ON_PASSWORD_CHANGE = True ACCOUNT_LOGOUT_ON_PASSWORD_CHANGE = True
ACCOUNT_PREVENT_ENUMERATION = True ACCOUNT_PREVENT_ENUMERATION = True
# 2FA # 2FA
@ -981,21 +1010,17 @@ ACCOUNT_ADAPTER = 'InvenTree.forms.CustomAccountAdapter'
MARKDOWNIFY = { MARKDOWNIFY = {
'default': { 'default': {
'BLEACH': True, 'BLEACH': True,
'WHITELIST_ATTRS': [ 'WHITELIST_ATTRS': ['href', 'src', 'alt'],
'href', 'MARKDOWN_EXTENSIONS': ['markdown.extensions.extra'],
'src',
'alt',
],
'MARKDOWN_EXTENSIONS': [
'markdown.extensions.extra'
],
'WHITELIST_TAGS': [ 'WHITELIST_TAGS': [
'a', 'a',
'abbr', 'abbr',
'b', 'b',
'blockquote', 'blockquote',
'em', 'em',
'h1', 'h2', 'h3', 'h1',
'h2',
'h3',
'i', 'i',
'img', 'img',
'li', 'li',
@ -1008,46 +1033,55 @@ MARKDOWNIFY = {
'tbody', 'tbody',
'th', 'th',
'tr', 'tr',
'td' 'td',
], ],
} }
} }
# Ignore these error typeps for in-database error logging # Ignore these error typeps for in-database error logging
IGNORED_ERRORS = [ IGNORED_ERRORS = [Http404, django.core.exceptions.PermissionDenied]
Http404,
django.core.exceptions.PermissionDenied,
]
# Maintenance mode # Maintenance mode
MAINTENANCE_MODE_RETRY_AFTER = 60 MAINTENANCE_MODE_RETRY_AFTER = 60
MAINTENANCE_MODE_STATE_BACKEND = 'maintenance_mode.backends.StaticStorageBackend' MAINTENANCE_MODE_STATE_BACKEND = 'maintenance_mode.backends.StaticStorageBackend'
# Are plugins enabled? # Are plugins enabled?
PLUGINS_ENABLED = get_boolean_setting('INVENTREE_PLUGINS_ENABLED', 'plugins_enabled', False) PLUGINS_ENABLED = get_boolean_setting(
'INVENTREE_PLUGINS_ENABLED', 'plugins_enabled', False
)
PLUGIN_FILE = config.get_plugin_file() PLUGIN_FILE = config.get_plugin_file()
# Plugin test settings # Plugin test settings
PLUGIN_TESTING = get_setting('INVENTREE_PLUGIN_TESTING', 'PLUGIN_TESTING', TESTING) # Are plugins being tested? PLUGIN_TESTING = get_setting(
PLUGIN_TESTING_SETUP = get_setting('INVENTREE_PLUGIN_TESTING_SETUP', 'PLUGIN_TESTING_SETUP', False) # Load plugins from setup hooks in testing? 'INVENTREE_PLUGIN_TESTING', 'PLUGIN_TESTING', TESTING
) # Are plugins being tested?
PLUGIN_TESTING_SETUP = get_setting(
'INVENTREE_PLUGIN_TESTING_SETUP', 'PLUGIN_TESTING_SETUP', False
) # Load plugins from setup hooks in testing?
PLUGIN_TESTING_EVENTS = False # Flag if events are tested right now PLUGIN_TESTING_EVENTS = False # Flag if events are tested right now
PLUGIN_RETRY = get_setting('INVENTREE_PLUGIN_RETRY', 'PLUGIN_RETRY', 5) # How often should plugin loading be tried? PLUGIN_RETRY = get_setting(
'INVENTREE_PLUGIN_RETRY', 'PLUGIN_RETRY', 5
) # How often should plugin loading be tried?
PLUGIN_FILE_CHECKED = False # Was the plugin file checked? PLUGIN_FILE_CHECKED = False # Was the plugin file checked?
# Site URL can be specified statically, or via a run-time setting # Site URL can be specified statically, or via a run-time setting
SITE_URL = get_setting('INVENTREE_SITE_URL', 'site_url', None) SITE_URL = get_setting('INVENTREE_SITE_URL', 'site_url', None)
if SITE_URL: if SITE_URL:
logger.info("Site URL: %s", SITE_URL) logger.info('Site URL: %s', SITE_URL)
# Check that the site URL is valid # Check that the site URL is valid
validator = URLValidator() validator = URLValidator()
validator(SITE_URL) validator(SITE_URL)
# User interface customization values # User interface customization values
CUSTOM_LOGO = get_custom_file('INVENTREE_CUSTOM_LOGO', 'customize.logo', 'custom logo', lookup_media=True) CUSTOM_LOGO = get_custom_file(
CUSTOM_SPLASH = get_custom_file('INVENTREE_CUSTOM_SPLASH', 'customize.splash', 'custom splash') 'INVENTREE_CUSTOM_LOGO', 'customize.logo', 'custom logo', lookup_media=True
)
CUSTOM_SPLASH = get_custom_file(
'INVENTREE_CUSTOM_SPLASH', 'customize.splash', 'custom splash'
)
CUSTOMIZE = get_setting('INVENTREE_CUSTOMIZE', 'customize', {}) CUSTOMIZE = get_setting('INVENTREE_CUSTOMIZE', 'customize', {})
@ -1056,7 +1090,7 @@ FRONTEND_SETTINGS = config.get_frontend_settings(debug=DEBUG)
FRONTEND_URL_BASE = FRONTEND_SETTINGS.get('base_url', 'platform') FRONTEND_URL_BASE = FRONTEND_SETTINGS.get('base_url', 'platform')
if DEBUG: if DEBUG:
logger.info("InvenTree running with DEBUG enabled") logger.info('InvenTree running with DEBUG enabled')
logger.info("MEDIA_ROOT: '%s'", MEDIA_ROOT) logger.info("MEDIA_ROOT: '%s'", MEDIA_ROOT)
logger.info("STATIC_ROOT: '%s'", STATIC_ROOT) logger.info("STATIC_ROOT: '%s'", STATIC_ROOT)
@ -1068,7 +1102,7 @@ FLAGS = {
{'condition': 'parameter', 'value': 'experimental='}, {'condition': 'parameter', 'value': 'experimental='},
], # Should experimental features be turned on? ], # Should experimental features be turned on?
'NEXT_GEN': [ 'NEXT_GEN': [
{'condition': 'parameter', 'value': 'ngen='}, {'condition': 'parameter', 'value': 'ngen='}
], # Should next-gen features be turned on? ], # Should next-gen features be turned on?
} }
@ -1076,12 +1110,12 @@ FLAGS = {
CUSTOM_FLAGS = get_setting('INVENTREE_FLAGS', 'flags', None, typecast=dict) CUSTOM_FLAGS = get_setting('INVENTREE_FLAGS', 'flags', None, typecast=dict)
if CUSTOM_FLAGS: if CUSTOM_FLAGS:
if not isinstance(CUSTOM_FLAGS, dict): if not isinstance(CUSTOM_FLAGS, dict):
logger.error("Invalid custom flags, must be valid dict: %s", str(CUSTOM_FLAGS)) logger.error('Invalid custom flags, must be valid dict: %s', str(CUSTOM_FLAGS))
else: else:
logger.info("Custom flags: %s", str(CUSTOM_FLAGS)) logger.info('Custom flags: %s', str(CUSTOM_FLAGS))
FLAGS.update(CUSTOM_FLAGS) FLAGS.update(CUSTOM_FLAGS)
# Magic login django-sesame # Magic login django-sesame
SESAME_MAX_AGE = 300 SESAME_MAX_AGE = 300
# LOGIN_REDIRECT_URL = f"/{FRONTEND_URL_BASE}/logged-in/" # LOGIN_REDIRECT_URL = f"/{FRONTEND_URL_BASE}/logged-in/"
LOGIN_REDIRECT_URL = "/index/" LOGIN_REDIRECT_URL = '/index/'

View File

@ -1,4 +1,5 @@
"""API endpoints for social authentication with allauth.""" """API endpoints for social authentication with allauth."""
import logging import logging
from importlib import import_module from importlib import import_module
@ -6,8 +7,7 @@ from django.urls import NoReverseMatch, include, path, reverse
from allauth.account.models import EmailAddress from allauth.account.models import EmailAddress
from allauth.socialaccount import providers from allauth.socialaccount import providers
from allauth.socialaccount.providers.oauth2.views import (OAuth2Adapter, from allauth.socialaccount.providers.oauth2.views import OAuth2Adapter, OAuth2LoginView
OAuth2LoginView)
from drf_spectacular.utils import OpenApiResponse, extend_schema from drf_spectacular.utils import OpenApiResponse, extend_schema
from rest_framework.exceptions import NotFound from rest_framework.exceptions import NotFound
from rest_framework.permissions import AllowAny, IsAuthenticated from rest_framework.permissions import AllowAny, IsAuthenticated
@ -23,6 +23,7 @@ logger = logging.getLogger('inventree')
class GenericOAuth2ApiLoginView(OAuth2LoginView): class GenericOAuth2ApiLoginView(OAuth2LoginView):
"""Api view to login a user with a social account""" """Api view to login a user with a social account"""
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
"""Dispatch the regular login view directly.""" """Dispatch the regular login view directly."""
return self.login(request, *args, **kwargs) return self.login(request, *args, **kwargs)
@ -44,8 +45,16 @@ class GenericOAuth2ApiConnectView(GenericOAuth2ApiLoginView):
def handle_oauth2(adapter: OAuth2Adapter): def handle_oauth2(adapter: OAuth2Adapter):
"""Define urls for oauth2 endpoints.""" """Define urls for oauth2 endpoints."""
return [ return [
path('login/', GenericOAuth2ApiLoginView.adapter_view(adapter), name=f'{provider.id}_api_login'), path(
path('connect/', GenericOAuth2ApiConnectView.adapter_view(adapter), name=f'{provider.id}_api_connect'), 'login/',
GenericOAuth2ApiLoginView.adapter_view(adapter),
name=f'{provider.id}_api_login',
),
path(
'connect/',
GenericOAuth2ApiConnectView.adapter_view(adapter),
name=f'{provider.id}_api_connect',
),
] ]
@ -64,15 +73,20 @@ social_auth_urlpatterns = []
provider_urlpatterns = [] provider_urlpatterns = []
for name, provider in providers.registry.provider_map.items(): for name, provider in providers.registry.provider_map.items():
try: try:
prov_mod = import_module(provider.get_package() + ".views") prov_mod = import_module(provider.get_package() + '.views')
except ImportError: except ImportError:
logger.exception("Could not import authentication provider %s", name) logger.exception('Could not import authentication provider %s', name)
continue continue
# Try to extract the adapter class # Try to extract the adapter class
adapters = [cls for cls in prov_mod.__dict__.values() if isinstance(cls, type) and not cls == OAuth2Adapter and issubclass(cls, OAuth2Adapter)] adapters = [
cls
for cls in prov_mod.__dict__.values()
if isinstance(cls, type)
and not cls == OAuth2Adapter
and issubclass(cls, OAuth2Adapter)
]
# Get urls # Get urls
urls = [] urls = []
@ -80,10 +94,17 @@ for name, provider in providers.registry.provider_map.items():
urls = handle_oauth2(adapter=adapters[0]) urls = handle_oauth2(adapter=adapters[0])
else: else:
if provider.id in legacy: if provider.id in legacy:
logger.warning('`%s` is not supported on platform UI. Use `%s` instead.', provider.id, legacy[provider.id]) logger.warning(
'`%s` is not supported on platform UI. Use `%s` instead.',
provider.id,
legacy[provider.id],
)
continue continue
else: else:
logger.error('Found handler that is not yet ready for platform UI: `%s`. Open an feature request on GitHub if you need it implemented.', provider.id) logger.error(
'Found handler that is not yet ready for platform UI: `%s`. Open an feature request on GitHub if you need it implemented.',
provider.id,
)
continue continue
provider_urlpatterns += [path(f'{provider.id}/', include(urls))] provider_urlpatterns += [path(f'{provider.id}/', include(urls))]
@ -93,6 +114,7 @@ social_auth_urlpatterns += provider_urlpatterns
class SocialProviderListView(ListAPI): class SocialProviderListView(ListAPI):
"""List of available social providers.""" """List of available social providers."""
permission_classes = (AllowAny,) permission_classes = (AllowAny,)
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
@ -102,21 +124,27 @@ class SocialProviderListView(ListAPI):
provider_data = { provider_data = {
'id': provider.id, 'id': provider.id,
'name': provider.name, 'name': provider.name,
'configured': False 'configured': False,
} }
try: try:
provider_data['login'] = request.build_absolute_uri(reverse(f'{provider.id}_api_login')) provider_data['login'] = request.build_absolute_uri(
reverse(f'{provider.id}_api_login')
)
except NoReverseMatch: except NoReverseMatch:
provider_data['login'] = None provider_data['login'] = None
try: try:
provider_data['connect'] = request.build_absolute_uri(reverse(f'{provider.id}_api_connect')) provider_data['connect'] = request.build_absolute_uri(
reverse(f'{provider.id}_api_connect')
)
except NoReverseMatch: except NoReverseMatch:
provider_data['connect'] = None provider_data['connect'] = None
provider_data['configured'] = InvenTree.sso.check_provider(provider) provider_data['configured'] = InvenTree.sso.check_provider(provider)
provider_data['display_name'] = InvenTree.sso.provider_display_name(provider) provider_data['display_name'] = InvenTree.sso.provider_display_name(
provider
)
provider_list.append(provider_data) provider_list.append(provider_data)
@ -124,7 +152,7 @@ class SocialProviderListView(ListAPI):
'sso_enabled': InvenTree.sso.login_enabled(), 'sso_enabled': InvenTree.sso.login_enabled(),
'sso_registration': InvenTree.sso.registration_enabled(), 'sso_registration': InvenTree.sso.registration_enabled(),
'mfa_required': InvenTreeSetting.get_setting('LOGIN_ENFORCE_MFA'), 'mfa_required': InvenTreeSetting.get_setting('LOGIN_ENFORCE_MFA'),
'providers': provider_list 'providers': provider_list,
} }
return Response(data) return Response(data)
@ -151,6 +179,7 @@ class EmptyEmailAddressSerializer(InvenTreeModelSerializer):
class EmailListView(ListCreateAPI): class EmailListView(ListCreateAPI):
"""List of registered email addresses for current users.""" """List of registered email addresses for current users."""
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
serializer_class = EmailAddressSerializer serializer_class = EmailAddressSerializer
@ -161,12 +190,15 @@ class EmailListView(ListCreateAPI):
class EmailActionMixin(CreateAPI): class EmailActionMixin(CreateAPI):
"""Mixin to modify email addresses for current users.""" """Mixin to modify email addresses for current users."""
serializer_class = EmptyEmailAddressSerializer serializer_class = EmptyEmailAddressSerializer
permission_classes = (IsAuthenticated,) permission_classes = (IsAuthenticated,)
def get_queryset(self): def get_queryset(self):
"""Filter queryset for current user.""" """Filter queryset for current user."""
return EmailAddress.objects.filter(user=self.request.user, pk=self.kwargs['pk']).first() return EmailAddress.objects.filter(
user=self.request.user, pk=self.kwargs['pk']
).first()
@extend_schema(responses={200: OpenApiResponse(response=EmailAddressSerializer)}) @extend_schema(responses={200: OpenApiResponse(response=EmailAddressSerializer)})
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):

View File

@ -1,6 +1,5 @@
"""Helper functions for Single Sign On functionality""" """Helper functions for Single Sign On functionality"""
import logging import logging
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
@ -49,7 +48,7 @@ def check_provider(provider, raise_error=False):
if allauth.app_settings.SITES_ENABLED: if allauth.app_settings.SITES_ENABLED:
# At least one matching site must be specified # At least one matching site must be specified
if not app.sites.exists(): if not app.sites.exists():
logger.error("SocialApp %s has no sites configured", app) logger.error('SocialApp %s has no sites configured', app)
return False return False
# At this point, we assume that the provider is correctly configured # At this point, we assume that the provider is correctly configured

View File

@ -13,7 +13,7 @@ from django_q.status import Stat
import InvenTree.email import InvenTree.email
import InvenTree.ready import InvenTree.ready
logger = logging.getLogger("inventree") logger = logging.getLogger('inventree')
def is_worker_running(**kwargs): def is_worker_running(**kwargs):
@ -33,9 +33,7 @@ def is_worker_running(**kwargs):
now = timezone.now() now = timezone.now()
past = now - timedelta(minutes=10) past = now - timedelta(minutes=10)
results = Success.objects.filter( results = Success.objects.filter(started__gte=past)
started__gte=past
)
# If any results are returned, then the background worker is running! # If any results are returned, then the background worker is running!
try: try:
@ -65,13 +63,13 @@ def check_system_health(**kwargs):
if not is_worker_running(**kwargs): # pragma: no cover if not is_worker_running(**kwargs): # pragma: no cover
result = False result = False
logger.warning(_("Background worker check failed")) logger.warning(_('Background worker check failed'))
if not InvenTree.email.is_email_configured(): # pragma: no cover if not InvenTree.email.is_email_configured(): # pragma: no cover
result = False result = False
logger.warning(_("Email backend not configured")) logger.warning(_('Email backend not configured'))
if not result: # pragma: no cover if not result: # pragma: no cover
logger.warning(_("InvenTree system health checks failed")) logger.warning(_('InvenTree system health checks failed'))
return result return result

View File

@ -9,68 +9,68 @@ class PurchaseOrderStatus(StatusCode):
"""Defines a set of status codes for a PurchaseOrder.""" """Defines a set of status codes for a PurchaseOrder."""
# Order status codes # Order status codes
PENDING = 10, _("Pending"), 'secondary' # Order is pending (not yet placed) PENDING = 10, _('Pending'), 'secondary' # Order is pending (not yet placed)
PLACED = 20, _("Placed"), 'primary' # Order has been placed with supplier PLACED = 20, _('Placed'), 'primary' # Order has been placed with supplier
COMPLETE = 30, _("Complete"), 'success' # Order has been completed COMPLETE = 30, _('Complete'), 'success' # Order has been completed
CANCELLED = 40, _("Cancelled"), 'danger' # Order was cancelled CANCELLED = 40, _('Cancelled'), 'danger' # Order was cancelled
LOST = 50, _("Lost"), 'warning' # Order was lost LOST = 50, _('Lost'), 'warning' # Order was lost
RETURNED = 60, _("Returned"), 'warning' # Order was returned RETURNED = 60, _('Returned'), 'warning' # Order was returned
class PurchaseOrderStatusGroups: class PurchaseOrderStatusGroups:
"""Groups for PurchaseOrderStatus codes.""" """Groups for PurchaseOrderStatus codes."""
# Open orders # Open orders
OPEN = [ OPEN = [PurchaseOrderStatus.PENDING.value, PurchaseOrderStatus.PLACED.value]
PurchaseOrderStatus.PENDING.value,
PurchaseOrderStatus.PLACED.value,
]
# Failed orders # Failed orders
FAILED = [ FAILED = [
PurchaseOrderStatus.CANCELLED.value, PurchaseOrderStatus.CANCELLED.value,
PurchaseOrderStatus.LOST.value, PurchaseOrderStatus.LOST.value,
PurchaseOrderStatus.RETURNED.value PurchaseOrderStatus.RETURNED.value,
] ]
class SalesOrderStatus(StatusCode): class SalesOrderStatus(StatusCode):
"""Defines a set of status codes for a SalesOrder.""" """Defines a set of status codes for a SalesOrder."""
PENDING = 10, _("Pending"), 'secondary' # Order is pending PENDING = 10, _('Pending'), 'secondary' # Order is pending
IN_PROGRESS = 15, _("In Progress"), 'primary' # Order has been issued, and is in progress IN_PROGRESS = (
SHIPPED = 20, _("Shipped"), 'success' # Order has been shipped to customer 15,
CANCELLED = 40, _("Cancelled"), 'danger' # Order has been cancelled _('In Progress'),
LOST = 50, _("Lost"), 'warning' # Order was lost 'primary',
RETURNED = 60, _("Returned"), 'warning' # Order was returned ) # Order has been issued, and is in progress
SHIPPED = 20, _('Shipped'), 'success' # Order has been shipped to customer
CANCELLED = 40, _('Cancelled'), 'danger' # Order has been cancelled
LOST = 50, _('Lost'), 'warning' # Order was lost
RETURNED = 60, _('Returned'), 'warning' # Order was returned
class SalesOrderStatusGroups: class SalesOrderStatusGroups:
"""Groups for SalesOrderStatus codes.""" """Groups for SalesOrderStatus codes."""
# Open orders # Open orders
OPEN = [ OPEN = [SalesOrderStatus.PENDING.value, SalesOrderStatus.IN_PROGRESS.value]
SalesOrderStatus.PENDING.value,
SalesOrderStatus.IN_PROGRESS.value,
]
# Completed orders # Completed orders
COMPLETE = [ COMPLETE = [SalesOrderStatus.SHIPPED.value]
SalesOrderStatus.SHIPPED.value,
]
class StockStatus(StatusCode): class StockStatus(StatusCode):
"""Status codes for Stock.""" """Status codes for Stock."""
OK = 10, _("OK"), 'success' # Item is OK OK = 10, _('OK'), 'success' # Item is OK
ATTENTION = 50, _("Attention needed"), 'warning' # Item requires attention ATTENTION = 50, _('Attention needed'), 'warning' # Item requires attention
DAMAGED = 55, _("Damaged"), 'warning' # Item is damaged DAMAGED = 55, _('Damaged'), 'warning' # Item is damaged
DESTROYED = 60, _("Destroyed"), 'danger' # Item is destroyed DESTROYED = 60, _('Destroyed'), 'danger' # Item is destroyed
REJECTED = 65, _("Rejected"), 'danger' # Item is rejected REJECTED = 65, _('Rejected'), 'danger' # Item is rejected
LOST = 70, _("Lost"), 'dark' # Item has been lost LOST = 70, _('Lost'), 'dark' # Item has been lost
QUARANTINED = 75, _("Quarantined"), 'info' # Item has been quarantined and is unavailable QUARANTINED = (
RETURNED = 85, _("Returned"), 'warning' # Item has been returned from a customer 75,
_('Quarantined'),
'info',
) # Item has been quarantined and is unavailable
RETURNED = 85, _('Returned'), 'warning' # Item has been returned from a customer
class StockStatusGroups: class StockStatusGroups:
@ -129,7 +129,7 @@ class StockHistoryCode(StatusCode):
BUILD_CONSUMED = 57, _('Consumed by build order') BUILD_CONSUMED = 57, _('Consumed by build order')
# Sales order codes # Sales order codes
SHIPPED_AGAINST_SALES_ORDER = 60, _("Shipped against Sales Order") SHIPPED_AGAINST_SALES_ORDER = 60, _('Shipped against Sales Order')
# Purchase order codes # Purchase order codes
RECEIVED_AGAINST_PURCHASE_ORDER = 70, _('Received against Purchase Order') RECEIVED_AGAINST_PURCHASE_ORDER = 70, _('Received against Purchase Order')
@ -145,59 +145,53 @@ class StockHistoryCode(StatusCode):
class BuildStatus(StatusCode): class BuildStatus(StatusCode):
"""Build status codes.""" """Build status codes."""
PENDING = 10, _("Pending"), 'secondary' # Build is pending / active PENDING = 10, _('Pending'), 'secondary' # Build is pending / active
PRODUCTION = 20, _("Production"), 'primary' # BuildOrder is in production PRODUCTION = 20, _('Production'), 'primary' # BuildOrder is in production
CANCELLED = 30, _("Cancelled"), 'danger' # Build was cancelled CANCELLED = 30, _('Cancelled'), 'danger' # Build was cancelled
COMPLETE = 40, _("Complete"), 'success' # Build is complete COMPLETE = 40, _('Complete'), 'success' # Build is complete
class BuildStatusGroups: class BuildStatusGroups:
"""Groups for BuildStatus codes.""" """Groups for BuildStatus codes."""
ACTIVE_CODES = [ ACTIVE_CODES = [BuildStatus.PENDING.value, BuildStatus.PRODUCTION.value]
BuildStatus.PENDING.value,
BuildStatus.PRODUCTION.value,
]
class ReturnOrderStatus(StatusCode): class ReturnOrderStatus(StatusCode):
"""Defines a set of status codes for a ReturnOrder""" """Defines a set of status codes for a ReturnOrder"""
# Order is pending, waiting for receipt of items # Order is pending, waiting for receipt of items
PENDING = 10, _("Pending"), 'secondary' PENDING = 10, _('Pending'), 'secondary'
# Items have been received, and are being inspected # Items have been received, and are being inspected
IN_PROGRESS = 20, _("In Progress"), 'primary' IN_PROGRESS = 20, _('In Progress'), 'primary'
COMPLETE = 30, _("Complete"), 'success' COMPLETE = 30, _('Complete'), 'success'
CANCELLED = 40, _("Cancelled"), 'danger' CANCELLED = 40, _('Cancelled'), 'danger'
class ReturnOrderStatusGroups: class ReturnOrderStatusGroups:
"""Groups for ReturnOrderStatus codes.""" """Groups for ReturnOrderStatus codes."""
OPEN = [ OPEN = [ReturnOrderStatus.PENDING.value, ReturnOrderStatus.IN_PROGRESS.value]
ReturnOrderStatus.PENDING.value,
ReturnOrderStatus.IN_PROGRESS.value,
]
class ReturnOrderLineStatus(StatusCode): class ReturnOrderLineStatus(StatusCode):
"""Defines a set of status codes for a ReturnOrderLineItem""" """Defines a set of status codes for a ReturnOrderLineItem"""
PENDING = 10, _("Pending"), 'secondary' PENDING = 10, _('Pending'), 'secondary'
# Item is to be returned to customer, no other action # Item is to be returned to customer, no other action
RETURN = 20, _("Return"), 'success' RETURN = 20, _('Return'), 'success'
# Item is to be repaired, and returned to customer # Item is to be repaired, and returned to customer
REPAIR = 30, _("Repair"), 'primary' REPAIR = 30, _('Repair'), 'primary'
# Item is to be replaced (new item shipped) # Item is to be replaced (new item shipped)
REPLACE = 40, _("Replace"), 'warning' REPLACE = 40, _('Replace'), 'warning'
# Item is to be refunded (cannot be repaired) # Item is to be refunded (cannot be repaired)
REFUND = 50, _("Refund"), 'info' REFUND = 50, _('Refund'), 'info'
# Item is rejected # Item is rejected
REJECT = 60, _("Reject"), 'danger' REJECT = 60, _('Reject'), 'danger'

View File

@ -16,20 +16,22 @@ from django.core.exceptions import AppRegistryNotReady
from django.core.management import call_command from django.core.management import call_command
from django.db import DEFAULT_DB_ALIAS, connections from django.db import DEFAULT_DB_ALIAS, connections
from django.db.migrations.executor import MigrationExecutor from django.db.migrations.executor import MigrationExecutor
from django.db.utils import (NotSupportedError, OperationalError, from django.db.utils import NotSupportedError, OperationalError, ProgrammingError
ProgrammingError)
from django.utils import timezone from django.utils import timezone
import requests import requests
from maintenance_mode.core import (get_maintenance_mode, maintenance_mode_on, from maintenance_mode.core import (
set_maintenance_mode) get_maintenance_mode,
maintenance_mode_on,
set_maintenance_mode,
)
from InvenTree.config import get_setting from InvenTree.config import get_setting
from plugin import registry from plugin import registry
from .version import isInvenTreeUpToDate from .version import isInvenTreeUpToDate
logger = logging.getLogger("inventree") logger = logging.getLogger('inventree')
def schedule_task(taskname, **kwargs): def schedule_task(taskname, **kwargs):
@ -44,7 +46,7 @@ def schedule_task(taskname, **kwargs):
try: try:
from django_q.models import Schedule from django_q.models import Schedule
except AppRegistryNotReady: # pragma: no cover except AppRegistryNotReady: # pragma: no cover
logger.info("Could not start background tasks - App registry not ready") logger.info('Could not start background tasks - App registry not ready')
return return
try: try:
@ -57,11 +59,7 @@ def schedule_task(taskname, **kwargs):
else: else:
logger.info("Creating scheduled task '%s'", taskname) logger.info("Creating scheduled task '%s'", taskname)
Schedule.objects.create( Schedule.objects.create(name=taskname, func=taskname, **kwargs)
name=taskname,
func=taskname,
**kwargs
)
except (OperationalError, ProgrammingError): # pragma: no cover except (OperationalError, ProgrammingError): # pragma: no cover
# Required if the DB is not ready yet # Required if the DB is not ready yet
pass pass
@ -96,7 +94,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool:
from InvenTree.ready import isInTestMode from InvenTree.ready import isInTestMode
if n_days <= 0: if n_days <= 0:
logger.info("Specified interval for task '%s' < 1 - task will not run", task_name) logger.info(
"Specified interval for task '%s' < 1 - task will not run", task_name
)
return False return False
# Sleep a random number of seconds to prevent worker conflict # Sleep a random number of seconds to prevent worker conflict
@ -119,7 +119,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool:
threshold = datetime.now() - timedelta(days=n_days) threshold = datetime.now() - timedelta(days=n_days)
if last_success > threshold: if last_success > threshold:
logger.info("Last successful run for '%s' was too recent - skipping task", task_name) logger.info(
"Last successful run for '%s' was too recent - skipping task", task_name
)
return False return False
# Check for any information we have about this task # Check for any information we have about this task
@ -136,7 +138,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool:
threshold = datetime.now() - timedelta(hours=12) threshold = datetime.now() - timedelta(hours=12)
if last_attempt > threshold: if last_attempt > threshold:
logger.info("Last attempt for '%s' was too recent - skipping task", task_name) logger.info(
"Last attempt for '%s' was too recent - skipping task", task_name
)
return False return False
# Record this attempt # Record this attempt
@ -152,17 +156,23 @@ def record_task_attempt(task_name: str):
logger.info("Logging task attempt for '%s'", task_name) logger.info("Logging task attempt for '%s'", task_name)
InvenTreeSetting.set_setting(f'_{task_name}_ATTEMPT', datetime.now().isoformat(), None) InvenTreeSetting.set_setting(
f'_{task_name}_ATTEMPT', datetime.now().isoformat(), None
)
def record_task_success(task_name: str): def record_task_success(task_name: str):
"""Record that a multi-day task was successful *now*""" """Record that a multi-day task was successful *now*"""
from common.models import InvenTreeSetting from common.models import InvenTreeSetting
InvenTreeSetting.set_setting(f'_{task_name}_SUCCESS', datetime.now().isoformat(), None) InvenTreeSetting.set_setting(
f'_{task_name}_SUCCESS', datetime.now().isoformat(), None
)
def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs) -> bool: def offload_task(
taskname, *args, force_async=False, force_sync=False, **kwargs
) -> bool:
"""Create an AsyncTask if workers are running. This is different to a 'scheduled' task, in that it only runs once! """Create an AsyncTask if workers are running. This is different to a 'scheduled' task, in that it only runs once!
If workers are not running or force_sync flag, is set then the task is ran synchronously. If workers are not running or force_sync flag, is set then the task is ran synchronously.
@ -205,7 +215,6 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs)
raise_warning(f"WARNING: '{taskname}' not offloaded due to {str(exc)}") raise_warning(f"WARNING: '{taskname}' not offloaded due to {str(exc)}")
return False return False
else: else:
if callable(taskname): if callable(taskname):
# function was passed - use that # function was passed - use that
_func = taskname _func = taskname
@ -215,14 +224,18 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs)
app, mod, func = taskname.split('.') app, mod, func = taskname.split('.')
app_mod = app + '.' + mod app_mod = app + '.' + mod
except ValueError: except ValueError:
raise_warning(f"WARNING: '{taskname}' not started - Malformed function path") raise_warning(
f"WARNING: '{taskname}' not started - Malformed function path"
)
return False return False
# Import module from app # Import module from app
try: try:
_mod = importlib.import_module(app_mod) _mod = importlib.import_module(app_mod)
except ModuleNotFoundError: except ModuleNotFoundError:
raise_warning(f"WARNING: '{taskname}' not started - No module named '{app_mod}'") raise_warning(
f"WARNING: '{taskname}' not started - No module named '{app_mod}'"
)
return False return False
# Retrieve function # Retrieve function
@ -236,7 +249,9 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs)
if not _func: if not _func:
_func = eval(func) # pragma: no cover _func = eval(func) # pragma: no cover
except NameError: except NameError:
raise_warning(f"WARNING: '{taskname}' not started - No function named '{func}'") raise_warning(
f"WARNING: '{taskname}' not started - No function named '{func}'"
)
return False return False
# Workers are not running: run it as synchronous task # Workers are not running: run it as synchronous task
@ -263,18 +278,19 @@ class ScheduledTask:
interval: str interval: str
minutes: int = None minutes: int = None
MINUTES = "I" MINUTES = 'I'
HOURLY = "H" HOURLY = 'H'
DAILY = "D" DAILY = 'D'
WEEKLY = "W" WEEKLY = 'W'
MONTHLY = "M" MONTHLY = 'M'
QUARTERLY = "Q" QUARTERLY = 'Q'
YEARLY = "Y" YEARLY = 'Y'
TYPE = [MINUTES, HOURLY, DAILY, WEEKLY, MONTHLY, QUARTERLY, YEARLY] TYPE = [MINUTES, HOURLY, DAILY, WEEKLY, MONTHLY, QUARTERLY, YEARLY]
class TaskRegister: class TaskRegister:
"""Registry for periodic tasks.""" """Registry for periodic tasks."""
task_list: List[ScheduledTask] = [] task_list: List[ScheduledTask] = []
def register(self, task, schedule, minutes: int = None): def register(self, task, schedule, minutes: int = None):
@ -320,6 +336,7 @@ def scheduled_task(interval: str, minutes: int = None, tasklist: TaskRegister =
_tasks.register(admin_class, interval, minutes=minutes) _tasks.register(admin_class, interval, minutes=minutes)
return admin_class return admin_class
return _task_wrapper return _task_wrapper
@ -332,7 +349,7 @@ def heartbeat():
try: try:
from django_q.models import Success from django_q.models import Success
except AppRegistryNotReady: # pragma: no cover except AppRegistryNotReady: # pragma: no cover
logger.info("Could not perform heartbeat task - App registry not ready") logger.info('Could not perform heartbeat task - App registry not ready')
return return
threshold = timezone.now() - timedelta(minutes=30) threshold = timezone.now() - timedelta(minutes=30)
@ -340,8 +357,7 @@ def heartbeat():
# Delete heartbeat results more than half an hour old, # Delete heartbeat results more than half an hour old,
# otherwise they just create extra noise # otherwise they just create extra noise
heartbeats = Success.objects.filter( heartbeats = Success.objects.filter(
func='InvenTree.tasks.heartbeat', func='InvenTree.tasks.heartbeat', started__lte=threshold
started__lte=threshold
) )
heartbeats.delete() heartbeats.delete()
@ -359,16 +375,16 @@ def delete_successful_tasks():
threshold = timezone.now() - timedelta(days=days) threshold = timezone.now() - timedelta(days=days)
# Delete successful tasks # Delete successful tasks
results = Success.objects.filter( results = Success.objects.filter(started__lte=threshold)
started__lte=threshold
)
if results.count() > 0: if results.count() > 0:
logger.info("Deleting %s successful task records", results.count()) logger.info('Deleting %s successful task records', results.count())
results.delete() results.delete()
except AppRegistryNotReady: # pragma: no cover except AppRegistryNotReady: # pragma: no cover
logger.info("Could not perform 'delete_successful_tasks' - App registry not ready") logger.info(
"Could not perform 'delete_successful_tasks' - App registry not ready"
)
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
@ -383,12 +399,10 @@ def delete_failed_tasks():
threshold = timezone.now() - timedelta(days=days) threshold = timezone.now() - timedelta(days=days)
# Delete failed tasks # Delete failed tasks
results = Failure.objects.filter( results = Failure.objects.filter(started__lte=threshold)
started__lte=threshold
)
if results.count() > 0: if results.count() > 0:
logger.info("Deleting %s failed task records", results.count()) logger.info('Deleting %s failed task records', results.count())
results.delete() results.delete()
except AppRegistryNotReady: # pragma: no cover except AppRegistryNotReady: # pragma: no cover
@ -406,47 +420,48 @@ def delete_old_error_logs():
days = InvenTreeSetting.get_setting('INVENTREE_DELETE_ERRORS_DAYS', 30) days = InvenTreeSetting.get_setting('INVENTREE_DELETE_ERRORS_DAYS', 30)
threshold = timezone.now() - timedelta(days=days) threshold = timezone.now() - timedelta(days=days)
errors = Error.objects.filter( errors = Error.objects.filter(when__lte=threshold)
when__lte=threshold,
)
if errors.count() > 0: if errors.count() > 0:
logger.info("Deleting %s old error logs", errors.count()) logger.info('Deleting %s old error logs', errors.count())
errors.delete() errors.delete()
except AppRegistryNotReady: # pragma: no cover except AppRegistryNotReady: # pragma: no cover
# Apps not yet loaded # Apps not yet loaded
logger.info("Could not perform 'delete_old_error_logs' - App registry not ready") logger.info(
"Could not perform 'delete_old_error_logs' - App registry not ready"
)
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
def delete_old_notifications(): def delete_old_notifications():
"""Delete old notification logs""" """Delete old notification logs"""
try: try:
from common.models import (InvenTreeSetting, NotificationEntry, from common.models import (
NotificationMessage) InvenTreeSetting,
NotificationEntry,
NotificationMessage,
)
days = InvenTreeSetting.get_setting('INVENTREE_DELETE_NOTIFICATIONS_DAYS', 30) days = InvenTreeSetting.get_setting('INVENTREE_DELETE_NOTIFICATIONS_DAYS', 30)
threshold = timezone.now() - timedelta(days=days) threshold = timezone.now() - timedelta(days=days)
items = NotificationEntry.objects.filter( items = NotificationEntry.objects.filter(updated__lte=threshold)
updated__lte=threshold
)
if items.count() > 0: if items.count() > 0:
logger.info("Deleted %s old notification entries", items.count()) logger.info('Deleted %s old notification entries', items.count())
items.delete() items.delete()
items = NotificationMessage.objects.filter( items = NotificationMessage.objects.filter(creation__lte=threshold)
creation__lte=threshold
)
if items.count() > 0: if items.count() > 0:
logger.info("Deleted %s old notification messages", items.count()) logger.info('Deleted %s old notification messages', items.count())
items.delete() items.delete()
except AppRegistryNotReady: except AppRegistryNotReady:
logger.info("Could not perform 'delete_old_notifications' - App registry not ready") logger.info(
"Could not perform 'delete_old_notifications' - App registry not ready"
)
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
@ -460,13 +475,17 @@ def check_for_updates():
logger.info("Could not perform 'check_for_updates' - App registry not ready") logger.info("Could not perform 'check_for_updates' - App registry not ready")
return return
interval = int(common.models.InvenTreeSetting.get_setting('INVENTREE_UPDATE_CHECK_INTERVAL', 7, cache=False)) interval = int(
common.models.InvenTreeSetting.get_setting(
'INVENTREE_UPDATE_CHECK_INTERVAL', 7, cache=False
)
)
# Check if we should check for updates *today* # Check if we should check for updates *today*
if not check_daily_holdoff('check_for_updates', interval): if not check_daily_holdoff('check_for_updates', interval):
return return
logger.info("Checking for InvenTree software updates") logger.info('Checking for InvenTree software updates')
headers = {} headers = {}
@ -475,15 +494,17 @@ def check_for_updates():
token = os.getenv('GITHUB_TOKEN', None) token = os.getenv('GITHUB_TOKEN', None)
if token: if token:
headers['Authorization'] = f"Bearer {token}" headers['Authorization'] = f'Bearer {token}'
response = requests.get( response = requests.get(
'https://api.github.com/repos/inventree/inventree/releases/latest', 'https://api.github.com/repos/inventree/inventree/releases/latest',
headers=headers headers=headers,
) )
if response.status_code != 200: if response.status_code != 200:
raise ValueError(f'Unexpected status code from GitHub API: {response.status_code}') # pragma: no cover raise ValueError(
f'Unexpected status code from GitHub API: {response.status_code}'
) # pragma: no cover
data = json.loads(response.text) data = json.loads(response.text)
@ -492,7 +513,7 @@ def check_for_updates():
if not tag: if not tag:
raise ValueError("'tag_name' missing from GitHub response") # pragma: no cover raise ValueError("'tag_name' missing from GitHub response") # pragma: no cover
match = re.match(r"^.*(\d+)\.(\d+)\.(\d+).*$", tag) match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', tag)
if len(match.groups()) != 3: # pragma: no cover if len(match.groups()) != 3: # pragma: no cover
logger.warning("Version '%s' did not match expected pattern", tag) logger.warning("Version '%s' did not match expected pattern", tag)
@ -506,31 +527,26 @@ def check_for_updates():
logger.info("Latest InvenTree version: '%s'", tag) logger.info("Latest InvenTree version: '%s'", tag)
# Save the version to the database # Save the version to the database
common.models.InvenTreeSetting.set_setting( common.models.InvenTreeSetting.set_setting('_INVENTREE_LATEST_VERSION', tag, None)
'_INVENTREE_LATEST_VERSION',
tag,
None
)
# Record that this task was successful # Record that this task was successful
record_task_success('check_for_updates') record_task_success('check_for_updates')
# Send notification if there is a new version # Send notification if there is a new version
if not isInvenTreeUpToDate(): if not isInvenTreeUpToDate():
logger.warning("InvenTree is not up-to-date, sending notification") logger.warning('InvenTree is not up-to-date, sending notification')
plg = registry.get_plugin('InvenTreeCoreNotificationsPlugin') plg = registry.get_plugin('InvenTreeCoreNotificationsPlugin')
if not plg: if not plg:
logger.warning("Cannot send notification - plugin not found") logger.warning('Cannot send notification - plugin not found')
return return
plg = plg.plugin_config() plg = plg.plugin_config()
if not plg: if not plg:
logger.warning("Cannot send notification - plugin config not found") logger.warning('Cannot send notification - plugin config not found')
return return
# Send notification # Send notification
trigger_superuser_notification( trigger_superuser_notification(
plg, plg, f'An update for InvenTree to version {tag} is available'
f'An update for InvenTree to version {tag} is available',
) )
@ -549,17 +565,21 @@ def update_exchange_rates(force: bool = False):
from InvenTree.exchange import InvenTreeExchange from InvenTree.exchange import InvenTreeExchange
except AppRegistryNotReady: # pragma: no cover except AppRegistryNotReady: # pragma: no cover
# Apps not yet loaded! # Apps not yet loaded!
logger.info("Could not perform 'update_exchange_rates' - App registry not ready") logger.info(
"Could not perform 'update_exchange_rates' - App registry not ready"
)
return return
except Exception as exc: # pragma: no cover except Exception as exc: # pragma: no cover
logger.info("Could not perform 'update_exchange_rates' - %s", exc) logger.info("Could not perform 'update_exchange_rates' - %s", exc)
return return
if not force: if not force:
interval = int(InvenTreeSetting.get_setting('CURRENCY_UPDATE_INTERVAL', 1, cache=False)) interval = int(
InvenTreeSetting.get_setting('CURRENCY_UPDATE_INTERVAL', 1, cache=False)
)
if not check_daily_holdoff('update_exchange_rates', interval): if not check_daily_holdoff('update_exchange_rates', interval):
logger.info("Skipping exchange rate update (interval not reached)") logger.info('Skipping exchange rate update (interval not reached)')
return return
backend = InvenTreeExchange() backend = InvenTreeExchange()
@ -570,15 +590,17 @@ def update_exchange_rates(force: bool = False):
backend.update_rates(base_currency=base) backend.update_rates(base_currency=base)
# Remove any exchange rates which are not in the provided currencies # Remove any exchange rates which are not in the provided currencies
Rate.objects.filter(backend="InvenTreeExchange").exclude(currency__in=currency_codes()).delete() Rate.objects.filter(backend='InvenTreeExchange').exclude(
currency__in=currency_codes()
).delete()
# Record successful task execution # Record successful task execution
record_task_success('update_exchange_rates') record_task_success('update_exchange_rates')
except (AppRegistryNotReady, OperationalError, ProgrammingError): except (AppRegistryNotReady, OperationalError, ProgrammingError):
logger.warning("Could not update exchange rates - database not ready") logger.warning('Could not update exchange rates - database not ready')
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.exception("Error updating exchange rates: %s", str(type(e))) logger.exception('Error updating exchange rates: %s', str(type(e)))
@scheduled_task(ScheduledTask.DAILY) @scheduled_task(ScheduledTask.DAILY)
@ -590,16 +612,20 @@ def run_backup():
# Backups are not enabled - exit early # Backups are not enabled - exit early
return return
interval = int(InvenTreeSetting.get_setting('INVENTREE_BACKUP_DAYS', 1, cache=False)) interval = int(
InvenTreeSetting.get_setting('INVENTREE_BACKUP_DAYS', 1, cache=False)
)
# Check if should run this task *today* # Check if should run this task *today*
if not check_daily_holdoff('run_backup', interval): if not check_daily_holdoff('run_backup', interval):
return return
logger.info("Performing automated database backup task") logger.info('Performing automated database backup task')
call_command("dbbackup", noinput=True, clean=True, compress=True, interactive=False) call_command('dbbackup', noinput=True, clean=True, compress=True, interactive=False)
call_command("mediabackup", noinput=True, clean=True, compress=True, interactive=False) call_command(
'mediabackup', noinput=True, clean=True, compress=True, interactive=False
)
# Record that this task was successful # Record that this task was successful
record_task_success('run_backup') record_task_success('run_backup')
@ -627,7 +653,7 @@ def check_for_migrations():
logger.info('There are %s pending migrations', n) logger.info('There are %s pending migrations', n)
InvenTreeSetting.set_setting('_PENDING_MIGRATIONS', n, None) InvenTreeSetting.set_setting('_PENDING_MIGRATIONS', n, None)
logger.info("Checking for pending database migrations") logger.info('Checking for pending database migrations')
# Force plugin registry reload # Force plugin registry reload
registry.check_reload() registry.check_reload()
@ -645,12 +671,12 @@ def check_for_migrations():
# Test if auto-updates are enabled # Test if auto-updates are enabled
if not get_setting('INVENTREE_AUTO_UPDATE', 'auto_update'): if not get_setting('INVENTREE_AUTO_UPDATE', 'auto_update'):
logger.info("Auto-update is disabled - skipping migrations") logger.info('Auto-update is disabled - skipping migrations')
return return
# Log open migrations # Log open migrations
for migration in plan: for migration in plan:
logger.info("- %s", str(migration[0])) logger.info('- %s', str(migration[0]))
# Set the application to maintenance mode - no access from now on. # Set the application to maintenance mode - no access from now on.
set_maintenance_mode(True) set_maintenance_mode(True)
@ -668,13 +694,13 @@ def check_for_migrations():
else: else:
set_pending_migrations(0) set_pending_migrations(0)
logger.info("Completed %s migrations", n) logger.info('Completed %s migrations', n)
# Make sure we are out of maintenance mode # Make sure we are out of maintenance mode
if get_maintenance_mode(): if get_maintenance_mode():
logger.warning("Maintenance mode was not disabled - forcing it now") logger.warning('Maintenance mode was not disabled - forcing it now')
set_maintenance_mode(False) set_maintenance_mode(False)
logger.info("Manually released maintenance mode") logger.info('Manually released maintenance mode')
# We should be current now - triggering full reload to make sure all models # We should be current now - triggering full reload to make sure all models
# are loaded fully in their new state. # are loaded fully in their new state.

View File

@ -17,6 +17,7 @@ class HTMLAPITests(InvenTreeTestCase):
which raised an AssertionError when using the HTML API interface, which raised an AssertionError when using the HTML API interface,
while the regular JSON interface continued to work as expected. while the regular JSON interface continued to work as expected.
""" """
roles = 'all' roles = 'all'
def test_part_api(self): def test_part_api(self):
@ -60,12 +61,7 @@ class HTMLAPITests(InvenTreeTestCase):
class APITests(InvenTreeAPITestCase): class APITests(InvenTreeAPITestCase):
"""Tests for the InvenTree API.""" """Tests for the InvenTree API."""
fixtures = [ fixtures = ['location', 'category', 'part', 'stock']
'location',
'category',
'part',
'stock'
]
token = None token = None
auto_login = False auto_login = False
@ -73,11 +69,11 @@ class APITests(InvenTreeAPITestCase):
"""Helper function to use basic auth.""" """Helper function to use basic auth."""
# Use basic authentication # Use basic authentication
authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii") authstring = bytes('{u}:{p}'.format(u=self.username, p=self.password), 'ascii')
# Use "basic" auth by default # Use "basic" auth by default
auth = b64encode(authstring).decode("ascii") auth = b64encode(authstring).decode('ascii')
self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth)) self.client.credentials(HTTP_AUTHORIZATION='Basic {auth}'.format(auth=auth))
def tokenAuth(self): def tokenAuth(self):
"""Helper function to use token auth.""" """Helper function to use token auth."""
@ -270,33 +266,20 @@ class BulkDeleteTests(InvenTreeAPITestCase):
url = reverse('api-stock-test-result-list') url = reverse('api-stock-test-result-list')
# DELETE without any of the required fields # DELETE without any of the required fields
response = self.delete( response = self.delete(url, {}, expected_code=400)
url,
{},
expected_code=400
)
self.assertIn('List of items or filters must be provided for bulk deletion', str(response.data)) self.assertIn(
'List of items or filters must be provided for bulk deletion',
str(response.data),
)
# DELETE with invalid 'items' # DELETE with invalid 'items'
response = self.delete( response = self.delete(url, {'items': {'hello': 'world'}}, expected_code=400)
url,
{
'items': {"hello": "world"},
},
expected_code=400,
)
self.assertIn("'items' must be supplied as a list object", str(response.data)) self.assertIn("'items' must be supplied as a list object", str(response.data))
# DELETE with invalid 'filters' # DELETE with invalid 'filters'
response = self.delete( response = self.delete(url, {'filters': [1, 2, 3]}, expected_code=400)
url,
{
'filters': [1, 2, 3],
},
expected_code=400,
)
self.assertIn("'filters' must be supplied as a dict object", str(response.data)) self.assertIn("'filters' must be supplied as a dict object", str(response.data))
@ -317,11 +300,7 @@ class SearchTests(InvenTreeAPITestCase):
def test_empty(self): def test_empty(self):
"""Test empty request""" """Test empty request"""
data = [ data = ['', None, {}]
'',
None,
{},
]
for d in data: for d in data:
response = self.post(reverse('api-search'), d, expected_code=400) response = self.post(reverse('api-search'), d, expected_code=400)
@ -331,13 +310,8 @@ class SearchTests(InvenTreeAPITestCase):
"""Test individual result types""" """Test individual result types"""
response = self.post( response = self.post(
reverse('api-search'), reverse('api-search'),
{ {'search': 'chair', 'limit': 3, 'part': {}, 'build': {}},
'search': 'chair', expected_code=200,
'limit': 3,
'part': {},
'build': {},
},
expected_code=200
) )
# No build results # No build results
@ -354,12 +328,7 @@ class SearchTests(InvenTreeAPITestCase):
# Search for orders # Search for orders
response = self.post( response = self.post(
reverse('api-search'), reverse('api-search'),
{ {'search': '01', 'limit': 2, 'purchaseorder': {}, 'salesorder': {}},
'search': '01',
'limit': 2,
'purchaseorder': {},
'salesorder': {},
},
expected_code=200, expected_code=200,
) )
@ -392,33 +361,25 @@ class SearchTests(InvenTreeAPITestCase):
'salesorder', 'salesorder',
] ]
query = { query = {'search': 'c', 'limit': 3}
'search': 'c',
'limit': 3,
}
for mdl in models: for mdl in models:
query[mdl] = {} query[mdl] = {}
response = self.post( response = self.post(reverse('api-search'), query, expected_code=200)
reverse('api-search'),
query,
expected_code=200
)
# Check for 'permission denied' error # Check for 'permission denied' error
for mdl in models: for mdl in models:
self.assertEqual(response.data[mdl]['error'], 'User does not have permission to view this model') self.assertEqual(
response.data[mdl]['error'],
'User does not have permission to view this model',
)
# Assign view roles for some parts # Assign view roles for some parts
self.assignRole('build.view') self.assignRole('build.view')
self.assignRole('part.view') self.assignRole('part.view')
response = self.post( response = self.post(reverse('api-search'), query, expected_code=200)
reverse('api-search'),
query,
expected_code=200
)
# Check for expected results, based on permissions # Check for expected results, based on permissions
# We expect results to be returned for the following model types # We expect results to be returned for the following model types
@ -438,4 +399,6 @@ class SearchTests(InvenTreeAPITestCase):
self.assertIn('count', result) self.assertIn('count', result)
else: else:
self.assertIn('error', result) self.assertIn('error', result)
self.assertEqual(result['error'], 'User does not have permission to view this model') self.assertEqual(
result['error'], 'User does not have permission to view this model'
)

View File

@ -1,6 +1,5 @@
"""Tests for api_version.""" """Tests for api_version."""
from django.urls import reverse from django.urls import reverse
from InvenTree.api_version import INVENTREE_API_VERSION from InvenTree.api_version import INVENTREE_API_VERSION

View File

@ -66,6 +66,7 @@ class MiddlewareTests(InvenTreeTestCase):
def test_error_exceptions(self): def test_error_exceptions(self):
"""Test that ignored errors are not logged.""" """Test that ignored errors are not logged."""
def check(excpected_nbr=0): def check(excpected_nbr=0):
# Check that errors are empty # Check that errors are empty
errors = Error.objects.all() errors = Error.objects.all()

View File

@ -67,15 +67,23 @@ class InvenTreeTaskTests(TestCase):
# Error runs # Error runs
# Malformed taskname # Malformed taskname
with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTree' not started - Malformed function path"): with self.assertWarnsMessage(
UserWarning, "WARNING: 'InvenTree' not started - Malformed function path"
):
InvenTree.tasks.offload_task('InvenTree') InvenTree.tasks.offload_task('InvenTree')
# Non existent app # Non existent app
with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTreeABC.test_tasks.doesnotmatter' not started - No module named 'InvenTreeABC.test_tasks'"): with self.assertWarnsMessage(
UserWarning,
"WARNING: 'InvenTreeABC.test_tasks.doesnotmatter' not started - No module named 'InvenTreeABC.test_tasks'",
):
InvenTree.tasks.offload_task('InvenTreeABC.test_tasks.doesnotmatter') InvenTree.tasks.offload_task('InvenTreeABC.test_tasks.doesnotmatter')
# Non existent function # Non existent function
with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTree.test_tasks.doesnotexist' not started - No function named 'doesnotexist'"): with self.assertWarnsMessage(
UserWarning,
"WARNING: 'InvenTree.test_tasks.doesnotexist' not started - No function named 'doesnotexist'",
):
InvenTree.tasks.offload_task('InvenTree.test_tasks.doesnotexist') InvenTree.tasks.offload_task('InvenTree.test_tasks.doesnotexist')
def test_task_hearbeat(self): def test_task_hearbeat(self):
@ -86,7 +94,9 @@ class InvenTreeTaskTests(TestCase):
"""Test the task delete_successful_tasks.""" """Test the task delete_successful_tasks."""
from django_q.models import Success from django_q.models import Success
Success.objects.create(name='abc', func='abc', stopped=threshold, started=threshold_low) Success.objects.create(
name='abc', func='abc', stopped=threshold, started=threshold_low
)
InvenTree.tasks.offload_task(InvenTree.tasks.delete_successful_tasks) InvenTree.tasks.offload_task(InvenTree.tasks.delete_successful_tasks)
results = Success.objects.filter(started__lte=threshold) results = Success.objects.filter(started__lte=threshold)
self.assertEqual(len(results), 0) self.assertEqual(len(results), 0)
@ -99,14 +109,14 @@ class InvenTreeTaskTests(TestCase):
error_obj.save() error_obj.save()
# Check that it is not empty # Check that it is not empty
errors = Error.objects.filter(when__lte=threshold,) errors = Error.objects.filter(when__lte=threshold)
self.assertNotEqual(len(errors), 0) self.assertNotEqual(len(errors), 0)
# Run action # Run action
InvenTree.tasks.offload_task(InvenTree.tasks.delete_old_error_logs) InvenTree.tasks.offload_task(InvenTree.tasks.delete_old_error_logs)
# Check that it is empty again # Check that it is empty again
errors = Error.objects.filter(when__lte=threshold,) errors = Error.objects.filter(when__lte=threshold)
self.assertEqual(len(errors), 0) self.assertEqual(len(errors), 0)
def test_task_check_for_updates(self): def test_task_check_for_updates(self):
@ -146,7 +156,9 @@ class InvenTreeTaskTests(TestCase):
# Cleanup # Cleanup
try: try:
migration_name = InvenTree.tasks.get_migration_plan()[0][0].name + '.py' migration_name = InvenTree.tasks.get_migration_plan()[0][0].name + '.py'
migration_path = settings.BASE_DIR / 'InvenTree' / 'migrations' / migration_name migration_path = (
settings.BASE_DIR / 'InvenTree' / 'migrations' / migration_name
)
migration_path.unlink() migration_path.unlink()
except IndexError: # pragma: no cover except IndexError: # pragma: no cover
pass pass

View File

@ -44,7 +44,7 @@ class URLTest(TestCase):
('part', 'templates'), ('part', 'templates'),
('report', 'templates'), ('report', 'templates'),
('stock', 'templates'), ('stock', 'templates'),
('templates', ), ('templates',),
] ]
template_files = [] template_files = []
@ -53,11 +53,9 @@ class URLTest(TestCase):
tld = os.path.join(here, '..') tld = os.path.join(here, '..')
for directory in template_dirs: for directory in template_dirs:
template_dir = os.path.join(tld, *directory) template_dir = os.path.join(tld, *directory)
for path in Path(template_dir).rglob(suffix): for path in Path(template_dir).rglob(suffix):
f = os.path.abspath(path) f = os.path.abspath(path)
if f not in template_files: if f not in template_files:
@ -69,25 +67,18 @@ class URLTest(TestCase):
"""Search for all instances of {% url %} in supplied template file.""" """Search for all instances of {% url %} in supplied template file."""
urls = [] urls = []
pattern = "{% url ['\"]([^'\"]+)['\"]([^%]*)%}" pattern = '{% url [\'"]([^\'"]+)[\'"]([^%]*)%}'
with open(input_file, 'r') as f: with open(input_file, 'r') as f:
data = f.read() data = f.read()
results = re.findall(pattern, data) results = re.findall(pattern, data)
for result in results: for result in results:
if len(result) == 2: if len(result) == 2:
urls.append([ urls.append([result[0].strip(), result[1].strip()])
result[0].strip(),
result[1].strip()
])
elif len(result) == 1: # pragma: no cover elif len(result) == 1: # pragma: no cover
urls.append([ urls.append([result[0].strip(), ''])
result[0].strip(),
''
])
return urls return urls
@ -100,16 +91,16 @@ class URLTest(TestCase):
pk = None pk = None
# TODO: Handle reverse lookup of admin URLs! # TODO: Handle reverse lookup of admin URLs!
if url.startswith("admin:"): if url.startswith('admin:'):
return return
# TODO can this be more elegant? # TODO can this be more elegant?
if url.startswith("account_"): if url.startswith('account_'):
return return
if pk: if pk:
# We will assume that there is at least one item in the database # We will assume that there is at least one item in the database
reverse(url, kwargs={"pk": 1}) reverse(url, kwargs={'pk': 1})
else: else:
reverse(url) reverse(url)
@ -122,14 +113,14 @@ class URLTest(TestCase):
def test_html_templates(self): def test_html_templates(self):
"""Test all HTML templates for broken url tags.""" """Test all HTML templates for broken url tags."""
template_files = self.find_files("*.html") template_files = self.find_files('*.html')
for f in template_files: for f in template_files:
self.check_file(f) self.check_file(f)
def test_js_templates(self): def test_js_templates(self):
"""Test all JS templates for broken url tags.""" """Test all JS templates for broken url tags."""
template_files = self.find_files("*.js") template_files = self.find_files('*.js')
for f in template_files: for f in template_files:
self.check_file(f) self.check_file(f)

View File

@ -23,13 +23,13 @@ class ViewTests(InvenTreeTestCase):
def test_index_redirect(self): def test_index_redirect(self):
"""Top-level URL should redirect to "index" page.""" """Top-level URL should redirect to "index" page."""
response = self.client.get("/") response = self.client.get('/')
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
def get_index_page(self): def get_index_page(self):
"""Retrieve the index page (used for subsequent unit tests)""" """Retrieve the index page (used for subsequent unit tests)"""
response = self.client.get("/index/") response = self.client.get('/index/')
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@ -60,35 +60,20 @@ class ViewTests(InvenTreeTestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
content = response.content.decode() content = response.content.decode()
user_panels = [ user_panels = ['account', 'user-display', 'user-home', 'user-reports']
'account',
'user-display',
'user-home',
'user-reports',
]
staff_panels = [ staff_panels = ['server', 'login', 'barcodes', 'pricing', 'parts', 'stock']
'server',
'login',
'barcodes',
'pricing',
'parts',
'stock',
]
plugin_panels = [ plugin_panels = ['plugin']
'plugin',
]
# Default user has staff access, so all panels will be present # Default user has staff access, so all panels will be present
for panel in user_panels + staff_panels + plugin_panels: for panel in user_panels + staff_panels + plugin_panels:
self.assertIn(f"select-{panel}", content) self.assertIn(f'select-{panel}', content)
self.assertIn(f"panel-{panel}", content) self.assertIn(f'panel-{panel}', content)
# Now create a user who does not have staff access # Now create a user who does not have staff access
pleb_user = get_user_model().objects.create_user( pleb_user = get_user_model().objects.create_user(
username='pleb', username='pleb', password='notstaff'
password='notstaff',
) )
pleb_user.groups.add(self.group) pleb_user.groups.add(self.group)
@ -98,10 +83,7 @@ class ViewTests(InvenTreeTestCase):
self.client.logout() self.client.logout()
result = self.client.login( result = self.client.login(username='pleb', password='notstaff')
username='pleb',
password='notstaff',
)
self.assertTrue(result) self.assertTrue(result)
@ -111,22 +93,24 @@ class ViewTests(InvenTreeTestCase):
# Normal user still has access to user-specific panels # Normal user still has access to user-specific panels
for panel in user_panels: for panel in user_panels:
self.assertIn(f"select-{panel}", content) self.assertIn(f'select-{panel}', content)
self.assertIn(f"panel-{panel}", content) self.assertIn(f'panel-{panel}', content)
# Normal user does NOT have access to global or plugin settings # Normal user does NOT have access to global or plugin settings
for panel in staff_panels + plugin_panels: for panel in staff_panels + plugin_panels:
self.assertNotIn(f"select-{panel}", content) self.assertNotIn(f'select-{panel}', content)
self.assertNotIn(f"panel-{panel}", content) self.assertNotIn(f'panel-{panel}', content)
def test_url_login(self): def test_url_login(self):
"""Test logging in via arguments""" """Test logging in via arguments"""
# Log out # Log out
self.client.logout() self.client.logout()
response = self.client.get("/index/") response = self.client.get('/index/')
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
# Try login with url # Try login with url
response = self.client.get(f"/accounts/login/?next=/&login={self.username}&password={self.password}") response = self.client.get(
f'/accounts/login/?next=/&login={self.username}&password={self.password}'
)
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, '/') self.assertEqual(response.url, '/')

View File

@ -45,12 +45,12 @@ class ConversionTest(TestCase):
def test_prefixes(self): def test_prefixes(self):
"""Test inputs where prefixes are used""" """Test inputs where prefixes are used"""
tests = { tests = {
"3": 3, '3': 3,
"3m": 3, '3m': 3,
"3mm": 0.003, '3mm': 0.003,
"3k": 3000, '3k': 3000,
"3u": 0.000003, '3u': 0.000003,
"3 inch": 0.0762, '3 inch': 0.0762,
} }
for val, expected in tests.items(): for val, expected in tests.items():
@ -60,13 +60,13 @@ class ConversionTest(TestCase):
def test_base_units(self): def test_base_units(self):
"""Test conversion to specified base units""" """Test conversion to specified base units"""
tests = { tests = {
"3": 3, '3': 3,
"3 dozen": 36, '3 dozen': 36,
"50 dozen kW": 600000, '50 dozen kW': 600000,
"1 / 10": 0.1, '1 / 10': 0.1,
"1/2 kW": 500, '1/2 kW': 500,
"1/2 dozen kW": 6000, '1/2 dozen kW': 6000,
"0.005 MW": 5000, '0.005 MW': 5000,
} }
for val, expected in tests.items(): for val, expected in tests.items():
@ -104,13 +104,7 @@ class ConversionTest(TestCase):
def test_invalid_units(self): def test_invalid_units(self):
"""Test conversion with bad units""" """Test conversion with bad units"""
tests = { tests = {'3': '10', '13': '-?-', '-3': 'xyz', '-12': '-12', '1/0': '1/0'}
'3': '10',
'13': '-?-',
'-3': 'xyz',
'-12': '-12',
'1/0': '1/0',
}
for val, unit in tests.items(): for val, unit in tests.items():
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
@ -118,12 +112,7 @@ class ConversionTest(TestCase):
def test_invalid_values(self): def test_invalid_values(self):
"""Test conversion of invalid inputs""" """Test conversion of invalid inputs"""
inputs = [ inputs = ['-x', '1/0', 'xyz', '12B45C']
'-x',
'1/0',
'xyz',
'12B45C'
]
for val in inputs: for val in inputs:
# Test with a provided unit # Test with a provided unit
@ -148,9 +137,7 @@ class ConversionTest(TestCase):
# Create a new custom unit # Create a new custom unit
CustomUnit.objects.create( CustomUnit.objects.create(
name='fanciful_unit', name='fanciful_unit', definition='henry / mm', symbol='hpmm'
definition='henry / mm',
symbol='hpmm',
) )
# Reload registry # Reload registry
@ -171,7 +158,9 @@ class ConversionTest(TestCase):
for val, expected in tests.items(): for val, expected in tests.items():
# Convert, and leave units # Convert, and leave units
q = InvenTree.conversion.convert_physical_value(val, 'henry / km', strip_units=False) q = InvenTree.conversion.convert_physical_value(
val, 'henry / km', strip_units=False
)
self.assertAlmostEqual(float(q.magnitude), expected, 2) self.assertAlmostEqual(float(q.magnitude), expected, 2)
# Convert and strip units # Convert and strip units
@ -184,24 +173,24 @@ class ValidatorTest(TestCase):
def test_overage(self): def test_overage(self):
"""Test overage validator.""" """Test overage validator."""
validate_overage("100%") validate_overage('100%')
validate_overage("10") validate_overage('10')
validate_overage("45.2 %") validate_overage('45.2 %')
with self.assertRaises(django_exceptions.ValidationError): with self.assertRaises(django_exceptions.ValidationError):
validate_overage("-1") validate_overage('-1')
with self.assertRaises(django_exceptions.ValidationError): with self.assertRaises(django_exceptions.ValidationError):
validate_overage("-2.04 %") validate_overage('-2.04 %')
with self.assertRaises(django_exceptions.ValidationError): with self.assertRaises(django_exceptions.ValidationError):
validate_overage("105%") validate_overage('105%')
with self.assertRaises(django_exceptions.ValidationError): with self.assertRaises(django_exceptions.ValidationError):
validate_overage("xxx %") validate_overage('xxx %')
with self.assertRaises(django_exceptions.ValidationError): with self.assertRaises(django_exceptions.ValidationError):
validate_overage("aaaa") validate_overage('aaaa')
def test_url_validation(self): def test_url_validation(self):
"""Test for AllowedURLValidator""" """Test for AllowedURLValidator"""
@ -241,7 +230,7 @@ class FormatTest(TestCase):
def test_parse(self): def test_parse(self):
"""Tests for the 'parse_format_string' function""" """Tests for the 'parse_format_string' function"""
# Extract data from a valid format string # Extract data from a valid format string
fmt = "PO-{abc:02f}-{ref:04d}-{date}-???" fmt = 'PO-{abc:02f}-{ref:04d}-{date}-???'
info = InvenTree.format.parse_format_string(fmt) info = InvenTree.format.parse_format_string(fmt)
@ -250,22 +239,17 @@ class FormatTest(TestCase):
self.assertIn('date', info) self.assertIn('date', info)
# Try with invalid strings # Try with invalid strings
for fmt in [ for fmt in ['PO-{{xyz}', 'PO-{xyz}}', 'PO-{xyz}-{']:
'PO-{{xyz}',
'PO-{xyz}}',
'PO-{xyz}-{',
]:
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
InvenTree.format.parse_format_string(fmt) InvenTree.format.parse_format_string(fmt)
def test_create_regex(self): def test_create_regex(self):
"""Test function for creating a regex from a format string""" """Test function for creating a regex from a format string"""
tests = { tests = {
"PO-123-{ref:04f}": r"^PO\-123\-(?P<ref>.+)$", 'PO-123-{ref:04f}': r'^PO\-123\-(?P<ref>.+)$',
"{PO}-???-{ref}-{date}-22": r"^(?P<PO>.+)\-...\-(?P<ref>.+)\-(?P<date>.+)\-22$", '{PO}-???-{ref}-{date}-22': r'^(?P<PO>.+)\-...\-(?P<ref>.+)\-(?P<date>.+)\-22$',
"ABC-123-###-{ref}": r"^ABC\-123\-\d\d\d\-(?P<ref>.+)$", 'ABC-123-###-{ref}': r'^ABC\-123\-\d\d\d\-(?P<ref>.+)$',
"ABC-123": r"^ABC\-123$", 'ABC-123': r'^ABC\-123$',
} }
for fmt, reg in tests.items(): for fmt, reg in tests.items():
@ -275,55 +259,44 @@ class FormatTest(TestCase):
"""Test that string validation works as expected""" """Test that string validation works as expected"""
# These tests should pass # These tests should pass
for value, pattern in { for value, pattern in {
"ABC-hello-123": "???-{q}-###", 'ABC-hello-123': '???-{q}-###',
"BO-1234": "BO-{ref}", 'BO-1234': 'BO-{ref}',
"111.222.fred.china": "???.###.{name}.{place}", '111.222.fred.china': '???.###.{name}.{place}',
"PO-1234": "PO-{ref:04d}" 'PO-1234': 'PO-{ref:04d}',
}.items(): }.items():
self.assertTrue(InvenTree.format.validate_string(value, pattern)) self.assertTrue(InvenTree.format.validate_string(value, pattern))
# These tests should fail # These tests should fail
for value, pattern in { for value, pattern in {
"ABC-hello-123": "###-{q}-???", 'ABC-hello-123': '###-{q}-???',
"BO-1234": "BO.{ref}", 'BO-1234': 'BO.{ref}',
"BO-####": "BO-{pattern}-{next}", 'BO-####': 'BO-{pattern}-{next}',
"BO-123d": "BO-{ref:04d}" 'BO-123d': 'BO-{ref:04d}',
}.items(): }.items():
self.assertFalse(InvenTree.format.validate_string(value, pattern)) self.assertFalse(InvenTree.format.validate_string(value, pattern))
def test_extract_value(self): def test_extract_value(self):
"""Test that we can extract named values based on a format string""" """Test that we can extract named values based on a format string"""
# Simple tests based on a straight-forward format string # Simple tests based on a straight-forward format string
fmt = "PO-###-{ref:04d}" fmt = 'PO-###-{ref:04d}'
tests = { tests = {'123': 'PO-123-123', '456': 'PO-123-456', '789': 'PO-123-789'}
"123": "PO-123-123",
"456": "PO-123-456",
"789": "PO-123-789",
}
for k, v in tests.items(): for k, v in tests.items():
self.assertEqual(InvenTree.format.extract_named_group('ref', v, fmt), k) self.assertEqual(InvenTree.format.extract_named_group('ref', v, fmt), k)
# However these ones should fail # However these ones should fail
tests = { tests = {'abc': 'PO-123-abc', 'xyz': 'PO-123-xyz'}
'abc': 'PO-123-abc',
'xyz': 'PO-123-xyz',
}
for v in tests.values(): for v in tests.values():
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
InvenTree.format.extract_named_group('ref', v, fmt) InvenTree.format.extract_named_group('ref', v, fmt)
# More complex tests # More complex tests
fmt = "PO-{date}-{test}-???-{ref}-###" fmt = 'PO-{date}-{test}-???-{ref}-###'
val = "PO-2022-02-01-hello-ABC-12345-222" val = 'PO-2022-02-01-hello-ABC-12345-222'
data = { data = {'date': '2022-02-01', 'test': 'hello', 'ref': '12345'}
'date': '2022-02-01',
'test': 'hello',
'ref': '12345',
}
for k, v in data.items(): for k, v in data.items():
self.assertEqual(InvenTree.format.extract_named_group(k, val, fmt), v) self.assertEqual(InvenTree.format.extract_named_group(k, val, fmt), v)
@ -332,60 +305,46 @@ class FormatTest(TestCase):
# Raises a ValueError as the format string is bad # Raises a ValueError as the format string is bad
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
InvenTree.format.extract_named_group( InvenTree.format.extract_named_group('test', 'PO-1234-5', 'PO-{test}-{')
"test",
"PO-1234-5",
"PO-{test}-{"
)
# Raises a NameError as the named group does not exist in the format string # Raises a NameError as the named group does not exist in the format string
with self.assertRaises(NameError): with self.assertRaises(NameError):
InvenTree.format.extract_named_group( InvenTree.format.extract_named_group('missing', 'PO-12345', 'PO-{test}')
"missing",
"PO-12345",
"PO-{test}",
)
# Raises a ValueError as the value does not match the format string # Raises a ValueError as the value does not match the format string
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
InvenTree.format.extract_named_group( InvenTree.format.extract_named_group('test', 'PO-1234', 'PO-{test}-1234')
"test",
"PO-1234",
"PO-{test}-1234",
)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
InvenTree.format.extract_named_group( InvenTree.format.extract_named_group('test', 'PO-ABC-xyz', 'PO-###-{test}')
"test",
"PO-ABC-xyz",
"PO-###-{test}",
)
def test_currency_formatting(self): def test_currency_formatting(self):
"""Test that currency formatting works correctly for multiple currencies""" """Test that currency formatting works correctly for multiple currencies"""
test_data = ( test_data = (
(Money( 3651.285718, "USD"), 4, "$3,651.2857" ), # noqa: E201,E202 (Money(3651.285718, 'USD'), 4, '$3,651.2857'), # noqa: E201,E202
(Money(487587.849178, "CAD"), 5, "CA$487,587.84918"), # noqa: E201,E202 (Money(487587.849178, 'CAD'), 5, 'CA$487,587.84918'), # noqa: E201,E202
(Money( 0.348102, "EUR"), 1, "€0.3" ), # noqa: E201,E202 (Money(0.348102, 'EUR'), 1, '€0.3'), # noqa: E201,E202
(Money( 0.916530, "GBP"), 1, "£0.9" ), # noqa: E201,E202 (Money(0.916530, 'GBP'), 1, '£0.9'), # noqa: E201,E202
(Money( 61.031024, "JPY"), 3, "¥61.031" ), # noqa: E201,E202 (Money(61.031024, 'JPY'), 3, '¥61.031'), # noqa: E201,E202
(Money( 49609.694602, "JPY"), 1, "¥49,609.7" ), # noqa: E201,E202 (Money(49609.694602, 'JPY'), 1, '¥49,609.7'), # noqa: E201,E202
(Money(155565.264777, "AUD"), 2, "A$155,565.26" ), # noqa: E201,E202 (Money(155565.264777, 'AUD'), 2, 'A$155,565.26'), # noqa: E201,E202
(Money( 0.820437, "CNY"), 4, "CN¥0.8204" ), # noqa: E201,E202 (Money(0.820437, 'CNY'), 4, 'CN¥0.8204'), # noqa: E201,E202
(Money( 7587.849178, "EUR"), 0, "€7,588" ), # noqa: E201,E202 (Money(7587.849178, 'EUR'), 0, '€7,588'), # noqa: E201,E202
(Money( 0.348102, "GBP"), 3, "£0.348" ), # noqa: E201,E202 (Money(0.348102, 'GBP'), 3, '£0.348'), # noqa: E201,E202
(Money( 0.652923, "CHF"), 0, "CHF1" ), # noqa: E201,E202 (Money(0.652923, 'CHF'), 0, 'CHF1'), # noqa: E201,E202
(Money( 0.820437, "CNY"), 1, "CN¥0.8" ), # noqa: E201,E202 (Money(0.820437, 'CNY'), 1, 'CN¥0.8'), # noqa: E201,E202
(Money(98789.5295680, "CHF"), 0, "CHF98,790" ), # noqa: E201,E202 (Money(98789.5295680, 'CHF'), 0, 'CHF98,790'), # noqa: E201,E202
(Money( 0.585787, "USD"), 1, "$0.6" ), # noqa: E201,E202 (Money(0.585787, 'USD'), 1, '$0.6'), # noqa: E201,E202
(Money( 0.690541, "CAD"), 3, "CA$0.691" ), # noqa: E201,E202 (Money(0.690541, 'CAD'), 3, 'CA$0.691'), # noqa: E201,E202
(Money( 427.814104, "AUD"), 5, "A$427.81410" ), # noqa: E201,E202 (Money(427.814104, 'AUD'), 5, 'A$427.81410'), # noqa: E201,E202
) )
with self.settings(LANGUAGE_CODE="en-us"): with self.settings(LANGUAGE_CODE='en-us'):
for value, decimal_places, expected_result in test_data: for value, decimal_places, expected_result in test_data:
result = InvenTree.format.format_money(value, decimal_places=decimal_places) result = InvenTree.format.format_money(
value, decimal_places=decimal_places
)
assert result == expected_result assert result == expected_result
@ -394,30 +353,35 @@ class TestHelpers(TestCase):
def test_absolute_url(self): def test_absolute_url(self):
"""Test helper function for generating an absolute URL""" """Test helper function for generating an absolute URL"""
base = "https://demo.inventree.org:12345" base = 'https://demo.inventree.org:12345'
InvenTreeSetting.set_setting('INVENTREE_BASE_URL', base, change_user=None) InvenTreeSetting.set_setting('INVENTREE_BASE_URL', base, change_user=None)
tests = { tests = {
"": base, '': base,
"api/": base + "/api/", 'api/': base + '/api/',
"/api/": base + "/api/", '/api/': base + '/api/',
"api": base + "/api", 'api': base + '/api',
"media/label/output/": base + "/media/label/output/", 'media/label/output/': base + '/media/label/output/',
"static/logo.png": base + "/static/logo.png", 'static/logo.png': base + '/static/logo.png',
"https://www.google.com": "https://www.google.com", 'https://www.google.com': 'https://www.google.com',
"https://demo.inventree.org:12345/out.html": "https://demo.inventree.org:12345/out.html", 'https://demo.inventree.org:12345/out.html': 'https://demo.inventree.org:12345/out.html',
"https://demo.inventree.org/test.html": "https://demo.inventree.org/test.html", 'https://demo.inventree.org/test.html': 'https://demo.inventree.org/test.html',
"http://www.cwi.nl:80/%7Eguido/Python.html": "http://www.cwi.nl:80/%7Eguido/Python.html", 'http://www.cwi.nl:80/%7Eguido/Python.html': 'http://www.cwi.nl:80/%7Eguido/Python.html',
"test.org": base + "/test.org", 'test.org': base + '/test.org',
} }
for url, expected in tests.items(): for url, expected in tests.items():
# Test with supplied base URL # Test with supplied base URL
self.assertEqual(InvenTree.helpers_model.construct_absolute_url(url, site_url=base), expected) self.assertEqual(
InvenTree.helpers_model.construct_absolute_url(url, site_url=base),
expected,
)
# Test without supplied base URL # Test without supplied base URL
self.assertEqual(InvenTree.helpers_model.construct_absolute_url(url), expected) self.assertEqual(
InvenTree.helpers_model.construct_absolute_url(url), expected
)
def test_image_url(self): def test_image_url(self):
"""Test if a filename looks like an image.""" """Test if a filename looks like an image."""
@ -453,7 +417,9 @@ class TestHelpers(TestCase):
"""Test static url helpers.""" """Test static url helpers."""
self.assertEqual(helpers.getStaticUrl('test.jpg'), '/static/test.jpg') self.assertEqual(helpers.getStaticUrl('test.jpg'), '/static/test.jpg')
self.assertEqual(helpers.getBlankImage(), '/static/img/blank_image.png') self.assertEqual(helpers.getBlankImage(), '/static/img/blank_image.png')
self.assertEqual(helpers.getBlankThumbnail(), '/static/img/blank_image.thumbnail.png') self.assertEqual(
helpers.getBlankThumbnail(), '/static/img/blank_image.thumbnail.png'
)
def testMediaUrl(self): def testMediaUrl(self):
"""Test getMediaUrl.""" """Test getMediaUrl."""
@ -476,12 +442,7 @@ class TestHelpers(TestCase):
def test_download_image(self): def test_download_image(self):
"""Test function for downloading image from remote URL""" """Test function for downloading image from remote URL"""
# Run check with a sequence of bad URLs # Run check with a sequence of bad URLs
for url in [ for url in ['blog', 'htp://test.com/?', 'google', '\\invalid-url']:
"blog",
"htp://test.com/?",
"google",
"\\invalid-url"
]:
with self.assertRaises(django_exceptions.ValidationError): with self.assertRaises(django_exceptions.ValidationError):
InvenTree.helpers_model.download_image_from_url(url) InvenTree.helpers_model.download_image_from_url(url)
@ -496,16 +457,17 @@ class TestHelpers(TestCase):
with self.assertRaises(expected_error): with self.assertRaises(expected_error):
while tries < retries: while tries < retries:
try: try:
InvenTree.helpers_model.download_image_from_url(url, timeout=timeout) InvenTree.helpers_model.download_image_from_url(
url, timeout=timeout
)
break break
except Exception as exc: except Exception as exc:
if type(exc) is expected_error: if type(exc) is expected_error:
# Re-throw this error # Re-throw this error
raise exc raise exc
else: else:
print("Unexpected error:", type(exc), exc) print('Unexpected error:', type(exc), exc)
tries += 1 tries += 1
time.sleep(10 * tries) time.sleep(10 * tries)
@ -518,16 +480,20 @@ class TestHelpers(TestCase):
# TODO: Re-implement this test when we are happier with the external service # TODO: Re-implement this test when we are happier with the external service
# dl_helper("https://httpstat.us/200?sleep=5000", requests.exceptions.ReadTimeout, timeout=1) # dl_helper("https://httpstat.us/200?sleep=5000", requests.exceptions.ReadTimeout, timeout=1)
large_img = "https://github.com/inventree/InvenTree/raw/master/InvenTree/InvenTree/static/img/paper_splash_large.jpg" large_img = 'https://github.com/inventree/InvenTree/raw/master/InvenTree/InvenTree/static/img/paper_splash_large.jpg'
InvenTreeSetting.set_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 1, change_user=None) InvenTreeSetting.set_setting(
'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 1, change_user=None
)
# Attempt to download an image which is too large # Attempt to download an image which is too large
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
InvenTree.helpers_model.download_image_from_url(large_img, timeout=10) InvenTree.helpers_model.download_image_from_url(large_img, timeout=10)
# Increase allowable download size # Increase allowable download size
InvenTreeSetting.set_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 5, change_user=None) InvenTreeSetting.set_setting(
'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 5, change_user=None
)
# Download a valid image (should not throw an error) # Download a valid image (should not throw an error)
InvenTree.helpers_model.download_image_from_url(large_img, timeout=10) InvenTree.helpers_model.download_image_from_url(large_img, timeout=10)
@ -561,14 +527,14 @@ class TestIncrement(TestCase):
def tests(self): def tests(self):
"""Test 'intelligent' incrementing function.""" """Test 'intelligent' incrementing function."""
tests = [ tests = [
("", '1'), ('', '1'),
(1, "2"), (1, '2'),
("001", "002"), ('001', '002'),
("1001", "1002"), ('1001', '1002'),
("ABC123", "ABC124"), ('ABC123', 'ABC124'),
("XYZ0", "XYZ1"), ('XYZ0', 'XYZ1'),
("123Q", "123Q"), ('123Q', '123Q'),
("QQQ", "QQQ"), ('QQQ', 'QQQ'),
] ]
for test in tests: for test in tests:
@ -584,13 +550,7 @@ class TestMakeBarcode(TestCase):
def test_barcode_extended(self): def test_barcode_extended(self):
"""Test creation of barcode with extended data.""" """Test creation of barcode with extended data."""
bc = helpers.MakeBarcode( bc = helpers.MakeBarcode(
"part", 'part', 3, {'id': 3, 'url': 'www.google.com'}, brief=False
3,
{
"id": 3,
"url": "www.google.com",
},
brief=False
) )
self.assertIn('part', bc) self.assertIn('part', bc)
@ -604,10 +564,7 @@ class TestMakeBarcode(TestCase):
def test_barcode_brief(self): def test_barcode_brief(self):
"""Test creation of simple barcode.""" """Test creation of simple barcode."""
bc = helpers.MakeBarcode( bc = helpers.MakeBarcode('stockitem', 7)
"stockitem",
7,
)
data = json.loads(bc) data = json.loads(bc)
self.assertEqual(len(data), 1) self.assertEqual(len(data), 1)
@ -619,16 +576,14 @@ class TestDownloadFile(TestCase):
def test_download(self): def test_download(self):
"""Tests for DownloadFile.""" """Tests for DownloadFile."""
helpers.DownloadFile("hello world", "out.txt") helpers.DownloadFile('hello world', 'out.txt')
helpers.DownloadFile(bytes(b"hello world"), "out.bin") helpers.DownloadFile(bytes(b'hello world'), 'out.bin')
class TestMPTT(TestCase): class TestMPTT(TestCase):
"""Tests for the MPTT tree models.""" """Tests for the MPTT tree models."""
fixtures = [ fixtures = ['location']
'location',
]
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -681,62 +636,62 @@ class TestSerialNumberExtraction(TestCase):
e = helpers.extract_serial_numbers e = helpers.extract_serial_numbers
# Test a range of numbers # Test a range of numbers
sn = e("1-5", 5, 1) sn = e('1-5', 5, 1)
self.assertEqual(len(sn), 5) self.assertEqual(len(sn), 5)
for i in range(1, 6): for i in range(1, 6):
self.assertIn(str(i), sn) self.assertIn(str(i), sn)
sn = e("11-30", 20, 1) sn = e('11-30', 20, 1)
self.assertEqual(len(sn), 20) self.assertEqual(len(sn), 20)
sn = e("1, 2, 3, 4, 5", 5, 1) sn = e('1, 2, 3, 4, 5', 5, 1)
self.assertEqual(len(sn), 5) self.assertEqual(len(sn), 5)
# Test partially specifying serials # Test partially specifying serials
sn = e("1, 2, 4+", 5, 1) sn = e('1, 2, 4+', 5, 1)
self.assertEqual(len(sn), 5) self.assertEqual(len(sn), 5)
self.assertEqual(sn, ['1', '2', '4', '5', '6']) self.assertEqual(sn, ['1', '2', '4', '5', '6'])
# Test groups are not interpolated if enough serials are supplied # Test groups are not interpolated if enough serials are supplied
sn = e("1, 2, 3, AF5-69H, 5", 5, 1) sn = e('1, 2, 3, AF5-69H, 5', 5, 1)
self.assertEqual(len(sn), 5) self.assertEqual(len(sn), 5)
self.assertEqual(sn, ['1', '2', '3', 'AF5-69H', '5']) self.assertEqual(sn, ['1', '2', '3', 'AF5-69H', '5'])
# Test groups are not interpolated with more than one hyphen in a word # Test groups are not interpolated with more than one hyphen in a word
sn = e("1, 2, TG-4SR-92, 4+", 5, 1) sn = e('1, 2, TG-4SR-92, 4+', 5, 1)
self.assertEqual(len(sn), 5) self.assertEqual(len(sn), 5)
self.assertEqual(sn, ['1', '2', "TG-4SR-92", '4', '5']) self.assertEqual(sn, ['1', '2', 'TG-4SR-92', '4', '5'])
# Test multiple placeholders # Test multiple placeholders
sn = e("1 2 ~ ~ ~", 5, 2) sn = e('1 2 ~ ~ ~', 5, 2)
self.assertEqual(len(sn), 5) self.assertEqual(len(sn), 5)
self.assertEqual(sn, ['1', '2', '3', '4', '5']) self.assertEqual(sn, ['1', '2', '3', '4', '5'])
sn = e("1-5, 10-15", 11, 1) sn = e('1-5, 10-15', 11, 1)
self.assertIn('3', sn) self.assertIn('3', sn)
self.assertIn('13', sn) self.assertIn('13', sn)
sn = e("1+", 10, 1) sn = e('1+', 10, 1)
self.assertEqual(len(sn), 10) self.assertEqual(len(sn), 10)
self.assertEqual(sn, [str(_) for _ in range(1, 11)]) self.assertEqual(sn, [str(_) for _ in range(1, 11)])
sn = e("4, 1+2", 4, 1) sn = e('4, 1+2', 4, 1)
self.assertEqual(len(sn), 4) self.assertEqual(len(sn), 4)
self.assertEqual(sn, ['4', '1', '2', '3']) self.assertEqual(sn, ['4', '1', '2', '3'])
sn = e("~", 1, 1) sn = e('~', 1, 1)
self.assertEqual(len(sn), 1) self.assertEqual(len(sn), 1)
self.assertEqual(sn, ['2']) self.assertEqual(sn, ['2'])
sn = e("~", 1, 3) sn = e('~', 1, 3)
self.assertEqual(len(sn), 1) self.assertEqual(len(sn), 1)
self.assertEqual(sn, ['4']) self.assertEqual(sn, ['4'])
sn = e("~+", 2, 4) sn = e('~+', 2, 4)
self.assertEqual(len(sn), 2) self.assertEqual(len(sn), 2)
self.assertEqual(sn, ['5', '6']) self.assertEqual(sn, ['5', '6'])
sn = e("~+3", 4, 4) sn = e('~+3', 4, 4)
self.assertEqual(len(sn), 4) self.assertEqual(len(sn), 4)
self.assertEqual(sn, ['5', '6', '7', '8']) self.assertEqual(sn, ['5', '6', '7', '8'])
@ -746,70 +701,70 @@ class TestSerialNumberExtraction(TestCase):
# Test duplicates # Test duplicates
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
e("1,2,3,3,3", 5, 1) e('1,2,3,3,3', 5, 1)
# Test invalid length # Test invalid length
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
e("1,2,3", 5, 1) e('1,2,3', 5, 1)
# Test empty string # Test empty string
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
e(", , ,", 0, 1) e(', , ,', 0, 1)
# Test incorrect sign in group # Test incorrect sign in group
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
e("10-2", 8, 1) e('10-2', 8, 1)
# Test invalid group # Test invalid group
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
e("1-5-10", 10, 1) e('1-5-10', 10, 1)
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
e("10, a, 7-70j", 4, 1) e('10, a, 7-70j', 4, 1)
# Test groups are not interpolated with word characters # Test groups are not interpolated with word characters
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
e("1, 2, 3, E-5", 5, 1) e('1, 2, 3, E-5', 5, 1)
# Extract a range of values with a smaller range # Extract a range of values with a smaller range
with self.assertRaises(ValidationError) as exc: with self.assertRaises(ValidationError) as exc:
e("11-50", 10, 1) e('11-50', 10, 1)
self.assertIn('Range quantity exceeds 10', str(exc)) self.assertIn('Range quantity exceeds 10', str(exc))
# Test groups are not interpolated with alpha characters # Test groups are not interpolated with alpha characters
with self.assertRaises(ValidationError) as exc: with self.assertRaises(ValidationError) as exc:
e("1, A-2, 3+", 5, 1) e('1, A-2, 3+', 5, 1)
self.assertIn('Invalid group range: A-2', str(exc)) self.assertIn('Invalid group range: A-2', str(exc))
def test_combinations(self): def test_combinations(self):
"""Test complex serial number combinations.""" """Test complex serial number combinations."""
e = helpers.extract_serial_numbers e = helpers.extract_serial_numbers
sn = e("1 3-5 9+2", 7, 1) sn = e('1 3-5 9+2', 7, 1)
self.assertEqual(len(sn), 7) self.assertEqual(len(sn), 7)
self.assertEqual(sn, ['1', '3', '4', '5', '9', '10', '11']) self.assertEqual(sn, ['1', '3', '4', '5', '9', '10', '11'])
sn = e("1,3-5,9+2", 7, 1) sn = e('1,3-5,9+2', 7, 1)
self.assertEqual(len(sn), 7) self.assertEqual(len(sn), 7)
self.assertEqual(sn, ['1', '3', '4', '5', '9', '10', '11']) self.assertEqual(sn, ['1', '3', '4', '5', '9', '10', '11'])
sn = e("~+2", 3, 13) sn = e('~+2', 3, 13)
self.assertEqual(len(sn), 3) self.assertEqual(len(sn), 3)
self.assertEqual(sn, ['14', '15', '16']) self.assertEqual(sn, ['14', '15', '16'])
sn = e("~+", 2, 13) sn = e('~+', 2, 13)
self.assertEqual(len(sn), 2) self.assertEqual(len(sn), 2)
self.assertEqual(sn, ['14', '15']) self.assertEqual(sn, ['14', '15'])
# Test multiple increment groups # Test multiple increment groups
sn = e("~+4, 20+4, 30+4", 15, 10) sn = e('~+4, 20+4, 30+4', 15, 10)
self.assertEqual(len(sn), 15) self.assertEqual(len(sn), 15)
for v in [14, 24, 34]: for v in [14, 24, 34]:
self.assertIn(str(v), sn) self.assertIn(str(v), sn)
# Test multiple range groups # Test multiple range groups
sn = e("11-20, 41-50, 91-100", 30, 1) sn = e('11-20, 41-50, 91-100', 30, 1)
self.assertEqual(len(sn), 30) self.assertEqual(len(sn), 30)
for v in range(11, 21): for v in range(11, 21):
@ -848,7 +803,7 @@ class TestVersionNumber(TestCase):
"""Test that the git commit information is extracted successfully.""" """Test that the git commit information is extracted successfully."""
envs = { envs = {
'INVENTREE_COMMIT_HASH': 'abcdef', 'INVENTREE_COMMIT_HASH': 'abcdef',
'INVENTREE_COMMIT_DATE': '2022-12-31' 'INVENTREE_COMMIT_DATE': '2022-12-31',
} }
# Check that the environment variables take priority # Check that the environment variables take priority
@ -861,10 +816,16 @@ class TestVersionNumber(TestCase):
# Check that the current .git values work too # Check that the current .git values work too
hash = str(subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8').strip() hash = str(
subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8'
).strip()
self.assertEqual(hash, version.inventreeCommitHash()) self.assertEqual(hash, version.inventreeCommitHash())
d = str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8').strip().split(' ')[0] d = (
str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8')
.strip()
.split(' ')[0]
)
self.assertEqual(d, version.inventreeCommitDate()) self.assertEqual(d, version.inventreeCommitDate())
@ -898,7 +859,7 @@ class CurrencyTests(TestCase):
break break
else: # pragma: no cover else: # pragma: no cover
print("Exchange rate update failed - retrying") print('Exchange rate update failed - retrying')
print(f'Expected {currency_codes()}, got {[a.currency for a in rates]}') print(f'Expected {currency_codes()}, got {[a.currency for a in rates]}')
time.sleep(1) time.sleep(1)
@ -975,16 +936,14 @@ class TestSettings(InvenTreeTestCase):
self.assertEqual(user_count(), 1) self.assertEqual(user_count(), 1)
# not enough set # not enough set
self.run_reload({ self.run_reload({'INVENTREE_ADMIN_USER': 'admin'})
'INVENTREE_ADMIN_USER': 'admin'
})
self.assertEqual(user_count(), 1) self.assertEqual(user_count(), 1)
# enough set # enough set
self.run_reload({ self.run_reload({
'INVENTREE_ADMIN_USER': 'admin', # set username 'INVENTREE_ADMIN_USER': 'admin', # set username
'INVENTREE_ADMIN_EMAIL': 'info@example.com', # set email 'INVENTREE_ADMIN_EMAIL': 'info@example.com', # set email
'INVENTREE_ADMIN_PASSWORD': 'password123' # set password 'INVENTREE_ADMIN_PASSWORD': 'password123', # set password
}) })
self.assertEqual(user_count(), 2) self.assertEqual(user_count(), 2)
@ -1027,27 +986,27 @@ class TestSettings(InvenTreeTestCase):
"""Test get_config_file.""" """Test get_config_file."""
# normal run - not configured # normal run - not configured
valid = [ valid = ['inventree/config.yaml', 'inventree/data/config.yaml']
'inventree/config.yaml',
'inventree/data/config.yaml',
]
self.assertTrue(any(opt in str(config.get_config_file()).lower() for opt in valid)) self.assertTrue(
any(opt in str(config.get_config_file()).lower() for opt in valid)
)
# with env set # with env set
with self.in_env_context({'INVENTREE_CONFIG_FILE': 'my_special_conf.yaml'}): with self.in_env_context({'INVENTREE_CONFIG_FILE': 'my_special_conf.yaml'}):
self.assertIn('inventree/my_special_conf.yaml', str(config.get_config_file()).lower()) self.assertIn(
'inventree/my_special_conf.yaml', str(config.get_config_file()).lower()
)
def test_helpers_plugin_file(self): def test_helpers_plugin_file(self):
"""Test get_plugin_file.""" """Test get_plugin_file."""
# normal run - not configured # normal run - not configured
valid = [ valid = ['inventree/plugins.txt', 'inventree/data/plugins.txt']
'inventree/plugins.txt',
'inventree/data/plugins.txt',
]
self.assertTrue(any(opt in str(config.get_plugin_file()).lower() for opt in valid)) self.assertTrue(
any(opt in str(config.get_plugin_file()).lower() for opt in valid)
)
# with env set # with env set
with self.in_env_context({'INVENTREE_PLUGIN_FILE': 'my_special_plugins.txt'}): with self.in_env_context({'INVENTREE_PLUGIN_FILE': 'my_special_plugins.txt'}):
@ -1064,11 +1023,15 @@ class TestSettings(InvenTreeTestCase):
self.assertEqual(config.get_setting(TEST_ENV_NAME, None), '321') self.assertEqual(config.get_setting(TEST_ENV_NAME, None), '321')
# test typecasting to dict - None should be mapped to empty dict # test typecasting to dict - None should be mapped to empty dict
self.assertEqual(config.get_setting(TEST_ENV_NAME, None, None, typecast=dict), {}) self.assertEqual(
config.get_setting(TEST_ENV_NAME, None, None, typecast=dict), {}
)
# test typecasting to dict - valid JSON string should be mapped to corresponding dict # test typecasting to dict - valid JSON string should be mapped to corresponding dict
with self.in_env_context({TEST_ENV_NAME: '{"a": 1}'}): with self.in_env_context({TEST_ENV_NAME: '{"a": 1}'}):
self.assertEqual(config.get_setting(TEST_ENV_NAME, None, typecast=dict), {"a": 1}) self.assertEqual(
config.get_setting(TEST_ENV_NAME, None, typecast=dict), {'a': 1}
)
# test typecasting to dict - invalid JSON string should be mapped to empty dict # test typecasting to dict - invalid JSON string should be mapped to empty dict
with self.in_env_context({TEST_ENV_NAME: "{'a': 1}"}): with self.in_env_context({TEST_ENV_NAME: "{'a': 1}"}):
@ -1084,8 +1047,8 @@ class TestInstanceName(InvenTreeTestCase):
self.assertEqual(version.inventreeInstanceTitle(), 'InvenTree') self.assertEqual(version.inventreeInstanceTitle(), 'InvenTree')
# set up required setting # set up required setting
InvenTreeSetting.set_setting("INVENTREE_INSTANCE_TITLE", True, self.user) InvenTreeSetting.set_setting('INVENTREE_INSTANCE_TITLE', True, self.user)
InvenTreeSetting.set_setting("INVENTREE_INSTANCE", "Testing title", self.user) InvenTreeSetting.set_setting('INVENTREE_INSTANCE', 'Testing title', self.user)
self.assertEqual(version.inventreeInstanceTitle(), 'Testing title') self.assertEqual(version.inventreeInstanceTitle(), 'Testing title')
@ -1096,7 +1059,9 @@ class TestInstanceName(InvenTreeTestCase):
def test_instance_url(self): def test_instance_url(self):
"""Test instance url settings.""" """Test instance url settings."""
# Set up required setting # Set up required setting
InvenTreeSetting.set_setting("INVENTREE_BASE_URL", "http://127.1.2.3", self.user) InvenTreeSetting.set_setting(
'INVENTREE_BASE_URL', 'http://127.1.2.3', self.user
)
# The site should also be changed # The site should also be changed
site_obj = Site.objects.all().order_by('id').first() site_obj = Site.objects.all().order_by('id').first()
@ -1106,12 +1071,7 @@ class TestInstanceName(InvenTreeTestCase):
class TestOffloadTask(InvenTreeTestCase): class TestOffloadTask(InvenTreeTestCase):
"""Tests for offloading tasks to the background worker""" """Tests for offloading tasks to the background worker"""
fixtures = [ fixtures = ['category', 'part', 'location', 'stock']
'category',
'part',
'location',
'stock',
]
def test_offload_tasks(self): def test_offload_tasks(self):
"""Test that we can offload various tasks to the background worker thread. """Test that we can offload various tasks to the background worker thread.
@ -1127,38 +1087,38 @@ class TestOffloadTask(InvenTreeTestCase):
Ref: https://github.com/inventree/InvenTree/pull/3273 Ref: https://github.com/inventree/InvenTree/pull/3273
""" """
self.assertTrue(offload_task( self.assertTrue(
offload_task(
'dummy_tasks.stock', 'dummy_tasks.stock',
item=StockItem.objects.get(pk=1), item=StockItem.objects.get(pk=1),
loc=StockLocation.objects.get(pk=1), loc=StockLocation.objects.get(pk=1),
force_async=True force_async=True,
)) )
)
self.assertTrue(offload_task( self.assertTrue(
'dummy_task.numbers', offload_task('dummy_task.numbers', 1, 2, 3, 4, 5, force_async=True)
1, 2, 3, 4, 5, )
force_async=True
))
# Offload a dummy task, but force sync # Offload a dummy task, but force sync
# This should fail, because the function does not exist # This should fail, because the function does not exist
with self.assertLogs(logger='inventree', level='WARNING') as log: with self.assertLogs(logger='inventree', level='WARNING') as log:
self.assertFalse(offload_task( self.assertFalse(
'dummy_task.numbers', offload_task('dummy_task.numbers', 1, 1, 1, force_sync=True)
1, 1, 1, )
force_sync=True
))
self.assertIn("Malformed function path", str(log.output)) self.assertIn('Malformed function path', str(log.output))
# Offload dummy task with a Part instance # Offload dummy task with a Part instance
# This should succeed, ensuring that the Part instance is correctly pickled # This should succeed, ensuring that the Part instance is correctly pickled
self.assertTrue(offload_task( self.assertTrue(
offload_task(
'dummy_tasks.parts', 'dummy_tasks.parts',
part=Part.objects.get(pk=1), part=Part.objects.get(pk=1),
cat=PartCategory.objects.get(pk=1), cat=PartCategory.objects.get(pk=1),
force_async=True force_async=True,
)) )
)
def test_daily_holdoff(self): def test_daily_holdoff(self):
"""Tests for daily task holdoff helper functions""" """Tests for daily task holdoff helper functions"""
@ -1180,7 +1140,9 @@ class TestOffloadTask(InvenTreeTestCase):
# An attempt has been logged, but it is too recent # An attempt has been logged, but it is too recent
result = InvenTree.tasks.check_daily_holdoff('dummy_task') result = InvenTree.tasks.check_daily_holdoff('dummy_task')
self.assertFalse(result) self.assertFalse(result)
self.assertIn("Last attempt for 'dummy_task' was too recent", str(cm.output)) self.assertIn(
"Last attempt for 'dummy_task' was too recent", str(cm.output)
)
# Mark last attempt a few days ago - should now return True # Mark last attempt a few days ago - should now return True
t_old = datetime.now() - timedelta(days=3) t_old = datetime.now() - timedelta(days=3)
@ -1191,13 +1153,17 @@ class TestOffloadTask(InvenTreeTestCase):
self.assertTrue(result) self.assertTrue(result)
# Last attempt should have been updated # Last attempt should have been updated
self.assertNotEqual(t_old, InvenTreeSetting.get_setting('_dummy_task_ATTEMPT', '', cache=False)) self.assertNotEqual(
t_old, InvenTreeSetting.get_setting('_dummy_task_ATTEMPT', '', cache=False)
)
# Last attempt should prevent us now # Last attempt should prevent us now
with self.assertLogs(logger='inventree', level='INFO') as cm: with self.assertLogs(logger='inventree', level='INFO') as cm:
result = InvenTree.tasks.check_daily_holdoff('dummy_task') result = InvenTree.tasks.check_daily_holdoff('dummy_task')
self.assertFalse(result) self.assertFalse(result)
self.assertIn("Last attempt for 'dummy_task' was too recent", str(cm.output)) self.assertIn(
"Last attempt for 'dummy_task' was too recent", str(cm.output)
)
# Configure so a task was successful too recently # Configure so a task was successful too recently
InvenTreeSetting.set_setting('_dummy_task_ATTEMPT', t_old, None) InvenTreeSetting.set_setting('_dummy_task_ATTEMPT', t_old, None)

View File

@ -24,7 +24,7 @@ def reload_translation_stats():
keys = _translation_stats.keys() keys = _translation_stats.keys()
# Note that the names used in the stats file may not align 100% # Note that the names used in the stats file may not align 100%
for (code, _lang) in settings.LANGUAGES: for code, _lang in settings.LANGUAGES:
if code in keys: if code in keys:
# Direct match, move on # Direct match, move on
continue continue

View File

@ -39,7 +39,7 @@ def getMigrationFileNames(app):
files = local_dir.joinpath('..', app, 'migrations').iterdir() files = local_dir.joinpath('..', app, 'migrations').iterdir()
# Regex pattern for migration files # Regex pattern for migration files
regex = re.compile(r"^[\d]+_.*\.py$") regex = re.compile(r'^[\d]+_.*\.py$')
migration_files = [] migration_files = []
@ -56,7 +56,6 @@ def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True):
oldest_file = None oldest_file = None
for f in getMigrationFileNames(app): for f in getMigrationFileNames(app):
if ignore_initial and f.startswith('0001_initial'): if ignore_initial and f.startswith('0001_initial'):
continue continue
@ -115,9 +114,7 @@ class UserMixin:
# Create a user to log in with # Create a user to log in with
cls.user = get_user_model().objects.create_user( cls.user = get_user_model().objects.create_user(
username=cls.username, username=cls.username, password=cls.password, email=cls.email
password=cls.password,
email=cls.email
) )
# Create a group for the user # Create a group for the user
@ -163,15 +160,15 @@ class UserMixin:
raise TypeError('assignRole: assign_all must be a boolean value') raise TypeError('assignRole: assign_all must be a boolean value')
if not role and not assign_all: if not role and not assign_all:
raise ValueError('assignRole: either role must be provided, or assign_all must be set') raise ValueError(
'assignRole: either role must be provided, or assign_all must be set'
)
if not assign_all and role: if not assign_all and role:
rule, perm = role.split('.') rule, perm = role.split('.')
for ruleset in group.rule_sets.all(): for ruleset in group.rule_sets.all():
if assign_all or ruleset.name == rule: if assign_all or ruleset.name == rule:
if assign_all or perm == 'view': if assign_all or perm == 'view':
ruleset.can_view = True ruleset.can_view = True
elif assign_all or perm == 'change': elif assign_all or perm == 'change':
@ -205,37 +202,24 @@ class ExchangeRateMixin:
def generate_exchange_rates(self): def generate_exchange_rates(self):
"""Helper function which generates some exchange rates to work with""" """Helper function which generates some exchange rates to work with"""
rates = { rates = {'AUD': 1.5, 'CAD': 1.7, 'GBP': 0.9, 'USD': 1.0}
'AUD': 1.5,
'CAD': 1.7,
'GBP': 0.9,
'USD': 1.0,
}
# Create a dummy backend # Create a dummy backend
ExchangeBackend.objects.create( ExchangeBackend.objects.create(name='InvenTreeExchange', base_currency='USD')
name='InvenTreeExchange',
base_currency='USD',
)
backend = ExchangeBackend.objects.get(name='InvenTreeExchange') backend = ExchangeBackend.objects.get(name='InvenTreeExchange')
items = [] items = []
for currency, rate in rates.items(): for currency, rate in rates.items():
items.append( items.append(Rate(currency=currency, value=rate, backend=backend))
Rate(
currency=currency,
value=rate,
backend=backend,
)
)
Rate.objects.bulk_create(items) Rate.objects.bulk_create(items)
class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase): class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase):
"""Testcase with user setup buildin.""" """Testcase with user setup buildin."""
pass pass
@ -243,7 +227,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
"""Base class for running InvenTree API tests.""" """Base class for running InvenTree API tests."""
@contextmanager @contextmanager
def assertNumQueriesLessThan(self, value, using='default', verbose=False, debug=False): def assertNumQueriesLessThan(
self, value, using='default', verbose=False, debug=False
):
"""Context manager to check that the number of queries is less than a certain value. """Context manager to check that the number of queries is less than a certain value.
Example: Example:
@ -255,14 +241,14 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
yield # your test will be run here yield # your test will be run here
if verbose: if verbose:
msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4) msg = '\r\n%s' % json.dumps(context.captured_queries, indent=4)
else: else:
msg = None msg = None
n = len(context.captured_queries) n = len(context.captured_queries)
if debug: if debug:
print(f"Expected less than {value} queries, got {n} queries") print(f'Expected less than {value} queries, got {n} queries')
self.assertLess(n, value, msg=msg) self.assertLess(n, value, msg=msg)
@ -273,8 +259,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
return return
if expected_code != response.status_code: if expected_code != response.status_code:
print(
print(f"Unexpected {method} response at '{url}': status_code = {response.status_code}") f"Unexpected {method} response at '{url}': status_code = {response.status_code}"
)
if hasattr(response, 'data'): if hasattr(response, 'data'):
print('data:', response.data) print('data:', response.data)
@ -359,7 +346,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
return response return response
def download_file(self, url, data, expected_code=None, expected_fn=None, decode=True): def download_file(
self, url, data, expected_code=None, expected_fn=None, decode=True
):
"""Download a file from the server, and return an in-memory file.""" """Download a file from the server, and return an in-memory file."""
response = self.client.get(url, data=data, format='json') response = self.client.get(url, data=data, format='json')
@ -367,7 +356,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
# Check that the response is of the correct type # Check that the response is of the correct type
if not isinstance(response, StreamingHttpResponse): if not isinstance(response, StreamingHttpResponse):
raise ValueError("Response is not a StreamingHttpResponse object as expected") raise ValueError(
'Response is not a StreamingHttpResponse object as expected'
)
# Extract filename # Extract filename
disposition = response.headers['Content-Disposition'] disposition = response.headers['Content-Disposition']
@ -394,7 +385,14 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
return file return file
def process_csv(self, file_object, delimiter=',', required_cols=None, excluded_cols=None, required_rows=None): def process_csv(
self,
file_object,
delimiter=',',
required_cols=None,
excluded_cols=None,
required_rows=None,
):
"""Helper function to process and validate a downloaded csv file.""" """Helper function to process and validate a downloaded csv file."""
# Check that the correct object type has been passed # Check that the correct object type has been passed
self.assertTrue(isinstance(file_object, io.StringIO)) self.assertTrue(isinstance(file_object, io.StringIO))

View File

@ -10,9 +10,11 @@ from django.urls import include, path, re_path
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from django.views.generic.base import RedirectView from django.views.generic.base import RedirectView
from dj_rest_auth.registration.views import (ConfirmEmailView, from dj_rest_auth.registration.views import (
ConfirmEmailView,
SocialAccountDisconnectView, SocialAccountDisconnectView,
SocialAccountListView) SocialAccountListView,
)
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView
from sesame.views import LoginView from sesame.views import LoginView
@ -28,36 +30,49 @@ import stock.api
import users.api import users.api
from build.urls import build_urls from build.urls import build_urls
from common.urls import common_urls from common.urls import common_urls
from company.urls import (company_urls, manufacturer_part_urls, from company.urls import company_urls, manufacturer_part_urls, supplier_part_urls
supplier_part_urls)
from order.urls import order_urls from order.urls import order_urls
from part.urls import part_urls from part.urls import part_urls
from plugin.urls import get_plugin_urls from plugin.urls import get_plugin_urls
from stock.urls import stock_urls from stock.urls import stock_urls
from web.urls import urlpatterns as platform_urls from web.urls import urlpatterns as platform_urls
from .api import (APISearchView, InfoView, NotFoundView, VersionTextView, from .api import APISearchView, InfoView, NotFoundView, VersionTextView, VersionView
VersionView)
from .magic_login import GetSimpleLoginView from .magic_login import GetSimpleLoginView
from .social_auth_urls import (EmailListView, EmailPrimaryView, from .social_auth_urls import (
EmailRemoveView, EmailVerifyView, EmailListView,
SocialProviderListView, social_auth_urlpatterns) EmailPrimaryView,
from .views import (AboutView, AppearanceSelectView, CustomConnectionsView, EmailRemoveView,
CustomEmailView, CustomLoginView, EmailVerifyView,
SocialProviderListView,
social_auth_urlpatterns,
)
from .views import (
AboutView,
AppearanceSelectView,
CustomConnectionsView,
CustomEmailView,
CustomLoginView,
CustomPasswordResetFromKeyView, CustomPasswordResetFromKeyView,
CustomSessionDeleteOtherView, CustomSessionDeleteView, CustomSessionDeleteOtherView,
DatabaseStatsView, DynamicJsView, EditUserView, IndexView, CustomSessionDeleteView,
NotificationsView, SearchView, SetPasswordView, DatabaseStatsView,
SettingsView, auth_request) DynamicJsView,
EditUserView,
IndexView,
NotificationsView,
SearchView,
SetPasswordView,
SettingsView,
auth_request,
)
admin.site.site_header = "InvenTree Admin" admin.site.site_header = 'InvenTree Admin'
apipatterns = [ apipatterns = [
# Global search # Global search
path('search/', APISearchView.as_view(), name='api-search'), path('search/', APISearchView.as_view(), name='api-search'),
re_path(r'^settings/', include(common.api.settings_api_urls)), re_path(r'^settings/', include(common.api.settings_api_urls)),
re_path(r'^part/', include(part.api.part_api_urls)), re_path(r'^part/', include(part.api.part_api_urls)),
re_path(r'^bom/', include(part.api.bom_api_urls)), re_path(r'^bom/', include(part.api.bom_api_urls)),
@ -69,117 +84,288 @@ apipatterns = [
re_path(r'^report/', include(report.api.report_api_urls)), re_path(r'^report/', include(report.api.report_api_urls)),
re_path(r'^user/', include(users.api.user_urls)), re_path(r'^user/', include(users.api.user_urls)),
re_path(r'^admin/', include(common.api.admin_api_urls)), re_path(r'^admin/', include(common.api.admin_api_urls)),
# Plugin endpoints # Plugin endpoints
path('', include(plugin.api.plugin_api_urls)), path('', include(plugin.api.plugin_api_urls)),
# Common endpoints endpoint # Common endpoints endpoint
path('', include(common.api.common_api_urls)), path('', include(common.api.common_api_urls)),
# OpenAPI Schema # OpenAPI Schema
re_path('schema/', SpectacularAPIView.as_view(custom_settings={'SCHEMA_PATH_PREFIX': '/api/'}), name='schema'), re_path(
'schema/',
SpectacularAPIView.as_view(custom_settings={'SCHEMA_PATH_PREFIX': '/api/'}),
name='schema',
),
# InvenTree information endpoints # InvenTree information endpoints
path("version-text", VersionTextView.as_view(), name="api-version-text"), # version text path(
'version-text', VersionTextView.as_view(), name='api-version-text'
), # version text
path('version/', VersionView.as_view(), name='api-version'), # version info path('version/', VersionView.as_view(), name='api-version'), # version info
path('', InfoView.as_view(), name='api-inventree-info'), # server info path('', InfoView.as_view(), name='api-inventree-info'), # server info
# Auth API endpoints # Auth API endpoints
path('auth/', include([ path(
re_path(r'^registration/account-confirm-email/(?P<key>[-:\w]+)/$', ConfirmEmailView.as_view(), name='account_confirm_email'), 'auth/',
include([
re_path(
r'^registration/account-confirm-email/(?P<key>[-:\w]+)/$',
ConfirmEmailView.as_view(),
name='account_confirm_email',
),
path('registration/', include('dj_rest_auth.registration.urls')), path('registration/', include('dj_rest_auth.registration.urls')),
path('providers/', SocialProviderListView.as_view(), name='social_providers'), path(
path('emails/', include([path('<int:pk>/', include([ 'providers/', SocialProviderListView.as_view(), name='social_providers'
path('primary/', EmailPrimaryView.as_view(), name='email-primary'), ),
path('verify/', EmailVerifyView.as_view(), name='email-verify'), path(
path('remove/', EmailRemoveView().as_view(), name='email-remove'),])), 'emails/',
path('', EmailListView.as_view(), name='email-list') include([
])), path(
'<int:pk>/',
include([
path(
'primary/',
EmailPrimaryView.as_view(),
name='email-primary',
),
path(
'verify/',
EmailVerifyView.as_view(),
name='email-verify',
),
path(
'remove/',
EmailRemoveView().as_view(),
name='email-remove',
),
]),
),
path('', EmailListView.as_view(), name='email-list'),
]),
),
path('social/', include(social_auth_urlpatterns)), path('social/', include(social_auth_urlpatterns)),
path('social/', SocialAccountListView.as_view(), name='social_account_list'), path(
path('social/<int:pk>/disconnect/', SocialAccountDisconnectView.as_view(), name='social_account_disconnect'), 'social/', SocialAccountListView.as_view(), name='social_account_list'
),
path(
'social/<int:pk>/disconnect/',
SocialAccountDisconnectView.as_view(),
name='social_account_disconnect',
),
path('', include('dj_rest_auth.urls')), path('', include('dj_rest_auth.urls')),
])), ]),
),
# Magic login URLs # Magic login URLs
path("email/generate/", csrf_exempt(GetSimpleLoginView().as_view()), name="sesame-generate",), path(
path("email/login/", LoginView.as_view(), name="sesame-login"), 'email/generate/',
csrf_exempt(GetSimpleLoginView().as_view()),
name='sesame-generate',
),
path('email/login/', LoginView.as_view(), name='sesame-login'),
# Unknown endpoint # Unknown endpoint
re_path(r'^.*$', NotFoundView.as_view(), name='api-404'), re_path(r'^.*$', NotFoundView.as_view(), name='api-404'),
] ]
settings_urls = [ settings_urls = [
re_path(r'^i18n/?', include('django.conf.urls.i18n')), re_path(r'^i18n/?', include('django.conf.urls.i18n')),
re_path(
re_path(r'^appearance/?', AppearanceSelectView.as_view(), name='settings-appearance'), r'^appearance/?', AppearanceSelectView.as_view(), name='settings-appearance'
),
# Catch any other urls # Catch any other urls
re_path(r'^.*$', SettingsView.as_view(template_name='InvenTree/settings/settings.html'), name='settings'), re_path(
r'^.*$',
SettingsView.as_view(template_name='InvenTree/settings/settings.html'),
name='settings',
),
] ]
notifications_urls = [ notifications_urls = [
# Catch any other urls # Catch any other urls
re_path(r'^.*$', NotificationsView.as_view(), name='notifications'), re_path(r'^.*$', NotificationsView.as_view(), name='notifications')
] ]
# These javascript files are served "dynamically" - i.e. rendered on demand # These javascript files are served "dynamically" - i.e. rendered on demand
dynamic_javascript_urls = [ dynamic_javascript_urls = [
re_path(r'^calendar.js', DynamicJsView.as_view(template_name='js/dynamic/calendar.js'), name='calendar.js'), re_path(
re_path(r'^nav.js', DynamicJsView.as_view(template_name='js/dynamic/nav.js'), name='nav.js'), r'^calendar.js',
re_path(r'^permissions.js', DynamicJsView.as_view(template_name='js/dynamic/permissions.js'), name='permissions.js'), DynamicJsView.as_view(template_name='js/dynamic/calendar.js'),
re_path(r'^settings.js', DynamicJsView.as_view(template_name='js/dynamic/settings.js'), name='settings.js'), name='calendar.js',
),
re_path(
r'^nav.js',
DynamicJsView.as_view(template_name='js/dynamic/nav.js'),
name='nav.js',
),
re_path(
r'^permissions.js',
DynamicJsView.as_view(template_name='js/dynamic/permissions.js'),
name='permissions.js',
),
re_path(
r'^settings.js',
DynamicJsView.as_view(template_name='js/dynamic/settings.js'),
name='settings.js',
),
] ]
# These javascript files are passed through the Django translation layer # These javascript files are passed through the Django translation layer
translated_javascript_urls = [ translated_javascript_urls = [
re_path(r'^api.js', DynamicJsView.as_view(template_name='js/translated/api.js'), name='api.js'), re_path(
re_path(r'^attachment.js', DynamicJsView.as_view(template_name='js/translated/attachment.js'), name='attachment.js'), r'^api.js',
re_path(r'^barcode.js', DynamicJsView.as_view(template_name='js/translated/barcode.js'), name='barcode.js'), DynamicJsView.as_view(template_name='js/translated/api.js'),
re_path(r'^bom.js', DynamicJsView.as_view(template_name='js/translated/bom.js'), name='bom.js'), name='api.js',
re_path(r'^build.js', DynamicJsView.as_view(template_name='js/translated/build.js'), name='build.js'), ),
re_path(r'^charts.js', DynamicJsView.as_view(template_name='js/translated/charts.js'), name='charts.js'), re_path(
re_path(r'^company.js', DynamicJsView.as_view(template_name='js/translated/company.js'), name='company.js'), r'^attachment.js',
re_path(r'^filters.js', DynamicJsView.as_view(template_name='js/translated/filters.js'), name='filters.js'), DynamicJsView.as_view(template_name='js/translated/attachment.js'),
re_path(r'^forms.js', DynamicJsView.as_view(template_name='js/translated/forms.js'), name='forms.js'), name='attachment.js',
re_path(r'^helpers.js', DynamicJsView.as_view(template_name='js/translated/helpers.js'), name='helpers.js'), ),
re_path(r'^index.js', DynamicJsView.as_view(template_name='js/translated/index.js'), name='index.js'), re_path(
re_path(r'^label.js', DynamicJsView.as_view(template_name='js/translated/label.js'), name='label.js'), r'^barcode.js',
re_path(r'^model_renderers.js', DynamicJsView.as_view(template_name='js/translated/model_renderers.js'), name='model_renderers.js'), DynamicJsView.as_view(template_name='js/translated/barcode.js'),
re_path(r'^modals.js', DynamicJsView.as_view(template_name='js/translated/modals.js'), name='modals.js'), name='barcode.js',
re_path(r'^order.js', DynamicJsView.as_view(template_name='js/translated/order.js'), name='order.js'), ),
re_path(r'^part.js', DynamicJsView.as_view(template_name='js/translated/part.js'), name='part.js'), re_path(
re_path(r'^purchase_order.js', DynamicJsView.as_view(template_name='js/translated/purchase_order.js'), name='purchase_order.js'), r'^bom.js',
re_path(r'^return_order.js', DynamicJsView.as_view(template_name='js/translated/return_order.js'), name='return_order.js'), DynamicJsView.as_view(template_name='js/translated/bom.js'),
re_path(r'^report.js', DynamicJsView.as_view(template_name='js/translated/report.js'), name='report.js'), name='bom.js',
re_path(r'^sales_order.js', DynamicJsView.as_view(template_name='js/translated/sales_order.js'), name='sales_order.js'), ),
re_path(r'^search.js', DynamicJsView.as_view(template_name='js/translated/search.js'), name='search.js'), re_path(
re_path(r'^stock.js', DynamicJsView.as_view(template_name='js/translated/stock.js'), name='stock.js'), r'^build.js',
re_path(r'^status_codes.js', DynamicJsView.as_view(template_name='js/translated/status_codes.js'), name='status_codes.js'), DynamicJsView.as_view(template_name='js/translated/build.js'),
re_path(r'^plugin.js', DynamicJsView.as_view(template_name='js/translated/plugin.js'), name='plugin.js'), name='build.js',
re_path(r'^pricing.js', DynamicJsView.as_view(template_name='js/translated/pricing.js'), name='pricing.js'), ),
re_path(r'^news.js', DynamicJsView.as_view(template_name='js/translated/news.js'), name='news.js'), re_path(
re_path(r'^tables.js', DynamicJsView.as_view(template_name='js/translated/tables.js'), name='tables.js'), r'^charts.js',
re_path(r'^table_filters.js', DynamicJsView.as_view(template_name='js/translated/table_filters.js'), name='table_filters.js'), DynamicJsView.as_view(template_name='js/translated/charts.js'),
re_path(r'^notification.js', DynamicJsView.as_view(template_name='js/translated/notification.js'), name='notification.js'), name='charts.js',
),
re_path(
r'^company.js',
DynamicJsView.as_view(template_name='js/translated/company.js'),
name='company.js',
),
re_path(
r'^filters.js',
DynamicJsView.as_view(template_name='js/translated/filters.js'),
name='filters.js',
),
re_path(
r'^forms.js',
DynamicJsView.as_view(template_name='js/translated/forms.js'),
name='forms.js',
),
re_path(
r'^helpers.js',
DynamicJsView.as_view(template_name='js/translated/helpers.js'),
name='helpers.js',
),
re_path(
r'^index.js',
DynamicJsView.as_view(template_name='js/translated/index.js'),
name='index.js',
),
re_path(
r'^label.js',
DynamicJsView.as_view(template_name='js/translated/label.js'),
name='label.js',
),
re_path(
r'^model_renderers.js',
DynamicJsView.as_view(template_name='js/translated/model_renderers.js'),
name='model_renderers.js',
),
re_path(
r'^modals.js',
DynamicJsView.as_view(template_name='js/translated/modals.js'),
name='modals.js',
),
re_path(
r'^order.js',
DynamicJsView.as_view(template_name='js/translated/order.js'),
name='order.js',
),
re_path(
r'^part.js',
DynamicJsView.as_view(template_name='js/translated/part.js'),
name='part.js',
),
re_path(
r'^purchase_order.js',
DynamicJsView.as_view(template_name='js/translated/purchase_order.js'),
name='purchase_order.js',
),
re_path(
r'^return_order.js',
DynamicJsView.as_view(template_name='js/translated/return_order.js'),
name='return_order.js',
),
re_path(
r'^report.js',
DynamicJsView.as_view(template_name='js/translated/report.js'),
name='report.js',
),
re_path(
r'^sales_order.js',
DynamicJsView.as_view(template_name='js/translated/sales_order.js'),
name='sales_order.js',
),
re_path(
r'^search.js',
DynamicJsView.as_view(template_name='js/translated/search.js'),
name='search.js',
),
re_path(
r'^stock.js',
DynamicJsView.as_view(template_name='js/translated/stock.js'),
name='stock.js',
),
re_path(
r'^status_codes.js',
DynamicJsView.as_view(template_name='js/translated/status_codes.js'),
name='status_codes.js',
),
re_path(
r'^plugin.js',
DynamicJsView.as_view(template_name='js/translated/plugin.js'),
name='plugin.js',
),
re_path(
r'^pricing.js',
DynamicJsView.as_view(template_name='js/translated/pricing.js'),
name='pricing.js',
),
re_path(
r'^news.js',
DynamicJsView.as_view(template_name='js/translated/news.js'),
name='news.js',
),
re_path(
r'^tables.js',
DynamicJsView.as_view(template_name='js/translated/tables.js'),
name='tables.js',
),
re_path(
r'^table_filters.js',
DynamicJsView.as_view(template_name='js/translated/table_filters.js'),
name='table_filters.js',
),
re_path(
r'^notification.js',
DynamicJsView.as_view(template_name='js/translated/notification.js'),
name='notification.js',
),
] ]
backendpatterns = [ backendpatterns = [
# "Dynamic" javascript files which are rendered using InvenTree templating. # "Dynamic" javascript files which are rendered using InvenTree templating.
re_path(r'^js/dynamic/', include(dynamic_javascript_urls)), re_path(r'^js/dynamic/', include(dynamic_javascript_urls)),
re_path(r'^js/i18n/', include(translated_javascript_urls)), re_path(r'^js/i18n/', include(translated_javascript_urls)),
re_path(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), re_path(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
re_path(r'^auth/?', auth_request), re_path(r'^auth/?', auth_request),
re_path(r'^api/', include(apipatterns)), re_path(r'^api/', include(apipatterns)),
re_path(r'^api-doc/', SpectacularRedocView.as_view(url_name='schema'), name='api-doc'), re_path(
r'^api-doc/', SpectacularRedocView.as_view(url_name='schema'), name='api-doc'
),
] ]
classic_frontendpatterns = [ classic_frontendpatterns = [
# Apps # Apps
re_path(r'^build/', include(build_urls)), re_path(r'^build/', include(build_urls)),
re_path(r'^common/', include(common_urls)), re_path(r'^common/', include(common_urls)),
@ -189,30 +375,40 @@ classic_frontendpatterns = [
re_path(r'^part/', include(part_urls)), re_path(r'^part/', include(part_urls)),
re_path(r'^stock/', include(stock_urls)), re_path(r'^stock/', include(stock_urls)),
re_path(r'^supplier-part/', include(supplier_part_urls)), re_path(r'^supplier-part/', include(supplier_part_urls)),
re_path(r'^edit-user/', EditUserView.as_view(), name='edit-user'), re_path(r'^edit-user/', EditUserView.as_view(), name='edit-user'),
re_path(r'^set-password/', SetPasswordView.as_view(), name='set-password'), re_path(r'^set-password/', SetPasswordView.as_view(), name='set-password'),
re_path(r'^index/', IndexView.as_view(), name='index'), re_path(r'^index/', IndexView.as_view(), name='index'),
re_path(r'^notifications/', include(notifications_urls)), re_path(r'^notifications/', include(notifications_urls)),
re_path(r'^search/', SearchView.as_view(), name='search'), re_path(r'^search/', SearchView.as_view(), name='search'),
re_path(r'^settings/', include(settings_urls)), re_path(r'^settings/', include(settings_urls)),
re_path(r'^about/', AboutView.as_view(), name='about'), re_path(r'^about/', AboutView.as_view(), name='about'),
re_path(r'^stats/', DatabaseStatsView.as_view(), name='stats'), re_path(r'^stats/', DatabaseStatsView.as_view(), name='stats'),
# DB user sessions # DB user sessions
path('accounts/sessions/other/delete/', view=CustomSessionDeleteOtherView.as_view(), name='session_delete_other', ), path(
re_path(r'^accounts/sessions/(?P<pk>\w+)/delete/$', view=CustomSessionDeleteView.as_view(), name='session_delete', ), 'accounts/sessions/other/delete/',
view=CustomSessionDeleteOtherView.as_view(),
name='session_delete_other',
),
re_path(
r'^accounts/sessions/(?P<pk>\w+)/delete/$',
view=CustomSessionDeleteView.as_view(),
name='session_delete',
),
# Single Sign On / allauth # Single Sign On / allauth
# overrides of urlpatterns # overrides of urlpatterns
re_path(r'^accounts/email/', CustomEmailView.as_view(), name='account_email'), re_path(r'^accounts/email/', CustomEmailView.as_view(), name='account_email'),
re_path(r'^accounts/social/connections/', CustomConnectionsView.as_view(), name='socialaccount_connections'), re_path(
re_path(r"^accounts/password/reset/key/(?P<uidb36>[0-9A-Za-z]+)-(?P<key>.+)/$", CustomPasswordResetFromKeyView.as_view(), name="account_reset_password_from_key"), r'^accounts/social/connections/',
CustomConnectionsView.as_view(),
name='socialaccount_connections',
),
re_path(
r'^accounts/password/reset/key/(?P<uidb36>[0-9A-Za-z]+)-(?P<key>.+)/$',
CustomPasswordResetFromKeyView.as_view(),
name='account_reset_password_from_key',
),
# Override login page # Override login page
re_path("accounts/login/", CustomLoginView.as_view(), name="account_login"), re_path('accounts/login/', CustomLoginView.as_view(), name='account_login'),
re_path(r'^accounts/', include('allauth_2fa.urls')), # MFA support re_path(r'^accounts/', include('allauth_2fa.urls')), # MFA support
re_path(r'^accounts/', include('allauth.urls')), # included urlpatterns re_path(r'^accounts/', include('allauth.urls')), # included urlpatterns
] ]
@ -220,7 +416,7 @@ classic_frontendpatterns = [
urlpatterns = [] urlpatterns = []
if settings.INVENTREE_ADMIN_ENABLED: if settings.INVENTREE_ADMIN_ENABLED:
admin_url = settings.INVENTREE_ADMIN_URL, admin_url = (settings.INVENTREE_ADMIN_URL,)
urlpatterns += [ urlpatterns += [
path(f'{admin_url}/error_log/', include('error_report.urls')), path(f'{admin_url}/error_log/', include('error_report.urls')),
path(f'{admin_url}/', admin.site.urls, name='inventree-admin'), path(f'{admin_url}/', admin.site.urls, name='inventree-admin'),
@ -251,8 +447,13 @@ if settings.DEBUG:
# Redirect for favicon.ico # Redirect for favicon.ico
urlpatterns.append( urlpatterns.append(
path('favicon.ico', RedirectView.as_view(url=f'{settings.STATIC_URL}img/favicon/favicon.ico')) path(
'favicon.ico',
RedirectView.as_view(url=f'{settings.STATIC_URL}img/favicon/favicon.ico'),
)
) )
# Send any unknown URLs to the parts page # Send any unknown URLs to the parts page
urlpatterns += [re_path(r'^.*$', RedirectView.as_view(url='/index/', permanent=False), name='index')] urlpatterns += [
re_path(r'^.*$', RedirectView.as_view(url='/index/', permanent=False), name='index')
]

View File

@ -69,7 +69,9 @@ class AllowedURLValidator(validators.URLValidator):
self.schemes = allowable_url_schemes() self.schemes = allowable_url_schemes()
# Determine if 'strict' URL validation is required (i.e. if the URL must have a schema prefix) # Determine if 'strict' URL validation is required (i.e. if the URL must have a schema prefix)
strict_urls = common.models.InvenTreeSetting.get_setting('INVENTREE_STRICT_URLS', True, cache=False) strict_urls = common.models.InvenTreeSetting.get_setting(
'INVENTREE_STRICT_URLS', True, cache=False
)
if not strict_urls: if not strict_urls:
# Allow URLs which do not have a provided schema # Allow URLs which do not have a provided schema
@ -117,7 +119,7 @@ def validate_overage(value):
i = Decimal(value) i = Decimal(value)
if i < 0: if i < 0:
raise ValidationError(_("Overage value must not be negative")) raise ValidationError(_('Overage value must not be negative'))
# Looks like a number # Looks like a number
return True return True
@ -133,17 +135,15 @@ def validate_overage(value):
f = float(v) f = float(v)
if f < 0: if f < 0:
raise ValidationError(_("Overage value must not be negative")) raise ValidationError(_('Overage value must not be negative'))
elif f > 100: elif f > 100:
raise ValidationError(_("Overage must not exceed 100%")) raise ValidationError(_('Overage must not exceed 100%'))
return True return True
except ValueError: except ValueError:
pass pass
raise ValidationError( raise ValidationError(_('Invalid value for overage'))
_("Invalid value for overage")
)
def validate_part_name_format(value): def validate_part_name_format(value):
@ -179,8 +179,6 @@ def validate_part_name_format(value):
try: try:
Template(value).render({'part': p}) Template(value).render({'part': p})
except Exception as exc: except Exception as exc:
raise ValidationError({ raise ValidationError({'value': str(exc)})
'value': str(exc)
})
return True return True

View File

@ -19,7 +19,7 @@ from dulwich.repo import NotGitRepository, Repo
from .api_version import INVENTREE_API_TEXT, INVENTREE_API_VERSION from .api_version import INVENTREE_API_TEXT, INVENTREE_API_VERSION
# InvenTree software version # InvenTree software version
INVENTREE_SW_VERSION = "0.14.0 dev" INVENTREE_SW_VERSION = '0.14.0 dev'
# Discover git # Discover git
try: try:
@ -32,8 +32,8 @@ except (NotGitRepository, FileNotFoundError):
def checkMinPythonVersion(): def checkMinPythonVersion():
"""Check that the Python version is at least 3.9""" """Check that the Python version is at least 3.9"""
version = sys.version.split(" ")[0] version = sys.version.split(' ')[0]
docs = "https://docs.inventree.org/en/stable/start/intro/#python-requirements" docs = 'https://docs.inventree.org/en/stable/start/intro/#python-requirements'
msg = f""" msg = f"""
InvenTree requires Python 3.9 or above - you are running version {version}. InvenTree requires Python 3.9 or above - you are running version {version}.
@ -47,22 +47,22 @@ def checkMinPythonVersion():
if sys.version_info.major == 3 and sys.version_info.minor < 9: if sys.version_info.major == 3 and sys.version_info.minor < 9:
raise RuntimeError(msg) raise RuntimeError(msg)
print(f"Python version {version} - {sys.executable}") print(f'Python version {version} - {sys.executable}')
def inventreeInstanceName(): def inventreeInstanceName():
"""Returns the InstanceName settings for the current database.""" """Returns the InstanceName settings for the current database."""
import common.models import common.models
return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "") return common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE', '')
def inventreeInstanceTitle(): def inventreeInstanceTitle():
"""Returns the InstanceTitle for the current database.""" """Returns the InstanceTitle for the current database."""
import common.models import common.models
if common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE_TITLE", False): if common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE_TITLE', False):
return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "") return common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE', '')
return 'InvenTree' return 'InvenTree'
@ -76,7 +76,7 @@ def inventreeVersionTuple(version=None):
if version is None: if version is None:
version = INVENTREE_SW_VERSION version = INVENTREE_SW_VERSION
match = re.match(r"^.*(\d+)\.(\d+)\.(\d+).*$", str(version)) match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', str(version))
return [int(g) for g in match.groups()] return [int(g) for g in match.groups()]
@ -93,14 +93,14 @@ def inventreeDocsVersion():
Release -> "major.minor.sub" e.g. "0.5.2" Release -> "major.minor.sub" e.g. "0.5.2"
""" """
if isInvenTreeDevelopmentVersion(): if isInvenTreeDevelopmentVersion():
return "latest" return 'latest'
return INVENTREE_SW_VERSION # pragma: no cover return INVENTREE_SW_VERSION # pragma: no cover
def inventreeDocUrl(): def inventreeDocUrl():
"""Return URL for InvenTree documentation site.""" """Return URL for InvenTree documentation site."""
tag = inventreeDocsVersion() tag = inventreeDocsVersion()
return f"https://docs.inventree.org/en/{tag}" return f'https://docs.inventree.org/en/{tag}'
def inventreeAppUrl(): def inventreeAppUrl():
@ -110,12 +110,12 @@ def inventreeAppUrl():
def inventreeCreditsUrl(): def inventreeCreditsUrl():
"""Return URL for InvenTree credits site.""" """Return URL for InvenTree credits site."""
return "https://docs.inventree.org/en/latest/credits/" return 'https://docs.inventree.org/en/latest/credits/'
def inventreeGithubUrl(): def inventreeGithubUrl():
"""Return URL for InvenTree github site.""" """Return URL for InvenTree github site."""
return "https://github.com/InvenTree/InvenTree/" return 'https://github.com/InvenTree/InvenTree/'
def isInvenTreeUpToDate(): def isInvenTreeUpToDate():
@ -124,7 +124,10 @@ def isInvenTreeUpToDate():
A background task periodically queries GitHub for latest version, and stores it to the database as "_INVENTREE_LATEST_VERSION" A background task periodically queries GitHub for latest version, and stores it to the database as "_INVENTREE_LATEST_VERSION"
""" """
import common.models import common.models
latest = common.models.InvenTreeSetting.get_setting('_INVENTREE_LATEST_VERSION', backup_value=None, create=False)
latest = common.models.InvenTreeSetting.get_setting(
'_INVENTREE_LATEST_VERSION', backup_value=None, create=False
)
# No record for "latest" version - we must assume we are up to date! # No record for "latest" version - we must assume we are up to date!
if not latest: if not latest:
@ -144,24 +147,26 @@ def inventreeApiVersion():
def parse_version_text(): def parse_version_text():
"""Parse the version text to structured data.""" """Parse the version text to structured data."""
patched_data = INVENTREE_API_TEXT.split("\n\n") patched_data = INVENTREE_API_TEXT.split('\n\n')
# Remove first newline on latest version # Remove first newline on latest version
patched_data[0] = patched_data[0].replace("\n", "", 1) patched_data[0] = patched_data[0].replace('\n', '', 1)
version_data = {} version_data = {}
for version in patched_data: for version in patched_data:
data = version.split("\n") data = version.split('\n')
version_split = data[0].split(' -> ') version_split = data[0].split(' -> ')
version_detail = version_split[1].split(':', 1) if len(version_split) > 1 else ['', ] version_detail = (
version_split[1].split(':', 1) if len(version_split) > 1 else ['']
)
new_data = { new_data = {
"version": version_split[0].strip(), 'version': version_split[0].strip(),
"date": version_detail[0].strip(), 'date': version_detail[0].strip(),
"gh": version_detail[1].strip() if len(version_detail) > 1 else None, 'gh': version_detail[1].strip() if len(version_detail) > 1 else None,
"text": data[1:], 'text': data[1:],
"latest": False, 'latest': False,
} }
version_data[new_data["version"]] = new_data version_data[new_data['version']] = new_data
return version_data return version_data
@ -183,7 +188,7 @@ def inventreeApiText(versions: int = 10, start_version: int = 0):
start_version = INVENTREE_API_VERSION - versions start_version = INVENTREE_API_VERSION - versions
return { return {
f"v{a}": version_data.get(f"v{a}", None) f'v{a}': version_data.get(f'v{a}', None)
for a in range(start_version, start_version + versions) for a in range(start_version, start_version + versions)
} }
@ -222,7 +227,9 @@ def inventreeCommitDate():
if main_commit is None: if main_commit is None:
return None return None
commit_dt = dt.fromtimestamp(main_commit.commit_time) + td(seconds=main_commit.commit_timezone) commit_dt = dt.fromtimestamp(main_commit.commit_time) + td(
seconds=main_commit.commit_timezone
)
return str(commit_dt.date()) return str(commit_dt.date())

View File

@ -5,8 +5,7 @@ as JSON objects and passing them to modal forms (using jQuery / bootstrap).
""" """
from django.contrib.auth import password_validation from django.contrib.auth import password_validation
from django.contrib.auth.mixins import (LoginRequiredMixin, from django.contrib.auth.mixins import LoginRequiredMixin, PermissionRequiredMixin
PermissionRequiredMixin)
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.http import HttpResponse, HttpResponseRedirect, JsonResponse from django.http import HttpResponse, HttpResponseRedirect, JsonResponse
from django.shortcuts import redirect from django.shortcuts import redirect
@ -20,8 +19,7 @@ from django.views.generic.base import RedirectView, TemplateView
from allauth.account.forms import AddEmailForm from allauth.account.forms import AddEmailForm
from allauth.account.models import EmailAddress from allauth.account.models import EmailAddress
from allauth.account.views import (EmailView, LoginView, from allauth.account.views import EmailView, LoginView, PasswordResetFromKeyView
PasswordResetFromKeyView)
from allauth.socialaccount.forms import DisconnectForm from allauth.socialaccount.forms import DisconnectForm
from allauth.socialaccount.views import ConnectionsView from allauth.socialaccount.views import ConnectionsView
from djmoney.contrib.exchange.models import ExchangeBackend, Rate from djmoney.contrib.exchange.models import ExchangeBackend, Rate
@ -104,7 +102,6 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
return True return True
for required in roles_required: for required in roles_required:
(role, permission) = required.split('.') (role, permission) = required.split('.')
if role not in RuleSet.RULESET_NAMES: if role not in RuleSet.RULESET_NAMES:
@ -138,12 +135,14 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
app_label = model._meta.app_label app_label = model._meta.app_label
model_name = model._meta.model_name model_name = model._meta.model_name
table = f"{app_label}_{model_name}" table = f'{app_label}_{model_name}'
permission = self.get_permission_class() permission = self.get_permission_class()
if not permission: if not permission:
raise AttributeError(f"permission_class not defined for {type(self).__name__}") raise AttributeError(
f'permission_class not defined for {type(self).__name__}'
)
# Check if the user has the required permission # Check if the user has the required permission
return RuleSet.check_table_permission(user, table, permission) return RuleSet.check_table_permission(user, table, permission)
@ -182,7 +181,6 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
} }
for view_class in permission_map.keys(): for view_class in permission_map.keys():
if issubclass(type(self), view_class): if issubclass(type(self), view_class):
return permission_map[view_class] return permission_map[view_class]
@ -282,9 +280,7 @@ class AjaxMixin(InvenTreeRoleMixin):
data['title'] = self.get_form_title() data['title'] = self.get_form_title()
data['html_form'] = render_to_string( data['html_form'] = render_to_string(
self.ajax_template_name, self.ajax_template_name, context, request=request
context,
request=request
) )
# Custom feedback`data # Custom feedback`data
@ -329,7 +325,9 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
""" """
super(UpdateView, self).get(request, *args, **kwargs) super(UpdateView, self).get(request, *args, **kwargs)
return self.renderJsonResponse(request, self.get_form(), context=self.get_context_data()) return self.renderJsonResponse(
request, self.get_form(), context=self.get_context_data()
)
def save(self, object, form, **kwargs): def save(self, object, form, **kwargs):
"""Method for updating the object in the database. Default implementation is very simple, but can be overridden if required. """Method for updating the object in the database. Default implementation is very simple, but can be overridden if required.
@ -379,7 +377,6 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
data[key] = value data[key] = value
if valid: if valid:
# Save the updated object to the database # Save the updated object to the database
self.save(self.object, form) self.save(self.object, form)
@ -399,8 +396,8 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
class EditUserView(AjaxUpdateView): class EditUserView(AjaxUpdateView):
"""View for editing user information.""" """View for editing user information."""
ajax_template_name = "modal_form.html" ajax_template_name = 'modal_form.html'
ajax_form_title = _("Edit User Information") ajax_form_title = _('Edit User Information')
form_class = EditUserForm form_class = EditUserForm
def get_object(self): def get_object(self):
@ -411,8 +408,8 @@ class EditUserView(AjaxUpdateView):
class SetPasswordView(AjaxUpdateView): class SetPasswordView(AjaxUpdateView):
"""View for setting user password.""" """View for setting user password."""
ajax_template_name = "InvenTree/password.html" ajax_template_name = 'InvenTree/password.html'
ajax_form_title = _("Set Password") ajax_form_title = _('Set Password')
form_class = SetPasswordForm form_class = SetPasswordForm
def get_object(self): def get_object(self):
@ -494,14 +491,14 @@ class SearchView(TemplateView):
class DynamicJsView(TemplateView): class DynamicJsView(TemplateView):
"""View for returning javacsript files, which instead of being served dynamically, are passed through the django translation engine!""" """View for returning javacsript files, which instead of being served dynamically, are passed through the django translation engine!"""
template_name = "" template_name = ''
content_type = 'text/javascript' content_type = 'text/javascript'
class SettingsView(TemplateView): class SettingsView(TemplateView):
"""View for configuring User settings.""" """View for configuring User settings."""
template_name = "InvenTree/settings/settings.html" template_name = 'InvenTree/settings/settings.html'
def get_context_data(self, **kwargs): def get_context_data(self, **kwargs):
"""Add data for template.""" """Add data for template."""
@ -509,33 +506,37 @@ class SettingsView(TemplateView):
ctx['settings'] = common_models.InvenTreeSetting.objects.all().order_by('key') ctx['settings'] = common_models.InvenTreeSetting.objects.all().order_by('key')
ctx["base_currency"] = common_settings.currency_code_default() ctx['base_currency'] = common_settings.currency_code_default()
ctx["currencies"] = common_settings.currency_codes ctx['currencies'] = common_settings.currency_codes
ctx["rates"] = Rate.objects.filter(backend="InvenTreeExchange") ctx['rates'] = Rate.objects.filter(backend='InvenTreeExchange')
ctx["categories"] = PartCategory.objects.all().order_by('tree_id', 'lft', 'name') ctx['categories'] = PartCategory.objects.all().order_by(
'tree_id', 'lft', 'name'
)
# When were the rates last updated? # When were the rates last updated?
try: try:
backend = ExchangeBackend.objects.filter(name='InvenTreeExchange') backend = ExchangeBackend.objects.filter(name='InvenTreeExchange')
if backend.exists(): if backend.exists():
backend = backend.first() backend = backend.first()
ctx["rates_updated"] = backend.last_update ctx['rates_updated'] = backend.last_update
except Exception: except Exception:
ctx["rates_updated"] = None ctx['rates_updated'] = None
# Forms and context for allauth # Forms and context for allauth
ctx['add_email_form'] = AddEmailForm ctx['add_email_form'] = AddEmailForm
ctx["can_add_email"] = EmailAddress.objects.can_add_email(self.request.user) ctx['can_add_email'] = EmailAddress.objects.can_add_email(self.request.user)
# Form and context for allauth social-accounts # Form and context for allauth social-accounts
ctx["request"] = self.request ctx['request'] = self.request
ctx['social_form'] = DisconnectForm(request=self.request) ctx['social_form'] = DisconnectForm(request=self.request)
# user db sessions # user db sessions
ctx['session_key'] = self.request.session.session_key ctx['session_key'] = self.request.session.session_key
ctx['session_list'] = self.request.user.session_set.filter(expire_date__gt=now()).order_by('-last_activity') ctx['session_list'] = self.request.user.session_set.filter(
expire_date__gt=now()
).order_by('-last_activity')
return ctx return ctx
@ -550,20 +551,23 @@ class AllauthOverrides(LoginRequiredMixin):
class CustomEmailView(AllauthOverrides, EmailView): class CustomEmailView(AllauthOverrides, EmailView):
"""Override of allauths EmailView to always show the settings but leave the functions allow.""" """Override of allauths EmailView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("settings")
success_url = reverse_lazy('settings')
class CustomConnectionsView(AllauthOverrides, ConnectionsView): class CustomConnectionsView(AllauthOverrides, ConnectionsView):
"""Override of allauths ConnectionsView to always show the settings but leave the functions allow.""" """Override of allauths ConnectionsView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("settings")
success_url = reverse_lazy('settings')
class CustomPasswordResetFromKeyView(PasswordResetFromKeyView): class CustomPasswordResetFromKeyView(PasswordResetFromKeyView):
"""Override of allauths PasswordResetFromKeyView to always show the settings but leave the functions allow.""" """Override of allauths PasswordResetFromKeyView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("account_login")
success_url = reverse_lazy('account_login')
class UserSessionOverride(): class UserSessionOverride:
"""Overrides sucessurl to lead to settings.""" """Overrides sucessurl to lead to settings."""
def get_success_url(self): def get_success_url(self):
@ -573,11 +577,13 @@ class UserSessionOverride():
class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView): class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView):
"""Revert to settings after session delete.""" """Revert to settings after session delete."""
pass pass
class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView): class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView):
"""Revert to settings after session delete.""" """Revert to settings after session delete."""
pass pass
@ -607,7 +613,9 @@ class AppearanceSelectView(RedirectView):
def get_user_theme(self): def get_user_theme(self):
"""Get current user color theme.""" """Get current user color theme."""
try: try:
user_theme = common_models.ColorTheme.objects.filter(user=self.request.user).get() user_theme = common_models.ColorTheme.objects.filter(
user=self.request.user
).get()
except common_models.ColorTheme.DoesNotExist: except common_models.ColorTheme.DoesNotExist:
user_theme = None user_theme = None
@ -638,18 +646,18 @@ class AppearanceSelectView(RedirectView):
class DatabaseStatsView(AjaxView): class DatabaseStatsView(AjaxView):
"""View for displaying database statistics.""" """View for displaying database statistics."""
ajax_template_name = "stats.html" ajax_template_name = 'stats.html'
ajax_form_title = _("System Information") ajax_form_title = _('System Information')
class AboutView(AjaxView): class AboutView(AjaxView):
"""A view for displaying InvenTree version information""" """A view for displaying InvenTree version information"""
ajax_template_name = "about.html" ajax_template_name = 'about.html'
ajax_form_title = _("About InvenTree") ajax_form_title = _('About InvenTree')
class NotificationsView(TemplateView): class NotificationsView(TemplateView):
"""View for showing notifications.""" """View for showing notifications."""
template_name = "InvenTree/notifications/notifications.html" template_name = 'InvenTree/notifications/notifications.html'

View File

@ -10,6 +10,8 @@ import os # pragma: no cover
from django.core.wsgi import get_wsgi_application # pragma: no cover from django.core.wsgi import get_wsgi_application # pragma: no cover
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "InvenTree.settings") # pragma: no cover os.environ.setdefault(
'DJANGO_SETTINGS_MODULE', 'InvenTree.settings'
) # pragma: no cover
application = get_wsgi_application() # pragma: no cover application = get_wsgi_application() # pragma: no cover

View File

@ -22,7 +22,7 @@ class SettingsAdmin(ImportExportModelAdmin):
class UserSettingsAdmin(ImportExportModelAdmin): class UserSettingsAdmin(ImportExportModelAdmin):
"""Admin settings for InvenTreeUserSetting.""" """Admin settings for InvenTreeUserSetting."""
list_display = ('key', 'value', 'user', ) list_display = ('key', 'value', 'user')
def get_readonly_fields(self, request, obj=None): # pragma: no cover def get_readonly_fields(self, request, obj=None): # pragma: no cover
"""Prevent the 'key' field being edited once the setting is created.""" """Prevent the 'key' field being edited once the setting is created."""
@ -40,23 +40,31 @@ class WebhookAdmin(ImportExportModelAdmin):
class NotificationEntryAdmin(admin.ModelAdmin): class NotificationEntryAdmin(admin.ModelAdmin):
"""Admin settings for NotificationEntry.""" """Admin settings for NotificationEntry."""
list_display = ('key', 'uid', 'updated', ) list_display = ('key', 'uid', 'updated')
class NotificationMessageAdmin(admin.ModelAdmin): class NotificationMessageAdmin(admin.ModelAdmin):
"""Admin settings for NotificationMessage.""" """Admin settings for NotificationMessage."""
list_display = ('age_human', 'user', 'category', 'name', 'read', 'target_object', 'source_object', ) list_display = (
'age_human',
'user',
'category',
'name',
'read',
'target_object',
'source_object',
)
list_filter = ('category', 'read', 'user', ) list_filter = ('category', 'read', 'user')
search_fields = ('name', 'category', 'message', ) search_fields = ('name', 'category', 'message')
class NewsFeedEntryAdmin(admin.ModelAdmin): class NewsFeedEntryAdmin(admin.ModelAdmin):
"""Admin settings for NewsFeedEntry.""" """Admin settings for NewsFeedEntry."""
list_display = ('title', 'author', 'published', 'summary', ) list_display = ('title', 'author', 'published', 'summary')
admin.site.register(common.models.InvenTreeSetting, SettingsAdmin) admin.site.register(common.models.InvenTreeSetting, SettingsAdmin)

View File

@ -23,8 +23,13 @@ from InvenTree.api import BulkDeleteMixin, MetadataView
from InvenTree.config import CONFIG_LOOKUPS from InvenTree.config import CONFIG_LOOKUPS
from InvenTree.filters import ORDER_FILTER, SEARCH_ORDER_FILTER from InvenTree.filters import ORDER_FILTER, SEARCH_ORDER_FILTER
from InvenTree.helpers import inheritors from InvenTree.helpers import inheritors
from InvenTree.mixins import (ListAPI, ListCreateAPI, RetrieveAPI, from InvenTree.mixins import (
RetrieveUpdateAPI, RetrieveUpdateDestroyAPI) ListAPI,
ListCreateAPI,
RetrieveAPI,
RetrieveUpdateAPI,
RetrieveUpdateDestroyAPI,
)
from InvenTree.permissions import IsStaffOrReadOnly, IsSuperuser from InvenTree.permissions import IsStaffOrReadOnly, IsSuperuser
from plugin.models import NotificationUserSetting from plugin.models import NotificationUserSetting
from plugin.serializers import NotificationUserSettingSerializer from plugin.serializers import NotificationUserSettingSerializer
@ -41,6 +46,7 @@ class CsrfExemptMixin(object):
class WebhookView(CsrfExemptMixin, APIView): class WebhookView(CsrfExemptMixin, APIView):
"""Endpoint for receiving webhooks.""" """Endpoint for receiving webhooks."""
authentication_classes = [] authentication_classes = []
permission_classes = [] permission_classes = []
model_class = common.models.WebhookEndpoint model_class = common.models.WebhookEndpoint
@ -66,8 +72,7 @@ class WebhookView(CsrfExemptMixin, APIView):
async_task(self._process_payload, message.id) async_task(self._process_payload, message.id)
else: else:
self._process_result( self._process_result(
self.webhook.process_payload(message, payload, headers), self.webhook.process_payload(message, payload, headers), message
message,
) )
data = self.webhook.get_return(payload, headers, request) data = self.webhook.get_return(payload, headers, request)
@ -76,8 +81,7 @@ class WebhookView(CsrfExemptMixin, APIView):
def _process_payload(self, message_id): def _process_payload(self, message_id):
message = common.models.WebhookMessage.objects.get(message_id=message_id) message = common.models.WebhookMessage.objects.get(message_id=message_id)
self._process_result( self._process_result(
self.webhook.process_payload(message, message.body, message.header), self.webhook.process_payload(message, message.body, message.header), message
message,
) )
def _process_result(self, result, message): def _process_result(self, result, message):
@ -108,9 +112,7 @@ class WebhookView(CsrfExemptMixin, APIView):
class CurrencyExchangeView(APIView): class CurrencyExchangeView(APIView):
"""API endpoint for displaying currency information""" """API endpoint for displaying currency information"""
permission_classes = [ permission_classes = [permissions.IsAuthenticated]
permissions.IsAuthenticated,
]
def get(self, request, format=None): def get(self, request, format=None):
"""Return information on available currency conversions""" """Return information on available currency conversions"""
@ -133,7 +135,9 @@ class CurrencyExchangeView(APIView):
updated = None updated = None
response = { response = {
'base_currency': common.models.InvenTreeSetting.get_setting('INVENTREE_DEFAULT_CURRENCY', 'USD'), 'base_currency': common.models.InvenTreeSetting.get_setting(
'INVENTREE_DEFAULT_CURRENCY', 'USD'
),
'exchange_rates': {}, 'exchange_rates': {},
'updated': updated, 'updated': updated,
} }
@ -150,10 +154,7 @@ class CurrencyRefreshView(APIView):
User must be a 'staff' user to access this endpoint User must be a 'staff' user to access this endpoint
""" """
permission_classes = [ permission_classes = [permissions.IsAuthenticated, permissions.IsAdminUser]
permissions.IsAuthenticated,
permissions.IsAdminUser,
]
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
"""Performing a POST request will update currency exchange rates""" """Performing a POST request will update currency exchange rates"""
@ -161,9 +162,7 @@ class CurrencyRefreshView(APIView):
update_exchange_rates(force=True) update_exchange_rates(force=True)
return Response({ return Response({'success': 'Exchange rates updated'})
'success': 'Exchange rates updated',
})
class SettingsList(ListAPI): class SettingsList(ListAPI):
@ -174,21 +173,15 @@ class SettingsList(ListAPI):
filter_backends = SEARCH_ORDER_FILTER filter_backends = SEARCH_ORDER_FILTER
ordering_fields = [ ordering_fields = ['pk', 'key', 'name']
'pk',
'key',
'name',
]
search_fields = [ search_fields = ['key']
'key',
]
class GlobalSettingsList(SettingsList): class GlobalSettingsList(SettingsList):
"""API endpoint for accessing a list of global settings objects.""" """API endpoint for accessing a list of global settings objects."""
queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith="_") queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith='_')
serializer_class = common.serializers.GlobalSettingsSerializer serializer_class = common.serializers.GlobalSettingsSerializer
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
@ -221,25 +214,24 @@ class GlobalSettingsDetail(RetrieveUpdateAPI):
""" """
lookup_field = 'key' lookup_field = 'key'
queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith="_") queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith='_')
serializer_class = common.serializers.GlobalSettingsSerializer serializer_class = common.serializers.GlobalSettingsSerializer
def get_object(self): def get_object(self):
"""Attempt to find a global setting object with the provided key.""" """Attempt to find a global setting object with the provided key."""
key = str(self.kwargs['key']).upper() key = str(self.kwargs['key']).upper()
if key.startswith('_') or key not in common.models.InvenTreeSetting.SETTINGS.keys(): if (
key.startswith('_')
or key not in common.models.InvenTreeSetting.SETTINGS.keys()
):
raise NotFound() raise NotFound()
return common.models.InvenTreeSetting.get_setting_object( return common.models.InvenTreeSetting.get_setting_object(
key, key, cache=False, create=True
cache=False, create=True
) )
permission_classes = [ permission_classes = [permissions.IsAuthenticated, GlobalSettingsPermissions]
permissions.IsAuthenticated,
GlobalSettingsPermissions,
]
class UserSettingsList(SettingsList): class UserSettingsList(SettingsList):
@ -294,18 +286,17 @@ class UserSettingsDetail(RetrieveUpdateAPI):
"""Attempt to find a user setting object with the provided key.""" """Attempt to find a user setting object with the provided key."""
key = str(self.kwargs['key']).upper() key = str(self.kwargs['key']).upper()
if key.startswith('_') or key not in common.models.InvenTreeUserSetting.SETTINGS.keys(): if (
key.startswith('_')
or key not in common.models.InvenTreeUserSetting.SETTINGS.keys()
):
raise NotFound() raise NotFound()
return common.models.InvenTreeUserSetting.get_setting_object( return common.models.InvenTreeUserSetting.get_setting_object(
key, key, user=self.request.user, cache=False, create=True
user=self.request.user,
cache=False, create=True
) )
permission_classes = [ permission_classes = [UserSettingsPermissions]
UserSettingsPermissions,
]
class NotificationUserSettingsList(SettingsList): class NotificationUserSettingsList(SettingsList):
@ -334,39 +325,29 @@ class NotificationUserSettingsDetail(RetrieveUpdateAPI):
queryset = NotificationUserSetting.objects.all() queryset = NotificationUserSetting.objects.all()
serializer_class = NotificationUserSettingSerializer serializer_class = NotificationUserSettingSerializer
permission_classes = [UserSettingsPermissions, ] permission_classes = [UserSettingsPermissions]
class NotificationMessageMixin: class NotificationMessageMixin:
"""Generic mixin for NotificationMessage.""" """Generic mixin for NotificationMessage."""
queryset = common.models.NotificationMessage.objects.all() queryset = common.models.NotificationMessage.objects.all()
serializer_class = common.serializers.NotificationMessageSerializer serializer_class = common.serializers.NotificationMessageSerializer
permission_classes = [UserSettingsPermissions, ] permission_classes = [UserSettingsPermissions]
class NotificationList(NotificationMessageMixin, BulkDeleteMixin, ListAPI): class NotificationList(NotificationMessageMixin, BulkDeleteMixin, ListAPI):
"""List view for all notifications of the current user.""" """List view for all notifications of the current user."""
permission_classes = [permissions.IsAuthenticated, ] permission_classes = [permissions.IsAuthenticated]
filter_backends = SEARCH_ORDER_FILTER filter_backends = SEARCH_ORDER_FILTER
ordering_fields = [ ordering_fields = ['category', 'name', 'read', 'creation']
'category',
'name',
'read',
'creation',
]
search_fields = [ search_fields = ['name', 'message']
'name',
'message',
]
filterset_fields = [ filterset_fields = ['category', 'read']
'category',
'read',
]
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
"""Only list notifications which apply to the current user.""" """Only list notifications which apply to the current user."""
@ -401,29 +382,27 @@ class NotificationReadAll(NotificationMessageMixin, RetrieveAPI):
self.queryset.filter(user=request.user, read=False).update(read=True) self.queryset.filter(user=request.user, read=False).update(read=True)
return Response({'status': 'ok'}) return Response({'status': 'ok'})
except Exception as exc: except Exception as exc:
raise serializers.ValidationError(detail=serializers.as_serializer_error(exc)) raise serializers.ValidationError(
detail=serializers.as_serializer_error(exc)
)
class NewsFeedMixin: class NewsFeedMixin:
"""Generic mixin for NewsFeedEntry.""" """Generic mixin for NewsFeedEntry."""
queryset = common.models.NewsFeedEntry.objects.all() queryset = common.models.NewsFeedEntry.objects.all()
serializer_class = common.serializers.NewsFeedEntrySerializer serializer_class = common.serializers.NewsFeedEntrySerializer
permission_classes = [IsAdminUser, ] permission_classes = [IsAdminUser]
class NewsFeedEntryList(NewsFeedMixin, BulkDeleteMixin, ListAPI): class NewsFeedEntryList(NewsFeedMixin, BulkDeleteMixin, ListAPI):
"""List view for all news items.""" """List view for all news items."""
filter_backends = ORDER_FILTER filter_backends = ORDER_FILTER
ordering_fields = [ ordering_fields = ['published', 'author', 'read']
'published',
'author',
'read',
]
filterset_fields = [ filterset_fields = ['read']
'read',
]
class NewsFeedEntryDetail(NewsFeedMixin, RetrieveUpdateDestroyAPI): class NewsFeedEntryDetail(NewsFeedMixin, RetrieveUpdateDestroyAPI):
@ -435,14 +414,14 @@ class ConfigList(ListAPI):
queryset = CONFIG_LOOKUPS queryset = CONFIG_LOOKUPS
serializer_class = common.serializers.ConfigSerializer serializer_class = common.serializers.ConfigSerializer
permission_classes = [IsSuperuser, ] permission_classes = [IsSuperuser]
class ConfigDetail(RetrieveAPI): class ConfigDetail(RetrieveAPI):
"""Detail view for an individual configuration.""" """Detail view for an individual configuration."""
serializer_class = common.serializers.ConfigSerializer serializer_class = common.serializers.ConfigSerializer
permission_classes = [IsSuperuser, ] permission_classes = [IsSuperuser]
def get_object(self): def get_object(self):
"""Attempt to find a config object with the provided key.""" """Attempt to find a config object with the provided key."""
@ -458,7 +437,7 @@ class NotesImageList(ListCreateAPI):
queryset = common.models.NotesImage.objects.all() queryset = common.models.NotesImage.objects.all()
serializer_class = common.serializers.NotesImageSerializer serializer_class = common.serializers.NotesImageSerializer
permission_classes = [permissions.IsAuthenticated, ] permission_classes = [permissions.IsAuthenticated]
def perform_create(self, serializer): def perform_create(self, serializer):
"""Create (upload) a new notes image""" """Create (upload) a new notes image"""
@ -475,14 +454,9 @@ class ProjectCodeList(ListCreateAPI):
permission_classes = [permissions.IsAuthenticated, IsStaffOrReadOnly] permission_classes = [permissions.IsAuthenticated, IsStaffOrReadOnly]
filter_backends = SEARCH_ORDER_FILTER filter_backends = SEARCH_ORDER_FILTER
ordering_fields = [ ordering_fields = ['code']
'code',
]
search_fields = [ search_fields = ['code', 'description']
'code',
'description',
]
class ProjectCodeDetail(RetrieveUpdateDestroyAPI): class ProjectCodeDetail(RetrieveUpdateDestroyAPI):
@ -515,14 +489,14 @@ class FlagList(ListAPI):
queryset = settings.FLAGS queryset = settings.FLAGS
serializer_class = common.serializers.FlagSerializer serializer_class = common.serializers.FlagSerializer
permission_classes = [permissions.AllowAny, ] permission_classes = [permissions.AllowAny]
class FlagDetail(RetrieveAPI): class FlagDetail(RetrieveAPI):
"""Detail view for an individual feature flag.""" """Detail view for an individual feature flag."""
serializer_class = common.serializers.FlagSerializer serializer_class = common.serializers.FlagSerializer
permission_classes = [permissions.AllowAny, ] permission_classes = [permissions.AllowAny]
def get_object(self): def get_object(self):
"""Attempt to find a config object with the provided key.""" """Attempt to find a config object with the provided key."""
@ -535,97 +509,175 @@ class FlagDetail(RetrieveAPI):
settings_api_urls = [ settings_api_urls = [
# User settings # User settings
re_path(r'^user/', include([ re_path(
r'^user/',
include([
# User Settings Detail # User Settings Detail
re_path(r'^(?P<key>\w+)/', UserSettingsDetail.as_view(), name='api-user-setting-detail'), re_path(
r'^(?P<key>\w+)/',
UserSettingsDetail.as_view(),
name='api-user-setting-detail',
),
# User Settings List # User Settings List
re_path(r'^.*$', UserSettingsList.as_view(), name='api-user-setting-list'), re_path(r'^.*$', UserSettingsList.as_view(), name='api-user-setting-list'),
])), ]),
),
# Notification settings # Notification settings
re_path(r'^notification/', include([ re_path(
r'^notification/',
include([
# Notification Settings Detail # Notification Settings Detail
path(r'<int:pk>/', NotificationUserSettingsDetail.as_view(), name='api-notification-setting-detail'), path(
r'<int:pk>/',
NotificationUserSettingsDetail.as_view(),
name='api-notification-setting-detail',
),
# Notification Settings List # Notification Settings List
re_path(r'^.*$', NotificationUserSettingsList.as_view(), name='api-notification-setting-list'), re_path(
])), r'^.*$',
NotificationUserSettingsList.as_view(),
name='api-notification-setting-list',
),
]),
),
# Global settings # Global settings
re_path(r'^global/', include([ re_path(
r'^global/',
include([
# Global Settings Detail # Global Settings Detail
re_path(r'^(?P<key>\w+)/', GlobalSettingsDetail.as_view(), name='api-global-setting-detail'), re_path(
r'^(?P<key>\w+)/',
GlobalSettingsDetail.as_view(),
name='api-global-setting-detail',
),
# Global Settings List # Global Settings List
re_path(r'^.*$', GlobalSettingsList.as_view(), name='api-global-setting-list'), re_path(
])), r'^.*$', GlobalSettingsList.as_view(), name='api-global-setting-list'
),
]),
),
] ]
common_api_urls = [ common_api_urls = [
# Webhooks # Webhooks
path('webhook/<slug:endpoint>/', WebhookView.as_view(), name='api-webhook'), path('webhook/<slug:endpoint>/', WebhookView.as_view(), name='api-webhook'),
# Uploaded images for notes # Uploaded images for notes
re_path(r'^notes-image-upload/', NotesImageList.as_view(), name='api-notes-image-list'), re_path(
r'^notes-image-upload/', NotesImageList.as_view(), name='api-notes-image-list'
),
# Project codes # Project codes
re_path(r'^project-code/', include([ re_path(
path(r'<int:pk>/', include([ r'^project-code/',
re_path(r'^metadata/', MetadataView.as_view(), {'model': common.models.ProjectCode}, name='api-project-code-metadata'), include([
re_path(r'^.*$', ProjectCodeDetail.as_view(), name='api-project-code-detail'), path(
])), r'<int:pk>/',
include([
re_path(
r'^metadata/',
MetadataView.as_view(),
{'model': common.models.ProjectCode},
name='api-project-code-metadata',
),
re_path(
r'^.*$',
ProjectCodeDetail.as_view(),
name='api-project-code-detail',
),
]),
),
re_path(r'^.*$', ProjectCodeList.as_view(), name='api-project-code-list'), re_path(r'^.*$', ProjectCodeList.as_view(), name='api-project-code-list'),
])), ]),
),
# Custom physical units # Custom physical units
re_path(r'^units/', include([ re_path(
path(r'<int:pk>/', include([ r'^units/',
re_path(r'^.*$', CustomUnitDetail.as_view(), name='api-custom-unit-detail'), include([
])), path(
r'<int:pk>/',
include([
re_path(
r'^.*$',
CustomUnitDetail.as_view(),
name='api-custom-unit-detail',
)
]),
),
re_path(r'^.*$', CustomUnitList.as_view(), name='api-custom-unit-list'), re_path(r'^.*$', CustomUnitList.as_view(), name='api-custom-unit-list'),
])), ]),
),
# Currencies # Currencies
re_path(r'^currency/', include([ re_path(
re_path(r'^exchange/', CurrencyExchangeView.as_view(), name='api-currency-exchange'), r'^currency/',
re_path(r'^refresh/', CurrencyRefreshView.as_view(), name='api-currency-refresh'), include([
])), re_path(
r'^exchange/',
CurrencyExchangeView.as_view(),
name='api-currency-exchange',
),
re_path(
r'^refresh/', CurrencyRefreshView.as_view(), name='api-currency-refresh'
),
]),
),
# Notifications # Notifications
re_path(r'^notifications/', include([ re_path(
r'^notifications/',
include([
# Individual purchase order detail URLs # Individual purchase order detail URLs
path(r'<int:pk>/', include([ path(
re_path(r'.*$', NotificationDetail.as_view(), name='api-notifications-detail'), r'<int:pk>/',
])), include([
re_path(
r'.*$',
NotificationDetail.as_view(),
name='api-notifications-detail',
)
]),
),
# Read all # Read all
re_path(r'^readall/', NotificationReadAll.as_view(), name='api-notifications-readall'), re_path(
r'^readall/',
NotificationReadAll.as_view(),
name='api-notifications-readall',
),
# Notification messages list # Notification messages list
re_path(r'^.*$', NotificationList.as_view(), name='api-notifications-list'), re_path(r'^.*$', NotificationList.as_view(), name='api-notifications-list'),
])), ]),
),
# News # News
re_path(r'^news/', include([ re_path(
path(r'<int:pk>/', include([ r'^news/',
re_path(r'.*$', NewsFeedEntryDetail.as_view(), name='api-news-detail'), include([
])), path(
r'<int:pk>/',
include([
re_path(
r'.*$', NewsFeedEntryDetail.as_view(), name='api-news-detail'
)
]),
),
re_path(r'^.*$', NewsFeedEntryList.as_view(), name='api-news-list'), re_path(r'^.*$', NewsFeedEntryList.as_view(), name='api-news-list'),
])), ]),
),
# Flags # Flags
path('flags/', include([ path(
'flags/',
include([
path('<str:key>/', FlagDetail.as_view(), name='api-flag-detail'), path('<str:key>/', FlagDetail.as_view(), name='api-flag-detail'),
re_path(r'^.*$', FlagList.as_view(), name='api-flag-list'), re_path(r'^.*$', FlagList.as_view(), name='api-flag-list'),
])), ]),
),
# Status # Status
path('generic/status/', include([ path(
path(f'<str:{StatusView.MODEL_REF}>/', include([ 'generic/status/',
path('', StatusView.as_view(), name='api-status'), include([
])), path(
f'<str:{StatusView.MODEL_REF}>/',
include([path('', StatusView.as_view(), name='api-status')]),
),
path('', AllStatusViews.as_view(), name='api-status-all'), path('', AllStatusViews.as_view(), name='api-status-all'),
])), ]),
),
] ]
admin_api_urls = [ admin_api_urls = [

View File

@ -30,10 +30,14 @@ class CommonConfig(AppConfig):
try: try:
import common.models import common.models
if common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED', backup_value=False, create=False, cache=False): if common.models.InvenTreeSetting.get_setting(
logger.info("Clearing SERVER_RESTART_REQUIRED flag") 'SERVER_RESTART_REQUIRED', backup_value=False, create=False, cache=False
):
logger.info('Clearing SERVER_RESTART_REQUIRED flag')
if not InvenTree.ready.isImportingData(): if not InvenTree.ready.isImportingData():
common.models.InvenTreeSetting.set_setting('SERVER_RESTART_REQUIRED', False, None) common.models.InvenTreeSetting.set_setting(
'SERVER_RESTART_REQUIRED', False, None
)
except Exception: except Exception:
pass pass

View File

@ -49,12 +49,12 @@ class FileManager:
ext = os.path.splitext(file.name)[-1].lower().replace('.', '') ext = os.path.splitext(file.name)[-1].lower().replace('.', '')
try: try:
if ext in ['csv', 'tsv', ]: if ext in ['csv', 'tsv']:
# These file formats need string decoding # These file formats need string decoding
raw_data = file.read().decode('utf-8') raw_data = file.read().decode('utf-8')
# Reset stream position to beginning of file # Reset stream position to beginning of file
file.seek(0) file.seek(0)
elif ext in ['xls', 'xlsx', 'json', 'yaml', ]: elif ext in ['xls', 'xlsx', 'json', 'yaml']:
raw_data = file.read() raw_data = file.read()
# Reset stream position to beginning of file # Reset stream position to beginning of file
file.seek(0) file.seek(0)
@ -81,7 +81,12 @@ class FileManager:
def update_headers(self): def update_headers(self):
"""Update headers.""" """Update headers."""
self.HEADERS = self.REQUIRED_HEADERS + self.ITEM_MATCH_HEADERS + self.OPTIONAL_MATCH_HEADERS + self.OPTIONAL_HEADERS self.HEADERS = (
self.REQUIRED_HEADERS
+ self.ITEM_MATCH_HEADERS
+ self.OPTIONAL_MATCH_HEADERS
+ self.OPTIONAL_HEADERS
)
def setup(self): def setup(self):
"""Setup headers should be overridden in usage to set the Different Headers.""" """Setup headers should be overridden in usage to set the Different Headers."""
@ -149,15 +154,9 @@ class FileManager:
break break
if not guess_exists: if not guess_exists:
headers.append({ headers.append({'name': header, 'guess': guess})
'name': header,
'guess': guess
})
else: else:
headers.append({ headers.append({'name': header, 'guess': None})
'name': header,
'guess': None
})
return headers return headers
@ -180,7 +179,6 @@ class FileManager:
rows = [] rows = []
for i in range(self.row_count()): for i in range(self.row_count()):
data = list(self.get_row_data(i)) data = list(self.get_row_data(i))
# Is the row completely empty? Skip! # Is the row completely empty? Skip!
@ -203,10 +201,7 @@ class FileManager:
if empty: if empty:
continue continue
row = { row = {'data': data, 'index': i}
'data': data,
'index': i
}
rows.append(row) rows.append(row)

View File

@ -9,10 +9,7 @@ from .files import FileManager
class UploadFileForm(forms.Form): class UploadFileForm(forms.Form):
"""Step 1 of FileManagementFormView.""" """Step 1 of FileManagementFormView."""
file = forms.FileField( file = forms.FileField(label=_('File'), help_text=_('Select file to upload'))
label=_('File'),
help_text=_('Select file to upload'),
)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Update label and help_text.""" """Update label and help_text."""
@ -67,9 +64,7 @@ class MatchFieldForm(forms.Form):
self.fields[field_name] = forms.ChoiceField( self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + headers_choices, choices=[('', '-' * 10)] + headers_choices,
required=False, required=False,
widget=forms.Select(attrs={ widget=forms.Select(attrs={'class': 'select fieldselect'}),
'class': 'select fieldselect',
})
) )
if col['guess']: if col['guess']:
self.fields[field_name].initial = col['guess'] self.fields[field_name].initial = col['guess']
@ -107,7 +102,9 @@ class MatchItemForm(forms.Form):
field_name = col_guess.lower() + '-' + str(row['index']) field_name = col_guess.lower() + '-' + str(row['index'])
# check if field def was overridden # check if field def was overridden
overriden_field = self.get_special_field(col_guess, row, file_manager) overriden_field = self.get_special_field(
col_guess, row, file_manager
)
if overriden_field: if overriden_field:
self.fields[field_name] = overriden_field self.fields[field_name] = overriden_field
@ -117,23 +114,23 @@ class MatchItemForm(forms.Form):
value = row.get(col_guess.lower(), '') value = row.get(col_guess.lower(), '')
# Set field input box # Set field input box
self.fields[field_name] = forms.CharField( self.fields[field_name] = forms.CharField(
required=True, required=True, initial=value
initial=value,
) )
# Create item selection box # Create item selection box
elif col_guess in file_manager.OPTIONAL_MATCH_HEADERS: elif col_guess in file_manager.OPTIONAL_MATCH_HEADERS:
# Get item options # Get item options
item_options = [(option.id, option) for option in row['match_options_' + col_guess]] item_options = [
(option.id, option)
for option in row['match_options_' + col_guess]
]
# Get item match # Get item match
item_match = row['match_' + col_guess] item_match = row['match_' + col_guess]
# Set field select box # Set field select box
self.fields[field_name] = forms.ChoiceField( self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + item_options, choices=[('', '-' * 10)] + item_options,
required=False, required=False,
widget=forms.Select(attrs={ widget=forms.Select(attrs={'class': 'select bomselect'}),
'class': 'select bomselect',
})
) )
# Update select box when match was found # Update select box when match was found
if item_match: if item_match:
@ -142,7 +139,9 @@ class MatchItemForm(forms.Form):
# Create item selection box # Create item selection box
elif col_guess in file_manager.ITEM_MATCH_HEADERS: elif col_guess in file_manager.ITEM_MATCH_HEADERS:
# Get item options # Get item options
item_options = [(option.id, option) for option in row['item_options']] item_options = [
(option.id, option) for option in row['item_options']
]
# Get item match # Get item match
item_match = row['item_match'] item_match = row['item_match']
# Set field name # Set field name
@ -151,9 +150,7 @@ class MatchItemForm(forms.Form):
self.fields[field_name] = forms.ChoiceField( self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + item_options, choices=[('', '-' * 10)] + item_options,
required=False, required=False,
widget=forms.Select(attrs={ widget=forms.Select(attrs={'class': 'select bomselect'}),
'class': 'select bomselect',
})
) )
# Update select box when match was found # Update select box when match was found
if item_match: if item_match:
@ -169,8 +166,7 @@ class MatchItemForm(forms.Form):
value = row.get(col_guess.lower(), '') value = row.get(col_guess.lower(), '')
# Set field input box # Set field input box
self.fields[field_name] = forms.CharField( self.fields[field_name] = forms.CharField(
required=False, required=False, initial=value
initial=value,
) )
def get_special_field(self, col_guess, row, file_manager): def get_special_field(self, col_guess, row, file_manager):

File diff suppressed because it is too large Load Diff

View File

@ -24,7 +24,7 @@ class NotificationMethod:
METHOD_NAME = '' METHOD_NAME = ''
METHOD_ICON = None METHOD_ICON = None
CONTEXT_BUILTIN = ['name', 'message', ] CONTEXT_BUILTIN = ['name', 'message']
CONTEXT_EXTRA = [] CONTEXT_EXTRA = []
GLOBAL_SETTING = None GLOBAL_SETTING = None
USER_SETTING = None USER_SETTING = None
@ -39,11 +39,15 @@ class NotificationMethod:
""" """
# Check if a sending fnc is defined # Check if a sending fnc is defined
if (not hasattr(self, 'send')) and (not hasattr(self, 'send_bulk')): if (not hasattr(self, 'send')) and (not hasattr(self, 'send_bulk')):
raise NotImplementedError('A NotificationMethod must either define a `send` or a `send_bulk` method') raise NotImplementedError(
'A NotificationMethod must either define a `send` or a `send_bulk` method'
)
# No method name is no good # No method name is no good
if self.METHOD_NAME in ('', None): if self.METHOD_NAME in ('', None):
raise NotImplementedError(f'The NotificationMethod {self.__class__} did not provide a METHOD_NAME') raise NotImplementedError(
f'The NotificationMethod {self.__class__} did not provide a METHOD_NAME'
)
# Check if plugin is disabled - if so do not gather targets etc. # Check if plugin is disabled - if so do not gather targets etc.
if self.global_setting_disable(): if self.global_setting_disable():
@ -61,9 +65,10 @@ class NotificationMethod:
def check_context(self, context): def check_context(self, context):
"""Check that all values defined in the methods CONTEXT were provided in the current context.""" """Check that all values defined in the methods CONTEXT were provided in the current context."""
def check(ref, obj): def check(ref, obj):
# the obj is not accessible so we are on the end # the obj is not accessible so we are on the end
if not isinstance(obj, (list, dict, tuple, )): if not isinstance(obj, (list, dict, tuple)):
return ref return ref
# check if the ref exists # check if the ref exists
@ -82,7 +87,9 @@ class NotificationMethod:
return check(ref[1:], obj[ref[0]]) return check(ref[1:], obj[ref[0]])
# other cases -> raise # other cases -> raise
raise NotImplementedError('This type can not be used as a context reference') raise NotImplementedError(
'This type can not be used as a context reference'
)
missing = [] missing = []
for item in (*self.CONTEXT_BUILTIN, *self.CONTEXT_EXTRA): for item in (*self.CONTEXT_BUILTIN, *self.CONTEXT_EXTRA):
@ -91,7 +98,9 @@ class NotificationMethod:
missing.append(ret) missing.append(ret)
if missing: if missing:
raise NotImplementedError(f'The `context` is missing the following items:\n{missing}') raise NotImplementedError(
f'The `context` is missing the following items:\n{missing}'
)
return context return context
@ -142,7 +151,12 @@ class NotificationMethod:
def usersetting(self, target): def usersetting(self, target):
"""Returns setting for this method for a given user.""" """Returns setting for this method for a given user."""
return NotificationUserSetting.get_setting(f'NOTIFICATION_METHOD_{self.METHOD_NAME.upper()}', user=target, method=self.METHOD_NAME) return NotificationUserSetting.get_setting(
f'NOTIFICATION_METHOD_{self.METHOD_NAME.upper()}',
user=target,
method=self.METHOD_NAME,
)
# endregion # endregion
@ -160,6 +174,8 @@ class BulkNotificationMethod(NotificationMethod):
def send_bulk(self): def send_bulk(self):
"""This function must be overridden.""" """This function must be overridden."""
raise NotImplementedError('The `send` method must be overridden!') raise NotImplementedError('The `send` method must be overridden!')
# endregion # endregion
@ -181,17 +197,25 @@ class MethodStorageClass:
selected_classes (class, optional): References to the classes that should be registered. Defaults to None. selected_classes (class, optional): References to the classes that should be registered. Defaults to None.
""" """
logger.debug('Collecting notification methods') logger.debug('Collecting notification methods')
current_method = InvenTree.helpers.inheritors(NotificationMethod) - IGNORED_NOTIFICATION_CLS current_method = (
InvenTree.helpers.inheritors(NotificationMethod) - IGNORED_NOTIFICATION_CLS
)
# for testing selective loading is made available # for testing selective loading is made available
if selected_classes: if selected_classes:
current_method = [item for item in current_method if item is selected_classes] current_method = [
item for item in current_method if item is selected_classes
]
# make sure only one of each method is added # make sure only one of each method is added
filtered_list = {} filtered_list = {}
for item in current_method: for item in current_method:
plugin = item.get_plugin(item) plugin = item.get_plugin(item)
ref = f'{plugin.package_path}_{item.METHOD_NAME}' if plugin else item.METHOD_NAME ref = (
f'{plugin.package_path}_{item.METHOD_NAME}'
if plugin
else item.METHOD_NAME
)
item.plugin = plugin() if plugin else None item.plugin = plugin() if plugin else None
filtered_list[ref] = item filtered_list[ref] = item
@ -217,9 +241,7 @@ class MethodStorageClass:
# make sure the setting exists # make sure the setting exists
self.user_settings[new_key] = item.USER_SETTING self.user_settings[new_key] = item.USER_SETTING
NotificationUserSetting.get_setting( NotificationUserSetting.get_setting(
key=new_key, key=new_key, user=user, method=item.METHOD_NAME
user=user,
method=item.METHOD_NAME,
) )
# save definition # save definition
@ -231,7 +253,7 @@ class MethodStorageClass:
return methods return methods
IGNORED_NOTIFICATION_CLS = {SingleNotificationMethod, BulkNotificationMethod, } IGNORED_NOTIFICATION_CLS = {SingleNotificationMethod, BulkNotificationMethod}
storage = MethodStorageClass() storage = MethodStorageClass()
@ -275,6 +297,7 @@ class NotificationBody:
app_label: App label (slugified) of the model app_label: App label (slugified) of the model
model_name': Name (slugified) of the model model_name': Name (slugified) of the model
""" """
name: str name: str
slug: str slug: str
message: str message: str
@ -286,24 +309,25 @@ class InvenTreeNotificationBodies:
Contains regularly used notification bodies. Contains regularly used notification bodies.
""" """
NewOrder = NotificationBody( NewOrder = NotificationBody(
name=_("New {verbose_name}"), name=_('New {verbose_name}'),
slug='{app_label}.new_{model_name}', slug='{app_label}.new_{model_name}',
message=_("A new order has been created and assigned to you"), message=_('A new order has been created and assigned to you'),
template='email/new_order_assigned.html', template='email/new_order_assigned.html',
) )
"""Send when a new order (build, sale or purchase) was created.""" """Send when a new order (build, sale or purchase) was created."""
OrderCanceled = NotificationBody( OrderCanceled = NotificationBody(
name=_("{verbose_name} canceled"), name=_('{verbose_name} canceled'),
slug='{app_label}.canceled_{model_name}', slug='{app_label}.canceled_{model_name}',
message=_("A order that is assigned to you was canceled"), message=_('A order that is assigned to you was canceled'),
template='email/canceled_order_assigned.html', template='email/canceled_order_assigned.html',
) )
"""Send when a order (sale, return or purchase) was canceled.""" """Send when a order (sale, return or purchase) was canceled."""
ItemsReceived = NotificationBody( ItemsReceived = NotificationBody(
name=_("Items Received"), name=_('Items Received'),
slug='purchase_order.items_received', slug='purchase_order.items_received',
message=_('Items have been received against a purchase order'), message=_('Items have been received against a purchase order'),
template='email/purchase_order_received.html', template='email/purchase_order_received.html',
@ -340,13 +364,19 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
if not obj_ref_value: if not obj_ref_value:
obj_ref_value = getattr(obj, 'id', None) obj_ref_value = getattr(obj, 'id', None)
if not obj_ref_value: if not obj_ref_value:
raise KeyError(f"Could not resolve an object reference for '{str(obj)}' with {obj_ref}, pk, id") raise KeyError(
f"Could not resolve an object reference for '{str(obj)}' with {obj_ref}, pk, id"
)
# Check if we have notified recently... # Check if we have notified recently...
delta = timedelta(days=1) delta = timedelta(days=1)
if common.models.NotificationEntry.check_recent(category, obj_ref_value, delta): if common.models.NotificationEntry.check_recent(category, obj_ref_value, delta):
logger.info("Notification '%s' has recently been sent for '%s' - SKIPPING", category, str(obj)) logger.info(
"Notification '%s' has recently been sent for '%s' - SKIPPING",
category,
str(obj),
)
return return
logger.info("Gathering users for notification '%s'", category) logger.info("Gathering users for notification '%s'", category)
@ -383,7 +413,9 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
target_users.add(user) target_users.add(user)
# Unhandled type # Unhandled type
else: else:
logger.error("Unknown target passed to trigger_notification method: %s", target) logger.error(
'Unknown target passed to trigger_notification method: %s', target
)
if target_users: if target_users:
logger.info("Sending notification '%s' for '%s'", category, str(obj)) logger.info("Sending notification '%s' for '%s'", category, str(obj))
@ -392,7 +424,7 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
if delivery_methods is None: if delivery_methods is None:
delivery_methods = storage.liste delivery_methods = storage.liste
else: else:
delivery_methods = (delivery_methods - IGNORED_NOTIFICATION_CLS) delivery_methods = delivery_methods - IGNORED_NOTIFICATION_CLS
for method in delivery_methods: for method in delivery_methods:
logger.info("Triggering notification method '%s'", method.METHOD_NAME) logger.info("Triggering notification method '%s'", method.METHOD_NAME)
@ -422,17 +454,15 @@ def trigger_superuser_notification(plugin: PluginConfig, msg: str):
trigger_notification( trigger_notification(
plugin, plugin,
'inventree.plugin', 'inventree.plugin',
context={ context={'error': plugin, 'name': _('Error raised by plugin'), 'message': msg},
'error': plugin,
'name': _('Error raised by plugin'),
'message': msg,
},
targets=users, targets=users,
delivery_methods={UIMessageNotification, }, delivery_methods={UIMessageNotification},
) )
def deliver_notification(cls: NotificationMethod, obj, category: str, targets, context: dict): def deliver_notification(
cls: NotificationMethod, obj, category: str, targets, context: dict
):
"""Send notification with the provided class. """Send notification with the provided class.
This: This:
@ -447,7 +477,12 @@ def deliver_notification(cls: NotificationMethod, obj, category: str, targets, c
if method.targets and len(method.targets) > 0: if method.targets and len(method.targets) > 0:
# Log start # Log start
logger.info("Notify users via '%s' for notification '%s' for '%s'", method.METHOD_NAME, category, str(obj)) logger.info(
"Notify users via '%s' for notification '%s' for '%s'",
method.METHOD_NAME,
category,
str(obj),
)
# Run setup for delivery method # Run setup for delivery method
method.setup() method.setup()
@ -472,6 +507,12 @@ def deliver_notification(cls: NotificationMethod, obj, category: str, targets, c
method.cleanup() method.cleanup()
# Log results # Log results
logger.info("Notified %s users via '%s' for notification '%s' for '%s' successfully", success_count, method.METHOD_NAME, category, str(obj)) logger.info(
"Notified %s users via '%s' for notification '%s' for '%s' successfully",
success_count,
method.METHOD_NAME,
category,
str(obj),
)
if not success: if not success:
logger.info("There were some problems") logger.info('There were some problems')

View File

@ -1,6 +1,5 @@
"""JSON serializers for common components.""" """JSON serializers for common components."""
from django.urls import reverse from django.urls import reverse
from flags.state import flag_state from flags.state import flag_state
@ -9,8 +8,10 @@ from rest_framework import serializers
import common.models as common_models import common.models as common_models
from InvenTree.helpers import get_objectreference from InvenTree.helpers import get_objectreference
from InvenTree.helpers_model import construct_absolute_url from InvenTree.helpers_model import construct_absolute_url
from InvenTree.serializers import (InvenTreeImageSerializerField, from InvenTree.serializers import (
InvenTreeModelSerializer) InvenTreeImageSerializerField,
InvenTreeModelSerializer,
)
from users.serializers import OwnerSerializer from users.serializers import OwnerSerializer
@ -62,10 +63,7 @@ class SettingsSerializer(InvenTreeModelSerializer):
if choices: if choices:
for choice in choices: for choice in choices:
results.append({ results.append({'value': choice[0], 'display_name': choice[1]})
'value': choice[0],
'display_name': choice[1],
})
return results return results
@ -131,8 +129,10 @@ class GenericReferencedSettingSerializer(SettingsSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Init overrides the Meta class to make it dynamic.""" """Init overrides the Meta class to make it dynamic."""
class CustomMeta: class CustomMeta:
"""Scaffold for custom Meta class.""" """Scaffold for custom Meta class."""
fields = [ fields = [
'pk', 'pk',
'key', 'key',
@ -204,10 +204,12 @@ class NotificationMessageSerializer(InvenTreeModelSerializer):
request = self.context['request'] request = self.context['request']
if request.user and request.user.is_staff: if request.user and request.user.is_staff:
meta = obj.target_object._meta meta = obj.target_object._meta
target['link'] = construct_absolute_url(reverse( target['link'] = construct_absolute_url(
reverse(
f'admin:{meta.db_table}_change', f'admin:{meta.db_table}_change',
kwargs={'object_id': obj.target_object_id} kwargs={'object_id': obj.target_object_id},
)) )
)
return target return target
@ -257,17 +259,9 @@ class NotesImageSerializer(InvenTreeModelSerializer):
"""Meta options for NotesImageSerializer.""" """Meta options for NotesImageSerializer."""
model = common_models.NotesImage model = common_models.NotesImage
fields = [ fields = ['pk', 'image', 'user', 'date']
'pk',
'image',
'user',
'date',
]
read_only_fields = [ read_only_fields = ['date', 'user']
'date',
'user',
]
image = InvenTreeImageSerializerField(required=True) image = InvenTreeImageSerializerField(required=True)
@ -279,13 +273,7 @@ class ProjectCodeSerializer(InvenTreeModelSerializer):
"""Meta options for ProjectCodeSerializer.""" """Meta options for ProjectCodeSerializer."""
model = common_models.ProjectCode model = common_models.ProjectCode
fields = [ fields = ['pk', 'code', 'description', 'responsible', 'responsible_detail']
'pk',
'code',
'description',
'responsible',
'responsible_detail',
]
responsible_detail = OwnerSerializer(source='responsible', read_only=True) responsible_detail = OwnerSerializer(source='responsible', read_only=True)
@ -313,9 +301,4 @@ class CustomUnitSerializer(InvenTreeModelSerializer):
"""Meta options for CustomUnitSerializer.""" """Meta options for CustomUnitSerializer."""
model = common_models.CustomUnit model = common_models.CustomUnit
fields = [ fields = ['pk', 'name', 'symbol', 'definition']
'pk',
'name',
'symbol',
'definition',
]

View File

@ -20,7 +20,9 @@ def currency_code_default():
return cached_value return cached_value
try: try:
code = InvenTreeSetting.get_setting('INVENTREE_DEFAULT_CURRENCY', backup_value='', create=True, cache=True) code = InvenTreeSetting.get_setting(
'INVENTREE_DEFAULT_CURRENCY', backup_value='', create=True, cache=True
)
except Exception: # pragma: no cover except Exception: # pragma: no cover
# Database may not yet be ready, no need to throw an error here # Database may not yet be ready, no need to throw an error here
code = '' code = ''

View File

@ -27,7 +27,9 @@ def delete_old_notifications():
try: try:
from common.models import NotificationEntry from common.models import NotificationEntry
except AppRegistryNotReady: # pragma: no cover except AppRegistryNotReady: # pragma: no cover
logger.info("Could not perform 'delete_old_notifications' - App registry not ready") logger.info(
"Could not perform 'delete_old_notifications' - App registry not ready"
)
return return
before = timezone.now() - timedelta(days=90) before = timezone.now() - timedelta(days=90)
@ -49,7 +51,7 @@ def update_news_feed():
try: try:
d = feedparser.parse(settings.INVENTREE_NEWS_URL) d = feedparser.parse(settings.INVENTREE_NEWS_URL)
except Exception as entry: # pragma: no cover except Exception as entry: # pragma: no cover
logger.warning("update_news_feed: Error parsing the newsfeed", entry) logger.warning('update_news_feed: Error parsing the newsfeed', entry)
return return
# Get a reference list # Get a reference list
@ -87,13 +89,15 @@ def delete_old_notes_images():
try: try:
from common.models import NotesImage from common.models import NotesImage
except AppRegistryNotReady: except AppRegistryNotReady:
logger.info("Could not perform 'delete_old_notes_images' - App registry not ready") logger.info(
"Could not perform 'delete_old_notes_images' - App registry not ready"
)
return return
# Remove any notes which point to non-existent image files # Remove any notes which point to non-existent image files
for note in NotesImage.objects.all(): for note in NotesImage.objects.all():
if not os.path.exists(note.image.path): if not os.path.exists(note.image.path):
logger.info("Deleting note %s - image file does not exist", note.image.path) logger.info('Deleting note %s - image file does not exist', note.image.path)
note.delete() note.delete()
note_classes = getModelsWithMixin(InvenTreeNotesMixin) note_classes = getModelsWithMixin(InvenTreeNotesMixin)
@ -112,7 +116,7 @@ def delete_old_notes_images():
break break
if not found: if not found:
logger.info("Deleting note %s - image file not linked to a note", img) logger.info('Deleting note %s - image file not linked to a note', img)
note.delete() note.delete()
# Finally, remove any images in the notes dir which are not linked to a note # Finally, remove any images in the notes dir which are not linked to a note
@ -127,7 +131,6 @@ def delete_old_notes_images():
all_notes = NotesImage.objects.all() all_notes = NotesImage.objects.all()
for image in images: for image in images:
found = False found = False
for note in all_notes: for note in all_notes:
img_path = os.path.basename(note.image.path) img_path = os.path.basename(note.image.path)
@ -136,5 +139,5 @@ def delete_old_notes_images():
break break
if not found: if not found:
logger.info("Deleting note %s - image file not linked to a note", image) logger.info('Deleting note %s - image file not linked to a note', image)
os.remove(os.path.join(notes_dir, image)) os.remove(os.path.join(notes_dir, image))

View File

@ -1,8 +1,12 @@
"""Tests for basic notification methods and functions in InvenTree.""" """Tests for basic notification methods and functions in InvenTree."""
import plugin.templatetags.plugin_extras as plugin_tags import plugin.templatetags.plugin_extras as plugin_tags
from common.notifications import (BulkNotificationMethod, NotificationMethod, from common.notifications import (
SingleNotificationMethod, storage) BulkNotificationMethod,
NotificationMethod,
SingleNotificationMethod,
storage,
)
from part.test_part import BaseNotificationIntegrationTest from part.test_part import BaseNotificationIntegrationTest
from plugin.models import NotificationUserSetting from plugin.models import NotificationUserSetting
@ -23,37 +27,31 @@ class BaseNotificationTests(BaseNotificationIntegrationTest):
"""A comment so we do not need a pass.""" """A comment so we do not need a pass."""
class NoNameNotificationMethod(NotificationMethod): class NoNameNotificationMethod(NotificationMethod):
def send(self): def send(self):
"""A comment so we do not need a pass.""" """A comment so we do not need a pass."""
class WrongContextNotificationMethod(NotificationMethod): class WrongContextNotificationMethod(NotificationMethod):
METHOD_NAME = 'WrongContextNotification' METHOD_NAME = 'WrongContextNotification'
CONTEXT_EXTRA = [ CONTEXT_EXTRA = ['aa', ('aa', 'bb'), ('templates', 'ccc'), (123,)]
'aa',
('aa', 'bb', ),
('templates', 'ccc', ),
(123, )
]
def send(self): def send(self):
"""A comment so we do not need a pass.""" """A comment so we do not need a pass."""
# no send / send bulk # no send / send bulk
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
FalseNotificationMethod('', '', '', '', ) FalseNotificationMethod('', '', '', '')
# no METHOD_NAME # no METHOD_NAME
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
NoNameNotificationMethod('', '', '', '', ) NoNameNotificationMethod('', '', '', '')
# a not existent context check # a not existent context check
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
WrongContextNotificationMethod('', '', '', '', ) WrongContextNotificationMethod('', '', '', '')
# no get_targets # no get_targets
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
AnotherFalseNotificationMethod('', '', '', {'name': 1, 'message': 2, }, ) AnotherFalseNotificationMethod('', '', '', {'name': 1, 'message': 2})
def test_failing_passing(self): def test_failing_passing(self):
"""Ensure that an error in one deliverymethod is not blocking all mehthods.""" """Ensure that an error in one deliverymethod is not blocking all mehthods."""
@ -67,7 +65,7 @@ class BaseNotificationTests(BaseNotificationIntegrationTest):
METHOD_NAME = 'ErrorImplementation' METHOD_NAME = 'ErrorImplementation'
def get_targets(self): def get_targets(self):
return [1, ] return [1]
def send(self, target): def send(self, target):
raise KeyError('This could be any error') raise KeyError('This could be any error')
@ -91,7 +89,7 @@ class BulkNotificationMethodTests(BaseNotificationIntegrationTest):
METHOD_NAME = 'WrongImplementationBulk' METHOD_NAME = 'WrongImplementationBulk'
def get_targets(self): def get_targets(self):
return [1, ] return [1]
with self.assertLogs(logger='inventree', level='ERROR'): with self.assertLogs(logger='inventree', level='ERROR'):
self._notification_run(WrongImplementation) self._notification_run(WrongImplementation)
@ -113,11 +111,12 @@ class SingleNotificationMethodTests(BaseNotificationIntegrationTest):
METHOD_NAME = 'WrongImplementationSingle' METHOD_NAME = 'WrongImplementationSingle'
def get_targets(self): def get_targets(self):
return [1, ] return [1]
with self.assertLogs(logger='inventree', level='ERROR'): with self.assertLogs(logger='inventree', level='ERROR'):
self._notification_run(WrongImplementation) self._notification_run(WrongImplementation)
# A integration test for notifications is provided in test_part.PartNotificationTest # A integration test for notifications is provided in test_part.PartNotificationTest
@ -144,7 +143,7 @@ class NotificationUserSettingTests(BaseNotificationIntegrationTest):
} }
def get_targets(self): def get_targets(self):
return [1, ] return [1]
def send_bulk(self): def send_bulk(self):
return True return True
@ -158,10 +157,14 @@ class NotificationUserSettingTests(BaseNotificationIntegrationTest):
# assertions for settings # assertions for settings
self.assertEqual(setting.name, 'Enable test notifications') self.assertEqual(setting.name, 'Enable test notifications')
self.assertEqual(setting.default_value, True) self.assertEqual(setting.default_value, True)
self.assertEqual(setting.description, 'Allow sending of test for event notifications') self.assertEqual(
setting.description, 'Allow sending of test for event notifications'
)
self.assertEqual(setting.units, 'alpha') self.assertEqual(setting.units, 'alpha')
# test tag and array # test tag and array
self.assertEqual(plugin_tags.notification_settings_list({'user': self.user}), array) self.assertEqual(
plugin_tags.notification_settings_list({'user': self.user}), array
)
self.assertEqual(array[0]['key'], 'NOTIFICATION_METHOD_TEST') self.assertEqual(array[0]['key'], 'NOTIFICATION_METHOD_TEST')
self.assertEqual(array[0]['method'], 'test') self.assertEqual(array[0]['method'], 'test')

View File

@ -15,4 +15,4 @@ class TaskTest(TestCase):
"""Test that the task `delete_old_notifications` runs through without errors.""" """Test that the task `delete_old_notifications` runs through without errors."""
# check empty run # check empty run
self.assertEqual(NotificationEntry.objects.all().count(), 0) self.assertEqual(NotificationEntry.objects.all().count(), 0)
offload_task(common_tasks.delete_old_notifications,) offload_task(common_tasks.delete_old_notifications)

View File

@ -17,16 +17,23 @@ from django.urls import reverse
import PIL import PIL
from InvenTree.helpers import str2bool from InvenTree.helpers import str2bool
from InvenTree.unit_test import (InvenTreeAPITestCase, InvenTreeTestCase, from InvenTree.unit_test import InvenTreeAPITestCase, InvenTreeTestCase, PluginMixin
PluginMixin)
from plugin import registry from plugin import registry
from plugin.models import NotificationUserSetting from plugin.models import NotificationUserSetting
from .api import WebhookView from .api import WebhookView
from .models import (ColorTheme, CustomUnit, InvenTreeSetting, from .models import (
InvenTreeUserSetting, NotesImage, NotificationEntry, ColorTheme,
NotificationMessage, ProjectCode, WebhookEndpoint, CustomUnit,
WebhookMessage) InvenTreeSetting,
InvenTreeUserSetting,
NotesImage,
NotificationEntry,
NotificationMessage,
ProjectCode,
WebhookEndpoint,
WebhookMessage,
)
CONTENT_TYPE_JSON = 'application/json' CONTENT_TYPE_JSON = 'application/json'
@ -34,9 +41,7 @@ CONTENT_TYPE_JSON = 'application/json'
class SettingsTest(InvenTreeTestCase): class SettingsTest(InvenTreeTestCase):
"""Tests for the 'settings' model.""" """Tests for the 'settings' model."""
fixtures = [ fixtures = ['settings']
'settings',
]
def test_settings_objects(self): def test_settings_objects(self):
"""Test fixture loading and lookup for settings.""" """Test fixture loading and lookup for settings."""
@ -50,7 +55,9 @@ class SettingsTest(InvenTreeTestCase):
self.assertEqual(instance_name.value, 'My very first InvenTree Instance') self.assertEqual(instance_name.value, 'My very first InvenTree Instance')
# Check object lookup (case insensitive) # Check object lookup (case insensitive)
self.assertEqual(InvenTreeSetting.get_setting_object('iNvEnTrEE_inSTanCE').pk, 1) self.assertEqual(
InvenTreeSetting.get_setting_object('iNvEnTrEE_inSTanCE').pk, 1
)
def test_settings_functions(self): def test_settings_functions(self):
"""Test settings functions and properties.""" """Test settings functions and properties."""
@ -61,14 +68,25 @@ class SettingsTest(InvenTreeTestCase):
stale_ref = 'STOCK_STALE_DAYS' stale_ref = 'STOCK_STALE_DAYS'
stale_days = InvenTreeSetting.get_setting_object(stale_ref, cache=False) stale_days = InvenTreeSetting.get_setting_object(stale_ref, cache=False)
report_size_obj = InvenTreeSetting.get_setting_object('REPORT_DEFAULT_PAGE_SIZE') report_size_obj = InvenTreeSetting.get_setting_object(
report_test_obj = InvenTreeSetting.get_setting_object('REPORT_ENABLE_TEST_REPORT') 'REPORT_DEFAULT_PAGE_SIZE'
)
report_test_obj = InvenTreeSetting.get_setting_object(
'REPORT_ENABLE_TEST_REPORT'
)
# check settings base fields # check settings base fields
self.assertEqual(instance_obj.name, 'Server Instance Name') self.assertEqual(instance_obj.name, 'Server Instance Name')
self.assertEqual(instance_obj.get_setting_name(instance_ref), 'Server Instance Name') self.assertEqual(
self.assertEqual(instance_obj.description, 'String descriptor for the server instance') instance_obj.get_setting_name(instance_ref), 'Server Instance Name'
self.assertEqual(instance_obj.get_setting_description(instance_ref), 'String descriptor for the server instance') )
self.assertEqual(
instance_obj.description, 'String descriptor for the server instance'
)
self.assertEqual(
instance_obj.get_setting_description(instance_ref),
'String descriptor for the server instance',
)
# check units # check units
self.assertEqual(instance_obj.units, '') self.assertEqual(instance_obj.units, '')
@ -90,7 +108,9 @@ class SettingsTest(InvenTreeTestCase):
# check as_int # check as_int
self.assertEqual(stale_days.as_int(), 0) self.assertEqual(stale_days.as_int(), 0)
self.assertEqual(instance_obj.as_int(), 'InvenTree') # not an int -> return default self.assertEqual(
instance_obj.as_int(), 'InvenTree'
) # not an int -> return default
# check as_bool # check as_bool
self.assertEqual(report_test_obj.as_bool(), True) self.assertEqual(report_test_obj.as_bool(), True)
@ -116,62 +136,66 @@ class SettingsTest(InvenTreeTestCase):
def test_all_settings(self): def test_all_settings(self):
"""Make sure that the all_settings function returns correctly""" """Make sure that the all_settings function returns correctly"""
result = InvenTreeSetting.all_settings() result = InvenTreeSetting.all_settings()
self.assertIn("INVENTREE_INSTANCE", result) self.assertIn('INVENTREE_INSTANCE', result)
self.assertIsInstance(result['INVENTREE_INSTANCE'], InvenTreeSetting) self.assertIsInstance(result['INVENTREE_INSTANCE'], InvenTreeSetting)
@mock.patch("common.models.InvenTreeSetting.get_setting_definition") @mock.patch('common.models.InvenTreeSetting.get_setting_definition')
def test_check_all_settings(self, get_setting_definition): def test_check_all_settings(self, get_setting_definition):
"""Make sure that the check_all_settings function returns correctly""" """Make sure that the check_all_settings function returns correctly"""
# define partial schema # define partial schema
settings_definition = { settings_definition = {
"AB": { # key that's has not already been accessed 'AB': { # key that's has not already been accessed
"required": True, 'required': True
}, },
"CD": { 'CD': {'required': True, 'protected': True},
"required": True, 'EF': {},
"protected": True,
},
"EF": {}
} }
def mocked(key, **kwargs): def mocked(key, **kwargs):
return settings_definition.get(key, {}) return settings_definition.get(key, {})
get_setting_definition.side_effect = mocked get_setting_definition.side_effect = mocked
self.assertEqual(InvenTreeSetting.check_all_settings(settings_definition=settings_definition), (False, ["AB", "CD"])) self.assertEqual(
InvenTreeSetting.set_setting('AB', "hello", self.user) InvenTreeSetting.check_all_settings(
InvenTreeSetting.set_setting('CD', "world", self.user) settings_definition=settings_definition
),
(False, ['AB', 'CD']),
)
InvenTreeSetting.set_setting('AB', 'hello', self.user)
InvenTreeSetting.set_setting('CD', 'world', self.user)
self.assertEqual(InvenTreeSetting.check_all_settings(), (True, [])) self.assertEqual(InvenTreeSetting.check_all_settings(), (True, []))
@mock.patch("common.models.InvenTreeSetting.get_setting_definition") @mock.patch('common.models.InvenTreeSetting.get_setting_definition')
def test_settings_validator(self, get_setting_definition): def test_settings_validator(self, get_setting_definition):
"""Make sure that the validator function gets called on set setting.""" """Make sure that the validator function gets called on set setting."""
def validator(x): def validator(x):
if x == "hello": if x == 'hello':
return x return x
raise ValidationError(f"{x} is not valid") raise ValidationError(f'{x} is not valid')
mock_validator = mock.Mock(side_effect=validator) mock_validator = mock.Mock(side_effect=validator)
# define partial schema # define partial schema
settings_definition = { settings_definition = {
"AB": { # key that's has not already been accessed 'AB': { # key that's has not already been accessed
"validator": mock_validator, 'validator': mock_validator
}, }
} }
def mocked(key, **kwargs): def mocked(key, **kwargs):
return settings_definition.get(key, {}) return settings_definition.get(key, {})
get_setting_definition.side_effect = mocked get_setting_definition.side_effect = mocked
InvenTreeSetting.set_setting("AB", "hello", self.user) InvenTreeSetting.set_setting('AB', 'hello', self.user)
mock_validator.assert_called_with("hello") mock_validator.assert_called_with('hello')
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
InvenTreeSetting.set_setting("AB", "world", self.user) InvenTreeSetting.set_setting('AB', 'world', self.user)
mock_validator.assert_called_with("world") mock_validator.assert_called_with('world')
def run_settings_check(self, key, setting): def run_settings_check(self, key, setting):
"""Test that all settings are valid. """Test that all settings are valid.
@ -194,7 +218,9 @@ class SettingsTest(InvenTreeTestCase):
self.assertIn('django.utils.functional.lazy', str(type(description))) self.assertIn('django.utils.functional.lazy', str(type(description)))
if key != key.upper(): if key != key.upper():
raise ValueError(f"Setting key '{key}' is not uppercase") # pragma: no cover raise ValueError(
f"Setting key '{key}' is not uppercase"
) # pragma: no cover
# Check that only allowed keys are provided # Check that only allowed keys are provided
allowed_keys = [ allowed_keys = [
@ -232,7 +258,6 @@ class SettingsTest(InvenTreeTestCase):
- Ensure that every setting has a description, which is translated - Ensure that every setting has a description, which is translated
""" """
for key, setting in InvenTreeSetting.SETTINGS.items(): for key, setting in InvenTreeSetting.SETTINGS.items():
try: try:
self.run_settings_check(key, setting) self.run_settings_check(key, setting)
except Exception as exc: # pragma: no cover except Exception as exc: # pragma: no cover
@ -249,7 +274,6 @@ class SettingsTest(InvenTreeTestCase):
def test_defaults(self): def test_defaults(self):
"""Populate the settings with default values.""" """Populate the settings with default values."""
for key in InvenTreeSetting.SETTINGS.keys(): for key in InvenTreeSetting.SETTINGS.keys():
value = InvenTreeSetting.get_setting_default(key) value = InvenTreeSetting.get_setting_default(key)
InvenTreeSetting.set_setting(key, value, self.user) InvenTreeSetting.set_setting(key, value, self.user)
@ -261,10 +285,14 @@ class SettingsTest(InvenTreeTestCase):
if setting.is_bool(): if setting.is_bool():
if setting.default_value in ['', None]: if setting.default_value in ['', None]:
raise ValueError(f'Default value for boolean setting {key} not provided') # pragma: no cover raise ValueError(
f'Default value for boolean setting {key} not provided'
) # pragma: no cover
if setting.default_value not in [True, False]: if setting.default_value not in [True, False]:
raise ValueError(f'Non-boolean default value specified for {key}') # pragma: no cover raise ValueError(
f'Non-boolean default value specified for {key}'
) # pragma: no cover
def test_global_setting_caching(self): def test_global_setting_caching(self):
"""Test caching operations for the global settings class""" """Test caching operations for the global settings class"""
@ -294,9 +322,7 @@ class SettingsTest(InvenTreeTestCase):
# Generate a number of new users # Generate a number of new users
for idx in range(5): for idx in range(5):
get_user_model().objects.create( get_user_model().objects.create(
username=f"User_{idx}", username=f'User_{idx}', password='hunter42', email='email@dot.com'
password="hunter42",
email="email@dot.com",
) )
key = 'SEARCH_PREVIEW_RESULTS' key = 'SEARCH_PREVIEW_RESULTS'
@ -305,7 +331,10 @@ class SettingsTest(InvenTreeTestCase):
for user in get_user_model().objects.all(): for user in get_user_model().objects.all():
setting = InvenTreeUserSetting.get_setting_object(key, user=user) setting = InvenTreeUserSetting.get_setting_object(key, user=user)
cache_key = setting.cache_key cache_key = setting.cache_key
self.assertEqual(cache_key, f"InvenTreeUserSetting:SEARCH_PREVIEW_RESULTS_user:{user.username}") self.assertEqual(
cache_key,
f'InvenTreeUserSetting:SEARCH_PREVIEW_RESULTS_user:{user.username}',
)
InvenTreeUserSetting.set_setting(key, user.pk, None, user=user) InvenTreeUserSetting.set_setting(key, user.pk, None, user=user)
self.assertIsNotNone(cache.get(cache_key)) self.assertIsNotNone(cache.get(cache_key))
@ -333,7 +362,9 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
n_public_settings = len([k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_')]) n_public_settings = len([
k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_')
])
# Number of results should match the number of settings # Number of results should match the number of settings
self.assertEqual(len(response.data), n_public_settings) self.assertEqual(len(response.data), n_public_settings)
@ -358,13 +389,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
# Test setting via the API # Test setting via the API
for val in ['cat', 'hat', 'bat', 'mat']: for val in ['cat', 'hat', 'bat', 'mat']:
response = self.patch( response = self.patch(url, {'value': val}, expected_code=200)
url,
{
'value': val,
},
expected_code=200
)
self.assertEqual(response.data['value'], val) self.assertEqual(response.data['value'], val)
@ -374,7 +399,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
def test_api_detail(self): def test_api_detail(self):
"""Test that we can access the detail view for a setting based on the <key>.""" """Test that we can access the detail view for a setting based on the <key>."""
# These keys are invalid, and should return 404 # These keys are invalid, and should return 404
for key in ["apple", "carrot", "dog"]: for key in ['apple', 'carrot', 'dog']:
response = self.get( response = self.get(
reverse('api-global-setting-detail', kwargs={'key': key}), reverse('api-global-setting-detail', kwargs={'key': key}),
expected_code=404, expected_code=404,
@ -394,13 +419,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
self.assertEqual(response.data['value'], 'InvenTree') self.assertEqual(response.data['value'], 'InvenTree')
# Now, the object should have been created in the DB # Now, the object should have been created in the DB
self.patch( self.patch(url, {'value': 'My new title'}, expected_code=200)
url,
{
'value': 'My new title',
},
expected_code=200,
)
setting = InvenTreeSetting.objects.get(key=key) setting = InvenTreeSetting.objects.get(key=key)
@ -451,8 +470,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
"""Test a boolean user setting value.""" """Test a boolean user setting value."""
# Ensure we have a boolean setting available # Ensure we have a boolean setting available
setting = InvenTreeUserSetting.get_setting_object( setting = InvenTreeUserSetting.get_setting_object(
'SEARCH_PREVIEW_SHOW_PARTS', 'SEARCH_PREVIEW_SHOW_PARTS', user=self.user
user=self.user
) )
# Check default values # Check default values
@ -465,20 +483,16 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
self.assertEqual(response.data['pk'], setting.pk) self.assertEqual(response.data['pk'], setting.pk)
self.assertEqual(response.data['key'], 'SEARCH_PREVIEW_SHOW_PARTS') self.assertEqual(response.data['key'], 'SEARCH_PREVIEW_SHOW_PARTS')
self.assertEqual(response.data['description'], 'Display parts in search preview window') self.assertEqual(
response.data['description'], 'Display parts in search preview window'
)
self.assertEqual(response.data['type'], 'boolean') self.assertEqual(response.data['type'], 'boolean')
self.assertEqual(len(response.data['choices']), 0) self.assertEqual(len(response.data['choices']), 0)
self.assertTrue(str2bool(response.data['value'])) self.assertTrue(str2bool(response.data['value']))
# Assign some truthy values # Assign some truthy values
for v in ['true', True, 1, 'y', 'TRUE']: for v in ['true', True, 1, 'y', 'TRUE']:
self.patch( self.patch(url, {'value': str(v)}, expected_code=200)
url,
{
'value': str(v),
},
expected_code=200,
)
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
@ -486,13 +500,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Assign some false(ish) values # Assign some false(ish) values
for v in ['false', False, '0', 'n', 'FalSe']: for v in ['false', False, '0', 'n', 'FalSe']:
self.patch( self.patch(url, {'value': str(v)}, expected_code=200)
url,
{
'value': str(v),
},
expected_code=200,
)
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
@ -500,13 +508,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Assign some invalid values # Assign some invalid values
for v in ['x', '', 'invalid', None, '-1', 'abcde']: for v in ['x', '', 'invalid', None, '-1', 'abcde']:
response = self.patch( response = self.patch(url, {'value': str(v)}, expected_code=200)
url,
{
'value': str(v),
},
expected_code=200
)
# Invalid values evaluate to False # Invalid values evaluate to False
self.assertFalse(str2bool(response.data['value'])) self.assertFalse(str2bool(response.data['value']))
@ -514,8 +516,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
def test_user_setting_choice(self): def test_user_setting_choice(self):
"""Test a user setting with choices.""" """Test a user setting with choices."""
setting = InvenTreeUserSetting.get_setting_object( setting = InvenTreeUserSetting.get_setting_object(
'DATE_DISPLAY_FORMAT', 'DATE_DISPLAY_FORMAT', user=self.user
user=self.user
) )
url = reverse('api-user-setting-detail', kwargs={'key': setting.key}) url = reverse('api-user-setting-detail', kwargs={'key': setting.key})
@ -525,37 +526,21 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Check that a valid option can be assigned via the API # Check that a valid option can be assigned via the API
for opt in ['YYYY-MM-DD', 'DD-MM-YYYY', 'MM/DD/YYYY']: for opt in ['YYYY-MM-DD', 'DD-MM-YYYY', 'MM/DD/YYYY']:
self.patch(url, {'value': opt}, expected_code=200)
self.patch(
url,
{
'value': opt,
},
expected_code=200,
)
setting.refresh_from_db() setting.refresh_from_db()
self.assertEqual(setting.value, opt) self.assertEqual(setting.value, opt)
# Send an invalid option # Send an invalid option
for opt in ['cat', 'dog', 12345]: for opt in ['cat', 'dog', 12345]:
response = self.patch(url, {'value': opt}, expected_code=400)
response = self.patch(
url,
{
'value': opt,
},
expected_code=400,
)
self.assertIn('Chosen value is not a valid option', str(response.data)) self.assertIn('Chosen value is not a valid option', str(response.data))
def test_user_setting_integer(self): def test_user_setting_integer(self):
"""Test a integer user setting value.""" """Test a integer user setting value."""
setting = InvenTreeUserSetting.get_setting_object( setting = InvenTreeUserSetting.get_setting_object(
'SEARCH_PREVIEW_RESULTS', 'SEARCH_PREVIEW_RESULTS', user=self.user, cache=False
user=self.user,
cache=False,
) )
url = reverse('api-user-setting-detail', kwargs={'key': setting.key}) url = reverse('api-user-setting-detail', kwargs={'key': setting.key})
@ -573,13 +558,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Set valid options via the api # Set valid options via the api
for v in [5, 15, 25]: for v in [5, 15, 25]:
self.patch( self.patch(url, {'value': v}, expected_code=200)
url,
{
'value': v,
},
expected_code=200,
)
setting.refresh_from_db() setting.refresh_from_db()
self.assertEqual(setting.to_native_value(), v) self.assertEqual(setting.to_native_value(), v)
@ -587,14 +566,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Set invalid options via the API # Set invalid options via the API
# Note that this particular setting has a MinValueValidator(1) associated with it # Note that this particular setting has a MinValueValidator(1) associated with it
for v in [0, -1, -5]: for v in [0, -1, -5]:
response = self.patch(url, {'value': v}, expected_code=400)
response = self.patch(
url,
{
'value': v,
},
expected_code=400,
)
class NotificationUserSettingsApiTest(InvenTreeAPITestCase): class NotificationUserSettingsApiTest(InvenTreeAPITestCase):
@ -608,9 +580,15 @@ class NotificationUserSettingsApiTest(InvenTreeAPITestCase):
def test_setting(self): def test_setting(self):
"""Test the string name for NotificationUserSetting.""" """Test the string name for NotificationUserSetting."""
NotificationUserSetting.set_setting('NOTIFICATION_METHOD_MAIL', True, change_user=self.user, user=self.user) NotificationUserSetting.set_setting(
test_setting = NotificationUserSetting.get_setting_object('NOTIFICATION_METHOD_MAIL', user=self.user) 'NOTIFICATION_METHOD_MAIL', True, change_user=self.user, user=self.user
self.assertEqual(str(test_setting), 'NOTIFICATION_METHOD_MAIL (for testuser): True') )
test_setting = NotificationUserSetting.get_setting_object(
'NOTIFICATION_METHOD_MAIL', user=self.user
)
self.assertEqual(
str(test_setting), 'NOTIFICATION_METHOD_MAIL (for testuser): True'
)
class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase): class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase):
@ -638,26 +616,38 @@ class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase):
registry.set_plugin_state('sample', True) registry.set_plugin_state('sample', True)
# get data # get data
url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'API_KEY'}) url = reverse(
'api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'API_KEY'}
)
response = self.get(url, expected_code=200) response = self.get(url, expected_code=200)
# check the right setting came through # check the right setting came through
self.assertTrue(response.data['key'], 'API_KEY') self.assertTrue(response.data['key'], 'API_KEY')
self.assertTrue(response.data['plugin'], 'sample') self.assertTrue(response.data['plugin'], 'sample')
self.assertTrue(response.data['type'], 'string') self.assertTrue(response.data['type'], 'string')
self.assertTrue(response.data['description'], 'Key required for accessing external API') self.assertTrue(
response.data['description'], 'Key required for accessing external API'
)
# Failure mode tests # Failure mode tests
# Non-existent plugin # Non-existent plugin
url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'doesnotexist', 'key': 'doesnotmatter'}) url = reverse(
'api-plugin-setting-detail',
kwargs={'plugin': 'doesnotexist', 'key': 'doesnotmatter'},
)
response = self.get(url, expected_code=404) response = self.get(url, expected_code=404)
self.assertIn("Plugin 'doesnotexist' not installed", str(response.data)) self.assertIn("Plugin 'doesnotexist' not installed", str(response.data))
# Wrong key # Wrong key
url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'doesnotexist'}) url = reverse(
'api-plugin-setting-detail',
kwargs={'plugin': 'sample', 'key': 'doesnotexist'},
)
response = self.get(url, expected_code=404) response = self.get(url, expected_code=404)
self.assertIn("Plugin 'sample' has no setting matching 'doesnotexist'", str(response.data)) self.assertIn(
"Plugin 'sample' has no setting matching 'doesnotexist'", str(response.data)
)
def test_invalid_setting_key(self): def test_invalid_setting_key(self):
"""Test that an invalid setting key returns a 404.""" """Test that an invalid setting key returns a 404."""
@ -684,32 +674,30 @@ class WebhookMessageTests(TestCase):
def test_missing_token(self): def test_missing_token(self):
"""Tests that token checks work.""" """Tests that token checks work."""
response = self.client.post( response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON)
self.url,
content_type=CONTENT_TYPE_JSON,
)
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status_code == HTTPStatus.FORBIDDEN
assert ( assert (
json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR json.loads(response.content)['detail']
== WebhookView.model_class.MESSAGE_TOKEN_ERROR
) )
def test_bad_token(self): def test_bad_token(self):
"""Test that a wrong token is not working.""" """Test that a wrong token is not working."""
response = self.client.post( response = self.client.post(
self.url, self.url, content_type=CONTENT_TYPE_JSON, **{'HTTP_TOKEN': '1234567fghj'}
content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': '1234567fghj'},
) )
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status_code == HTTPStatus.FORBIDDEN
assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR) assert (
json.loads(response.content)['detail']
== WebhookView.model_class.MESSAGE_TOKEN_ERROR
)
def test_bad_url(self): def test_bad_url(self):
"""Test that a wrongly formed url is not working.""" """Test that a wrongly formed url is not working."""
response = self.client.post( response = self.client.post(
'/api/webhook/1234/', '/api/webhook/1234/', content_type=CONTENT_TYPE_JSON
content_type=CONTENT_TYPE_JSON,
) )
assert response.status_code == HTTPStatus.NOT_FOUND assert response.status_code == HTTPStatus.NOT_FOUND
@ -725,7 +713,8 @@ class WebhookMessageTests(TestCase):
assert response.status_code == HTTPStatus.NOT_ACCEPTABLE assert response.status_code == HTTPStatus.NOT_ACCEPTABLE
assert ( assert (
json.loads(response.content)['detail'] == 'Expecting property name enclosed in double quotes' json.loads(response.content)['detail']
== 'Expecting property name enclosed in double quotes'
) )
def test_success_no_token_check(self): def test_success_no_token_check(self):
@ -735,10 +724,7 @@ class WebhookMessageTests(TestCase):
self.endpoint_def.save() self.endpoint_def.save()
# check # check
response = self.client.post( response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON)
self.url,
content_type=CONTENT_TYPE_JSON,
)
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK
@ -751,13 +737,13 @@ class WebhookMessageTests(TestCase):
self.endpoint_def.save() self.endpoint_def.save()
# check # check
response = self.client.post( response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON)
self.url,
content_type=CONTENT_TYPE_JSON,
)
assert response.status_code == HTTPStatus.FORBIDDEN assert response.status_code == HTTPStatus.FORBIDDEN
assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR) assert (
json.loads(response.content)['detail']
== WebhookView.model_class.MESSAGE_TOKEN_ERROR
)
def test_success_hmac(self): def test_success_hmac(self):
"""Test with a valid HMAC provided.""" """Test with a valid HMAC provided."""
@ -783,7 +769,7 @@ class WebhookMessageTests(TestCase):
""" """
response = self.client.post( response = self.client.post(
self.url, self.url,
data={"this": "is a message"}, data={'this': 'is a message'},
content_type=CONTENT_TYPE_JSON, content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': str(self.endpoint_def.token)}, **{'HTTP_TOKEN': str(self.endpoint_def.token)},
) )
@ -791,15 +777,13 @@ class WebhookMessageTests(TestCase):
assert response.status_code == HTTPStatus.OK assert response.status_code == HTTPStatus.OK
assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK
message = WebhookMessage.objects.get() message = WebhookMessage.objects.get()
assert message.body == {"this": "is a message"} assert message.body == {'this': 'is a message'}
class NotificationTest(InvenTreeAPITestCase): class NotificationTest(InvenTreeAPITestCase):
"""Tests for NotificationEntry.""" """Tests for NotificationEntry."""
fixtures = [ fixtures = ['users']
'users',
]
def test_check_notification_entries(self): def test_check_notification_entries(self):
"""Test that notification entries can be created.""" """Test that notification entries can be created."""
@ -832,7 +816,10 @@ class NotificationTest(InvenTreeAPITestCase):
self.assertIn('GET', response.data['actions']) self.assertIn('GET', response.data['actions'])
self.assertNotIn('POST', response.data['actions']) self.assertNotIn('POST', response.data['actions'])
self.assertEqual(response.data['description'], 'List view for all notifications of the current user.') self.assertEqual(
response.data['description'],
'List view for all notifications of the current user.',
)
# POST action should fail (not allowed) # POST action should fail (not allowed)
response = self.post(url, {}, expected_code=405) response = self.post(url, {}, expected_code=405)
@ -867,13 +854,7 @@ class NotificationTest(InvenTreeAPITestCase):
ntf.save() ntf.save()
# Read out via API again # Read out via API again
response = self.get( response = self.get(url, {'read': True}, expected_code=200)
url,
{
'read': True,
},
expected_code=200
)
# Check validity of returned data # Check validity of returned data
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
@ -882,15 +863,7 @@ class NotificationTest(InvenTreeAPITestCase):
# Now, let's bulk delete all 'unread' notifications via the API, # Now, let's bulk delete all 'unread' notifications via the API,
# but only associated with the logged in user # but only associated with the logged in user
response = self.delete( response = self.delete(url, {'filters': {'read': False}}, expected_code=204)
url,
{
'filters': {
'read': False,
}
},
expected_code=204,
)
# Only 7 notifications should have been deleted, # Only 7 notifications should have been deleted,
# as the notifications associated with other users must remain untouched # as the notifications associated with other users must remain untouched
@ -907,13 +880,17 @@ class CommonTest(InvenTreeAPITestCase):
from plugin import registry from plugin import registry
# set flag true # set flag true
common.models.InvenTreeSetting.set_setting('SERVER_RESTART_REQUIRED', True, None) common.models.InvenTreeSetting.set_setting(
'SERVER_RESTART_REQUIRED', True, None
)
# reload the app # reload the app
registry.reload_plugins() registry.reload_plugins()
# now it should be false again # now it should be false again
self.assertFalse(common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED')) self.assertFalse(
common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED')
)
def test_config_api(self): def test_config_api(self):
"""Test config URLs.""" """Test config URLs."""
@ -926,8 +903,13 @@ class CommonTest(InvenTreeAPITestCase):
# Successful checks # Successful checks
data = [ data = [
self.get(reverse('api-config-list'), expected_code=200).data[0], # list endpoint self.get(reverse('api-config-list'), expected_code=200).data[
self.get(reverse('api-config-detail', kwargs={'key': 'INVENTREE_DEBUG'}), expected_code=200).data, # detail endpoint 0
], # list endpoint
self.get(
reverse('api-config-detail', kwargs={'key': 'INVENTREE_DEBUG'}),
expected_code=200,
).data, # detail endpoint
] ]
for item in data: for item in data:
@ -956,21 +938,33 @@ class CommonTest(InvenTreeAPITestCase):
self.assertEqual(response.data[0]['key'], 'EXPERIMENTAL') self.assertEqual(response.data[0]['key'], 'EXPERIMENTAL')
self.assertTrue(response.data[0]['conditions']) self.assertTrue(response.data[0]['conditions'])
response = self.get(reverse('api-flag-detail', kwargs={'key': 'EXPERIMENTAL'}), expected_code=200) response = self.get(
reverse('api-flag-detail', kwargs={'key': 'EXPERIMENTAL'}),
expected_code=200,
)
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
self.assertEqual(response.data['key'], 'EXPERIMENTAL') self.assertEqual(response.data['key'], 'EXPERIMENTAL')
self.assertTrue(response.data['conditions']) self.assertTrue(response.data['conditions'])
# Try without param -> false # Try without param -> false
response = self.get(reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), expected_code=200) response = self.get(
reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), expected_code=200
)
self.assertFalse(response.data['state']) self.assertFalse(response.data['state'])
# Try with param -> true # Try with param -> true
response = self.get(reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), {'ngen': ''}, expected_code=200) response = self.get(
reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}),
{'ngen': ''},
expected_code=200,
)
self.assertTrue(response.data['state']) self.assertTrue(response.data['state'])
# Try non existent flag # Try non existent flag
response = self.get(reverse('api-flag-detail', kwargs={'key': 'NON_EXISTENT'}), expected_code=404) response = self.get(
reverse('api-flag-detail', kwargs={'key': 'NON_EXISTENT'}),
expected_code=404,
)
# Turn into normal user again # Turn into normal user again
self.user.is_superuser = False self.user.is_superuser = False
@ -1038,7 +1032,7 @@ class CurrencyAPITests(InvenTreeAPITestCase):
# Delay and try again # Delay and try again
time.sleep(10) time.sleep(10)
raise TimeoutError("Could not refresh currency exchange data after 5 attempts") raise TimeoutError('Could not refresh currency exchange data after 5 attempts')
class NotesImageTest(InvenTreeAPITestCase): class NotesImageTest(InvenTreeAPITestCase):
@ -1052,25 +1046,29 @@ class NotesImageTest(InvenTreeAPITestCase):
response = self.post( response = self.post(
reverse('api-notes-image-list'), reverse('api-notes-image-list'),
data={ data={
'image': SimpleUploadedFile('test.txt', b"this is not an image file", content_type='text/plain'), 'image': SimpleUploadedFile(
}, 'test.txt', b'this is not an image file', content_type='text/plain'
format='multipart',
expected_code=400
) )
self.assertIn("Upload a valid image", str(response.data['image']))
# Test upload of an invalid image file
response = self.post(
reverse('api-notes-image-list'),
data={
'image': SimpleUploadedFile('test.png', b"this is not an image file", content_type='image/png'),
}, },
format='multipart', format='multipart',
expected_code=400, expected_code=400,
) )
self.assertIn("Upload a valid image", str(response.data['image'])) self.assertIn('Upload a valid image', str(response.data['image']))
# Test upload of an invalid image file
response = self.post(
reverse('api-notes-image-list'),
data={
'image': SimpleUploadedFile(
'test.png', b'this is not an image file', content_type='image/png'
)
},
format='multipart',
expected_code=400,
)
self.assertIn('Upload a valid image', str(response.data['image']))
# Check that no extra database entries have been created # Check that no extra database entries have been created
self.assertEqual(NotesImage.objects.count(), n) self.assertEqual(NotesImage.objects.count(), n)
@ -1089,10 +1087,12 @@ class NotesImageTest(InvenTreeAPITestCase):
self.post( self.post(
reverse('api-notes-image-list'), reverse('api-notes-image-list'),
data={ data={
'image': SimpleUploadedFile('test.png', contents, content_type='image/png'), 'image': SimpleUploadedFile(
'test.png', contents, content_type='image/png'
)
}, },
format='multipart', format='multipart',
expected_code=201 expected_code=201,
) )
# Check that a new file has been created # Check that a new file has been created
@ -1136,7 +1136,7 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# Delete it # Delete it
self.delete( self.delete(
reverse('api-project-code-detail', kwargs={'pk': code.pk}), reverse('api-project-code-detail', kwargs={'pk': code.pk}),
expected_code=204 expected_code=204,
) )
# Check it is gone # Check it is gone
@ -1147,25 +1147,22 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# Create a new project code # Create a new project code
response = self.post( response = self.post(
self.url, self.url,
data={ data={'code': 'PRJ-001', 'description': 'Test project code'},
'code': 'PRJ-001', expected_code=400,
'description': 'Test project code',
},
expected_code=400
) )
self.assertIn('project code with this Project Code already exists', str(response.data['code'])) self.assertIn(
'project code with this Project Code already exists',
str(response.data['code']),
)
def test_write_access(self): def test_write_access(self):
"""Test that non-staff users have read-only access""" """Test that non-staff users have read-only access"""
# By default user has staff access, can create a new project code # By default user has staff access, can create a new project code
response = self.post( response = self.post(
self.url, self.url,
data={ data={'code': 'PRJ-xxx', 'description': 'Test project code'},
'code': 'PRJ-xxx', expected_code=201,
'description': 'Test project code',
},
expected_code=201
) )
pk = response.data['pk'] pk = response.data['pk']
@ -1173,10 +1170,8 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# Test we can edit, also # Test we can edit, also
response = self.patch( response = self.patch(
reverse('api-project-code-detail', kwargs={'pk': pk}), reverse('api-project-code-detail', kwargs={'pk': pk}),
data={ data={'code': 'PRJ-999'},
'code': 'PRJ-999', expected_code=200,
},
expected_code=200
) )
self.assertEqual(response.data['code'], 'PRJ-999') self.assertEqual(response.data['code'], 'PRJ-999')
@ -1188,20 +1183,15 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# As user does not have staff access, should return 403 for list endpoint # As user does not have staff access, should return 403 for list endpoint
response = self.post( response = self.post(
self.url, self.url,
data={ data={'code': 'PRJ-123', 'description': 'Test project code'},
'code': 'PRJ-123', expected_code=403,
'description': 'Test project code'
},
expected_code=403
) )
# Should also return 403 for detail endpoint # Should also return 403 for detail endpoint
response = self.patch( response = self.patch(
reverse('api-project-code-detail', kwargs={'pk': pk}), reverse('api-project-code-detail', kwargs={'pk': pk}),
data={ data={'code': 'PRJ-999'},
'code': 'PRJ-999', expected_code=403,
},
expected_code=403
) )
@ -1219,8 +1209,14 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
super().setUpTestData() super().setUpTestData()
units = [ units = [
CustomUnit(name='metres_per_amp', definition='meter / ampere', symbol='m/A'), CustomUnit(
CustomUnit(name='hectares_per_second', definition='hectares per second', symbol='ha/s'), name='metres_per_amp', definition='meter / ampere', symbol='m/A'
),
CustomUnit(
name='hectares_per_second',
definition='hectares per second',
symbol='ha/s',
),
] ]
CustomUnit.objects.bulk_create(units) CustomUnit.objects.bulk_create(units)
@ -1240,10 +1236,8 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
self.patch( self.patch(
reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}), reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}),
{ {'name': 'new_unit_name'},
'name': 'new_unit_name', expected_code=403,
},
expected_code=403
) )
# Ok, what if we have permission? # Ok, what if we have permission?
@ -1252,9 +1246,7 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
self.patch( self.patch(
reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}), reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}),
{ {'name': 'new_unit_name'},
'name': 'new_unit_name',
},
# expected_code=200 # expected_code=200
) )
@ -1269,21 +1261,9 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
self.user.save() self.user.save()
# Test invalid 'name' values (must be valid identifier) # Test invalid 'name' values (must be valid identifier)
invalid_name_values = [ invalid_name_values = ['1', '1abc', 'abc def', 'abc-def', 'abc.def']
'1',
'1abc',
'abc def',
'abc-def',
'abc.def',
]
url = reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}) url = reverse('api-custom-unit-detail', kwargs={'pk': unit.pk})
for name in invalid_name_values: for name in invalid_name_values:
self.patch( self.patch(url, {'name': name}, expected_code=400)
url,
{
'name': name,
},
expected_code=400
)

View File

@ -1,4 +1,3 @@
"""URL lookup for common views.""" """URL lookup for common views."""
common_urls = [ common_urls = []
]

View File

@ -81,11 +81,7 @@ class FileManagementFormView(MultiStepFormView):
('fields', forms.MatchFieldForm), ('fields', forms.MatchFieldForm),
('items', forms.MatchItemForm), ('items', forms.MatchItemForm),
] ]
form_steps_description = [ form_steps_description = [_('Upload File'), _('Match Fields'), _('Match Items')]
_("Upload File"),
_("Match Fields"),
_("Match Items"),
]
media_folder = 'file_upload/' media_folder = 'file_upload/'
extra_context_data = {} extra_context_data = {}
@ -95,8 +91,12 @@ class FileManagementFormView(MultiStepFormView):
super().__init__(self, *args, **kwargs) super().__init__(self, *args, **kwargs)
# Check for file manager class # Check for file manager class
if not hasattr(self, 'file_manager_class') and not issubclass(self.file_manager_class, FileManager): if not hasattr(self, 'file_manager_class') and not issubclass(
raise NotImplementedError('A subclass of a file manager class needs to be set!') self.file_manager_class, FileManager
):
raise NotImplementedError(
'A subclass of a file manager class needs to be set!'
)
def get_context_data(self, form=None, **kwargs): def get_context_data(self, form=None, **kwargs):
"""Handle context data.""" """Handle context data."""
@ -106,7 +106,6 @@ class FileManagementFormView(MultiStepFormView):
context = super().get_context_data(form=form, **kwargs) context = super().get_context_data(form=form, **kwargs)
if self.steps.current in ('fields', 'items'): if self.steps.current in ('fields', 'items'):
# Get columns and row data # Get columns and row data
self.columns = self.file_manager.columns() self.columns = self.file_manager.columns()
self.rows = self.file_manager.rows() self.rows = self.file_manager.rows()
@ -140,7 +139,9 @@ class FileManagementFormView(MultiStepFormView):
# Get file # Get file
file = upload_files.get('upload-file', None) file = upload_files.get('upload-file', None)
if file: if file:
self.file_manager = self.file_manager_class(file=file, name=self.name) self.file_manager = self.file_manager_class(
file=file, name=self.name
)
def get_form_kwargs(self, step=None): def get_form_kwargs(self, step=None):
"""Update kwargs to dynamically build forms.""" """Update kwargs to dynamically build forms."""
@ -150,15 +151,11 @@ class FileManagementFormView(MultiStepFormView):
if step == 'upload': if step == 'upload':
# Dynamically build upload form # Dynamically build upload form
if self.name: if self.name:
kwargs = { kwargs = {'name': self.name}
'name': self.name
}
return kwargs return kwargs
elif step == 'fields': elif step == 'fields':
# Dynamically build match field form # Dynamically build match field form
kwargs = { kwargs = {'file_manager': self.file_manager}
'file_manager': self.file_manager
}
return kwargs return kwargs
elif step == 'items': elif step == 'items':
# Dynamically build match item form # Dynamically build match item form
@ -206,7 +203,6 @@ class FileManagementFormView(MultiStepFormView):
self.row_data = {} self.row_data = {}
for item, value in form_data.items(): for item, value in form_data.items():
# Column names as passed as col_name_<idx> where idx is an integer # Column names as passed as col_name_<idx> where idx is an integer
# Extract the column names # Extract the column names
@ -220,7 +216,6 @@ class FileManagementFormView(MultiStepFormView):
# Extract the column selections (in the 'select fields' view) # Extract the column selections (in the 'select fields' view)
if item.startswith('fields-'): if item.startswith('fields-'):
try: try:
col_name = item.replace('fields-', '') col_name = item.replace('fields-', '')
except ValueError: except ValueError:
@ -258,10 +253,7 @@ class FileManagementFormView(MultiStepFormView):
self.columns = [] self.columns = []
for idx, value in self.column_names.items(): for idx, value in self.column_names.items():
header = ({ header = {'name': value, 'guess': self.column_selections.get(idx, '')}
'name': value,
'guess': self.column_selections.get(idx, ''),
})
self.columns.append(header) self.columns.append(header)
if self.row_data: if self.row_data:
@ -280,18 +272,10 @@ class FileManagementFormView(MultiStepFormView):
'guess': self.column_selections[idx], 'guess': self.column_selections[idx],
} }
cell_data = { cell_data = {'cell': item, 'idx': idx, 'column': column_data}
'cell': item,
'idx': idx,
'column': column_data,
}
data.append(cell_data) data.append(cell_data)
row = { row = {'index': row_idx, 'data': data, 'errors': {}}
'index': row_idx,
'data': data,
'errors': {},
}
self.rows.append(row) self.rows.append(row)
@ -344,11 +328,7 @@ class FileManagementFormView(MultiStepFormView):
try: try:
if idx not in items: if idx not in items:
# Insert into items # Insert into items
items.update({ items.update({idx: {self.form_field_map[field]: form_value}})
idx: {
self.form_field_map[field]: form_value,
}
})
else: else:
# Update items # Update items
items[idx][self.form_field_map[field]] = form_value items[idx][self.form_field_map[field]] = form_value
@ -383,14 +363,15 @@ class FileManagementFormView(MultiStepFormView):
duplicates = [] duplicates = []
for col in self.column_names: for col in self.column_names:
if col in self.column_selections: if col in self.column_selections:
guess = self.column_selections[col] guess = self.column_selections[col]
else: else:
guess = None guess = None
if guess: if guess:
n = list(self.column_selections.values()).count(self.column_selections[col]) n = list(self.column_selections.values()).count(
self.column_selections[col]
)
if n > 1 and self.column_selections[col] not in duplicates: if n > 1 and self.column_selections[col] not in duplicates:
duplicates.append(self.column_selections[col]) duplicates.append(self.column_selections[col])
@ -459,7 +440,9 @@ class FileManagementAjaxView(AjaxView):
wizard_back = self.request.POST.get('act-btn_back', None) wizard_back = self.request.POST.get('act-btn_back', None)
if wizard_back: if wizard_back:
back_step_index = self.get_step_index() - 1 back_step_index = self.get_step_index() - 1
self.storage.current_step = list(self.get_form_list().keys())[back_step_index] self.storage.current_step = list(self.get_form_list().keys())[
back_step_index
]
return self.renderJsonResponse(request, data={'form_valid': None}) return self.renderJsonResponse(request, data={'form_valid': None})
# validate form # validate form
@ -499,13 +482,19 @@ class FileManagementAjaxView(AjaxView):
data = {} data = {}
self.setTemplate() self.setTemplate()
return super().renderJsonResponse(request, form=form, data=data, context=context) return super().renderJsonResponse(
request, form=form, data=data, context=context
)
def get_data(self) -> dict: def get_data(self) -> dict:
"""Get extra context data.""" """Get extra context data."""
data = super().get_data() data = super().get_data()
data['hideErrorMessage'] = '1' # hide the error data['hideErrorMessage'] = '1' # hide the error
buttons = [{'name': 'back', 'title': _('Previous Step')}] if self.get_step_index() > 0 else [] buttons = (
[{'name': 'back', 'title': _('Previous Step')}]
if self.get_step_index() > 0
else []
)
data['buttons'] = buttons # set buttons data['buttons'] = buttons # set buttons
return data return data

View File

@ -9,9 +9,16 @@ from import_export.fields import Field
from InvenTree.admin import InvenTreeResource from InvenTree.admin import InvenTreeResource
from part.models import Part from part.models import Part
from .models import (Address, Company, Contact, ManufacturerPart, from .models import (
ManufacturerPartAttachment, ManufacturerPartParameter, Address,
SupplierPart, SupplierPriceBreak) Company,
Contact,
ManufacturerPart,
ManufacturerPartAttachment,
ManufacturerPartParameter,
SupplierPart,
SupplierPriceBreak,
)
class CompanyResource(InvenTreeResource): class CompanyResource(InvenTreeResource):
@ -19,6 +26,7 @@ class CompanyResource(InvenTreeResource):
class Meta: class Meta:
"""Metaclass defines extra options""" """Metaclass defines extra options"""
model = Company model = Company
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -32,10 +40,7 @@ class CompanyAdmin(ImportExportModelAdmin):
list_display = ('name', 'website', 'contact') list_display = ('name', 'website', 'contact')
search_fields = [ search_fields = ['name', 'description']
'name',
'description',
]
class SupplierPartResource(InvenTreeResource): class SupplierPartResource(InvenTreeResource):
@ -43,6 +48,7 @@ class SupplierPartResource(InvenTreeResource):
class Meta: class Meta:
"""Metaclass defines extra admin options""" """Metaclass defines extra admin options"""
model = SupplierPart model = SupplierPart
skip_unchanged = True skip_unchanged = True
report_skipped = True report_skipped = True
@ -70,18 +76,11 @@ class SupplierPartAdmin(ImportExportModelAdmin):
list_display = ('part', 'supplier', 'SKU') list_display = ('part', 'supplier', 'SKU')
search_fields = [ search_fields = ['supplier__name', 'part__name', 'manufacturer_part__MPN', 'SKU']
'supplier__name',
'part__name',
'manufacturer_part__MPN',
'SKU',
]
inlines = [ inlines = [SupplierPriceBreakInline]
SupplierPriceBreakInline,
]
autocomplete_fields = ('part', 'supplier', 'manufacturer_part',) autocomplete_fields = ('part', 'supplier', 'manufacturer_part')
class ManufacturerPartResource(InvenTreeResource): class ManufacturerPartResource(InvenTreeResource):
@ -89,6 +88,7 @@ class ManufacturerPartResource(InvenTreeResource):
class Meta: class Meta:
"""Metaclass defines extra admin options""" """Metaclass defines extra admin options"""
model = ManufacturerPart model = ManufacturerPart
skip_unchanged = True skip_unchanged = True
report_skipped = True report_skipped = True
@ -98,7 +98,9 @@ class ManufacturerPartResource(InvenTreeResource):
part_name = Field(attribute='part__full_name', readonly=True) part_name = Field(attribute='part__full_name', readonly=True)
manufacturer = Field(attribute='manufacturer', widget=widgets.ForeignKeyWidget(Company)) manufacturer = Field(
attribute='manufacturer', widget=widgets.ForeignKeyWidget(Company)
)
manufacturer_name = Field(attribute='manufacturer__name', readonly=True) manufacturer_name = Field(attribute='manufacturer__name', readonly=True)
@ -110,13 +112,9 @@ class ManufacturerPartAdmin(ImportExportModelAdmin):
list_display = ('part', 'manufacturer', 'MPN') list_display = ('part', 'manufacturer', 'MPN')
search_fields = [ search_fields = ['manufacturer__name', 'part__name', 'MPN']
'manufacturer__name',
'part__name',
'MPN',
]
autocomplete_fields = ('part', 'manufacturer',) autocomplete_fields = ('part', 'manufacturer')
class ManufacturerPartAttachmentAdmin(ImportExportModelAdmin): class ManufacturerPartAttachmentAdmin(ImportExportModelAdmin):
@ -132,6 +130,7 @@ class ManufacturerPartParameterResource(InvenTreeResource):
class Meta: class Meta:
"""Metaclass defines extra admin options""" """Metaclass defines extra admin options"""
model = ManufacturerPartParameter model = ManufacturerPartParameter
skip_unchanged = True skip_unchanged = True
report_skipped = True report_skipped = True
@ -145,11 +144,7 @@ class ManufacturerPartParameterAdmin(ImportExportModelAdmin):
list_display = ('manufacturer_part', 'name', 'value') list_display = ('manufacturer_part', 'name', 'value')
search_fields = [ search_fields = ['manufacturer_part__manufacturer__name', 'name', 'value']
'manufacturer_part__manufacturer__name',
'name',
'value'
]
autocomplete_fields = ('manufacturer_part',) autocomplete_fields = ('manufacturer_part',)
@ -159,6 +154,7 @@ class SupplierPriceBreakResource(InvenTreeResource):
class Meta: class Meta:
"""Metaclass defines extra admin options""" """Metaclass defines extra admin options"""
model = SupplierPriceBreak model = SupplierPriceBreak
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -192,6 +188,7 @@ class AddressResource(InvenTreeResource):
class Meta: class Meta:
"""Metaclass defining extra options""" """Metaclass defining extra options"""
model = Address model = Address
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -207,11 +204,7 @@ class AddressAdmin(ImportExportModelAdmin):
list_display = ('company', 'line1', 'postal_code', 'country') list_display = ('company', 'line1', 'postal_code', 'country')
search_fields = [ search_fields = ['company', 'country', 'postal_code']
'company',
'country',
'postal_code',
]
class ContactResource(InvenTreeResource): class ContactResource(InvenTreeResource):
@ -219,6 +212,7 @@ class ContactResource(InvenTreeResource):
class Meta: class Meta:
"""Metaclass defining extra options""" """Metaclass defining extra options"""
model = Contact model = Contact
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -234,11 +228,7 @@ class ContactAdmin(ImportExportModelAdmin):
list_display = ('company', 'name', 'role', 'email', 'phone') list_display = ('company', 'name', 'role', 'email', 'phone')
search_fields = [ search_fields = ['company', 'name', 'email']
'company',
'name',
'email',
]
admin.site.register(Company, CompanyAdmin) admin.site.register(Company, CompanyAdmin)

View File

@ -6,23 +6,37 @@ from django.urls import include, path, re_path
from django_filters import rest_framework as rest_filters from django_filters import rest_framework as rest_filters
import part.models import part.models
from InvenTree.api import (AttachmentMixin, ListCreateDestroyAPIView, from InvenTree.api import AttachmentMixin, ListCreateDestroyAPIView, MetadataView
MetadataView) from InvenTree.filters import (
from InvenTree.filters import (ORDER_FILTER, SEARCH_ORDER_FILTER, ORDER_FILTER,
SEARCH_ORDER_FILTER_ALIAS) SEARCH_ORDER_FILTER,
SEARCH_ORDER_FILTER_ALIAS,
)
from InvenTree.helpers import str2bool from InvenTree.helpers import str2bool
from InvenTree.mixins import ListCreateAPI, RetrieveUpdateDestroyAPI from InvenTree.mixins import ListCreateAPI, RetrieveUpdateDestroyAPI
from .models import (Address, Company, CompanyAttachment, Contact, from .models import (
ManufacturerPart, ManufacturerPartAttachment, Address,
ManufacturerPartParameter, SupplierPart, Company,
SupplierPriceBreak) CompanyAttachment,
from .serializers import (AddressSerializer, CompanyAttachmentSerializer, Contact,
CompanySerializer, ContactSerializer, ManufacturerPart,
ManufacturerPartAttachment,
ManufacturerPartParameter,
SupplierPart,
SupplierPriceBreak,
)
from .serializers import (
AddressSerializer,
CompanyAttachmentSerializer,
CompanySerializer,
ContactSerializer,
ManufacturerPartAttachmentSerializer, ManufacturerPartAttachmentSerializer,
ManufacturerPartParameterSerializer, ManufacturerPartParameterSerializer,
ManufacturerPartSerializer, SupplierPartSerializer, ManufacturerPartSerializer,
SupplierPriceBreakSerializer) SupplierPartSerializer,
SupplierPriceBreakSerializer,
)
class CompanyList(ListCreateAPI): class CompanyList(ListCreateAPI):
@ -46,24 +60,11 @@ class CompanyList(ListCreateAPI):
filter_backends = SEARCH_ORDER_FILTER filter_backends = SEARCH_ORDER_FILTER
filterset_fields = [ filterset_fields = ['is_customer', 'is_manufacturer', 'is_supplier', 'name']
'is_customer',
'is_manufacturer',
'is_supplier',
'name',
]
search_fields = [ search_fields = ['name', 'description', 'website']
'name',
'description',
'website',
]
ordering_fields = [ ordering_fields = ['name', 'parts_supplied', 'parts_manufactured']
'name',
'parts_supplied',
'parts_manufactured',
]
ordering = 'name' ordering = 'name'
@ -88,9 +89,7 @@ class CompanyAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
queryset = CompanyAttachment.objects.all() queryset = CompanyAttachment.objects.all()
serializer_class = CompanyAttachmentSerializer serializer_class = CompanyAttachmentSerializer
filterset_fields = [ filterset_fields = ['company']
'company',
]
class CompanyAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): class CompanyAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI):
@ -108,18 +107,11 @@ class ContactList(ListCreateDestroyAPIView):
filter_backends = SEARCH_ORDER_FILTER filter_backends = SEARCH_ORDER_FILTER
filterset_fields = [ filterset_fields = ['company']
'company',
]
search_fields = [ search_fields = ['company__name', 'name']
'company__name',
'name',
]
ordering_fields = [ ordering_fields = ['name']
'name',
]
ordering = 'name' ordering = 'name'
@ -139,13 +131,9 @@ class AddressList(ListCreateDestroyAPIView):
filter_backends = SEARCH_ORDER_FILTER filter_backends = SEARCH_ORDER_FILTER
filterset_fields = [ filterset_fields = ['company']
'company',
]
ordering_fields = [ ordering_fields = ['title']
'title',
]
ordering = 'title' ordering = 'title'
@ -164,13 +152,7 @@ class ManufacturerPartFilter(rest_filters.FilterSet):
"""Metaclass options.""" """Metaclass options."""
model = ManufacturerPart model = ManufacturerPart
fields = [ fields = ['manufacturer', 'MPN', 'part', 'tags__name', 'tags__slug']
'manufacturer',
'MPN',
'part',
'tags__name',
'tags__slug',
]
# Filter by 'active' status of linked part # Filter by 'active' status of linked part
active = rest_filters.BooleanFilter(field_name='part__active') active = rest_filters.BooleanFilter(field_name='part__active')
@ -184,10 +166,7 @@ class ManufacturerPartList(ListCreateDestroyAPIView):
""" """
queryset = ManufacturerPart.objects.all().prefetch_related( queryset = ManufacturerPart.objects.all().prefetch_related(
'part', 'part', 'manufacturer', 'supplier_parts', 'tags'
'manufacturer',
'supplier_parts',
'tags',
) )
serializer_class = ManufacturerPartSerializer serializer_class = ManufacturerPartSerializer
@ -200,7 +179,9 @@ class ManufacturerPartList(ListCreateDestroyAPIView):
params = self.request.query_params params = self.request.query_params
kwargs['part_detail'] = str2bool(params.get('part_detail', None)) kwargs['part_detail'] = str2bool(params.get('part_detail', None))
kwargs['manufacturer_detail'] = str2bool(params.get('manufacturer_detail', None)) kwargs['manufacturer_detail'] = str2bool(
params.get('manufacturer_detail', None)
)
kwargs['pretty'] = str2bool(params.get('pretty', None)) kwargs['pretty'] = str2bool(params.get('pretty', None))
except AttributeError: except AttributeError:
pass pass
@ -241,9 +222,7 @@ class ManufacturerPartAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
queryset = ManufacturerPartAttachment.objects.all() queryset = ManufacturerPartAttachment.objects.all()
serializer_class = ManufacturerPartAttachmentSerializer serializer_class = ManufacturerPartAttachmentSerializer
filterset_fields = [ filterset_fields = ['manufacturer_part']
'manufacturer_part',
]
class ManufacturerPartAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI): class ManufacturerPartAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI):
@ -258,17 +237,17 @@ class ManufacturerPartParameterFilter(rest_filters.FilterSet):
class Meta: class Meta:
"""Metaclass options""" """Metaclass options"""
model = ManufacturerPartParameter model = ManufacturerPartParameter
fields = [ fields = ['name', 'value', 'units', 'manufacturer_part']
'name',
'value',
'units',
'manufacturer_part',
]
manufacturer = rest_filters.ModelChoiceFilter(queryset=Company.objects.all(), field_name='manufacturer_part__manufacturer') manufacturer = rest_filters.ModelChoiceFilter(
queryset=Company.objects.all(), field_name='manufacturer_part__manufacturer'
)
part = rest_filters.ModelChoiceFilter(queryset=part.models.Part.objects.all(), field_name='manufacturer_part__part') part = rest_filters.ModelChoiceFilter(
queryset=part.models.Part.objects.all(), field_name='manufacturer_part__part'
)
class ManufacturerPartParameterList(ListCreateDestroyAPIView): class ManufacturerPartParameterList(ListCreateDestroyAPIView):
@ -284,9 +263,7 @@ class ManufacturerPartParameterList(ListCreateDestroyAPIView):
try: try:
params = self.request.query_params params = self.request.query_params
optional_fields = [ optional_fields = ['manufacturer_part_detail']
'manufacturer_part_detail',
]
for key in optional_fields: for key in optional_fields:
kwargs[key] = str2bool(params.get(key, None)) kwargs[key] = str2bool(params.get(key, None))
@ -300,11 +277,7 @@ class ManufacturerPartParameterList(ListCreateDestroyAPIView):
filter_backends = SEARCH_ORDER_FILTER filter_backends = SEARCH_ORDER_FILTER
search_fields = [ search_fields = ['name', 'value', 'units']
'name',
'value',
'units',
]
class ManufacturerPartParameterDetail(RetrieveUpdateDestroyAPI): class ManufacturerPartParameterDetail(RetrieveUpdateDestroyAPI):
@ -337,7 +310,7 @@ class SupplierPartFilter(rest_filters.FilterSet):
MPN = rest_filters.CharFilter( MPN = rest_filters.CharFilter(
label='Manufacturer Part Number', label='Manufacturer Part Number',
field_name='manufacturer_part__MPN', field_name='manufacturer_part__MPN',
lookup_expr='iexact' lookup_expr='iexact',
) )
@ -348,9 +321,7 @@ class SupplierPartList(ListCreateDestroyAPIView):
- POST: Create a new SupplierPart object - POST: Create a new SupplierPart object
""" """
queryset = SupplierPart.objects.all().prefetch_related( queryset = SupplierPart.objects.all().prefetch_related('tags')
'tags',
)
filterset_class = SupplierPartFilter filterset_class = SupplierPartFilter
def get_queryset(self, *args, **kwargs): def get_queryset(self, *args, **kwargs):
@ -376,7 +347,9 @@ class SupplierPartList(ListCreateDestroyAPIView):
company = params.get('company', None) company = params.get('company', None)
if company is not None: if company is not None:
queryset = queryset.filter(Q(manufacturer_part__manufacturer=company) | Q(supplier=company)).distinct() queryset = queryset.filter(
Q(manufacturer_part__manufacturer=company) | Q(supplier=company)
).distinct()
return queryset return queryset
@ -387,7 +360,9 @@ class SupplierPartList(ListCreateDestroyAPIView):
params = self.request.query_params params = self.request.query_params
kwargs['part_detail'] = str2bool(params.get('part_detail', None)) kwargs['part_detail'] = str2bool(params.get('part_detail', None))
kwargs['supplier_detail'] = str2bool(params.get('supplier_detail', True)) kwargs['supplier_detail'] = str2bool(params.get('supplier_detail', True))
kwargs['manufacturer_detail'] = str2bool(params.get('manufacturer_detail', None)) kwargs['manufacturer_detail'] = str2bool(
params.get('manufacturer_detail', None)
)
kwargs['pretty'] = str2bool(params.get('pretty', None)) kwargs['pretty'] = str2bool(params.get('pretty', None))
except AttributeError: except AttributeError:
pass pass
@ -446,8 +421,7 @@ class SupplierPartDetail(RetrieveUpdateDestroyAPI):
queryset = SupplierPart.objects.all() queryset = SupplierPart.objects.all()
serializer_class = SupplierPartSerializer serializer_class = SupplierPartSerializer
read_only_fields = [ read_only_fields = []
]
class SupplierPriceBreakFilter(rest_filters.FilterSet): class SupplierPriceBreakFilter(rest_filters.FilterSet):
@ -457,10 +431,7 @@ class SupplierPriceBreakFilter(rest_filters.FilterSet):
"""Metaclass options""" """Metaclass options"""
model = SupplierPriceBreak model = SupplierPriceBreak
fields = [ fields = ['part', 'quantity']
'part',
'quantity',
]
base_part = rest_filters.ModelChoiceFilter( base_part = rest_filters.ModelChoiceFilter(
label='Base Part', label='Base Part',
@ -469,9 +440,7 @@ class SupplierPriceBreakFilter(rest_filters.FilterSet):
) )
supplier = rest_filters.ModelChoiceFilter( supplier = rest_filters.ModelChoiceFilter(
label='Supplier', label='Supplier', queryset=Company.objects.all(), field_name='part__supplier'
queryset=Company.objects.all(),
field_name='part__supplier',
) )
@ -503,9 +472,7 @@ class SupplierPriceBreakList(ListCreateAPI):
filter_backends = ORDER_FILTER filter_backends = ORDER_FILTER
ordering_fields = [ ordering_fields = ['quantity']
'quantity',
]
ordering = 'quantity' ordering = 'quantity'
@ -518,37 +485,74 @@ class SupplierPriceBreakDetail(RetrieveUpdateDestroyAPI):
manufacturer_part_api_urls = [ manufacturer_part_api_urls = [
# Base URL for ManufacturerPartAttachment API endpoints # Base URL for ManufacturerPartAttachment API endpoints
re_path(r'^attachment/', include([ re_path(
path(r'<int:pk>/', ManufacturerPartAttachmentDetail.as_view(), name='api-manufacturer-part-attachment-detail'), r'^attachment/',
re_path(r'^$', ManufacturerPartAttachmentList.as_view(), name='api-manufacturer-part-attachment-list'), include([
])), path(
r'<int:pk>/',
re_path(r'^parameter/', include([ ManufacturerPartAttachmentDetail.as_view(),
path(r'<int:pk>/', ManufacturerPartParameterDetail.as_view(), name='api-manufacturer-part-parameter-detail'), name='api-manufacturer-part-attachment-detail',
),
re_path(
r'^$',
ManufacturerPartAttachmentList.as_view(),
name='api-manufacturer-part-attachment-list',
),
]),
),
re_path(
r'^parameter/',
include([
path(
r'<int:pk>/',
ManufacturerPartParameterDetail.as_view(),
name='api-manufacturer-part-parameter-detail',
),
# Catch anything else # Catch anything else
re_path(r'^.*$', ManufacturerPartParameterList.as_view(), name='api-manufacturer-part-parameter-list'), re_path(
])), r'^.*$',
ManufacturerPartParameterList.as_view(),
re_path(r'^(?P<pk>\d+)/?', include([ name='api-manufacturer-part-parameter-list',
re_path('^metadata/', MetadataView.as_view(), {'model': ManufacturerPart}, name='api-manufacturer-part-metadata'), ),
re_path('^.*$', ManufacturerPartDetail.as_view(), name='api-manufacturer-part-detail'), ]),
])), ),
re_path(
r'^(?P<pk>\d+)/?',
include([
re_path(
'^metadata/',
MetadataView.as_view(),
{'model': ManufacturerPart},
name='api-manufacturer-part-metadata',
),
re_path(
'^.*$',
ManufacturerPartDetail.as_view(),
name='api-manufacturer-part-detail',
),
]),
),
# Catch anything else # Catch anything else
re_path(r'^.*$', ManufacturerPartList.as_view(), name='api-manufacturer-part-list'), re_path(r'^.*$', ManufacturerPartList.as_view(), name='api-manufacturer-part-list'),
] ]
supplier_part_api_urls = [ supplier_part_api_urls = [
re_path(
re_path(r'^(?P<pk>\d+)/?', include([ r'^(?P<pk>\d+)/?',
re_path('^metadata/', MetadataView.as_view(), {'model': SupplierPart}, name='api-supplier-part-metadata'), include([
re_path('^.*$', SupplierPartDetail.as_view(), name='api-supplier-part-detail'), re_path(
])), '^metadata/',
MetadataView.as_view(),
{'model': SupplierPart},
name='api-supplier-part-metadata',
),
re_path(
'^.*$', SupplierPartDetail.as_view(), name='api-supplier-part-detail'
),
]),
),
# Catch anything else # Catch anything else
re_path(r'^.*$', SupplierPartList.as_view(), name='api-supplier-part-list'), re_path(r'^.*$', SupplierPartList.as_view(), name='api-supplier-part-list'),
] ]
@ -556,39 +560,74 @@ supplier_part_api_urls = [
company_api_urls = [ company_api_urls = [
re_path(r'^part/manufacturer/', include(manufacturer_part_api_urls)), re_path(r'^part/manufacturer/', include(manufacturer_part_api_urls)),
re_path(r'^part/', include(supplier_part_api_urls)), re_path(r'^part/', include(supplier_part_api_urls)),
# Supplier price breaks # Supplier price breaks
re_path(r'^price-break/', include([ re_path(
r'^price-break/',
re_path(r'^(?P<pk>\d+)/?', SupplierPriceBreakDetail.as_view(), name='api-part-supplier-price-detail'), include([
re_path(r'^.*$', SupplierPriceBreakList.as_view(), name='api-part-supplier-price-list'), re_path(
])), r'^(?P<pk>\d+)/?',
SupplierPriceBreakDetail.as_view(),
re_path(r'^(?P<pk>\d+)/?', include([ name='api-part-supplier-price-detail',
re_path(r'^metadata/', MetadataView.as_view(), {'model': Company}, name='api-company-metadata'), ),
re_path(
r'^.*$',
SupplierPriceBreakList.as_view(),
name='api-part-supplier-price-list',
),
]),
),
re_path(
r'^(?P<pk>\d+)/?',
include([
re_path(
r'^metadata/',
MetadataView.as_view(),
{'model': Company},
name='api-company-metadata',
),
re_path(r'^.*$', CompanyDetail.as_view(), name='api-company-detail'), re_path(r'^.*$', CompanyDetail.as_view(), name='api-company-detail'),
])), ]),
),
re_path(r'^attachment/', include([ re_path(
path(r'<int:pk>/', CompanyAttachmentDetail.as_view(), name='api-company-attachment-detail'), r'^attachment/',
re_path(r'^$', CompanyAttachmentList.as_view(), name='api-company-attachment-list'), include([
])), path(
r'<int:pk>/',
re_path(r'^contact/', include([ CompanyAttachmentDetail.as_view(),
re_path(r'^(?P<pk>\d+)/?', include([ name='api-company-attachment-detail',
re_path('^metadata/', MetadataView.as_view(), {'model': Contact}, name='api-contact-metadata'), ),
re_path(
r'^$',
CompanyAttachmentList.as_view(),
name='api-company-attachment-list',
),
]),
),
re_path(
r'^contact/',
include([
re_path(
r'^(?P<pk>\d+)/?',
include([
re_path(
'^metadata/',
MetadataView.as_view(),
{'model': Contact},
name='api-contact-metadata',
),
re_path('^.*$', ContactDetail.as_view(), name='api-contact-detail'), re_path('^.*$', ContactDetail.as_view(), name='api-contact-detail'),
])), ]),
),
re_path(r'^.*$', ContactList.as_view(), name='api-contact-list'), re_path(r'^.*$', ContactList.as_view(), name='api-contact-list'),
])), ]),
),
re_path(r'^address/', include([ re_path(
r'^address/',
include([
path('<int:pk>/', AddressDetail.as_view(), name='api-address-detail'), path('<int:pk>/', AddressDetail.as_view(), name='api-address-detail'),
re_path(r'^.*$', AddressList.as_view(), name='api-address-list'), re_path(r'^.*$', AddressList.as_view(), name='api-address-list'),
])), ]),
),
re_path(r'^.*$', CompanyList.as_view(), name='api-company-list'), re_path(r'^.*$', CompanyList.as_view(), name='api-company-list'),
] ]

View File

@ -29,8 +29,12 @@ import InvenTree.tasks
import InvenTree.validators import InvenTree.validators
from common.settings import currency_code_default from common.settings import currency_code_default
from InvenTree.fields import InvenTreeURLField, RoundingDecimalField from InvenTree.fields import InvenTreeURLField, RoundingDecimalField
from InvenTree.models import (InvenTreeAttachment, InvenTreeBarcodeMixin, from InvenTree.models import (
InvenTreeNotesMixin, MetadataMixin) InvenTreeAttachment,
InvenTreeBarcodeMixin,
InvenTreeNotesMixin,
MetadataMixin,
)
from InvenTree.status_codes import PurchaseOrderStatusGroups from InvenTree.status_codes import PurchaseOrderStatusGroups
@ -87,20 +91,24 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options"""
ordering = ['name', ]
ordering = ['name']
constraints = [ constraints = [
UniqueConstraint(fields=['name', 'email'], name='unique_name_email_pair') UniqueConstraint(fields=['name', 'email'], name='unique_name_email_pair')
] ]
verbose_name_plural = "Companies" verbose_name_plural = 'Companies'
@staticmethod @staticmethod
def get_api_url(): def get_api_url():
"""Return the API URL associated with the Company model""" """Return the API URL associated with the Company model"""
return reverse('api-company-list') return reverse('api-company-list')
name = models.CharField(max_length=100, blank=False, name = models.CharField(
max_length=100,
blank=False,
help_text=_('Company name'), help_text=_('Company name'),
verbose_name=_('Company name')) verbose_name=_('Company name'),
)
description = models.CharField( description = models.CharField(
max_length=500, max_length=500,
@ -110,41 +118,62 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
) )
website = InvenTreeURLField( website = InvenTreeURLField(
blank=True, blank=True, verbose_name=_('Website'), help_text=_('Company website URL')
verbose_name=_('Website'),
help_text=_('Company website URL')
) )
phone = models.CharField(max_length=50, phone = models.CharField(
max_length=50,
verbose_name=_('Phone number'), verbose_name=_('Phone number'),
blank=True, help_text=_('Contact phone number')) blank=True,
help_text=_('Contact phone number'),
)
email = models.EmailField(blank=True, null=True, email = models.EmailField(
verbose_name=_('Email'), help_text=_('Contact email address')) blank=True,
null=True,
verbose_name=_('Email'),
help_text=_('Contact email address'),
)
contact = models.CharField(max_length=100, contact = models.CharField(
max_length=100,
verbose_name=_('Contact'), verbose_name=_('Contact'),
blank=True, help_text=_('Point of contact')) blank=True,
help_text=_('Point of contact'),
)
link = InvenTreeURLField(blank=True, verbose_name=_('Link'), help_text=_('Link to external company information')) link = InvenTreeURLField(
blank=True,
verbose_name=_('Link'),
help_text=_('Link to external company information'),
)
image = StdImageField( image = StdImageField(
upload_to=rename_company_image, upload_to=rename_company_image,
null=True, null=True,
blank=True, blank=True,
variations={ variations={'thumbnail': (128, 128), 'preview': (256, 256)},
'thumbnail': (128, 128),
'preview': (256, 256),
},
delete_orphans=True, delete_orphans=True,
verbose_name=_('Image'), verbose_name=_('Image'),
) )
is_customer = models.BooleanField(default=False, verbose_name=_('is customer'), help_text=_('Do you sell items to this company?')) is_customer = models.BooleanField(
default=False,
verbose_name=_('is customer'),
help_text=_('Do you sell items to this company?'),
)
is_supplier = models.BooleanField(default=True, verbose_name=_('is supplier'), help_text=_('Do you purchase items from this company?')) is_supplier = models.BooleanField(
default=True,
verbose_name=_('is supplier'),
help_text=_('Do you purchase items from this company?'),
)
is_manufacturer = models.BooleanField(default=False, verbose_name=_('is manufacturer'), help_text=_('Does this company manufacture parts?')) is_manufacturer = models.BooleanField(
default=False,
verbose_name=_('is manufacturer'),
help_text=_('Does this company manufacture parts?'),
)
currency = models.CharField( currency = models.CharField(
max_length=3, max_length=3,
@ -186,7 +215,7 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
def __str__(self): def __str__(self):
"""Get string representation of a Company.""" """Get string representation of a Company."""
return f"{self.name} - {self.description}" return f'{self.name} - {self.description}'
def get_absolute_url(self): def get_absolute_url(self):
"""Get the web URL for the detail view for this Company.""" """Get the web URL for the detail view for this Company."""
@ -207,13 +236,18 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
@property @property
def parts(self): def parts(self):
"""Return SupplierPart objects which are supplied or manufactured by this company.""" """Return SupplierPart objects which are supplied or manufactured by this company."""
return SupplierPart.objects.filter(Q(supplier=self.id) | Q(manufacturer_part__manufacturer=self.id)).distinct() return SupplierPart.objects.filter(
Q(supplier=self.id) | Q(manufacturer_part__manufacturer=self.id)
).distinct()
@property @property
def stock_items(self): def stock_items(self):
"""Return a list of all stock items supplied or manufactured by this company.""" """Return a list of all stock items supplied or manufactured by this company."""
stock = apps.get_model('stock', 'StockItem') stock = apps.get_model('stock', 'StockItem')
return stock.objects.filter(Q(supplier_part__supplier=self.id) | Q(supplier_part__manufacturer_part__manufacturer=self.id)).distinct() return stock.objects.filter(
Q(supplier_part__supplier=self.id)
| Q(supplier_part__manufacturer_part__manufacturer=self.id)
).distinct()
class CompanyAttachment(InvenTreeAttachment): class CompanyAttachment(InvenTreeAttachment):
@ -229,7 +263,8 @@ class CompanyAttachment(InvenTreeAttachment):
return os.path.join('company_files', str(self.company.pk)) return os.path.join('company_files', str(self.company.pk))
company = models.ForeignKey( company = models.ForeignKey(
Company, on_delete=models.CASCADE, Company,
on_delete=models.CASCADE,
verbose_name=_('Company'), verbose_name=_('Company'),
related_name='attachments', related_name='attachments',
) )
@ -251,8 +286,9 @@ class Contact(MetadataMixin, models.Model):
"""Return the API URL associated with the Contcat model""" """Return the API URL associated with the Contcat model"""
return reverse('api-contact-list') return reverse('api-contact-list')
company = models.ForeignKey(Company, related_name='contacts', company = models.ForeignKey(
on_delete=models.CASCADE) Company, related_name='contacts', on_delete=models.CASCADE
)
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
@ -281,7 +317,8 @@ class Address(models.Model):
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options"""
verbose_name_plural = "Addresses"
verbose_name_plural = 'Addresses'
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
"""Custom init function""" """Custom init function"""
@ -289,12 +326,13 @@ class Address(models.Model):
def __str__(self): def __str__(self):
"""Defines string representation of address to supple a one-line to API calls""" """Defines string representation of address to supple a one-line to API calls"""
available_lines = [self.line1, available_lines = [
self.line1,
self.line2, self.line2,
self.postal_code, self.postal_code,
self.postal_city, self.postal_city,
self.province, self.province,
self.country self.country,
] ]
populated_lines = [] populated_lines = []
@ -302,14 +340,16 @@ class Address(models.Model):
if len(line) > 0: if len(line) > 0:
populated_lines.append(line) populated_lines.append(line)
return ", ".join(populated_lines) return ', '.join(populated_lines)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""Run checks when saving an address: """Run checks when saving an address:
- If this address is marked as "primary", ensure that all other addresses for this company are marked as non-primary - If this address is marked as "primary", ensure that all other addresses for this company are marked as non-primary
""" """
others = list(Address.objects.filter(company=self.company).exclude(pk=self.pk).all()) others = list(
Address.objects.filter(company=self.company).exclude(pk=self.pk).all()
)
# If this is the *only* address for this company, make it the primary one # If this is the *only* address for this company, make it the primary one
if len(others) == 0: if len(others) == 0:
@ -329,63 +369,88 @@ class Address(models.Model):
"""Return the API URL associated with the Contcat model""" """Return the API URL associated with the Contcat model"""
return reverse('api-address-list') return reverse('api-address-list')
company = models.ForeignKey(Company, related_name='addresses', company = models.ForeignKey(
Company,
related_name='addresses',
on_delete=models.CASCADE, on_delete=models.CASCADE,
verbose_name=_('Company'), verbose_name=_('Company'),
help_text=_('Select company')) help_text=_('Select company'),
)
title = models.CharField(max_length=100, title = models.CharField(
max_length=100,
verbose_name=_('Address title'), verbose_name=_('Address title'),
help_text=_('Title describing the address entry'), help_text=_('Title describing the address entry'),
blank=False) blank=False,
)
primary = models.BooleanField(default=False, primary = models.BooleanField(
default=False,
verbose_name=_('Primary address'), verbose_name=_('Primary address'),
help_text=_('Set as primary address')) help_text=_('Set as primary address'),
)
line1 = models.CharField(max_length=50, line1 = models.CharField(
max_length=50,
verbose_name=_('Line 1'), verbose_name=_('Line 1'),
help_text=_('Address line 1'), help_text=_('Address line 1'),
blank=True) blank=True,
)
line2 = models.CharField(max_length=50, line2 = models.CharField(
max_length=50,
verbose_name=_('Line 2'), verbose_name=_('Line 2'),
help_text=_('Address line 2'), help_text=_('Address line 2'),
blank=True) blank=True,
)
postal_code = models.CharField(max_length=10, postal_code = models.CharField(
max_length=10,
verbose_name=_('Postal code'), verbose_name=_('Postal code'),
help_text=_('Postal code'), help_text=_('Postal code'),
blank=True) blank=True,
)
postal_city = models.CharField(max_length=50, postal_city = models.CharField(
max_length=50,
verbose_name=_('City/Region'), verbose_name=_('City/Region'),
help_text=_('Postal code city/region'), help_text=_('Postal code city/region'),
blank=True) blank=True,
)
province = models.CharField(max_length=50, province = models.CharField(
max_length=50,
verbose_name=_('State/Province'), verbose_name=_('State/Province'),
help_text=_('State or province'), help_text=_('State or province'),
blank=True) blank=True,
)
country = models.CharField(max_length=50, country = models.CharField(
max_length=50,
verbose_name=_('Country'), verbose_name=_('Country'),
help_text=_('Address country'), help_text=_('Address country'),
blank=True) blank=True,
)
shipping_notes = models.CharField(max_length=100, shipping_notes = models.CharField(
max_length=100,
verbose_name=_('Courier shipping notes'), verbose_name=_('Courier shipping notes'),
help_text=_('Notes for shipping courier'), help_text=_('Notes for shipping courier'),
blank=True) blank=True,
)
internal_shipping_notes = models.CharField(max_length=100, internal_shipping_notes = models.CharField(
max_length=100,
verbose_name=_('Internal shipping notes'), verbose_name=_('Internal shipping notes'),
help_text=_('Shipping notes for internal use'), help_text=_('Shipping notes for internal use'),
blank=True) blank=True,
)
link = InvenTreeURLField(blank=True, link = InvenTreeURLField(
blank=True,
verbose_name=_('Link'), verbose_name=_('Link'),
help_text=_('Link to address information (external)')) help_text=_('Link to address information (external)'),
)
class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model): class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
@ -401,6 +466,7 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options"""
unique_together = ('part', 'manufacturer', 'MPN') unique_together = ('part', 'manufacturer', 'MPN')
@staticmethod @staticmethod
@ -408,12 +474,12 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
"""Return the API URL associated with the ManufacturerPart instance""" """Return the API URL associated with the ManufacturerPart instance"""
return reverse('api-manufacturer-part-list') return reverse('api-manufacturer-part-list')
part = models.ForeignKey('part.Part', on_delete=models.CASCADE, part = models.ForeignKey(
'part.Part',
on_delete=models.CASCADE,
related_name='manufacturer_parts', related_name='manufacturer_parts',
verbose_name=_('Base Part'), verbose_name=_('Base Part'),
limit_choices_to={ limit_choices_to={'purchaseable': True},
'purchaseable': True,
},
help_text=_('Select part'), help_text=_('Select part'),
) )
@ -422,9 +488,7 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
on_delete=models.CASCADE, on_delete=models.CASCADE,
null=True, null=True,
related_name='manufactured_parts', related_name='manufactured_parts',
limit_choices_to={ limit_choices_to={'is_manufacturer': True},
'is_manufacturer': True
},
verbose_name=_('Manufacturer'), verbose_name=_('Manufacturer'),
help_text=_('Select manufacturer'), help_text=_('Select manufacturer'),
) )
@ -433,19 +497,22 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
null=True, null=True,
max_length=100, max_length=100,
verbose_name=_('MPN'), verbose_name=_('MPN'),
help_text=_('Manufacturer Part Number') help_text=_('Manufacturer Part Number'),
) )
link = InvenTreeURLField( link = InvenTreeURLField(
blank=True, null=True, blank=True,
null=True,
verbose_name=_('Link'), verbose_name=_('Link'),
help_text=_('URL for external manufacturer part link') help_text=_('URL for external manufacturer part link'),
) )
description = models.CharField( description = models.CharField(
max_length=250, blank=True, null=True, max_length=250,
blank=True,
null=True,
verbose_name=_('Description'), verbose_name=_('Description'),
help_text=_('Manufacturer part description') help_text=_('Manufacturer part description'),
) )
tags = TaggableManager(blank=True) tags = TaggableManager(blank=True)
@ -456,12 +523,20 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
manufacturer_part = None manufacturer_part = None
try: try:
manufacturer_part = ManufacturerPart.objects.get(part=part, manufacturer=manufacturer, MPN=mpn) manufacturer_part = ManufacturerPart.objects.get(
part=part, manufacturer=manufacturer, MPN=mpn
)
except ManufacturerPart.DoesNotExist: except ManufacturerPart.DoesNotExist:
pass pass
if not manufacturer_part: if not manufacturer_part:
manufacturer_part = ManufacturerPart(part=part, manufacturer=manufacturer, MPN=mpn, description=description, link=link) manufacturer_part = ManufacturerPart(
part=part,
manufacturer=manufacturer,
MPN=mpn,
description=description,
link=link,
)
manufacturer_part.save() manufacturer_part.save()
return manufacturer_part return manufacturer_part
@ -489,10 +564,14 @@ class ManufacturerPartAttachment(InvenTreeAttachment):
def getSubdir(self): def getSubdir(self):
"""Return the subdirectory where attachment files for the ManufacturerPart model are located""" """Return the subdirectory where attachment files for the ManufacturerPart model are located"""
return os.path.join("manufacturer_part_files", str(self.manufacturer_part.id)) return os.path.join('manufacturer_part_files', str(self.manufacturer_part.id))
manufacturer_part = models.ForeignKey(ManufacturerPart, on_delete=models.CASCADE, manufacturer_part = models.ForeignKey(
verbose_name=_('Manufacturer Part'), related_name='attachments') ManufacturerPart,
on_delete=models.CASCADE,
verbose_name=_('Manufacturer Part'),
related_name='attachments',
)
class ManufacturerPartParameter(models.Model): class ManufacturerPartParameter(models.Model):
@ -505,6 +584,7 @@ class ManufacturerPartParameter(models.Model):
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options"""
unique_together = ('manufacturer_part', 'name') unique_together = ('manufacturer_part', 'name')
@staticmethod @staticmethod
@ -523,21 +603,22 @@ class ManufacturerPartParameter(models.Model):
max_length=500, max_length=500,
blank=False, blank=False,
verbose_name=_('Name'), verbose_name=_('Name'),
help_text=_('Parameter name') help_text=_('Parameter name'),
) )
value = models.CharField( value = models.CharField(
max_length=500, max_length=500,
blank=False, blank=False,
verbose_name=_('Value'), verbose_name=_('Value'),
help_text=_('Parameter value') help_text=_('Parameter value'),
) )
units = models.CharField( units = models.CharField(
max_length=64, max_length=64,
blank=True, null=True, blank=True,
null=True,
verbose_name=_('Units'), verbose_name=_('Units'),
help_text=_('Parameter units') help_text=_('Parameter units'),
) )
@ -551,10 +632,10 @@ class SupplierPartManager(models.Manager):
def get_queryset(self): def get_queryset(self):
"""Prefetch related fields when querying against the SupplierPart model""" """Prefetch related fields when querying against the SupplierPart model"""
# Always prefetch related models # Always prefetch related models
return super().get_queryset().prefetch_related( return (
'part', super()
'supplier', .get_queryset()
'manufacturer_part__manufacturer', .prefetch_related('part', 'supplier', 'manufacturer_part__manufacturer')
) )
@ -580,6 +661,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options"""
unique_together = ('part', 'supplier', 'SKU') unique_together = ('part', 'supplier', 'SKU')
# This model was moved from the 'Part' app # This model was moved from the 'Part' app
@ -600,11 +682,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
def api_instance_filters(self): def api_instance_filters(self):
"""Return custom API filters for this particular instance""" """Return custom API filters for this particular instance"""
return { return {'manufacturer_part': {'part': self.part.pk}}
'manufacturer_part': {
'part': self.part.pk
}
}
def clean(self): def clean(self):
"""Custom clean action for the SupplierPart model: """Custom clean action for the SupplierPart model:
@ -624,36 +702,38 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
try: try:
# Attempt conversion to specified unit # Attempt conversion to specified unit
native_value = InvenTree.conversion.convert_physical_value( native_value = InvenTree.conversion.convert_physical_value(
self.pack_quantity, self.part.units, self.pack_quantity, self.part.units, strip_units=False
strip_units=False
) )
# If part units are not provided, value must be dimensionless # If part units are not provided, value must be dimensionless
if not self.part.units and not InvenTree.conversion.is_dimensionless(native_value): if not self.part.units and not InvenTree.conversion.is_dimensionless(
native_value
):
raise ValidationError({ raise ValidationError({
'pack_quantity': _("Pack units must be compatible with the base part units") 'pack_quantity': _(
'Pack units must be compatible with the base part units'
)
}) })
# Native value must be greater than zero # Native value must be greater than zero
if float(native_value.magnitude) <= 0: if float(native_value.magnitude) <= 0:
raise ValidationError({ raise ValidationError({
'pack_quantity': _("Pack units must be greater than zero") 'pack_quantity': _('Pack units must be greater than zero')
}) })
# Update native pack units value # Update native pack units value
self.pack_quantity_native = Decimal(native_value.magnitude) self.pack_quantity_native = Decimal(native_value.magnitude)
except ValidationError as e: except ValidationError as e:
raise ValidationError({ raise ValidationError({'pack_quantity': e.messages})
'pack_quantity': e.messages
})
# Ensure that the linked manufacturer_part points to the same part! # Ensure that the linked manufacturer_part points to the same part!
if self.manufacturer_part and self.part: if self.manufacturer_part and self.part:
if self.manufacturer_part.part != self.part: if self.manufacturer_part.part != self.part:
raise ValidationError({ raise ValidationError({
'manufacturer_part': _("Linked manufacturer part must reference the same base part"), 'manufacturer_part': _(
'Linked manufacturer part must reference the same base part'
)
}) })
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
@ -666,7 +746,9 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
# Retrieve manufacturer part # Retrieve manufacturer part
try: try:
manufacturer_part = ManufacturerPart.objects.get(manufacturer__name=manufacturer_name, MPN=MPN) manufacturer_part = ManufacturerPart.objects.get(
manufacturer__name=manufacturer_name, MPN=MPN
)
except (ValueError, Company.DoesNotExist): except (ValueError, Company.DoesNotExist):
# ManufacturerPart does not exist # ManufacturerPart does not exist
pass pass
@ -676,7 +758,9 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
# Connect ManufacturerPart to SupplierPart # Connect ManufacturerPart to SupplierPart
self.manufacturer_part = manufacturer_part self.manufacturer_part = manufacturer_part
else: else:
raise ValidationError(f'SupplierPart {self.__str__} is already linked to {self.manufacturer_part}') raise ValidationError(
f'SupplierPart {self.__str__} is already linked to {self.manufacturer_part}'
)
self.clean() self.clean()
self.validate_unique() self.validate_unique()
@ -684,17 +768,17 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
super().save(*args, **kwargs) super().save(*args, **kwargs)
part = models.ForeignKey( part = models.ForeignKey(
'part.Part', on_delete=models.CASCADE, 'part.Part',
on_delete=models.CASCADE,
related_name='supplier_parts', related_name='supplier_parts',
verbose_name=_('Base Part'), verbose_name=_('Base Part'),
limit_choices_to={ limit_choices_to={'purchaseable': True},
'purchaseable': True,
},
help_text=_('Select part'), help_text=_('Select part'),
) )
supplier = models.ForeignKey( supplier = models.ForeignKey(
Company, on_delete=models.CASCADE, Company,
on_delete=models.CASCADE,
related_name='supplied_parts', related_name='supplied_parts',
limit_choices_to={'is_supplier': True}, limit_choices_to={'is_supplier': True},
verbose_name=_('Supplier'), verbose_name=_('Supplier'),
@ -703,50 +787,71 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
SKU = models.CharField( SKU = models.CharField(
max_length=100, max_length=100,
verbose_name=__("SKU = Stock Keeping Unit (supplier part number)", 'SKU'), verbose_name=__('SKU = Stock Keeping Unit (supplier part number)', 'SKU'),
help_text=_('Supplier stock keeping unit') help_text=_('Supplier stock keeping unit'),
) )
manufacturer_part = models.ForeignKey( manufacturer_part = models.ForeignKey(
ManufacturerPart, on_delete=models.CASCADE, ManufacturerPart,
blank=True, null=True, on_delete=models.CASCADE,
blank=True,
null=True,
related_name='supplier_parts', related_name='supplier_parts',
verbose_name=_('Manufacturer Part'), verbose_name=_('Manufacturer Part'),
help_text=_('Select manufacturer part'), help_text=_('Select manufacturer part'),
) )
link = InvenTreeURLField( link = InvenTreeURLField(
blank=True, null=True, blank=True,
null=True,
verbose_name=_('Link'), verbose_name=_('Link'),
help_text=_('URL for external supplier part link') help_text=_('URL for external supplier part link'),
) )
description = models.CharField( description = models.CharField(
max_length=250, blank=True, null=True, max_length=250,
blank=True,
null=True,
verbose_name=_('Description'), verbose_name=_('Description'),
help_text=_('Supplier part description') help_text=_('Supplier part description'),
) )
note = models.CharField( note = models.CharField(
max_length=100, blank=True, null=True, max_length=100,
blank=True,
null=True,
verbose_name=_('Note'), verbose_name=_('Note'),
help_text=_('Notes') help_text=_('Notes'),
) )
base_cost = models.DecimalField(max_digits=10, decimal_places=3, default=0, validators=[MinValueValidator(0)], verbose_name=_('base cost'), help_text=_('Minimum charge (e.g. stocking fee)')) base_cost = models.DecimalField(
max_digits=10,
decimal_places=3,
default=0,
validators=[MinValueValidator(0)],
verbose_name=_('base cost'),
help_text=_('Minimum charge (e.g. stocking fee)'),
)
packaging = models.CharField(max_length=50, blank=True, null=True, verbose_name=_('Packaging'), help_text=_('Part packaging')) packaging = models.CharField(
max_length=50,
blank=True,
null=True,
verbose_name=_('Packaging'),
help_text=_('Part packaging'),
)
pack_quantity = models.CharField( pack_quantity = models.CharField(
max_length=25, max_length=25,
verbose_name=_('Pack Quantity'), verbose_name=_('Pack Quantity'),
help_text=_('Total quantity supplied in a single pack. Leave empty for single items.'), help_text=_(
'Total quantity supplied in a single pack. Leave empty for single items.'
),
blank=True, blank=True,
) )
pack_quantity_native = RoundingDecimalField( pack_quantity_native = RoundingDecimalField(
max_digits=20, decimal_places=10, default=1, max_digits=20, decimal_places=10, default=1, null=True
null=True,
) )
def base_quantity(self, quantity=1) -> Decimal: def base_quantity(self, quantity=1) -> Decimal:
@ -756,20 +861,29 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
return q return q
multiple = models.PositiveIntegerField(default=1, validators=[MinValueValidator(1)], verbose_name=_('multiple'), help_text=_('Order multiple')) multiple = models.PositiveIntegerField(
default=1,
validators=[MinValueValidator(1)],
verbose_name=_('multiple'),
help_text=_('Order multiple'),
)
# TODO - Reimplement lead-time as a charfield with special validation (pattern matching). # TODO - Reimplement lead-time as a charfield with special validation (pattern matching).
# lead_time = models.DurationField(blank=True, null=True) # lead_time = models.DurationField(blank=True, null=True)
available = models.DecimalField( available = models.DecimalField(
max_digits=10, decimal_places=3, default=0, max_digits=10,
decimal_places=3,
default=0,
validators=[MinValueValidator(0)], validators=[MinValueValidator(0)],
verbose_name=_('Available'), verbose_name=_('Available'),
help_text=_('Quantity available from supplier'), help_text=_('Quantity available from supplier'),
) )
availability_updated = models.DateTimeField( availability_updated = models.DateTimeField(
null=True, blank=True, verbose_name=_('Availability Updated'), null=True,
blank=True,
verbose_name=_('Availability Updated'),
help_text=_('Date of last update of availability data'), help_text=_('Date of last update of availability data'),
) )
@ -821,17 +935,15 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
if self.price_breaks.filter(quantity=quantity, part=self.pk).exists(): if self.price_breaks.filter(quantity=quantity, part=self.pk).exists():
return return
SupplierPriceBreak.objects.create( SupplierPriceBreak.objects.create(part=self, quantity=quantity, price=price)
part=self,
quantity=quantity,
price=price
)
get_price = common.models.get_price get_price = common.models.get_price
def open_orders(self): def open_orders(self):
"""Return a database query for PurchaseOrder line items for this SupplierPart, limited to purchase orders that are open / outstanding.""" """Return a database query for PurchaseOrder line items for this SupplierPart, limited to purchase orders that are open / outstanding."""
return self.purchase_order_line_items.prefetch_related('order').filter(order__status__in=PurchaseOrderStatusGroups.OPEN) return self.purchase_order_line_items.prefetch_related('order').filter(
order__status__in=PurchaseOrderStatusGroups.OPEN
)
def on_order(self): def on_order(self):
"""Return the total quantity of items currently on order. """Return the total quantity of items currently on order.
@ -852,7 +964,10 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
def purchase_orders(self): def purchase_orders(self):
"""Returns a list of purchase orders relating to this supplier part.""" """Returns a list of purchase orders relating to this supplier part."""
return [line.order for line in self.purchase_order_line_items.all().prefetch_related('order')] return [
line.order
for line in self.purchase_order_line_items.all().prefetch_related('order')
]
@property @property
def pretty_name(self): def pretty_name(self):
@ -891,7 +1006,8 @@ class SupplierPriceBreak(common.models.PriceBreak):
class Meta: class Meta:
"""Metaclass defines extra model options""" """Metaclass defines extra model options"""
unique_together = ("part", "quantity")
unique_together = ('part', 'quantity')
# This model was moved from the 'Part' app # This model was moved from the 'Part' app
db_table = 'part_supplierpricebreak' db_table = 'part_supplierpricebreak'
@ -905,22 +1021,31 @@ class SupplierPriceBreak(common.models.PriceBreak):
"""Return the API URL associated with the SupplierPriceBreak model""" """Return the API URL associated with the SupplierPriceBreak model"""
return reverse('api-part-supplier-price-list') return reverse('api-part-supplier-price-list')
part = models.ForeignKey(SupplierPart, on_delete=models.CASCADE, related_name='pricebreaks', verbose_name=_('Part'),) part = models.ForeignKey(
SupplierPart,
on_delete=models.CASCADE,
related_name='pricebreaks',
verbose_name=_('Part'),
)
@receiver(post_save, sender=SupplierPriceBreak, dispatch_uid='post_save_supplier_price_break') @receiver(
post_save, sender=SupplierPriceBreak, dispatch_uid='post_save_supplier_price_break'
)
def after_save_supplier_price(sender, instance, created, **kwargs): def after_save_supplier_price(sender, instance, created, **kwargs):
"""Callback function when a SupplierPriceBreak is created or updated""" """Callback function when a SupplierPriceBreak is created or updated"""
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
if instance.part and instance.part.part: if instance.part and instance.part.part:
instance.part.part.schedule_pricing_update(create=True) instance.part.part.schedule_pricing_update(create=True)
@receiver(post_delete, sender=SupplierPriceBreak, dispatch_uid='post_delete_supplier_price_break') @receiver(
post_delete,
sender=SupplierPriceBreak,
dispatch_uid='post_delete_supplier_price_break',
)
def after_delete_supplier_price(sender, instance, **kwargs): def after_delete_supplier_price(sender, instance, **kwargs):
"""Callback function when a SupplierPriceBreak is deleted""" """Callback function when a SupplierPriceBreak is deleted"""
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData(): if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
if instance.part and instance.part.part: if instance.part and instance.part.part:
instance.part.part.schedule_pricing_update(create=False) instance.part.part.schedule_pricing_update(create=False)

View File

@ -10,20 +10,29 @@ from sql_util.utils import SubqueryCount
from taggit.serializers import TagListSerializerField from taggit.serializers import TagListSerializerField
import part.filters import part.filters
from InvenTree.serializers import (InvenTreeAttachmentSerializer, from InvenTree.serializers import (
InvenTreeAttachmentSerializer,
InvenTreeCurrencySerializer, InvenTreeCurrencySerializer,
InvenTreeDecimalField, InvenTreeDecimalField,
InvenTreeImageSerializerField, InvenTreeImageSerializerField,
InvenTreeModelSerializer, InvenTreeModelSerializer,
InvenTreeMoneySerializer, InvenTreeMoneySerializer,
InvenTreeTagModelSerializer, InvenTreeTagModelSerializer,
RemoteImageMixin) RemoteImageMixin,
)
from part.serializers import PartBriefSerializer from part.serializers import PartBriefSerializer
from .models import (Address, Company, CompanyAttachment, Contact, from .models import (
ManufacturerPart, ManufacturerPartAttachment, Address,
ManufacturerPartParameter, SupplierPart, Company,
SupplierPriceBreak) CompanyAttachment,
Contact,
ManufacturerPart,
ManufacturerPartAttachment,
ManufacturerPartParameter,
SupplierPart,
SupplierPriceBreak,
)
class CompanyBriefSerializer(InvenTreeModelSerializer): class CompanyBriefSerializer(InvenTreeModelSerializer):
@ -33,13 +42,7 @@ class CompanyBriefSerializer(InvenTreeModelSerializer):
"""Metaclass options.""" """Metaclass options."""
model = Company model = Company
fields = [ fields = ['pk', 'url', 'name', 'description', 'image']
'pk',
'url',
'name',
'description',
'image',
]
url = serializers.CharField(source='get_absolute_url', read_only=True) url = serializers.CharField(source='get_absolute_url', read_only=True)
@ -86,7 +89,7 @@ class AddressBriefSerializer(InvenTreeModelSerializer):
'province', 'province',
'country', 'country',
'shipping_notes', 'shipping_notes',
'internal_shipping_notes' 'internal_shipping_notes',
] ]
@ -119,7 +122,7 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
'parts_manufactured', 'parts_manufactured',
'remote_image', 'remote_image',
'address_count', 'address_count',
'primary_address' 'primary_address',
] ]
@staticmethod @staticmethod
@ -130,13 +133,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
parts_manufactured=SubqueryCount('manufactured_parts') parts_manufactured=SubqueryCount('manufactured_parts')
) )
queryset = queryset.annotate( queryset = queryset.annotate(parts_supplied=SubqueryCount('supplied_parts'))
parts_supplied=SubqueryCount('supplied_parts')
)
queryset = queryset.annotate( queryset = queryset.annotate(address_count=SubqueryCount('addresses'))
address_count=SubqueryCount('addresses')
)
return queryset return queryset
@ -150,7 +149,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
parts_manufactured = serializers.IntegerField(read_only=True) parts_manufactured = serializers.IntegerField(read_only=True)
address_count = serializers.IntegerField(read_only=True) address_count = serializers.IntegerField(read_only=True)
currency = InvenTreeCurrencySerializer(help_text=_('Default currency used for this supplier'), required=True) currency = InvenTreeCurrencySerializer(
help_text=_('Default currency used for this supplier'), required=True
)
def save(self): def save(self):
"""Save the Company instance""" """Save the Company instance"""
@ -167,12 +168,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
remote_img.save(buffer, format=fmt) remote_img.save(buffer, format=fmt)
# Construct a simplified name for the image # Construct a simplified name for the image
filename = f"company_{company.pk}_image.{fmt.lower()}" filename = f'company_{company.pk}_image.{fmt.lower()}'
company.image.save( company.image.save(filename, ContentFile(buffer.getvalue()))
filename,
ContentFile(buffer.getvalue()),
)
return self.instance return self.instance
@ -182,11 +180,10 @@ class CompanyAttachmentSerializer(InvenTreeAttachmentSerializer):
class Meta: class Meta:
"""Metaclass defines serializer options""" """Metaclass defines serializer options"""
model = CompanyAttachment model = CompanyAttachment
fields = InvenTreeAttachmentSerializer.attachment_fields([ fields = InvenTreeAttachmentSerializer.attachment_fields(['company'])
'company',
])
class ContactSerializer(InvenTreeModelSerializer): class ContactSerializer(InvenTreeModelSerializer):
@ -196,14 +193,7 @@ class ContactSerializer(InvenTreeModelSerializer):
"""Metaclass options""" """Metaclass options"""
model = Contact model = Contact
fields = [ fields = ['pk', 'company', 'name', 'phone', 'email', 'role']
'pk',
'company',
'name',
'phone',
'email',
'role',
]
class ManufacturerPartSerializer(InvenTreeTagModelSerializer): class ManufacturerPartSerializer(InvenTreeTagModelSerializer):
@ -224,7 +214,6 @@ class ManufacturerPartSerializer(InvenTreeTagModelSerializer):
'MPN', 'MPN',
'link', 'link',
'barcode_hash', 'barcode_hash',
'tags', 'tags',
] ]
@ -249,11 +238,15 @@ class ManufacturerPartSerializer(InvenTreeTagModelSerializer):
part_detail = PartBriefSerializer(source='part', many=False, read_only=True) part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
manufacturer_detail = CompanyBriefSerializer(source='manufacturer', many=False, read_only=True) manufacturer_detail = CompanyBriefSerializer(
source='manufacturer', many=False, read_only=True
)
pretty_name = serializers.CharField(read_only=True) pretty_name = serializers.CharField(read_only=True)
manufacturer = serializers.PrimaryKeyRelatedField(queryset=Company.objects.filter(is_manufacturer=True)) manufacturer = serializers.PrimaryKeyRelatedField(
queryset=Company.objects.filter(is_manufacturer=True)
)
class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer): class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer):
@ -264,9 +257,7 @@ class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer):
model = ManufacturerPartAttachment model = ManufacturerPartAttachment
fields = InvenTreeAttachmentSerializer.attachment_fields([ fields = InvenTreeAttachmentSerializer.attachment_fields(['manufacturer_part'])
'manufacturer_part',
])
class ManufacturerPartParameterSerializer(InvenTreeModelSerializer): class ManufacturerPartParameterSerializer(InvenTreeModelSerializer):
@ -295,7 +286,9 @@ class ManufacturerPartParameterSerializer(InvenTreeModelSerializer):
if not man_detail: if not man_detail:
self.fields.pop('manufacturer_part_detail') self.fields.pop('manufacturer_part_detail')
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', many=False, read_only=True) manufacturer_part_detail = ManufacturerPartSerializer(
source='manufacturer_part', many=False, read_only=True
)
class SupplierPartSerializer(InvenTreeTagModelSerializer): class SupplierPartSerializer(InvenTreeTagModelSerializer):
@ -330,14 +323,10 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
'supplier_detail', 'supplier_detail',
'url', 'url',
'updated', 'updated',
'tags', 'tags',
] ]
read_only_fields = [ read_only_fields = ['availability_updated', 'barcode_hash']
'availability_updated',
'barcode_hash',
]
tags = TagListSerializerField(required=False) tags = TagListSerializerField(required=False)
@ -377,19 +366,27 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
part_detail = PartBriefSerializer(source='part', many=False, read_only=True) part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
supplier_detail = CompanyBriefSerializer(source='supplier', many=False, read_only=True) supplier_detail = CompanyBriefSerializer(
source='supplier', many=False, read_only=True
)
manufacturer_detail = CompanyBriefSerializer(source='manufacturer_part.manufacturer', many=False, read_only=True) manufacturer_detail = CompanyBriefSerializer(
source='manufacturer_part.manufacturer', many=False, read_only=True
)
pretty_name = serializers.CharField(read_only=True) pretty_name = serializers.CharField(read_only=True)
supplier = serializers.PrimaryKeyRelatedField(queryset=Company.objects.filter(is_supplier=True)) supplier = serializers.PrimaryKeyRelatedField(
queryset=Company.objects.filter(is_supplier=True)
)
manufacturer = serializers.CharField(read_only=True) manufacturer = serializers.CharField(read_only=True)
MPN = serializers.CharField(read_only=True) MPN = serializers.CharField(read_only=True)
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', part_detail=False, read_only=True) manufacturer_part_detail = ManufacturerPartSerializer(
source='manufacturer_part', part_detail=False, read_only=True
)
url = serializers.CharField(source='get_absolute_url', read_only=True) url = serializers.CharField(source='get_absolute_url', read_only=True)
@ -403,9 +400,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
Fields: Fields:
in_stock: Current stock quantity for each SupplierPart in_stock: Current stock quantity for each SupplierPart
""" """
queryset = queryset.annotate( queryset = queryset.annotate(in_stock=part.filters.annotate_total_stock())
in_stock=part.filters.annotate_total_stock()
)
return queryset return queryset
@ -436,10 +431,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
MPN = self.initial_data.get('MPN', None) MPN = self.initial_data.get('MPN', None)
if manufacturer and MPN: if manufacturer and MPN:
kwargs = { kwargs = {'manufacturer': manufacturer, 'MPN': MPN}
'manufacturer': manufacturer,
'MPN': MPN,
}
supplier_part.save(**kwargs) supplier_part.save(**kwargs)
return supplier_part return supplier_part
@ -479,17 +471,19 @@ class SupplierPriceBreakSerializer(InvenTreeModelSerializer):
quantity = InvenTreeDecimalField() quantity = InvenTreeDecimalField()
price = InvenTreeMoneySerializer( price = InvenTreeMoneySerializer(allow_null=True, required=True, label=_('Price'))
allow_null=True,
required=True,
label=_('Price'),
)
price_currency = InvenTreeCurrencySerializer() price_currency = InvenTreeCurrencySerializer()
supplier = serializers.PrimaryKeyRelatedField(source='part.supplier', many=False, read_only=True) supplier = serializers.PrimaryKeyRelatedField(
source='part.supplier', many=False, read_only=True
)
supplier_detail = CompanyBriefSerializer(source='part.supplier', many=False, read_only=True) supplier_detail = CompanyBriefSerializer(
source='part.supplier', many=False, read_only=True
)
# Detail serializer for SupplierPart # Detail serializer for SupplierPart
part_detail = SupplierPartSerializer(source='part', brief=True, many=False, read_only=True) part_detail = SupplierPartSerializer(
source='part', brief=True, many=False, read_only=True
)

View File

@ -12,10 +12,7 @@ from .models import Address, Company, Contact, ManufacturerPart, SupplierPart
class CompanyTest(InvenTreeAPITestCase): class CompanyTest(InvenTreeAPITestCase):
"""Series of tests for the Company DRF API.""" """Series of tests for the Company DRF API."""
roles = [ roles = ['purchase_order.add', 'purchase_order.change']
'purchase_order.add',
'purchase_order.change',
]
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -23,9 +20,18 @@ class CompanyTest(InvenTreeAPITestCase):
super().setUpTestData() super().setUpTestData()
# Create some company objects to work with # Create some company objects to work with
cls.acme = Company.objects.create(name='ACME', description='Supplier', is_customer=False, is_supplier=True) cls.acme = Company.objects.create(
Company.objects.create(name='Drippy Cup Co.', description='Customer', is_customer=True, is_supplier=False) name='ACME', description='Supplier', is_customer=False, is_supplier=True
Company.objects.create(name='Sippy Cup Emporium', description='Another supplier') )
Company.objects.create(
name='Drippy Cup Co.',
description='Customer',
is_customer=True,
is_supplier=False,
)
Company.objects.create(
name='Sippy Cup Emporium', description='Another supplier'
)
def test_company_list(self): def test_company_list(self):
"""Test the list API endpoint for the Company model""" """Test the list API endpoint for the Company model"""
@ -82,22 +88,13 @@ class CompanyTest(InvenTreeAPITestCase):
url = reverse('api-company-list') url = reverse('api-company-list')
# Name is required # Name is required
response = self.post( response = self.post(url, {'description': 'A description!'}, expected_code=400)
url,
{
'description': 'A description!',
},
expected_code=400
)
# Minimal example, checking default values # Minimal example, checking default values
response = self.post( response = self.post(
url, url,
{ {'name': 'My API Company', 'description': 'A company created via the API'},
'name': 'My API Company', expected_code=201,
'description': 'A company created via the API',
},
expected_code=201
) )
self.assertTrue(response.data['is_supplier']) self.assertTrue(response.data['is_supplier'])
@ -110,14 +107,14 @@ class CompanyTest(InvenTreeAPITestCase):
response = self.post( response = self.post(
url, url,
{ {
'name': "Another Company", 'name': 'Another Company',
'description': "Also created via the API!", 'description': 'Also created via the API!',
'currency': 'AUD', 'currency': 'AUD',
'is_supplier': False, 'is_supplier': False,
'is_manufacturer': True, 'is_manufacturer': True,
'is_customer': True, 'is_customer': True,
}, },
expected_code=201 expected_code=201,
) )
self.assertEqual(response.data['currency'], 'AUD') self.assertEqual(response.data['currency'], 'AUD')
@ -128,12 +125,8 @@ class CompanyTest(InvenTreeAPITestCase):
# Attempt to create with invalid currency # Attempt to create with invalid currency
response = self.post( response = self.post(
url, url,
{ {'name': 'A name', 'description': 'A description', 'currency': 'POQD'},
'name': "A name", expected_code=400,
'description': 'A description',
'currency': 'POQD',
},
expected_code=400
) )
self.assertTrue('currency' in response.data) self.assertTrue('currency' in response.data)
@ -151,10 +144,8 @@ class ContactTest(InvenTreeAPITestCase):
# Create some companies # Create some companies
companies = [ companies = [
Company( Company(name=f'Company {idx}', description='Some company')
name=f"Company {idx}", for idx in range(3)
description="Some company"
) for idx in range(3)
] ]
Company.objects.bulk_create(companies) Company.objects.bulk_create(companies)
@ -164,10 +155,7 @@ class ContactTest(InvenTreeAPITestCase):
# Create some contacts # Create some contacts
for cmp in Company.objects.all(): for cmp in Company.objects.all():
contacts += [ contacts += [
Contact( Contact(company=cmp, name=f'My name {idx}') for idx in range(3)
company=cmp,
name=f"My name {idx}",
) for idx in range(3)
] ]
Contact.objects.bulk_create(contacts) Contact.objects.bulk_create(contacts)
@ -187,13 +175,7 @@ class ContactTest(InvenTreeAPITestCase):
# Filter by particular company # Filter by particular company
for cmp in Company.objects.all(): for cmp in Company.objects.all():
response = self.get( response = self.get(self.url, {'company': cmp.pk}, expected_code=200)
self.url,
{
'company': cmp.pk,
},
expected_code=200
)
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
@ -205,23 +187,13 @@ class ContactTest(InvenTreeAPITestCase):
# Without required permissions, creation should fail # Without required permissions, creation should fail
self.post( self.post(
self.url, self.url, {'company': company.pk, 'name': 'Joe Bloggs'}, expected_code=403
{
'company': company.pk,
'name': 'Joe Bloggs',
},
expected_code=403
) )
self.assignRole('return_order.add') self.assignRole('return_order.add')
self.post( self.post(
self.url, self.url, {'company': company.pk, 'name': 'Joe Bloggs'}, expected_code=201
{
'company': company.pk,
'name': 'Joe Bloggs',
},
expected_code=201
) )
self.assertEqual(Contact.objects.count(), n + 1) self.assertEqual(Contact.objects.count(), n + 1)
@ -239,23 +211,11 @@ class ContactTest(InvenTreeAPITestCase):
for key in ['pk', 'name', 'role']: for key in ['pk', 'name', 'role']:
self.assertIn(key, data) self.assertIn(key, data)
self.patch( self.patch(url, {'role': 'model'}, expected_code=403)
url,
{
'role': 'model',
},
expected_code=403
)
self.assignRole('purchase_order.change') self.assignRole('purchase_order.change')
self.patch( self.patch(url, {'role': 'x'}, expected_code=200)
url,
{
'role': 'x',
},
expected_code=200
)
# Get the contact again # Get the contact again
contact = Contact.objects.first() contact = Contact.objects.first()
@ -291,10 +251,8 @@ class AddressTest(InvenTreeAPITestCase):
cls.num_addr = 3 cls.num_addr = 3
# Create some companies # Create some companies
companies = [ companies = [
Company( Company(name=f'Company {idx}', description='Some company')
name=f"Company {idx}", for idx in range(cls.num_companies)
description="Some company"
) for idx in range(cls.num_companies)
] ]
Company.objects.bulk_create(companies) Company.objects.bulk_create(companies)
@ -304,10 +262,8 @@ class AddressTest(InvenTreeAPITestCase):
# Create some contacts # Create some contacts
for cmp in Company.objects.all(): for cmp in Company.objects.all():
addresses += [ addresses += [
Address( Address(company=cmp, title=f'Address no. {idx}')
company=cmp, for idx in range(cls.num_addr)
title=f"Address no. {idx}",
) for idx in range(cls.num_addr)
] ]
cls.url = reverse('api-address-list') cls.url = reverse('api-address-list')
@ -332,21 +288,11 @@ class AddressTest(InvenTreeAPITestCase):
"""Test creating a new address""" """Test creating a new address"""
company = Company.objects.first() company = Company.objects.first()
self.post(self.url, self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=403)
{
'company': company.pk,
'title': 'HQ'
},
expected_code=403)
self.assignRole('purchase_order.add') self.assignRole('purchase_order.add')
self.post(self.url, self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=201)
{
'company': company.pk,
'title': 'HQ'
},
expected_code=201)
def test_get(self): def test_get(self):
"""Test that objects are properly returned from a get""" """Test that objects are properly returned from a get"""
@ -357,7 +303,15 @@ class AddressTest(InvenTreeAPITestCase):
self.assertEqual(response.data['pk'], addr.pk) self.assertEqual(response.data['pk'], addr.pk)
for key in ['title', 'line1', 'line2', 'postal_code', 'postal_city', 'province', 'country']: for key in [
'title',
'line1',
'line2',
'postal_code',
'postal_city',
'province',
'country',
]:
self.assertIn(key, response.data) self.assertIn(key, response.data)
def test_edit(self): def test_edit(self):
@ -366,23 +320,11 @@ class AddressTest(InvenTreeAPITestCase):
url = reverse('api-address-detail', kwargs={'pk': addr.pk}) url = reverse('api-address-detail', kwargs={'pk': addr.pk})
self.patch( self.patch(url, {'title': 'Hello'}, expected_code=403)
url,
{
'title': 'Hello'
},
expected_code=403
)
self.assignRole('purchase_order.change') self.assignRole('purchase_order.change')
self.patch( self.patch(url, {'title': 'World'}, expected_code=200)
url,
{
'title': 'World'
},
expected_code=200
)
data = self.get(url, expected_code=200).data data = self.get(url, expected_code=200).data
@ -415,10 +357,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
'supplier_part', 'supplier_part',
] ]
roles = [ roles = ['part.add', 'part.change']
'part.add',
'part.change',
]
def test_manufacturer_part_list(self): def test_manufacturer_part_list(self):
"""Test the ManufacturerPart API list functionality""" """Test the ManufacturerPart API list functionality"""
@ -429,11 +368,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 3) self.assertEqual(len(response.data), 3)
# Create manufacturer part # Create manufacturer part
data = { data = {'part': 1, 'manufacturer': 7, 'MPN': 'MPN_TEST'}
'part': 1,
'manufacturer': 7,
'MPN': 'MPN_TEST',
}
response = self.client.post(url, data, format='json') response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data['MPN'], 'MPN_TEST') self.assertEqual(response.data['MPN'], 'MPN_TEST')
@ -456,9 +391,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(response.data['MPN'], 'MPN123') self.assertEqual(response.data['MPN'], 'MPN123')
# Change the MPN # Change the MPN
data = { data = {'MPN': 'MPN-TEST-123'}
'MPN': 'MPN-TEST-123',
}
response = self.client.patch(url, data, format='json') response = self.client.patch(url, data, format='json')
@ -485,7 +418,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
'MPN': 'PART_NUMBER', 'MPN': 'PART_NUMBER',
'link': 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E', 'link': 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
}, },
expected_code=201 expected_code=201,
) )
pk = response.data['pk'] pk = response.data['pk']
@ -503,10 +436,16 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
# Check link is not modified # Check link is not modified
self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E') self.assertEqual(
response.data['link'],
'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
)
# Check link is not modified # Check link is not modified
self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E') self.assertEqual(
response.data['link'],
'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
)
class SupplierPartTest(InvenTreeAPITestCase): class SupplierPartTest(InvenTreeAPITestCase):
@ -521,12 +460,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
'supplier_part', 'supplier_part',
] ]
roles = [ roles = ['part.add', 'part.change', 'part.add', 'purchase_order.change']
'part.add',
'part.change',
'part.add',
'purchase_order.change',
]
def test_supplier_part_list(self): def test_supplier_part_list(self):
"""Test the SupplierPart API list functionality""" """Test the SupplierPart API list functionality"""
@ -543,10 +477,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), supplier.supplied_parts.count()) self.assertEqual(len(response.data), supplier.supplied_parts.count())
# Filter by Part reference # Filter by Part reference
expected = { expected = {1: 4, 25: 2}
1: 4,
25: 2,
}
for pk, n in expected.items(): for pk, n in expected.items():
response = self.get(url, {'part': pk}, expected_code=200) response = self.get(url, {'part': pk}, expected_code=200)
@ -559,12 +490,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
# Should fail when sending an invalid 'available' field # Should fail when sending an invalid 'available' field
response = self.post( response = self.post(
url, url,
{ {'part': 1, 'supplier': 2, 'SKU': 'QQ', 'available': 'not a number'},
'part': 1,
'supplier': 2,
'SKU': 'QQ',
'available': 'not a number',
},
expected_code=400, expected_code=400,
) )
@ -572,13 +498,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
# Create a SupplierPart without specifying available quantity # Create a SupplierPart without specifying available quantity
response = self.post( response = self.post(
url, url, {'part': 1, 'supplier': 2, 'SKU': 'QQ'}, expected_code=201
{
'part': 1,
'supplier': 2,
'SKU': 'QQ',
},
expected_code=201
) )
sp = SupplierPart.objects.get(pk=response.data['pk']) sp = SupplierPart.objects.get(pk=response.data['pk'])
@ -589,9 +509,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
# Now, *update* the available quantity via the API # Now, *update* the available quantity via the API
self.patch( self.patch(
reverse('api-supplier-part-detail', kwargs={'pk': sp.pk}), reverse('api-supplier-part-detail', kwargs={'pk': sp.pk}),
{ {'available': 1234},
'available': 1234,
},
expected_code=200, expected_code=200,
) )
@ -602,12 +520,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
# We should also be able to create a SupplierPart with initial 'available' quantity # We should also be able to create a SupplierPart with initial 'available' quantity
response = self.post( response = self.post(
url, url,
{ {'part': 1, 'supplier': 2, 'SKU': 'QQQ', 'available': 999},
'part': 1,
'supplier': 2,
'SKU': 'QQQ',
'available': 999,
},
expected_code=201, expected_code=201,
) )
@ -629,11 +542,7 @@ class CompanyMetadataAPITest(InvenTreeAPITestCase):
'supplier_part', 'supplier_part',
] ]
roles = [ roles = ['company.change', 'purchase_order.change', 'part.change']
'company.change',
'purchase_order.change',
'part.change',
]
def metatester(self, apikey, model): def metatester(self, apikey, model):
"""Generic tester""" """Generic tester"""
@ -651,17 +560,15 @@ class CompanyMetadataAPITest(InvenTreeAPITestCase):
self.patch( self.patch(
url, url,
{ {'metadata': {f'abc-{numstr}': f'xyz-{apikey}-{numstr}'}},
'metadata': { expected_code=200,
f'abc-{numstr}': f'xyz-{apikey}-{numstr}',
}
},
expected_code=200
) )
# Refresh # Refresh
modeldata.refresh_from_db() modeldata.refresh_from_db()
self.assertEqual(modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}') self.assertEqual(
modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}'
)
def test_metadata(self): def test_metadata(self):
"""Test all endpoints""" """Test all endpoints"""

View File

@ -16,9 +16,7 @@ class TestForwardMigrations(MigratorTestCase):
Company = self.old_state.apps.get_model('company', 'company') Company = self.old_state.apps.get_model('company', 'company')
Company.objects.create( Company.objects.create(
name='MSPC', name='MSPC', description='Michael Scotts Paper Company', is_supplier=True
description='Michael Scotts Paper Company',
is_supplier=True
) )
def test_migrations(self): def test_migrations(self):
@ -49,7 +47,10 @@ class TestManufacturerField(MigratorTestCase):
part = Part.objects.create( part = Part.objects.create(
name='Screw', name='Screw',
description='A single screw', description='A single screw',
level=0, tree_id=0, lft=0, rght=0 level=0,
tree_id=0,
lft=0,
rght=0,
) )
# Create a company to act as the supplier # Create a company to act as the supplier
@ -62,17 +63,11 @@ class TestManufacturerField(MigratorTestCase):
# Add some SupplierPart objects # Add some SupplierPart objects
SupplierPart.objects.create( SupplierPart.objects.create(
part=part, part=part, supplier=supplier, SKU='SCREW.001', manufacturer_name='ACME'
supplier=supplier,
SKU='SCREW.001',
manufacturer_name='ACME',
) )
SupplierPart.objects.create( SupplierPart.objects.create(
part=part, part=part, supplier=supplier, SKU='SCREW.002', manufacturer_name='Zero Corp'
supplier=supplier,
SKU='SCREW.002',
manufacturer_name='Zero Corp',
) )
self.assertEqual(Company.objects.count(), 1) self.assertEqual(Company.objects.count(), 1)
@ -186,17 +181,12 @@ class TestManufacturerPart(MigratorTestCase):
# No Manufacturer # No Manufacturer
SupplierPart.objects.create( SupplierPart.objects.create(
part=part, part=part, supplier=supplier_2, SKU='CAP-CER-01UF-2', MPN='MUR-CAP-123456'
supplier=supplier_2,
SKU='CAP-CER-01UF-2',
MPN='MUR-CAP-123456',
) )
# No Manufacturer data # No Manufacturer data
SupplierPart.objects.create( SupplierPart.objects.create(
part=part, part=part, supplier=supplier_2, SKU='CAP-CER-01UF-3'
supplier=supplier_2,
SKU='CAP-CER-01UF-3',
) )
def test_manufacturer_part_objects(self): def test_manufacturer_part_objects(self):
@ -238,17 +228,20 @@ class TestCurrencyMigration(MigratorTestCase):
Part = self.old_state.apps.get_model('part', 'part') Part = self.old_state.apps.get_model('part', 'part')
part = Part.objects.create( part = Part.objects.create(
name="PART", description="A purchaseable part", name='PART',
description='A purchaseable part',
purchaseable=True, purchaseable=True,
level=0, level=0,
tree_id=0, tree_id=0,
lft=0, lft=0,
rght=0 rght=0,
) )
Company = self.old_state.apps.get_model('company', 'company') Company = self.old_state.apps.get_model('company', 'company')
supplier = Company.objects.create(name='Supplier', description='A supplier', is_supplier=True) supplier = Company.objects.create(
name='Supplier', description='A supplier', is_supplier=True
)
SupplierPart = self.old_state.apps.get_model('company', 'supplierpart') SupplierPart = self.old_state.apps.get_model('company', 'supplierpart')
@ -256,8 +249,12 @@ class TestCurrencyMigration(MigratorTestCase):
Currency = self.old_state.apps.get_model('common', 'currency') Currency = self.old_state.apps.get_model('common', 'currency')
aud = Currency.objects.create(symbol='$', suffix='AUD', description='Australian Dollars', value=1.0) aud = Currency.objects.create(
usd = Currency.objects.create(symbol='$', suffix='USD', description='US Dollars', value=1.0) symbol='$', suffix='AUD', description='Australian Dollars', value=1.0
)
usd = Currency.objects.create(
symbol='$', suffix='USD', description='US Dollars', value=1.0
)
PB = self.old_state.apps.get_model('company', 'supplierpricebreak') PB = self.old_state.apps.get_model('company', 'supplierpricebreak')
@ -312,7 +309,7 @@ class TestAddressMigration(MigratorTestCase):
a2 = Address.objects.filter(company=c2.pk).first() a2 = Address.objects.filter(company=c2.pk).first()
self.assertEqual(a1.line1, self.short_l1) self.assertEqual(a1.line1, self.short_l1)
self.assertEqual(a1.line2, "") self.assertEqual(a1.line2, '')
self.assertEqual(a2.line1, self.long_l1) self.assertEqual(a2.line1, self.long_l1)
self.assertEqual(a2.line2, self.l2) self.assertEqual(a2.line2, self.l2)
self.assertEqual(c1.address, '') self.assertEqual(c1.address, '')
@ -332,12 +329,18 @@ class TestSupplierPartQuantity(MigratorTestCase):
SupplierPart = self.old_state.apps.get_model('company', 'supplierpart') SupplierPart = self.old_state.apps.get_model('company', 'supplierpart')
self.part = Part.objects.create( self.part = Part.objects.create(
name="PART", description="A purchaseable part", name='PART',
description='A purchaseable part',
purchaseable=True, purchaseable=True,
level=0, tree_id=0, lft=0, rght=0 level=0,
tree_id=0,
lft=0,
rght=0,
) )
self.supplier = Company.objects.create(name='Supplier', description='A supplier', is_supplier=True) self.supplier = Company.objects.create(
name='Supplier', description='A supplier', is_supplier=True
)
self.supplier_parts = [] self.supplier_parts = []
@ -356,7 +359,6 @@ class TestSupplierPartQuantity(MigratorTestCase):
SupplierPart = self.new_state.apps.get_model('company', 'supplierpart') SupplierPart = self.new_state.apps.get_model('company', 'supplierpart')
for i, sp in enumerate(SupplierPart.objects.all()): for i, sp in enumerate(SupplierPart.objects.all()):
self.assertEqual(sp.pack_quantity, str(i + 1)) self.assertEqual(sp.pack_quantity, str(i + 1))
self.assertEqual(sp.pack_quantity_native, i + 1) self.assertEqual(sp.pack_quantity_native, i + 1)

View File

@ -15,17 +15,15 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
def test_pack_quantity_dimensionless(self): def test_pack_quantity_dimensionless(self):
"""Test valid values for the 'pack_quantity' field""" """Test valid values for the 'pack_quantity' field"""
# Create a part without units (dimensionless) # Create a part without units (dimensionless)
part = Part.objects.create(name='Test Part', description='Test part description', component=True) part = Part.objects.create(
name='Test Part', description='Test part description', component=True
)
# Create a supplier (company) # Create a supplier (company)
company = Company.objects.create(name='Test Company', is_supplier=True) company = Company.objects.create(name='Test Company', is_supplier=True)
# Create a supplier part for this part # Create a supplier part for this part
sp = SupplierPart.objects.create( sp = SupplierPart.objects.create(part=part, supplier=company, SKU='TEST-SKU')
part=part,
supplier=company,
SKU='TEST-SKU'
)
# All these values are valid for a dimensionless part # All these values are valid for a dimensionless part
pass_tests = { pass_tests = {
@ -37,14 +35,7 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
} }
# All these values are invalid for a dimensionless part # All these values are invalid for a dimensionless part
fail_tests = [ fail_tests = ['1.2m', '-1', '0', '0.0', '100 feet', '0 amps']
'1.2m',
'-1',
'0',
'0.0',
'100 feet',
'0 amps'
]
for test, expected in pass_tests.items(): for test, expected in pass_tests.items():
sp.pack_quantity = test sp.pack_quantity = test
@ -59,17 +50,18 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
def test_pack_quantity(self): def test_pack_quantity(self):
"""Test pack_quantity for a part with a specified dimension""" """Test pack_quantity for a part with a specified dimension"""
# Create a part with units 'm' # Create a part with units 'm'
part = Part.objects.create(name='Test Part', description='Test part description', component=True, units='m') part = Part.objects.create(
name='Test Part',
description='Test part description',
component=True,
units='m',
)
# Create a supplier (company) # Create a supplier (company)
company = Company.objects.create(name='Test Company', is_supplier=True) company = Company.objects.create(name='Test Company', is_supplier=True)
# Create a supplier part for this part # Create a supplier part for this part
sp = SupplierPart.objects.create( sp = SupplierPart.objects.create(part=part, supplier=company, SKU='TEST-SKU')
part=part,
supplier=company,
SKU='TEST-SKU'
)
# All these values are valid for a part with dimension 'm' # All these values are valid for a part with dimension 'm'
pass_tests = { pass_tests = {
@ -87,23 +79,14 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
# All these values are invalid for a part with dimension 'm' # All these values are invalid for a part with dimension 'm'
# Either the values are invalid, or the units are incompatible # Either the values are invalid, or the units are incompatible
fail_tests = [ fail_tests = ['-1', '-1m', '0', '0m', '12 deg', '57 amps', '-12 oz', '17 yaks']
'-1',
'-1m',
'0',
'0m',
'12 deg',
'57 amps',
'-12 oz',
'17 yaks',
]
for test, expected in pass_tests.items(): for test, expected in pass_tests.items():
sp.pack_quantity = test sp.pack_quantity = test
sp.full_clean() sp.full_clean()
self.assertEqual( self.assertEqual(
round(Decimal(sp.pack_quantity_native), 10), round(Decimal(sp.pack_quantity_native), 10),
round(Decimal(str(expected)), 10) round(Decimal(str(expected)), 10),
) )
for test in fail_tests: for test in fail_tests:

View File

@ -36,7 +36,9 @@ class CompanyViewTest(InvenTreeTestCase):
def test_manufacturer_part_detail_view(self): def test_manufacturer_part_detail_view(self):
"""Test the manufacturer part detail view.""" """Test the manufacturer part detail view."""
response = self.client.get(reverse('manufacturer-part-detail', kwargs={'pk': 1})) response = self.client.get(
reverse('manufacturer-part-detail', kwargs={'pk': 1})
)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertContains(response, 'MPN123') self.assertContains(response, 'MPN123')

View File

@ -8,8 +8,14 @@ from django.test import TestCase
from part.models import Part from part.models import Part
from .models import (Address, Company, Contact, ManufacturerPart, SupplierPart, from .models import (
rename_company_image) Address,
Company,
Contact,
ManufacturerPart,
SupplierPart,
rename_company_image,
)
class CompanySimpleTest(TestCase): class CompanySimpleTest(TestCase):
@ -31,11 +37,13 @@ class CompanySimpleTest(TestCase):
"""Perform initialization for the tests in this class""" """Perform initialization for the tests in this class"""
super().setUpTestData() super().setUpTestData()
Company.objects.create(name='ABC Co.', Company.objects.create(
name='ABC Co.',
description='Seller of ABC products', description='Seller of ABC products',
website='www.abc-sales.com', website='www.abc-sales.com',
is_customer=False, is_customer=False,
is_supplier=True) is_supplier=True,
)
cls.acme0001 = SupplierPart.objects.get(SKU='ACME0001') cls.acme0001 = SupplierPart.objects.get(SKU='ACME0001')
cls.acme0002 = SupplierPart.objects.get(SKU='ACME0002') cls.acme0002 = SupplierPart.objects.get(SKU='ACME0002')
@ -95,9 +103,9 @@ class CompanySimpleTest(TestCase):
"""Unit tests for supplier part pricing""" """Unit tests for supplier part pricing"""
m2x4 = Part.objects.get(name='M2x4 LPHS') m2x4 = Part.objects.get(name='M2x4 LPHS')
self.assertEqual(m2x4.get_price_info(5.5), "38.5 - 41.25") self.assertEqual(m2x4.get_price_info(5.5), '38.5 - 41.25')
self.assertEqual(m2x4.get_price_info(10), "70 - 75") self.assertEqual(m2x4.get_price_info(10), '70 - 75')
self.assertEqual(m2x4.get_price_info(100), "125 - 350") self.assertEqual(m2x4.get_price_info(100), '125 - 350')
pmin, pmax = m2x4.get_price_range(5) pmin, pmax = m2x4.get_price_range(5)
self.assertEqual(pmin, 35) self.assertEqual(pmin, 35)
@ -113,18 +121,14 @@ class CompanySimpleTest(TestCase):
"""Test validation for currency selection.""" """Test validation for currency selection."""
# Create a company with a valid currency code (should pass) # Create a company with a valid currency code (should pass)
company = Company.objects.create( company = Company.objects.create(
name='Test', name='Test', description='Toast', currency='AUD'
description='Toast',
currency='AUD',
) )
company.full_clean() company.full_clean()
# Create a company with an invalid currency code (should fail) # Create a company with an invalid currency code (should fail)
company = Company.objects.create( company = Company.objects.create(
name='test', name='test', description='Toasty', currency='XZY'
description='Toasty',
currency='XZY',
) )
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
@ -154,7 +158,9 @@ class ContactSimpleTest(TestCase):
def setUp(self): def setUp(self):
"""Initialization for the tests in this class""" """Initialization for the tests in this class"""
# Create a simple company # Create a simple company
self.c = Company.objects.create(name='Test Corp.', description='We make stuff good') self.c = Company.objects.create(
name='Test Corp.', description='We make stuff good'
)
# Add some contacts # Add some contacts
Contact.objects.create(name='Joe Smith', company=self.c) Contact.objects.create(name='Joe Smith', company=self.c)
@ -178,7 +184,9 @@ class AddressTest(TestCase):
def setUp(self): def setUp(self):
"""Initialization for the tests in this class""" """Initialization for the tests in this class"""
# Create a simple company # Create a simple company
self.c = Company.objects.create(name='Test Corp.', description='We make stuff good') self.c = Company.objects.create(
name='Test Corp.', description='We make stuff good'
)
def test_create(self): def test_create(self):
"""Test that object creation with only company supplied is successful""" """Test that object creation with only company supplied is successful"""
@ -214,27 +222,28 @@ class AddressTest(TestCase):
def test_model_str(self): def test_model_str(self):
"""Test value of __str__""" """Test value of __str__"""
t = "Test address" t = 'Test address'
l1 = "Busy street 56" l1 = 'Busy street 56'
l2 = "Red building" l2 = 'Red building'
pcd = "12345" pcd = '12345'
pct = "City" pct = 'City'
pv = "Province" pv = 'Province'
cn = "COUNTRY" cn = 'COUNTRY'
addr = Address.objects.create(company=self.c, addr = Address.objects.create(
company=self.c,
title=t, title=t,
line1=l1, line1=l1,
line2=l2, line2=l2,
postal_code=pcd, postal_code=pcd,
postal_city=pct, postal_city=pct,
province=pv, province=pv,
country=cn) country=cn,
)
self.assertEqual(str(addr), f'{l1}, {l2}, {pcd}, {pct}, {pv}, {cn}') self.assertEqual(str(addr), f'{l1}, {l2}, {pcd}, {pct}, {pv}, {cn}')
addr2 = Address.objects.create(company=self.c, addr2 = Address.objects.create(
title=t, company=self.c, title=t, line1=l1, postal_code=pcd
line1=l1, )
postal_code=pcd)
self.assertEqual(str(addr2), f'{l1}, {pcd}') self.assertEqual(str(addr2), f'{l1}, {pcd}')
@ -242,13 +251,7 @@ class AddressTest(TestCase):
class ManufacturerPartSimpleTest(TestCase): class ManufacturerPartSimpleTest(TestCase):
"""Unit tests for the ManufacturerPart model""" """Unit tests for the ManufacturerPart model"""
fixtures = [ fixtures = ['category', 'company', 'location', 'part', 'manufacturer_part']
'category',
'company',
'location',
'part',
'manufacturer_part',
]
def setUp(self): def setUp(self):
"""Initialization for the unit tests in this class""" """Initialization for the unit tests in this class"""
@ -266,9 +269,7 @@ class ManufacturerPartSimpleTest(TestCase):
# Create a supplier part # Create a supplier part
supplier = Company.objects.get(pk=5) supplier = Company.objects.get(pk=5)
supplier_part = SupplierPart.objects.create( supplier_part = SupplierPart.objects.create(
part=self.part, part=self.part, supplier=supplier, SKU='SKU_TEST'
supplier=supplier,
SKU='SKU_TEST',
) )
supplier_part.save() supplier_part.save()

View File

@ -5,29 +5,41 @@ from django.urls import include, path, re_path
from . import views from . import views
company_urls = [ company_urls = [
# Detail URLs for a specific Company instance # Detail URLs for a specific Company instance
path(r'<int:pk>/', include([ path(
re_path(r'^.*$', views.CompanyDetail.as_view(), name='company-detail'), r'<int:pk>/',
])), include([
re_path(r'^.*$', views.CompanyDetail.as_view(), name='company-detail')
]),
),
re_path(r'suppliers/', views.CompanyIndex.as_view(), name='supplier-index'), re_path(r'suppliers/', views.CompanyIndex.as_view(), name='supplier-index'),
re_path(r'manufacturers/', views.CompanyIndex.as_view(), name='manufacturer-index'), re_path(r'manufacturers/', views.CompanyIndex.as_view(), name='manufacturer-index'),
re_path(r'customers/', views.CompanyIndex.as_view(), name='customer-index'), re_path(r'customers/', views.CompanyIndex.as_view(), name='customer-index'),
# Redirect any other patterns to the 'company' index which displays all companies # Redirect any other patterns to the 'company' index which displays all companies
re_path(r'^.*$', views.CompanyIndex.as_view(), name='company-index'), re_path(r'^.*$', views.CompanyIndex.as_view(), name='company-index'),
] ]
manufacturer_part_urls = [ manufacturer_part_urls = [
path(
path(r'<int:pk>/', views.ManufacturerPartDetail.as_view(template_name='company/manufacturer_part.html'), name='manufacturer-part-detail'), r'<int:pk>/',
views.ManufacturerPartDetail.as_view(
template_name='company/manufacturer_part.html'
),
name='manufacturer-part-detail',
)
] ]
supplier_part_urls = [ supplier_part_urls = [
path(r'<int:pk>/', include([ path(
re_path('^.*$', views.SupplierPartDetail.as_view(template_name='company/supplier_part.html'), name='supplier-part-detail'), r'<int:pk>/',
])) include([
re_path(
'^.*$',
views.SupplierPartDetail.as_view(
template_name='company/supplier_part.html'
),
name='supplier-part-detail',
)
]),
)
] ]

View File

@ -44,14 +44,14 @@ class CompanyIndex(InvenTreeRoleMixin, ListView):
'button_text': _('New Customer'), 'button_text': _('New Customer'),
'filters': {'is_customer': 'true'}, 'filters': {'is_customer': 'true'},
'pagetype': 'customers', 'pagetype': 'customers',
} },
} }
default = { default = {
'title': _('Companies'), 'title': _('Companies'),
'button_text': _('New Company'), 'button_text': _('New Company'),
'filters': {}, 'filters': {},
'pagetype': 'companies' 'pagetype': 'companies',
} }
context = None context = None
@ -88,6 +88,7 @@ class CompanyIndex(InvenTreeRoleMixin, ListView):
class CompanyDetail(InvenTreePluginViewMixin, DetailView): class CompanyDetail(InvenTreePluginViewMixin, DetailView):
"""Detail view for Company object.""" """Detail view for Company object."""
context_obect_name = 'company' context_obect_name = 'company'
template_name = 'company/detail.html' template_name = 'company/detail.html'
queryset = Company.objects.all() queryset = Company.objects.all()
@ -97,6 +98,7 @@ class CompanyDetail(InvenTreePluginViewMixin, DetailView):
class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView): class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView):
"""Detail view for ManufacturerPart.""" """Detail view for ManufacturerPart."""
model = ManufacturerPart model = ManufacturerPart
template_name = 'company/manufacturer_part_detail.html' template_name = 'company/manufacturer_part_detail.html'
context_object_name = 'part' context_object_name = 'part'
@ -106,6 +108,7 @@ class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView):
class SupplierPartDetail(InvenTreePluginViewMixin, DetailView): class SupplierPartDetail(InvenTreePluginViewMixin, DetailView):
"""Detail view for SupplierPart.""" """Detail view for SupplierPart."""
model = SupplierPart model = SupplierPart
template_name = 'company/supplier_part_detail.html' template_name = 'company/supplier_part_detail.html'
context_object_name = 'part' context_object_name = 'part'

View File

@ -9,9 +9,4 @@ States can be extended with custom options for each InvenTree instance - those o
from .states import StatusCode from .states import StatusCode
from .transition import StateTransitionMixin, TransitionMethod, storage from .transition import StateTransitionMixin, TransitionMethod, storage
__all__ = [ __all__ = [StatusCode, storage, TransitionMethod, StateTransitionMixin]
StatusCode,
storage,
TransitionMethod,
StateTransitionMixin,
]

View File

@ -18,9 +18,7 @@ class StatusView(APIView):
all available 'StockStatus' codes all available 'StockStatus' codes
""" """
permission_classes = [ permission_classes = [permissions.IsAuthenticated]
permissions.IsAuthenticated,
]
# Override status_class for implementing subclass # Override status_class for implementing subclass
MODEL_REF = 'statusmodel' MODEL_REF = 'statusmodel'
@ -30,7 +28,9 @@ class StatusView(APIView):
status_model = self.kwargs.get(self.MODEL_REF, None) status_model = self.kwargs.get(self.MODEL_REF, None)
if status_model is None: if status_model is None:
raise ValidationError(f"StatusView view called without '{self.MODEL_REF}' parameter") raise ValidationError(
f"StatusView view called without '{self.MODEL_REF}' parameter"
)
return status_model return status_model
@ -39,15 +39,12 @@ class StatusView(APIView):
status_class = self.get_status_model() status_class = self.get_status_model()
if not inspect.isclass(status_class): if not inspect.isclass(status_class):
raise NotImplementedError("`status_class` not a class") raise NotImplementedError('`status_class` not a class')
if not issubclass(status_class, StatusCode): if not issubclass(status_class, StatusCode):
raise NotImplementedError("`status_class` not a valid StatusCode class") raise NotImplementedError('`status_class` not a valid StatusCode class')
data = { data = {'class': status_class.__name__, 'values': status_class.dict()}
'class': status_class.__name__,
'values': status_class.dict(),
}
return Response(data) return Response(data)
@ -55,9 +52,7 @@ class StatusView(APIView):
class AllStatusViews(StatusView): class AllStatusViews(StatusView):
"""Endpoint for listing all defined status models.""" """Endpoint for listing all defined status models."""
permission_classes = [ permission_classes = [permissions.IsAuthenticated]
permissions.IsAuthenticated,
]
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Perform a GET request to learn information about status codes""" """Perform a GET request to learn information about status codes"""

View File

@ -1,4 +1,5 @@
"""Generic implementation of status for InvenTree models.""" """Generic implementation of status for InvenTree models."""
import enum import enum
import re import re
@ -148,12 +149,10 @@ class StatusCode(BaseEnum):
@classmethod @classmethod
def dict(cls, key=None): def dict(cls, key=None):
"""Return a dict representation containing all required information""" """Return a dict representation containing all required information"""
return {x.name: { return {
'color': x.color, x.name: {'color': x.color, 'key': x.value, 'label': x.label, 'name': x.name}
'key': x.value, for x in cls.values(key)
'label': x.label, }
'name': x.name,
} for x in cls.values(key)}
@classmethod @classmethod
def list(cls): def list(cls):

View File

@ -35,8 +35,7 @@ class TransitionTests(InvenTreeTestCase):
def test_class(self): def test_class(self):
"""Ensure that the class itself works.""" """Ensure that the class itself works."""
class ErrorImplementation(TransitionMethod): class ErrorImplementation(TransitionMethod): ...
...
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
ErrorImplementation() ErrorImplementation()
@ -90,7 +89,6 @@ class TransitionTests(InvenTreeTestCase):
class ValidImplementation(TransitionMethod): class ValidImplementation(TransitionMethod):
def transition(self, *args, **kwargs): def transition(self, *args, **kwargs):
global raise_function global raise_function
if raise_function: if raise_function:
@ -103,7 +101,9 @@ class TransitionTests(InvenTreeTestCase):
self.assertIn(ValidImplementation, storage.list) self.assertIn(ValidImplementation, storage.list)
# Ensure that the function is called # Ensure that the function is called
self.assertEqual(StateTransitionMixin.handle_transition(0, 1, self, self, dflt), 1234) self.assertEqual(
StateTransitionMixin.handle_transition(0, 1, self, self, dflt), 1234
)
_clean_storage([ValidImplementationNoEffect, ValidImplementation]) _clean_storage([ValidImplementationNoEffect, ValidImplementation])

View File

@ -1,4 +1,5 @@
"""Tests for the generic states module.""" """Tests for the generic states module."""
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -13,9 +14,9 @@ from .states import StatusCode
class GeneralStatus(StatusCode): class GeneralStatus(StatusCode):
"""Defines a set of status codes for tests.""" """Defines a set of status codes for tests."""
PENDING = 10, _("Pending"), 'secondary' PENDING = 10, _('Pending'), 'secondary'
PLACED = 20, _("Placed"), 'primary' PLACED = 20, _('Placed'), 'primary'
COMPLETE = 30, _("Complete"), 'success' COMPLETE = 30, _('Complete'), 'success'
ABC = None # This should be ignored ABC = None # This should be ignored
_DEF = None # This should be ignored _DEF = None # This should be ignored
jkl = None # This should be ignored jkl = None # This should be ignored
@ -27,6 +28,7 @@ class GeneralStatus(StatusCode):
class GeneralStateTest(InvenTreeTestCase): class GeneralStateTest(InvenTreeTestCase):
"""Test that the StatusCode class works.""" """Test that the StatusCode class works."""
def test_code_definition(self): def test_code_definition(self):
"""Test that the status code class has been defined correctly.""" """Test that the status code class has been defined correctly."""
self.assertEqual(GeneralStatus.PENDING, 10) self.assertEqual(GeneralStatus.PENDING, 10)
@ -36,36 +38,84 @@ class GeneralStateTest(InvenTreeTestCase):
def test_code_functions(self): def test_code_functions(self):
"""Test that the status code class functions work correctly""" """Test that the status code class functions work correctly"""
# render # render
self.assertEqual(GeneralStatus.render(10), "<span class='badge rounded-pill bg-secondary'>Pending</span>") self.assertEqual(
self.assertEqual(GeneralStatus.render(20), "<span class='badge rounded-pill bg-primary'>Placed</span>") GeneralStatus.render(10),
"<span class='badge rounded-pill bg-secondary'>Pending</span>",
)
self.assertEqual(
GeneralStatus.render(20),
"<span class='badge rounded-pill bg-primary'>Placed</span>",
)
# render with invalid key # render with invalid key
self.assertEqual(GeneralStatus.render(100), 100) self.assertEqual(GeneralStatus.render(100), 100)
# list # list
self.assertEqual(GeneralStatus.list(), [{'color': 'secondary', 'key': 10, 'label': 'Pending', 'name': 'PENDING'}, {'color': 'primary', 'key': 20, 'label': 'Placed', 'name': 'PLACED'}, {'color': 'success', 'key': 30, 'label': 'Complete', 'name': 'COMPLETE'}]) self.assertEqual(
GeneralStatus.list(),
[
{
'color': 'secondary',
'key': 10,
'label': 'Pending',
'name': 'PENDING',
},
{'color': 'primary', 'key': 20, 'label': 'Placed', 'name': 'PLACED'},
{
'color': 'success',
'key': 30,
'label': 'Complete',
'name': 'COMPLETE',
},
],
)
# text # text
self.assertEqual(GeneralStatus.text(10), 'Pending') self.assertEqual(GeneralStatus.text(10), 'Pending')
self.assertEqual(GeneralStatus.text(20), 'Placed') self.assertEqual(GeneralStatus.text(20), 'Placed')
# items # items
self.assertEqual(list(GeneralStatus.items()), [(10, 'Pending'), (20, 'Placed'), (30, 'Complete')]) self.assertEqual(
list(GeneralStatus.items()),
[(10, 'Pending'), (20, 'Placed'), (30, 'Complete')],
)
# keys # keys
self.assertEqual(list(GeneralStatus.keys()), ([10, 20, 30])) self.assertEqual(list(GeneralStatus.keys()), ([10, 20, 30]))
# labels # labels
self.assertEqual(list(GeneralStatus.labels()), ['Pending', 'Placed', 'Complete']) self.assertEqual(
list(GeneralStatus.labels()), ['Pending', 'Placed', 'Complete']
)
# names # names
self.assertEqual(GeneralStatus.names(), {'PENDING': 10, 'PLACED': 20, 'COMPLETE': 30}) self.assertEqual(
GeneralStatus.names(), {'PENDING': 10, 'PLACED': 20, 'COMPLETE': 30}
)
# dict # dict
self.assertEqual(GeneralStatus.dict(), { self.assertEqual(
'PENDING': {'key': 10, 'name': 'PENDING', 'label': 'Pending', 'color': 'secondary'}, GeneralStatus.dict(),
'PLACED': {'key': 20, 'name': 'PLACED', 'label': 'Placed', 'color': 'primary'}, {
'COMPLETE': {'key': 30, 'name': 'COMPLETE', 'label': 'Complete', 'color': 'success'}, 'PENDING': {
}) 'key': 10,
'name': 'PENDING',
'label': 'Pending',
'color': 'secondary',
},
'PLACED': {
'key': 20,
'name': 'PLACED',
'label': 'Placed',
'color': 'primary',
},
'COMPLETE': {
'key': 30,
'name': 'COMPLETE',
'label': 'Complete',
'color': 'success',
},
},
)
# label # label
self.assertEqual(GeneralStatus.label(10), 'Pending') self.assertEqual(GeneralStatus.label(10), 'Pending')
@ -74,7 +124,10 @@ class GeneralStateTest(InvenTreeTestCase):
"""Test that the status code tag functions.""" """Test that the status code tag functions."""
from .tags import status_label from .tags import status_label
self.assertEqual(status_label('general', 10), "<span class='badge rounded-pill bg-secondary'>Pending</span>") self.assertEqual(
status_label('general', 10),
"<span class='badge rounded-pill bg-secondary'>Pending</span>",
)
# invalid type # invalid type
with self.assertRaises(ValueError) as e: with self.assertRaises(ValueError) as e:
@ -87,24 +140,54 @@ class GeneralStateTest(InvenTreeTestCase):
def test_api(self): def test_api(self):
"""Test StatusView API view.""" """Test StatusView API view."""
view = StatusView.as_view() view = StatusView.as_view()
rqst = RequestFactory().get('status/',) rqst = RequestFactory().get('status/')
force_authenticate(rqst, user=self.user) force_authenticate(rqst, user=self.user)
# Correct call # Correct call
resp = view(rqst, **{StatusView.MODEL_REF: GeneralStatus}) resp = view(rqst, **{StatusView.MODEL_REF: GeneralStatus})
self.assertEqual(resp.data, {'class': 'GeneralStatus', 'values': {'COMPLETE': {'key': 30, 'name': 'COMPLETE', 'label': 'Complete', 'color': 'success'}, 'PENDING': {'key': 10, 'name': 'PENDING', 'label': 'Pending', 'color': 'secondary'}, 'PLACED': {'key': 20, 'name': 'PLACED', 'label': 'Placed', 'color': 'primary'}}}) self.assertEqual(
resp.data,
{
'class': 'GeneralStatus',
'values': {
'COMPLETE': {
'key': 30,
'name': 'COMPLETE',
'label': 'Complete',
'color': 'success',
},
'PENDING': {
'key': 10,
'name': 'PENDING',
'label': 'Pending',
'color': 'secondary',
},
'PLACED': {
'key': 20,
'name': 'PLACED',
'label': 'Placed',
'color': 'primary',
},
},
},
)
# No status defined # No status defined
resp = view(rqst, **{StatusView.MODEL_REF: None}) resp = view(rqst, **{StatusView.MODEL_REF: None})
self.assertEqual(resp.status_code, 400) self.assertEqual(resp.status_code, 400)
self.assertEqual(str(resp.rendered_content, 'utf-8'), '["StatusView view called without \'statusmodel\' parameter"]') self.assertEqual(
str(resp.rendered_content, 'utf-8'),
'["StatusView view called without \'statusmodel\' parameter"]',
)
# Invalid call - not a class # Invalid call - not a class
with self.assertRaises(NotImplementedError) as e: with self.assertRaises(NotImplementedError) as e:
resp = view(rqst, **{StatusView.MODEL_REF: 'invalid'}) resp = view(rqst, **{StatusView.MODEL_REF: 'invalid'})
self.assertEqual(str(e.exception), "`status_class` not a class") self.assertEqual(str(e.exception), '`status_class` not a class')
# Invalid call - not the right class # Invalid call - not the right class
with self.assertRaises(NotImplementedError) as e: with self.assertRaises(NotImplementedError) as e:
resp = view(rqst, **{StatusView.MODEL_REF: object}) resp = view(rqst, **{StatusView.MODEL_REF: object})
self.assertEqual(str(e.exception), "`status_class` not a valid StatusCode class") self.assertEqual(
str(e.exception), '`status_class` not a valid StatusCode class'
)

View File

@ -1,4 +1,5 @@
"""Classes and functions for plugin controlled object state transitions.""" """Classes and functions for plugin controlled object state transitions."""
import InvenTree.helpers import InvenTree.helpers
@ -15,8 +16,10 @@ class TransitionMethod:
- The needed functions are implemented - The needed functions are implemented
""" """
# Check if a sending fnc is defined # Check if a sending fnc is defined
if (not hasattr(self, 'transition')): if not hasattr(self, 'transition'):
raise NotImplementedError('A TransitionMethod must define a `transition` method') raise NotImplementedError(
'A TransitionMethod must define a `transition` method'
)
class TransitionMethodStorageClass: class TransitionMethodStorageClass:
@ -62,7 +65,9 @@ class StateTransitionMixin:
``` ```
""" """
def handle_transition(self, current_state, target_state, instance, default_action, **kwargs): def handle_transition(
self, current_state, target_state, instance, default_action, **kwargs
):
"""Handle a state transition for an object. """Handle a state transition for an object.
Args: Args:
@ -74,7 +79,9 @@ class StateTransitionMixin:
# Check if there is a custom override function for this transition # Check if there is a custom override function for this transition
for override in storage.list: for override in storage.list:
rslt = override.transition(current_state, target_state, instance, default_action, **kwargs) rslt = override.transition(
current_state, target_state, instance, default_action, **kwargs
)
if rslt: if rslt:
return rslt return rslt

View File

@ -5,6 +5,4 @@ from django import template
register = template.Library() register = template.Library()
from generic.states.tags import status_label # noqa: E402 from generic.states.tags import status_label # noqa: E402
__all__ = [ __all__ = [status_label]
status_label,
]

View File

@ -2,7 +2,7 @@
import multiprocessing import multiprocessing
bind = "0.0.0.0:8000" bind = '0.0.0.0:8000'
workers = multiprocessing.cpu_count() * 2 + 1 workers = multiprocessing.cpu_count() * 2 + 1

View File

@ -7,6 +7,7 @@ import label.models
class LabelAdmin(admin.ModelAdmin): class LabelAdmin(admin.ModelAdmin):
"""Admin class for the various label models""" """Admin class for the various label models"""
list_display = ('name', 'description', 'label', 'filters', 'enabled') list_display = ('name', 'description', 'label', 'filters', 'enabled')

View File

@ -18,8 +18,7 @@ import label.models
import label.serializers import label.serializers
from InvenTree.api import MetadataView from InvenTree.api import MetadataView
from InvenTree.filters import InvenTreeSearchFilter from InvenTree.filters import InvenTreeSearchFilter
from InvenTree.mixins import (ListCreateAPI, RetrieveAPI, from InvenTree.mixins import ListCreateAPI, RetrieveAPI, RetrieveUpdateDestroyAPI
RetrieveUpdateDestroyAPI)
from part.models import Part from part.models import Part
from plugin.builtin.labels.inventree_label import InvenTreeLabelPlugin from plugin.builtin.labels.inventree_label import InvenTreeLabelPlugin
from plugin.registry import registry from plugin.registry import registry
@ -59,7 +58,7 @@ class LabelFilterMixin:
for id in ids: for id in ids:
try: try:
valid_ids.append(int(id)) valid_ids.append(int(id))
except (ValueError): except ValueError:
pass pass
# Filter queryset by matching ID values # Filter queryset by matching ID values
@ -120,34 +119,23 @@ class LabelListView(LabelFilterMixin, ListCreateAPI):
return queryset return queryset
filter_backends = [ filter_backends = [DjangoFilterBackend, InvenTreeSearchFilter]
DjangoFilterBackend,
InvenTreeSearchFilter
]
filterset_fields = [ filterset_fields = ['enabled']
'enabled',
]
search_fields = [ search_fields = ['name', 'description']
'name',
'description',
]
@method_decorator(cache_page(5), name='dispatch') @method_decorator(cache_page(5), name='dispatch')
class LabelPrintMixin(LabelFilterMixin): class LabelPrintMixin(LabelFilterMixin):
"""Mixin for printing labels.""" """Mixin for printing labels."""
rolemap = { rolemap = {'GET': 'view', 'POST': 'view'}
"GET": "view",
"POST": "view",
}
def check_permissions(self, request): def check_permissions(self, request):
"""Override request method to GET so that also non superusers can print using a post request.""" """Override request method to GET so that also non superusers can print using a post request."""
if request.method == "POST": if request.method == 'POST':
request = clone_request(request, "GET") request = clone_request(request, 'GET')
return super().check_permissions(request) return super().check_permissions(request)
@method_decorator(never_cache) @method_decorator(never_cache)
@ -161,7 +149,9 @@ class LabelPrintMixin(LabelFilterMixin):
plugin = self.get_plugin(self.request) plugin = self.get_plugin(self.request)
kwargs.setdefault('context', self.get_serializer_context()) kwargs.setdefault('context', self.get_serializer_context())
serializer = plugin.get_printing_options_serializer(self.request, *args, **kwargs) serializer = plugin.get_printing_options_serializer(
self.request, *args, **kwargs
)
# if no serializer is defined, return an empty serializer # if no serializer is defined, return an empty serializer
if not serializer: if not serializer:
@ -171,8 +161,12 @@ class LabelPrintMixin(LabelFilterMixin):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""Perform a GET request against this endpoint to print labels""" """Perform a GET request against this endpoint to print labels"""
common.models.InvenTreeUserSetting.set_setting('DEFAULT_' + self.ITEM_KEY.upper() + '_LABEL_TEMPLATE', common.models.InvenTreeUserSetting.set_setting(
self.get_object().pk, None, user=request.user) 'DEFAULT_' + self.ITEM_KEY.upper() + '_LABEL_TEMPLATE',
self.get_object().pk,
None,
user=request.user,
)
return self.print(request, self.get_items()) return self.print(request, self.get_items())
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
@ -205,8 +199,10 @@ class LabelPrintMixin(LabelFilterMixin):
if not plugin.is_active(): if not plugin.is_active():
raise ValidationError(f"Plugin '{plugin_key}' is not enabled") raise ValidationError(f"Plugin '{plugin_key}' is not enabled")
if not plugin.mixin_enabled("labels"): if not plugin.mixin_enabled('labels'):
raise ValidationError(f"Plugin '{plugin_key}' is not a label printing plugin") raise ValidationError(
f"Plugin '{plugin_key}' is not a label printing plugin"
)
# Only return the plugin if it is enabled and has the label printing mixin # Only return the plugin if it is enabled and has the label printing mixin
return plugin return plugin
@ -228,18 +224,24 @@ class LabelPrintMixin(LabelFilterMixin):
raise ValidationError('Label has invalid dimensions') raise ValidationError('Label has invalid dimensions')
# if the plugin returns a serializer, validate the data # if the plugin returns a serializer, validate the data
if serializer := plugin.get_printing_options_serializer(request, data=request.data, context=self.get_serializer_context()): if serializer := plugin.get_printing_options_serializer(
request, data=request.data, context=self.get_serializer_context()
):
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
# At this point, we offload the label(s) to the selected plugin. # At this point, we offload the label(s) to the selected plugin.
# The plugin is responsible for handling the request and returning a response. # The plugin is responsible for handling the request and returning a response.
result = plugin.print_labels(label, items_to_print, request, printing_options=request.data) result = plugin.print_labels(
label, items_to_print, request, printing_options=request.data
)
if isinstance(result, JsonResponse): if isinstance(result, JsonResponse):
result['plugin'] = plugin.plugin_slug() result['plugin'] = plugin.plugin_slug()
return result return result
raise ValidationError(f"Plugin '{plugin.plugin_slug()}' returned invalid response type '{type(result)}'") raise ValidationError(
f"Plugin '{plugin.plugin_slug()}' returned invalid response type '{type(result)}'"
)
class StockItemLabelMixin: class StockItemLabelMixin:
@ -261,16 +263,19 @@ class StockItemLabelList(StockItemLabelMixin, LabelListView):
- item: Filter by single stock item - item: Filter by single stock item
- items: Filter by list of stock items - items: Filter by list of stock items
""" """
pass pass
class StockItemLabelDetail(StockItemLabelMixin, RetrieveUpdateDestroyAPI): class StockItemLabelDetail(StockItemLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single StockItemLabel object.""" """API endpoint for a single StockItemLabel object."""
pass pass
class StockItemLabelPrint(StockItemLabelMixin, LabelPrintMixin, RetrieveAPI): class StockItemLabelPrint(StockItemLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a StockItemLabel object.""" """API endpoint for printing a StockItemLabel object."""
pass pass
@ -293,21 +298,25 @@ class StockLocationLabelList(StockLocationLabelMixin, LabelListView):
- location: Filter by a single stock location - location: Filter by a single stock location
- locations: Filter by list of stock locations - locations: Filter by list of stock locations
""" """
pass pass
class StockLocationLabelDetail(StockLocationLabelMixin, RetrieveUpdateDestroyAPI): class StockLocationLabelDetail(StockLocationLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single StockLocationLabel object.""" """API endpoint for a single StockLocationLabel object."""
pass pass
class StockLocationLabelPrint(StockLocationLabelMixin, LabelPrintMixin, RetrieveAPI): class StockLocationLabelPrint(StockLocationLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a StockLocationLabel object.""" """API endpoint for printing a StockLocationLabel object."""
pass pass
class PartLabelMixin: class PartLabelMixin:
"""Mixin for PartLabel endpoints""" """Mixin for PartLabel endpoints"""
queryset = label.models.PartLabel.objects.all() queryset = label.models.PartLabel.objects.all()
serializer_class = label.serializers.PartLabelSerializer serializer_class = label.serializers.PartLabelSerializer
@ -317,16 +326,19 @@ class PartLabelMixin:
class PartLabelList(PartLabelMixin, LabelListView): class PartLabelList(PartLabelMixin, LabelListView):
"""API endpoint for viewing list of PartLabel objects.""" """API endpoint for viewing list of PartLabel objects."""
pass pass
class PartLabelDetail(PartLabelMixin, RetrieveUpdateDestroyAPI): class PartLabelDetail(PartLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single PartLabel object.""" """API endpoint for a single PartLabel object."""
pass pass
class PartLabelPrint(PartLabelMixin, LabelPrintMixin, RetrieveAPI): class PartLabelPrint(PartLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a PartLabel object.""" """API endpoint for printing a PartLabel object."""
pass pass
@ -342,70 +354,147 @@ class BuildLineLabelMixin:
class BuildLineLabelList(BuildLineLabelMixin, LabelListView): class BuildLineLabelList(BuildLineLabelMixin, LabelListView):
"""API endpoint for viewing a list of BuildLineLabel objects""" """API endpoint for viewing a list of BuildLineLabel objects"""
pass pass
class BuildLineLabelDetail(BuildLineLabelMixin, RetrieveUpdateDestroyAPI): class BuildLineLabelDetail(BuildLineLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single BuildLineLabel object""" """API endpoint for a single BuildLineLabel object"""
pass pass
class BuildLineLabelPrint(BuildLineLabelMixin, LabelPrintMixin, RetrieveAPI): class BuildLineLabelPrint(BuildLineLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a BuildLineLabel object""" """API endpoint for printing a BuildLineLabel object"""
pass pass
label_api_urls = [ label_api_urls = [
# Stock item labels # Stock item labels
re_path(r'stock/', include([ re_path(
r'stock/',
include([
# Detail views # Detail views
path(r'<int:pk>/', include([ path(
re_path(r'print/?', StockItemLabelPrint.as_view(), name='api-stockitem-label-print'), r'<int:pk>/',
re_path(r'metadata/', MetadataView.as_view(), {'model': label.models.StockItemLabel}, name='api-stockitem-label-metadata'), include([
re_path(r'^.*$', StockItemLabelDetail.as_view(), name='api-stockitem-label-detail'), re_path(
])), r'print/?',
StockItemLabelPrint.as_view(),
name='api-stockitem-label-print',
),
re_path(
r'metadata/',
MetadataView.as_view(),
{'model': label.models.StockItemLabel},
name='api-stockitem-label-metadata',
),
re_path(
r'^.*$',
StockItemLabelDetail.as_view(),
name='api-stockitem-label-detail',
),
]),
),
# List view # List view
re_path(r'^.*$', StockItemLabelList.as_view(), name='api-stockitem-label-list'), re_path(
])), r'^.*$', StockItemLabelList.as_view(), name='api-stockitem-label-list'
),
]),
),
# Stock location labels # Stock location labels
re_path(r'location/', include([ re_path(
r'location/',
include([
# Detail views # Detail views
path(r'<int:pk>/', include([ path(
re_path(r'print/?', StockLocationLabelPrint.as_view(), name='api-stocklocation-label-print'), r'<int:pk>/',
re_path(r'metadata/', MetadataView.as_view(), {'model': label.models.StockLocationLabel}, name='api-stocklocation-label-metadata'), include([
re_path(r'^.*$', StockLocationLabelDetail.as_view(), name='api-stocklocation-label-detail'), re_path(
])), r'print/?',
StockLocationLabelPrint.as_view(),
name='api-stocklocation-label-print',
),
re_path(
r'metadata/',
MetadataView.as_view(),
{'model': label.models.StockLocationLabel},
name='api-stocklocation-label-metadata',
),
re_path(
r'^.*$',
StockLocationLabelDetail.as_view(),
name='api-stocklocation-label-detail',
),
]),
),
# List view # List view
re_path(r'^.*$', StockLocationLabelList.as_view(), name='api-stocklocation-label-list'), re_path(
])), r'^.*$',
StockLocationLabelList.as_view(),
name='api-stocklocation-label-list',
),
]),
),
# Part labels # Part labels
re_path(r'^part/', include([ re_path(
r'^part/',
include([
# Detail views # Detail views
path(r'<int:pk>/', include([ path(
re_path(r'^print/', PartLabelPrint.as_view(), name='api-part-label-print'), r'<int:pk>/',
re_path(r'^metadata/', MetadataView.as_view(), {'model': label.models.PartLabel}, name='api-part-label-metadata'), include([
re_path(r'^.*$', PartLabelDetail.as_view(), name='api-part-label-detail'), re_path(
])), r'^print/',
PartLabelPrint.as_view(),
name='api-part-label-print',
),
re_path(
r'^metadata/',
MetadataView.as_view(),
{'model': label.models.PartLabel},
name='api-part-label-metadata',
),
re_path(
r'^.*$', PartLabelDetail.as_view(), name='api-part-label-detail'
),
]),
),
# List view # List view
re_path(r'^.*$', PartLabelList.as_view(), name='api-part-label-list'), re_path(r'^.*$', PartLabelList.as_view(), name='api-part-label-list'),
])), ]),
),
# BuildLine labels # BuildLine labels
re_path(r'^buildline/', include([ re_path(
r'^buildline/',
include([
# Detail views # Detail views
path(r'<int:pk>/', include([ path(
re_path(r'^print/', BuildLineLabelPrint.as_view(), name='api-buildline-label-print'), r'<int:pk>/',
re_path(r'^metadata/', MetadataView.as_view(), {'model': label.models.BuildLineLabel}, name='api-buildline-label-metadata'), include([
re_path(r'^.*$', BuildLineLabelDetail.as_view(), name='api-buildline-label-detail'), re_path(
])), r'^print/',
BuildLineLabelPrint.as_view(),
name='api-buildline-label-print',
),
re_path(
r'^metadata/',
MetadataView.as_view(),
{'model': label.models.BuildLineLabel},
name='api-buildline-label-metadata',
),
re_path(
r'^.*$',
BuildLineLabelDetail.as_view(),
name='api-buildline-label-detail',
),
]),
),
# List view # List view
re_path(r'^.*$', BuildLineLabelList.as_view(), name='api-buildline-label-list'), re_path(
])), r'^.*$', BuildLineLabelList.as_view(), name='api-buildline-label-list'
),
]),
),
] ]

View File

@ -14,7 +14,7 @@ from django.db.utils import IntegrityError, OperationalError, ProgrammingError
import InvenTree.ready import InvenTree.ready
logger = logging.getLogger("inventree") logger = logging.getLogger('inventree')
def hashFile(filename): def hashFile(filename):
@ -36,23 +36,37 @@ class LabelConfig(AppConfig):
def ready(self): def ready(self):
"""This function is called whenever the label app is loaded.""" """This function is called whenever the label app is loaded."""
# skip loading if plugin registry is not loaded or we run in a background thread # skip loading if plugin registry is not loaded or we run in a background thread
if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread(): if (
not InvenTree.ready.isPluginRegistryLoaded()
or not InvenTree.ready.isInMainThread()
):
return return
if InvenTree.ready.isRunningMigrations(): if InvenTree.ready.isRunningMigrations():
return return
if InvenTree.ready.canAppAccessDatabase(allow_test=False) and not InvenTree.ready.isImportingData(): if (
InvenTree.ready.canAppAccessDatabase(allow_test=False)
and not InvenTree.ready.isImportingData()
):
try: try:
self.create_labels() # pragma: no cover self.create_labels() # pragma: no cover
except (AppRegistryNotReady, IntegrityError, OperationalError, ProgrammingError): except (
AppRegistryNotReady,
IntegrityError,
OperationalError,
ProgrammingError,
):
# Database might not yet be ready # Database might not yet be ready
warnings.warn('Database was not ready for creating labels', stacklevel=2) warnings.warn(
'Database was not ready for creating labels', stacklevel=2
)
def create_labels(self): def create_labels(self):
"""Create all default templates.""" """Create all default templates."""
# Test if models are ready # Test if models are ready
import label.models import label.models
assert bool(label.models.StockLocationLabel is not None) assert bool(label.models.StockLocationLabel is not None)
# Create the categories # Create the categories
@ -66,7 +80,7 @@ class LabelConfig(AppConfig):
'description': 'Simple QR code label', 'description': 'Simple QR code label',
'width': 24, 'width': 24,
'height': 24, 'height': 24,
}, }
], ],
) )
@ -87,8 +101,8 @@ class LabelConfig(AppConfig):
'description': 'Label with QR code and name of location', 'description': 'Label with QR code and name of location',
'width': 50, 'width': 50,
'height': 24, 'height': 24,
} },
] ],
) )
self.create_labels_category( self.create_labels_category(
@ -109,7 +123,7 @@ class LabelConfig(AppConfig):
'width': 70, 'width': 70,
'height': 24, 'height': 24,
}, },
] ],
) )
self.create_labels_category( self.create_labels_category(
@ -122,24 +136,16 @@ class LabelConfig(AppConfig):
'description': 'Example build line label', 'description': 'Example build line label',
'width': 125, 'width': 125,
'height': 48, 'height': 48,
}, }
] ],
) )
def create_labels_category(self, model, ref_name, labels): def create_labels_category(self, model, ref_name, labels):
"""Create folder and database entries for the default templates, if they do not already exist.""" """Create folder and database entries for the default templates, if they do not already exist."""
# Create root dir for templates # Create root dir for templates
src_dir = Path(__file__).parent.joinpath( src_dir = Path(__file__).parent.joinpath('templates', 'label', ref_name)
'templates',
'label',
ref_name,
)
dst_dir = settings.MEDIA_ROOT.joinpath( dst_dir = settings.MEDIA_ROOT.joinpath('label', 'inventree', ref_name)
'label',
'inventree',
ref_name,
)
if not dst_dir.exists(): if not dst_dir.exists():
logger.info("Creating required directory: '%s'", dst_dir) logger.info("Creating required directory: '%s'", dst_dir)
@ -151,12 +157,7 @@ class LabelConfig(AppConfig):
def create_template_label(self, model, src_dir, ref_name, label): def create_template_label(self, model, src_dir, ref_name, label):
"""Ensure a label template is in place.""" """Ensure a label template is in place."""
filename = os.path.join( filename = os.path.join('label', 'inventree', ref_name, label['file'])
'label',
'inventree',
ref_name,
label['file']
)
src_file = src_dir.joinpath(label['file']) src_file = src_dir.joinpath(label['file'])
dst_file = settings.MEDIA_ROOT.joinpath(filename) dst_file = settings.MEDIA_ROOT.joinpath(filename)
@ -187,7 +188,10 @@ class LabelConfig(AppConfig):
if model.objects.filter(label=filename).exists(): if model.objects.filter(label=filename).exists():
return # pragma: no cover return # pragma: no cover
except Exception: except Exception:
logger.exception("Failed to query label for '%s' - you should run 'invoke update' first!", filename) logger.exception(
"Failed to query label for '%s' - you should run 'invoke update' first!",
filename,
)
logger.info("Creating entry for %s '%s'", model, label['name']) logger.info("Creating entry for %s '%s'", model, label['name'])

View File

@ -25,12 +25,12 @@ from plugin.registry import registry
try: try:
from django_weasyprint import WeasyTemplateResponseMixin from django_weasyprint import WeasyTemplateResponseMixin
except OSError as err: # pragma: no cover except OSError as err: # pragma: no cover
print(f"OSError: {err}") print(f'OSError: {err}')
print("You may require some further system packages to be installed.") print('You may require some further system packages to be installed.')
sys.exit(1) sys.exit(1)
logger = logging.getLogger("inventree") logger = logging.getLogger('inventree')
def rename_label(instance, filename): def rename_label(instance, filename):
@ -97,7 +97,7 @@ class LabelTemplate(MetadataMixin, models.Model):
abstract = True abstract = True
# Each class of label files will be stored in a separate subdirectory # Each class of label files will be stored in a separate subdirectory
SUBDIR = "label" SUBDIR = 'label'
# Object we will be printing against (will be filled out later) # Object we will be printing against (will be filled out later)
object_to_print = None object_to_print = None
@ -109,17 +109,16 @@ class LabelTemplate(MetadataMixin, models.Model):
def __str__(self): def __str__(self):
"""Format a string representation of a label instance""" """Format a string representation of a label instance"""
return f"{self.name} - {self.description}" return f'{self.name} - {self.description}'
name = models.CharField( name = models.CharField(
blank=False, max_length=100, blank=False, max_length=100, verbose_name=_('Name'), help_text=_('Label name')
verbose_name=_('Name'),
help_text=_('Label name'),
) )
description = models.CharField( description = models.CharField(
max_length=250, max_length=250,
blank=True, null=True, blank=True,
null=True,
verbose_name=_('Description'), verbose_name=_('Description'),
help_text=_('Label description'), help_text=_('Label description'),
) )
@ -127,7 +126,8 @@ class LabelTemplate(MetadataMixin, models.Model):
label = models.FileField( label = models.FileField(
upload_to=rename_label, upload_to=rename_label,
unique=True, unique=True,
blank=False, null=False, blank=False,
null=False,
verbose_name=_('Label'), verbose_name=_('Label'),
help_text=_('Label template file'), help_text=_('Label template file'),
validators=[FileExtensionValidator(allowed_extensions=['html'])], validators=[FileExtensionValidator(allowed_extensions=['html'])],
@ -143,18 +143,18 @@ class LabelTemplate(MetadataMixin, models.Model):
default=50, default=50,
verbose_name=_('Width [mm]'), verbose_name=_('Width [mm]'),
help_text=_('Label width, specified in mm'), help_text=_('Label width, specified in mm'),
validators=[MinValueValidator(2)] validators=[MinValueValidator(2)],
) )
height = models.FloatField( height = models.FloatField(
default=20, default=20,
verbose_name=_('Height [mm]'), verbose_name=_('Height [mm]'),
help_text=_('Label height, specified in mm'), help_text=_('Label height, specified in mm'),
validators=[MinValueValidator(2)] validators=[MinValueValidator(2)],
) )
filename_pattern = models.CharField( filename_pattern = models.CharField(
default="label.pdf", default='label.pdf',
verbose_name=_('Filename Pattern'), verbose_name=_('Filename Pattern'),
help_text=_('Pattern for generating label filenames'), help_text=_('Pattern for generating label filenames'),
max_length=100, max_length=100,
@ -249,11 +249,7 @@ class LabelTemplate(MetadataMixin, models.Model):
context = self.context(request, **kwargs) context = self.context(request, **kwargs)
return render_to_string( return render_to_string(self.template_name, context, request)
self.template_name,
context,
request
)
def render(self, request, target_object=None, **kwargs): def render(self, request, target_object=None, **kwargs):
"""Render the label template to a PDF file. """Render the label template to a PDF file.
@ -269,16 +265,13 @@ class LabelTemplate(MetadataMixin, models.Model):
wp = WeasyprintLabelMixin( wp = WeasyprintLabelMixin(
request, request,
self.template_name, self.template_name,
base_url=request.build_absolute_uri("/"), base_url=request.build_absolute_uri('/'),
presentational_hints=True, presentational_hints=True,
filename=self.generate_filename(request), filename=self.generate_filename(request),
**kwargs **kwargs,
) )
return wp.render_to_response( return wp.render_to_response(context, **kwargs)
context,
**kwargs
)
class LabelOutput(models.Model): class LabelOutput(models.Model):
@ -293,22 +286,14 @@ class LabelOutput(models.Model):
# File will be stored in a subdirectory # File will be stored in a subdirectory
label = models.FileField( label = models.FileField(
upload_to=rename_label_output, upload_to=rename_label_output, unique=True, blank=False, null=False
unique=True, blank=False, null=False,
) )
# Creation date of label output # Creation date of label output
created = models.DateField( created = models.DateField(auto_now_add=True, editable=False)
auto_now_add=True,
editable=False,
)
# User who generated the label # User who generated the label
user = models.ForeignKey( user = models.ForeignKey(User, on_delete=models.SET_NULL, blank=True, null=True)
User,
on_delete=models.SET_NULL,
blank=True, null=True,
)
class StockItemLabel(LabelTemplate): class StockItemLabel(LabelTemplate):
@ -319,15 +304,14 @@ class StockItemLabel(LabelTemplate):
"""Return the API URL associated with the StockItemLabel model""" """Return the API URL associated with the StockItemLabel model"""
return reverse('api-stockitem-label-list') # pragma: no cover return reverse('api-stockitem-label-list') # pragma: no cover
SUBDIR = "stockitem" SUBDIR = 'stockitem'
filters = models.CharField( filters = models.CharField(
blank=True, max_length=250, blank=True,
max_length=250,
help_text=_('Query filters (comma-separated list of key=value pairs)'), help_text=_('Query filters (comma-separated list of key=value pairs)'),
verbose_name=_('Filters'), verbose_name=_('Filters'),
validators=[ validators=[validate_stock_item_filters],
validate_stock_item_filters
]
) )
def get_context_data(self, request): def get_context_data(self, request):
@ -348,7 +332,6 @@ class StockItemLabel(LabelTemplate):
'qr_url': request.build_absolute_uri(stock_item.get_absolute_url()), 'qr_url': request.build_absolute_uri(stock_item.get_absolute_url()),
'tests': stock_item.testResultMap(), 'tests': stock_item.testResultMap(),
'parameters': stock_item.part.parameters_map(), 'parameters': stock_item.part.parameters_map(),
} }
@ -360,24 +343,21 @@ class StockLocationLabel(LabelTemplate):
"""Return the API URL associated with the StockLocationLabel model""" """Return the API URL associated with the StockLocationLabel model"""
return reverse('api-stocklocation-label-list') # pragma: no cover return reverse('api-stocklocation-label-list') # pragma: no cover
SUBDIR = "stocklocation" SUBDIR = 'stocklocation'
filters = models.CharField( filters = models.CharField(
blank=True, max_length=250, blank=True,
max_length=250,
help_text=_('Query filters (comma-separated list of key=value pairs)'), help_text=_('Query filters (comma-separated list of key=value pairs)'),
verbose_name=_('Filters'), verbose_name=_('Filters'),
validators=[ validators=[validate_stock_location_filters],
validate_stock_location_filters]
) )
def get_context_data(self, request): def get_context_data(self, request):
"""Generate context data for each provided StockLocation.""" """Generate context data for each provided StockLocation."""
location = self.object_to_print location = self.object_to_print
return { return {'location': location, 'qr_data': location.format_barcode(brief=True)}
'location': location,
'qr_data': location.format_barcode(brief=True),
}
class PartLabel(LabelTemplate): class PartLabel(LabelTemplate):
@ -391,12 +371,11 @@ class PartLabel(LabelTemplate):
SUBDIR = 'part' SUBDIR = 'part'
filters = models.CharField( filters = models.CharField(
blank=True, max_length=250, blank=True,
max_length=250,
help_text=_('Query filters (comma-separated list of key=value pairs)'), help_text=_('Query filters (comma-separated list of key=value pairs)'),
verbose_name=_('Filters'), verbose_name=_('Filters'),
validators=[ validators=[validate_part_filters],
validate_part_filters
]
) )
def get_context_data(self, request): def get_context_data(self, request):
@ -427,12 +406,11 @@ class BuildLineLabel(LabelTemplate):
SUBDIR = 'buildline' SUBDIR = 'buildline'
filters = models.CharField( filters = models.CharField(
blank=True, max_length=250, blank=True,
max_length=250,
help_text=_('Query filters (comma-separated list of key=value pairs)'), help_text=_('Query filters (comma-separated list of key=value pairs)'),
verbose_name=_('Filters'), verbose_name=_('Filters'),
validators=[ validators=[validate_build_line_filters],
validate_build_line_filters
]
) )
def get_context_data(self, request): def get_context_data(self, request):

View File

@ -1,8 +1,10 @@
"""API serializers for the label app""" """API serializers for the label app"""
import label.models import label.models
from InvenTree.serializers import (InvenTreeAttachmentSerializerField, from InvenTree.serializers import (
InvenTreeModelSerializer) InvenTreeAttachmentSerializerField,
InvenTreeModelSerializer,
)
class LabelSerializerBase(InvenTreeModelSerializer): class LabelSerializerBase(InvenTreeModelSerializer):
@ -13,14 +15,7 @@ class LabelSerializerBase(InvenTreeModelSerializer):
@staticmethod @staticmethod
def label_fields(): def label_fields():
"""Generic serializer fields for a label template""" """Generic serializer fields for a label template"""
return [ return ['pk', 'name', 'description', 'label', 'filters', 'enabled']
'pk',
'name',
'description',
'label',
'filters',
'enabled',
]
class StockItemLabelSerializer(LabelSerializerBase): class StockItemLabelSerializer(LabelSerializerBase):

View File

@ -8,17 +8,9 @@ from InvenTree.unit_test import InvenTreeAPITestCase
class TestReportTests(InvenTreeAPITestCase): class TestReportTests(InvenTreeAPITestCase):
"""Tests for the StockItem TestReport templates.""" """Tests for the StockItem TestReport templates."""
fixtures = [ fixtures = ['category', 'part', 'location', 'stock']
'category',
'part',
'location',
'stock',
]
roles = [ roles = ['stock.view', 'stock_location.view']
'stock.view',
'stock_location.view',
]
list_url = reverse('api-stockitem-testreport-list') list_url = reverse('api-stockitem-testreport-list')
@ -42,22 +34,10 @@ class TestReportTests(InvenTreeAPITestCase):
self.assertEqual(len(response), 0) self.assertEqual(len(response), 0)
# TODO - Add some tests to this response # TODO - Add some tests to this response
response = self.do_list( response = self.do_list({'item': 10})
{
'item': 10,
}
)
# TODO - Add some tests to this response # TODO - Add some tests to this response
response = self.do_list( response = self.do_list({'item': 100000})
{
'item': 100000,
}
)
# TODO - Add some tests to this response # TODO - Add some tests to this response
response = self.do_list( response = self.do_list({'items': [10, 11, 12]})
{
'items': [10, 11, 12],
}
)

View File

@ -24,12 +24,7 @@ from .models import PartLabel, StockItemLabel, StockLocationLabel
class LabelTest(InvenTreeAPITestCase): class LabelTest(InvenTreeAPITestCase):
"""Unit test class for label models""" """Unit test class for label models"""
fixtures = [ fixtures = ['category', 'part', 'location', 'stock']
'category',
'part',
'location',
'stock'
]
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
@ -49,12 +44,9 @@ class LabelTest(InvenTreeAPITestCase):
def test_default_files(self): def test_default_files(self):
"""Test that label files exist in the MEDIA directory.""" """Test that label files exist in the MEDIA directory."""
def test_subdir(ref_name): def test_subdir(ref_name):
item_dir = settings.MEDIA_ROOT.joinpath( item_dir = settings.MEDIA_ROOT.joinpath('label', 'inventree', ref_name)
'label',
'inventree',
ref_name,
)
self.assertTrue(len([item_dir.iterdir()]) > 0) self.assertTrue(len([item_dir.iterdir()]) > 0)
test_subdir('stockitem') test_subdir('stockitem')
@ -63,13 +55,13 @@ class LabelTest(InvenTreeAPITestCase):
def test_filters(self): def test_filters(self):
"""Test the label filters.""" """Test the label filters."""
filter_string = "part__pk=10" filter_string = 'part__pk=10'
filters = validateFilterString(filter_string, model=StockItem) filters = validateFilterString(filter_string, model=StockItem)
self.assertEqual(type(filters), dict) self.assertEqual(type(filters), dict)
bad_filter_string = "part_pk=10" bad_filter_string = 'part_pk=10'
with self.assertRaises(ValidationError): with self.assertRaises(ValidationError):
validateFilterString(bad_filter_string, model=StockItem) validateFilterString(bad_filter_string, model=StockItem)
@ -115,14 +107,11 @@ class LabelTest(InvenTreeAPITestCase):
buffer = io.StringIO() buffer = io.StringIO()
buffer.write(label_data) buffer.write(label_data)
template = ContentFile(buffer.getvalue(), "label.html") template = ContentFile(buffer.getvalue(), 'label.html')
# Construct a label template # Construct a label template
label = PartLabel.objects.create( label = PartLabel.objects.create(
name='test', name='test', description='Test label', enabled=True, label=template
description='Test label',
enabled=True,
label=template,
) )
# Ensure we are in "debug" mode (so the report is generated as HTML) # Ensure we are in "debug" mode (so the report is generated as HTML)
@ -151,7 +140,7 @@ class LabelTest(InvenTreeAPITestCase):
content = f.read() content = f.read()
# Test that each element has been rendered correctly # Test that each element has been rendered correctly
self.assertIn(f"part: {part_pk} - {part_name}", content) self.assertIn(f'part: {part_pk} - {part_name}', content)
self.assertIn(f'data: {{"part": {part_pk}}}', content) self.assertIn(f'data: {{"part": {part_pk}}}', content)
self.assertIn(f'http://testserver/part/{part_pk}/', content) self.assertIn(f'http://testserver/part/{part_pk}/', content)

View File

@ -7,17 +7,17 @@ import sys
def main(): def main():
"""Run administrative tasks.""" """Run administrative tasks."""
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "InvenTree.settings") os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'InvenTree.settings')
try: try:
from django.core.management import execute_from_command_line from django.core.management import execute_from_command_line
except ImportError as exc: # pragma: no cover except ImportError as exc: # pragma: no cover
raise ImportError( raise ImportError(
"Couldn't import Django. Are you sure it's installed and " "Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you " 'available on your PYTHONPATH environment variable? Did you '
"forget to activate a virtual environment?" 'forget to activate a virtual environment?'
) from exc ) from exc
execute_from_command_line(sys.argv) execute_from_command_line(sys.argv)
if __name__ == "__main__": if __name__ == '__main__':
main() main()

View File

@ -50,52 +50,54 @@ class PriceResourceMixin:
# region general classes # region general classes
class GeneralExtraLineAdmin: class GeneralExtraLineAdmin:
"""Admin class template for the 'ExtraLineItem' models""" """Admin class template for the 'ExtraLineItem' models"""
list_display = (
'order',
'quantity',
'reference'
)
search_fields = [ list_display = ('order', 'quantity', 'reference')
'order__reference',
'order__customer__name',
'reference',
]
autocomplete_fields = ('order', ) search_fields = ['order__reference', 'order__customer__name', 'reference']
autocomplete_fields = ('order',)
class GeneralExtraLineMeta: class GeneralExtraLineMeta:
"""Metaclass template for the 'ExtraLineItem' models""" """Metaclass template for the 'ExtraLineItem' models"""
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
clean_model_instances = True clean_model_instances = True
# endregion # endregion
class PurchaseOrderLineItemInlineAdmin(admin.StackedInline): class PurchaseOrderLineItemInlineAdmin(admin.StackedInline):
"""Inline admin class for the PurchaseOrderLineItem model""" """Inline admin class for the PurchaseOrderLineItem model"""
model = models.PurchaseOrderLineItem model = models.PurchaseOrderLineItem
extra = 0 extra = 0
class PurchaseOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource): class PurchaseOrderResource(
ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource
):
"""Class for managing import / export of PurchaseOrder data.""" """Class for managing import / export of PurchaseOrder data."""
class Meta: class Meta:
"""Metaclass""" """Metaclass"""
model = models.PurchaseOrder model = models.PurchaseOrder
skip_unchanged = True skip_unchanged = True
clean_model_instances = True clean_model_instances = True
exclude = [ exclude = ['metadata']
'metadata',
]
# Add number of line items # Add number of line items
line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True) line_items = Field(
attribute='line_count', widget=widgets.IntegerWidget(), readonly=True
)
# Is this order overdue? # Is this order overdue?
overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True) overdue = Field(
attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True
)
class PurchaseOrderAdmin(ImportExportModelAdmin): class PurchaseOrderAdmin(ImportExportModelAdmin):
@ -103,48 +105,39 @@ class PurchaseOrderAdmin(ImportExportModelAdmin):
resource_class = PurchaseOrderResource resource_class = PurchaseOrderResource
exclude = [ exclude = ['reference_int']
'reference_int',
]
list_display = ( list_display = ('reference', 'supplier', 'status', 'description', 'creation_date')
'reference',
'supplier',
'status',
'description',
'creation_date'
)
search_fields = [ search_fields = ['reference', 'supplier__name', 'description']
'reference',
'supplier__name',
'description',
]
inlines = [ inlines = [PurchaseOrderLineItemInlineAdmin]
PurchaseOrderLineItemInlineAdmin
]
autocomplete_fields = ('supplier',) autocomplete_fields = ('supplier',)
class SalesOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource): class SalesOrderResource(
ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource
):
"""Class for managing import / export of SalesOrder data.""" """Class for managing import / export of SalesOrder data."""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options"""
model = models.SalesOrder model = models.SalesOrder
skip_unchanged = True skip_unchanged = True
clean_model_instances = True clean_model_instances = True
exclude = [ exclude = ['metadata']
'metadata',
]
# Add number of line items # Add number of line items
line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True) line_items = Field(
attribute='line_count', widget=widgets.IntegerWidget(), readonly=True
)
# Is this order overdue? # Is this order overdue?
overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True) overdue = Field(
attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True
)
class SalesOrderAdmin(ImportExportModelAdmin): class SalesOrderAdmin(ImportExportModelAdmin):
@ -152,23 +145,11 @@ class SalesOrderAdmin(ImportExportModelAdmin):
resource_class = SalesOrderResource resource_class = SalesOrderResource
exclude = [ exclude = ['reference_int']
'reference_int',
]
list_display = ( list_display = ('reference', 'customer', 'status', 'description', 'creation_date')
'reference',
'customer',
'status',
'description',
'creation_date',
)
search_fields = [ search_fields = ['reference', 'customer__name', 'description']
'reference',
'customer__name',
'description',
]
autocomplete_fields = ('customer',) autocomplete_fields = ('customer',)
@ -178,6 +159,7 @@ class PurchaseOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
class Meta: class Meta:
"""Metaclass""" """Metaclass"""
model = models.PurchaseOrderLineItem model = models.PurchaseOrderLineItem
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -212,6 +194,7 @@ class SalesOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
class Meta: class Meta:
"""Metaclass options""" """Metaclass options"""
model = models.SalesOrderLineItem model = models.SalesOrderLineItem
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -240,6 +223,7 @@ class SalesOrderExtraLineResource(PriceResourceMixin, InvenTreeResource):
class Meta(GeneralExtraLineMeta): class Meta(GeneralExtraLineMeta):
"""Metaclass options.""" """Metaclass options."""
model = models.SalesOrderExtraLine model = models.SalesOrderExtraLine
@ -248,20 +232,16 @@ class PurchaseOrderLineItemAdmin(ImportExportModelAdmin):
resource_class = PurchaseOrderLineItemResource resource_class = PurchaseOrderLineItemResource
list_display = ( list_display = ('order', 'part', 'quantity', 'reference')
'order',
'part',
'quantity',
'reference'
)
search_fields = ('reference',) search_fields = ('reference',)
autocomplete_fields = ('order', 'part', 'destination',) autocomplete_fields = ('order', 'part', 'destination')
class PurchaseOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): class PurchaseOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the PurchaseOrderExtraLine model""" """Admin class for the PurchaseOrderExtraLine model"""
resource_class = PurchaseOrderExtraLineResource resource_class = PurchaseOrderExtraLineResource
@ -270,12 +250,7 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin):
resource_class = SalesOrderLineItemResource resource_class = SalesOrderLineItemResource
list_display = ( list_display = ('order', 'part', 'quantity', 'reference')
'order',
'part',
'quantity',
'reference'
)
search_fields = [ search_fields = [
'part__name', 'part__name',
@ -284,28 +259,21 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin):
'reference', 'reference',
] ]
autocomplete_fields = ('order', 'part',) autocomplete_fields = ('order', 'part')
class SalesOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): class SalesOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the SalesOrderExtraLine model""" """Admin class for the SalesOrderExtraLine model"""
resource_class = SalesOrderExtraLineResource resource_class = SalesOrderExtraLineResource
class SalesOrderShipmentAdmin(ImportExportModelAdmin): class SalesOrderShipmentAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderShipment model""" """Admin class for the SalesOrderShipment model"""
list_display = [ list_display = ['order', 'shipment_date', 'reference']
'order',
'shipment_date',
'reference',
]
search_fields = [ search_fields = ['reference', 'order__reference', 'order__customer__name']
'reference',
'order__reference',
'order__customer__name',
]
autocomplete_fields = ('order',) autocomplete_fields = ('order',)
@ -313,26 +281,23 @@ class SalesOrderShipmentAdmin(ImportExportModelAdmin):
class SalesOrderAllocationAdmin(ImportExportModelAdmin): class SalesOrderAllocationAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderAllocation model""" """Admin class for the SalesOrderAllocation model"""
list_display = ( list_display = ('line', 'item', 'quantity')
'line',
'item',
'quantity'
)
autocomplete_fields = ('line', 'shipment', 'item',) autocomplete_fields = ('line', 'shipment', 'item')
class ReturnOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource): class ReturnOrderResource(
ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource
):
"""Class for managing import / export of ReturnOrder data""" """Class for managing import / export of ReturnOrder data"""
class Meta: class Meta:
"""Metaclass options""" """Metaclass options"""
model = models.ReturnOrder model = models.ReturnOrder
skip_unchanged = True skip_unchanged = True
clean_model_instances = True clean_model_instances = True
exclude = [ exclude = ['metadata']
'metadata',
]
class ReturnOrderAdmin(ImportExportModelAdmin): class ReturnOrderAdmin(ImportExportModelAdmin):
@ -340,25 +305,13 @@ class ReturnOrderAdmin(ImportExportModelAdmin):
resource_class = ReturnOrderResource resource_class = ReturnOrderResource
exclude = [ exclude = ['reference_int']
'reference_int',
]
list_display = [ list_display = ['reference', 'customer', 'status']
'reference',
'customer',
'status',
]
search_fields = [ search_fields = ['reference', 'customer__name', 'description']
'reference',
'customer__name',
'description',
]
autocomplete_fields = [ autocomplete_fields = ['customer']
'customer',
]
class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource): class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
@ -366,6 +319,7 @@ class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
class Meta: class Meta:
"""Metaclass options""" """Metaclass options"""
model = models.ReturnOrderLineItem model = models.ReturnOrderLineItem
skip_unchanged = True skip_unchanged = True
report_skipped = False report_skipped = False
@ -377,11 +331,7 @@ class ReturnOrderLineItemAdmin(ImportExportModelAdmin):
resource_class = ReturnOrderLineItemResource resource_class = ReturnOrderLineItemResource
list_display = [ list_display = ['order', 'item', 'reference']
'order',
'item',
'reference',
]
class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource): class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource):
@ -389,11 +339,13 @@ class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource):
class Meta(GeneralExtraLineMeta): class Meta(GeneralExtraLineMeta):
"""Metaclass options""" """Metaclass options"""
model = models.ReturnOrderExtraLine model = models.ReturnOrderExtraLine
class ReturnOrdeerExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin): class ReturnOrdeerExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the ReturnOrderExtraLine model""" """Admin class for the ReturnOrderExtraLine model"""
resource_class = ReturnOrderExtraLineClass resource_class = ReturnOrderExtraLineClass

File diff suppressed because it is too large Load Diff

View File

@ -5,4 +5,5 @@ from django.apps import AppConfig
class OrderConfig(AppConfig): class OrderConfig(AppConfig):
"""Configuration class for the 'order' app""" """Configuration class for the 'order' app"""
name = 'order' name = 'order'

View File

@ -17,14 +17,16 @@ class OrderMatchItemForm(MatchItemForm):
if 'quantity' in col_guess.lower(): if 'quantity' in col_guess.lower():
return forms.CharField( return forms.CharField(
required=False, required=False,
widget=forms.NumberInput(attrs={ widget=forms.NumberInput(
attrs={
'name': 'quantity' + str(row['index']), 'name': 'quantity' + str(row['index']),
'class': 'numberinput', 'class': 'numberinput',
'type': 'number', 'type': 'number',
'min': '0', 'min': '0',
'step': 'any', 'step': 'any',
'value': clean_decimal(row.get('quantity', '')), 'value': clean_decimal(row.get('quantity', '')),
}) }
),
) )
# set price field # set price field
elif 'price' in col_guess.lower(): elif 'price' in col_guess.lower():

Some files were not shown because too many files have changed in this diff Show More