[CI] Enable python autoformat (#6169)

* Squashed commit of the following:

commit f5cf7b2e78
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:36:57 2024 +0100

    fixed reqs

commit 9d845bee98
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:32:35 2024 +0100

    disable autofix/format

commit aff5f27148
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:50 2024 +0100

    adjust checks

commit 47271cf1ef
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:22 2024 +0100

    reorder order of operations

commit e1bf178b40
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:01:09 2024 +0100

    adapted ruff settings to better fit code base

commit ad7d88a6f4
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:59:45 2024 +0100

    auto fixed docstring

commit a2e54a760e
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:46:35 2024 +0100

    fix getattr useage

commit cb80c73bc6
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:25:09 2024 +0100

    fix requirements file

commit b7780bbd21
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:42:28 2024 +0100

    fix removed sections

commit 71f1681f55
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:41:21 2024 +0100

    fix djlint syntax

commit a0bcf1bcce
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:35:28 2024 +0100

    remove flake8 from code base

commit 22475b31cc
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:34:56 2024 +0100

    remove flake8 from code base

commit 0413350f14
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:39 2024 +0100

    moved ruff section

commit d90c48a0bf
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:24 2024 +0100

    move djlint config to pyproject

commit c5ce55d511
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:20:39 2024 +0100

    added isort again

commit 42a41d23af
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:19:02 2024 +0100

    move config section

commit 8569233181
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:17:52 2024 +0100

    fix codespell error

commit 2897c6704d
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 17:29:21 2024 +0100

    replaced flake8 with ruff
    mostly for speed improvements

* enable autoformat

* added autofixes

* switched to single quotes everywhere

* switched to ruff for import sorting

* fix wrong url response

* switched to pathlib for lookup

* fixed lookup

* Squashed commit of the following:

commit d3b795824b
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:56:17 2024 +0100

    fixed source path

commit 0bac0c19b8
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:47:53 2024 +0100

    fixed req

commit 9f61f01d9c
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 22:45:18 2024 +0100

    added missing toml req

commit 91b71ed24a
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:49:50 2024 +0100

    moved isort config

commit 12460b0419
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:43:22 2024 +0100

    remove flake8 section from setup.cfg

commit f5cf7b2e78
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:36:57 2024 +0100

    fixed reqs

commit 9d845bee98
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:32:35 2024 +0100

    disable autofix/format

commit aff5f27148
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:50 2024 +0100

    adjust checks

commit 47271cf1ef
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:28:22 2024 +0100

    reorder order of operations

commit e1bf178b40
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 20:01:09 2024 +0100

    adapted ruff settings to better fit code base

commit ad7d88a6f4
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:59:45 2024 +0100

    auto fixed docstring

commit a2e54a760e
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:46:35 2024 +0100

    fix getattr useage

commit cb80c73bc6
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 19:25:09 2024 +0100

    fix requirements file

commit b7780bbd21
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:42:28 2024 +0100

    fix removed sections

commit 71f1681f55
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:41:21 2024 +0100

    fix djlint syntax

commit a0bcf1bcce
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:35:28 2024 +0100

    remove flake8 from code base

commit 22475b31cc
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:34:56 2024 +0100

    remove flake8 from code base

commit 0413350f14
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:39 2024 +0100

    moved ruff section

commit d90c48a0bf
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:24:24 2024 +0100

    move djlint config to pyproject

commit c5ce55d511
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:20:39 2024 +0100

    added isort again

commit 42a41d23af
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:19:02 2024 +0100

    move config section

commit 8569233181
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 18:17:52 2024 +0100

    fix codespell error

commit 2897c6704d
Author: Matthias Mair <code@mjmair.com>
Date:   Sun Jan 7 17:29:21 2024 +0100

    replaced flake8 with ruff
    mostly for speed improvements

* fix coverage souce format

---------

Co-authored-by: Oliver Walters <oliver.henry.walters@gmail.com>
This commit is contained in:
Matthias Mair 2024-01-11 01:28:58 +01:00 committed by GitHub
parent 9715af564f
commit 4b14986591
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
257 changed files with 13422 additions and 12200 deletions

View File

@ -15,18 +15,14 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: mixed-line-ending
- repo: https://github.com/pycqa/isort
rev: '5.12.0'
hooks:
- id: isort
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
hooks:
# - id: ruff-format
# args: [--preview]
- id: ruff-format
args: [--preview]
- id: ruff
args: [
#--fix,
--fix,
--preview
]
- repo: https://github.com/jazzband/pip-tools

View File

@ -30,14 +30,18 @@ class InvenTreeResource(ModelResource):
using_transactions,
collect_failed_rows,
rollback_on_validation_errors=None,
**kwargs
**kwargs,
):
"""Override the default import_data_inner function to provide better error handling"""
if len(dataset) > self.MAX_IMPORT_ROWS:
raise ImportExportError(f"Dataset contains too many rows (max {self.MAX_IMPORT_ROWS})")
raise ImportExportError(
f'Dataset contains too many rows (max {self.MAX_IMPORT_ROWS})'
)
if len(dataset.headers) > self.MAX_IMPORT_COLS:
raise ImportExportError(f"Dataset contains too many columns (max {self.MAX_IMPORT_COLS})")
raise ImportExportError(
f'Dataset contains too many columns (max {self.MAX_IMPORT_COLS})'
)
return super().import_data_inner(
dataset,
@ -46,7 +50,7 @@ class InvenTreeResource(ModelResource):
using_transactions,
collect_failed_rows,
rollback_on_validation_errors=rollback_on_validation_errors,
**kwargs
**kwargs,
)
def export_resource(self, obj):
@ -75,10 +79,7 @@ class InvenTreeResource(ModelResource):
"""Return fields, with some common exclusions"""
fields = super().get_fields(**kwargs)
fields_to_exclude = [
'metadata',
'lft', 'rght', 'tree_id', 'level',
]
fields_to_exclude = ['metadata', 'lft', 'rght', 'tree_id', 'level']
return [f for f in fields if f.column_name not in fields_to_exclude]

View File

@ -31,9 +31,7 @@ from .views import AjaxView
class VersionView(APIView):
"""Simple JSON endpoint for InvenTree version information."""
permission_classes = [
permissions.IsAdminUser,
]
permission_classes = [permissions.IsAdminUser]
def get(self, request, *args, **kwargs):
"""Return information about the InvenTree server."""
@ -47,20 +45,21 @@ class VersionView(APIView):
'commit_date': InvenTree.version.inventreeCommitDate(),
'commit_branch': InvenTree.version.inventreeBranch(),
'python': InvenTree.version.inventreePythonVersion(),
'django': InvenTree.version.inventreeDjangoVersion()
'django': InvenTree.version.inventreeDjangoVersion(),
},
'links': {
'doc': InvenTree.version.inventreeDocUrl(),
'code': InvenTree.version.inventreeGithubUrl(),
'credit': InvenTree.version.inventreeCreditsUrl(),
'app': InvenTree.version.inventreeAppUrl(),
'bug': f'{InvenTree.version.inventreeGithubUrl()}/issues'
}
'bug': f'{InvenTree.version.inventreeGithubUrl()}/issues',
},
})
class VersionSerializer(serializers.Serializer):
"""Serializer for a single version."""
version = serializers.CharField()
date = serializers.CharField()
gh = serializers.CharField()
@ -69,16 +68,19 @@ class VersionSerializer(serializers.Serializer):
class Meta:
"""Meta class for VersionSerializer."""
fields = ['version', 'date', 'gh', 'text', 'latest']
class VersionApiSerializer(serializers.Serializer):
"""Serializer for the version api endpoint."""
VersionSerializer(many=True)
class VersionTextView(ListAPI):
"""Simple JSON endpoint for InvenTree version text."""
permission_classes = [permissions.IsAdminUser]
@extend_schema(responses={200: OpenApiResponse(response=VersionApiSerializer)})
@ -119,10 +121,10 @@ class InfoView(AjaxView):
'debug_mode': settings.DEBUG,
'docker_mode': settings.DOCKER,
'system_health': check_system_health() if is_staff else None,
'database': InvenTree.version.inventreeDatabase()if is_staff else None,
'database': InvenTree.version.inventreeDatabase() if is_staff else None,
'platform': InvenTree.version.inventreePlatform() if is_staff else None,
'installer': InvenTree.version.inventreeInstaller() if is_staff else None,
'target': InvenTree.version.inventreeTarget()if is_staff else None,
'target': InvenTree.version.inventreeTarget() if is_staff else None,
}
return JsonResponse(data)
@ -130,7 +132,9 @@ class InfoView(AjaxView):
def check_auth_header(self, request):
"""Check if user is authenticated via a token in the header."""
# TODO @matmair: remove after refacgtor of Token check is done
headers = request.headers.get('Authorization', request.headers.get('authorization'))
headers = request.headers.get(
'Authorization', request.headers.get('authorization')
)
if not headers:
return False
@ -160,7 +164,7 @@ class NotFoundView(AjaxView):
'detail': _('API endpoint not found'),
'url': request.build_absolute_uri(),
},
status=404
status=404,
)
def options(self, request, *args, **kwargs):
@ -228,24 +232,25 @@ class BulkDeleteMixin:
if not items and not filters:
raise ValidationError({
"non_field_errors": ["List of items or filters must be provided for bulk deletion"],
'non_field_errors': [
'List of items or filters must be provided for bulk deletion'
]
})
if items and type(items) is not list:
raise ValidationError({
"items": ["'items' must be supplied as a list object"]
'items': ["'items' must be supplied as a list object"]
})
if filters and type(filters) is not dict:
raise ValidationError({
"filters": ["'filters' must be supplied as a dict object"]
'filters': ["'filters' must be supplied as a dict object"]
})
# Keep track of how many items we deleted
n_deleted = 0
with transaction.atomic():
# Start with *all* models and perform basic filtering
queryset = model.objects.all()
queryset = self.filter_delete_queryset(queryset, request)
@ -261,16 +266,12 @@ class BulkDeleteMixin:
n_deleted = queryset.count()
queryset.delete()
return Response(
{
'success': f"Deleted {n_deleted} items",
},
status=204
)
return Response({'success': f'Deleted {n_deleted} items'}, status=204)
class ListCreateDestroyAPIView(BulkDeleteMixin, ListCreateAPI):
"""Custom API endpoint which provides BulkDelete functionality in addition to List and Create"""
...
@ -307,24 +308,17 @@ class APIDownloadMixin:
def download_queryset(self, queryset, export_format):
"""This function must be implemented to provide a downloadFile request."""
raise NotImplementedError("download_queryset method not implemented!")
raise NotImplementedError('download_queryset method not implemented!')
class AttachmentMixin:
"""Mixin for creating attachment objects, and ensuring the user information is saved correctly."""
permission_classes = [
permissions.IsAuthenticated,
RolePermission,
]
permission_classes = [permissions.IsAuthenticated, RolePermission]
filter_backends = SEARCH_ORDER_FILTER
search_fields = [
'attachment',
'comment',
'link',
]
search_fields = ['attachment', 'comment', 'link']
def perform_create(self, serializer):
"""Save the user information when a file is uploaded."""
@ -342,9 +336,7 @@ class APISearchView(APIView):
Is much more efficient and simplifies code!
"""
permission_classes = [
permissions.IsAuthenticated,
]
permission_classes = [permissions.IsAuthenticated]
def get_result_types(self):
"""Construct a list of search types we can return"""
@ -384,14 +376,11 @@ class APISearchView(APIView):
}
if 'search' not in data:
raise ValidationError({
'search': 'Search term must be provided',
})
raise ValidationError({'search': 'Search term must be provided'})
for key, cls in self.get_result_types().items():
# Only return results which are specifically requested
if key in data:
params = data[key]
for k, v in pass_through_params.items():
@ -418,16 +407,18 @@ class APISearchView(APIView):
table = f'{app_label}_{model_name}'
try:
if users.models.RuleSet.check_table_permission(request.user, table, 'view'):
if users.models.RuleSet.check_table_permission(
request.user, table, 'view'
):
results[key] = view.list(request, *args, **kwargs).data
else:
results[key] = {
'error': _('User does not have permission to view this model')
'error': _(
'User does not have permission to view this model'
)
}
except Exception as exc:
results[key] = {
'error': str(exc)
}
results[key] = {'error': str(exc)}
return Response(results)
@ -442,7 +433,9 @@ class MetadataView(RetrieveUpdateAPI):
model = self.kwargs.get(self.MODEL_REF, None)
if model is None:
raise ValidationError(f"MetadataView called without '{self.MODEL_REF}' parameter")
raise ValidationError(
f"MetadataView called without '{self.MODEL_REF}' parameter"
)
return model

View File

@ -1,6 +1,5 @@
"""InvenTree API version information."""
# InvenTree API version
INVENTREE_API_VERSION = 160
"""Increment this API version number whenever there is a significant change to the API that any clients need to know about."""

View File

@ -16,11 +16,12 @@ import InvenTree.ready
import InvenTree.tasks
from InvenTree.config import get_setting
logger = logging.getLogger("inventree")
logger = logging.getLogger('inventree')
class InvenTreeConfig(AppConfig):
"""AppConfig for inventree app."""
name = 'InvenTree'
def ready(self):
@ -36,7 +37,10 @@ class InvenTreeConfig(AppConfig):
- Adding users set in the current environment
"""
# skip loading if plugin registry is not loaded or we run in a background thread
if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread():
if (
not InvenTree.ready.isPluginRegistryLoaded()
or not InvenTree.ready.isInMainThread()
):
return
# Skip if running migrations
@ -44,7 +48,6 @@ class InvenTreeConfig(AppConfig):
return
if InvenTree.ready.canAppAccessDatabase() or settings.TESTING_ENV:
self.remove_obsolete_tasks()
self.collect_tasks()
@ -81,11 +84,11 @@ class InvenTreeConfig(AppConfig):
try:
Schedule.objects.filter(func__in=obsolete).delete()
except Exception:
logger.exception("Failed to remove obsolete tasks - database not ready")
logger.exception('Failed to remove obsolete tasks - database not ready')
def start_background_tasks(self):
"""Start all background tests for InvenTree."""
logger.info("Starting background tasks...")
logger.info('Starting background tasks...')
from django_q.models import Schedule
@ -102,15 +105,16 @@ class InvenTreeConfig(AppConfig):
tasks = InvenTree.tasks.tasks.task_list
for task in tasks:
ref_name = f'{task.func.__module__}.{task.func.__name__}'
if ref_name in existing_tasks.keys():
# This task already exists - update the details if required
existing_task = existing_tasks[ref_name]
if existing_task.schedule_type != task.interval or existing_task.minutes != task.minutes:
if (
existing_task.schedule_type != task.interval
or existing_task.minutes != task.minutes
):
existing_task.schedule_type = task.interval
existing_task.minutes = task.minutes
tasks_to_update.append(existing_task)
@ -128,20 +132,17 @@ class InvenTreeConfig(AppConfig):
if len(tasks_to_create) > 0:
Schedule.objects.bulk_create(tasks_to_create)
logger.info("Created %s new scheduled tasks", len(tasks_to_create))
logger.info('Created %s new scheduled tasks', len(tasks_to_create))
if len(tasks_to_update) > 0:
Schedule.objects.bulk_update(tasks_to_update, ['schedule_type', 'minutes'])
logger.info("Updated %s existing scheduled tasks", len(tasks_to_update))
logger.info('Updated %s existing scheduled tasks', len(tasks_to_update))
# Put at least one task onto the background worker stack,
# which will be processed as soon as the worker comes online
InvenTree.tasks.offload_task(
InvenTree.tasks.heartbeat,
force_async=True,
)
InvenTree.tasks.offload_task(InvenTree.tasks.heartbeat, force_async=True)
logger.info("Started %s scheduled background tasks...", len(tasks))
logger.info('Started %s scheduled background tasks...', len(tasks))
def collect_tasks(self):
"""Collect all background tasks."""
@ -153,7 +154,7 @@ class InvenTreeConfig(AppConfig):
try:
import_module(f'{app.module.__package__}.tasks')
except Exception as e: # pragma: no cover
logger.exception("Error loading tasks for %s: %s", app_name, e)
logger.exception('Error loading tasks for %s: %s', app_name, e)
def update_exchange_rates(self): # pragma: no cover
"""Update exchange rates each time the server is started.
@ -184,16 +185,20 @@ class InvenTreeConfig(AppConfig):
if last_update is None:
# Never been updated
logger.info("Exchange backend has never been updated")
logger.info('Exchange backend has never been updated')
update = True
# Backend currency has changed?
if base_currency != backend.base_currency:
logger.info("Base currency changed from %s to %s", backend.base_currency, base_currency)
logger.info(
'Base currency changed from %s to %s',
backend.base_currency,
base_currency,
)
update = True
except (ExchangeBackend.DoesNotExist):
logger.info("Exchange backend not found - updating")
except ExchangeBackend.DoesNotExist:
logger.info('Exchange backend not found - updating')
update = True
except Exception:
@ -204,9 +209,9 @@ class InvenTreeConfig(AppConfig):
try:
update_exchange_rates()
except OperationalError:
logger.warning("Could not update exchange rates - database not ready")
logger.warning('Could not update exchange rates - database not ready')
except Exception as e:
logger.exception("Error updating exchange rates: %s (%s)", e, type(e))
logger.exception('Error updating exchange rates: %s (%s)', e, type(e))
def add_user_on_startup(self):
"""Add a user on startup."""
@ -218,7 +223,9 @@ class InvenTreeConfig(AppConfig):
add_user = get_setting('INVENTREE_ADMIN_USER', 'admin_user')
add_email = get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email')
add_password = get_setting('INVENTREE_ADMIN_PASSWORD', 'admin_password')
add_password_file = get_setting("INVENTREE_ADMIN_PASSWORD_FILE", "admin_password_file", None)
add_password_file = get_setting(
'INVENTREE_ADMIN_PASSWORD_FILE', 'admin_password_file', None
)
# check if all values are present
set_variables = 0
@ -239,7 +246,9 @@ class InvenTreeConfig(AppConfig):
# if a password file is present, do not warn - will be handled later
if add_password_file:
return
logger.warning('Not all required settings for adding a user on startup are present:\nINVENTREE_ADMIN_USER, INVENTREE_ADMIN_EMAIL, INVENTREE_ADMIN_PASSWORD')
logger.warning(
'Not all required settings for adding a user on startup are present:\nINVENTREE_ADMIN_USER, INVENTREE_ADMIN_EMAIL, INVENTREE_ADMIN_PASSWORD'
)
return
# good to go -> create user
@ -253,9 +262,11 @@ class InvenTreeConfig(AppConfig):
try:
with transaction.atomic():
if user.objects.filter(username=add_user).exists():
logger.info("User %s already exists - skipping creation", add_user)
logger.info('User %s already exists - skipping creation', add_user)
else:
new_user = user.objects.create_superuser(add_user, add_email, add_password)
new_user = user.objects.create_superuser(
add_user, add_email, add_password
)
logger.info('User %s was created!', str(new_user))
except IntegrityError:
logger.warning('The user "%s" could not be created', add_user)
@ -263,12 +274,12 @@ class InvenTreeConfig(AppConfig):
def add_user_from_file(self):
"""Add the superuser from a file."""
# stop if checks were already created
if hasattr(settings, "USER_ADDED_FILE") and settings.USER_ADDED_FILE:
if hasattr(settings, 'USER_ADDED_FILE') and settings.USER_ADDED_FILE:
return
# get values
add_password_file = get_setting(
"INVENTREE_ADMIN_PASSWORD_FILE", "admin_password_file", None
'INVENTREE_ADMIN_PASSWORD_FILE', 'admin_password_file', None
)
# no variable set -> do not try anything
@ -284,7 +295,11 @@ class InvenTreeConfig(AppConfig):
return
# good to go -> create user
self._create_admin_user(get_setting('INVENTREE_ADMIN_USER', 'admin_user', 'admin'), get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email', ''), add_password_file.read_text(encoding="utf-8"))
self._create_admin_user(
get_setting('INVENTREE_ADMIN_USER', 'admin_user', 'admin'),
get_setting('INVENTREE_ADMIN_EMAIL', 'admin_email', ''),
add_password_file.read_text(encoding='utf-8'),
)
# do not try again
settings.USER_ADDED_FILE = True

View File

@ -24,22 +24,14 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
here = os.path.abspath(os.path.dirname(__file__))
output_dir = os.path.join(
here,
'..',
'..',
'js_tmp',
)
output_dir = os.path.join(here, '..', '..', 'js_tmp')
output_dir = os.path.abspath(output_dir)
if not os.path.exists(output_dir):
os.mkdir(output_dir)
output_file = os.path.join(
output_dir,
filename,
)
output_file = os.path.join(output_dir, filename)
with open(output_file, 'wb') as output:
output.write(response.content)
@ -48,12 +40,7 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
"""Download files in directory."""
here = os.path.abspath(os.path.dirname(__file__))
js_template_dir = os.path.join(
here,
'..',
'templates',
'js',
)
js_template_dir = os.path.join(here, '..', 'templates', 'js')
directory = os.path.join(js_template_dir, subdir)
@ -76,9 +63,9 @@ class RenderJavascriptFiles(InvenTreeTestCase): # pragma: no cover
"""Look for all javascript files."""
n = 0
print("Rendering javascript files...")
print('Rendering javascript files...')
n += self.download_files('translated', '/js/i18n')
n += self.download_files('dynamic', '/js/dynamic')
print(f"Rendered {n} javascript files.")
print(f'Rendered {n} javascript files.')

View File

@ -51,7 +51,11 @@ def to_dict(value):
try:
return json.loads(value)
except Exception as error:
logger.exception("Failed to parse value '%s' as JSON with error %s. Ensure value is a valid JSON string.", value, error)
logger.exception(
"Failed to parse value '%s' as JSON with error %s. Ensure value is a valid JSON string.",
value,
error,
)
return {}
@ -90,12 +94,14 @@ def get_config_file(create=True) -> Path:
cfg_filename = base_dir.joinpath('config.yaml').resolve()
if not cfg_filename.exists() and create:
print("InvenTree configuration file 'config.yaml' not found - creating default file")
print(
"InvenTree configuration file 'config.yaml' not found - creating default file"
)
ensure_dir(cfg_filename.parent)
cfg_template = base_dir.joinpath("config_template.yaml")
cfg_template = base_dir.joinpath('config_template.yaml')
shutil.copyfile(cfg_template, cfg_filename)
print(f"Created config file {cfg_filename}")
print(f'Created config file {cfg_filename}')
return cfg_filename
@ -153,7 +159,13 @@ def do_typecast(value, type, var_name=None):
return val
except Exception as error:
if var_name:
logger.exception("Failed to typecast '%s' with value '%s' to type '%s' with error %s", var_name, value, type, error)
logger.exception(
"Failed to typecast '%s' with value '%s' to type '%s' with error %s",
var_name,
value,
type,
error,
)
return value
@ -174,7 +186,12 @@ def get_setting(env_var=None, config_key=None, default_value=None, typecast=None
def set_metadata(source: str):
"""Set lookup metadata for the setting."""
key = env_var or config_key
CONFIG_LOOKUPS[key] = {'env_var': env_var, 'config_key': config_key, 'source': source, 'accessed': datetime.datetime.now()}
CONFIG_LOOKUPS[key] = {
'env_var': env_var,
'config_key': config_key,
'source': source,
'accessed': datetime.datetime.now(),
}
# First, try to load from the environment variables
if env_var is not None:
@ -192,7 +209,6 @@ def get_setting(env_var=None, config_key=None, default_value=None, typecast=None
# Hack to allow 'path traversal' in configuration file
for key in config_key.strip().split('.'):
if type(cfg_data) is not dict or key not in cfg_data:
result = None
break
@ -276,12 +292,16 @@ def get_plugin_file():
plugin_file = Path(plugin_file)
if not plugin_file.exists():
logger.warning("Plugin configuration file does not exist - creating default file")
logger.warning(
'Plugin configuration file does not exist - creating default file'
)
logger.info("Creating plugin file at '%s'", plugin_file)
ensure_dir(plugin_file.parent)
# If opening the file fails (no write permission, for example), then this will throw an error
plugin_file.write_text("# InvenTree Plugins (uses PIP framework to install)\n\n")
plugin_file.write_text(
'# InvenTree Plugins (uses PIP framework to install)\n\n'
)
return plugin_file
@ -303,7 +323,7 @@ def get_secret_key():
"""
# Look for environment variable
if secret_key := get_setting('INVENTREE_SECRET_KEY', 'secret_key'):
logger.info("SECRET_KEY loaded by INVENTREE_SECRET_KEY") # pragma: no cover
logger.info('SECRET_KEY loaded by INVENTREE_SECRET_KEY') # pragma: no cover
return secret_key
# Look for secret key file
@ -311,7 +331,7 @@ def get_secret_key():
secret_key_file = Path(secret_key_file).resolve()
else:
# Default location for secret key file
secret_key_file = get_base_dir().joinpath("secret_key.txt").resolve()
secret_key_file = get_base_dir().joinpath('secret_key.txt').resolve()
if not secret_key_file.exists():
logger.info("Generating random key file at '%s'", secret_key_file)
@ -329,7 +349,9 @@ def get_secret_key():
return key_data
def get_custom_file(env_ref: str, conf_ref: str, log_ref: str, lookup_media: bool = False):
def get_custom_file(
env_ref: str, conf_ref: str, log_ref: str, lookup_media: bool = False
):
"""Returns the checked path to a custom file.
Set lookup_media to True to also search in the media folder.
@ -345,12 +367,17 @@ def get_custom_file(env_ref: str, conf_ref: str, log_ref: str, lookup_media: boo
static_storage = StaticFilesStorage()
if static_storage.exists(value):
logger.info("Loading %s from %s directory: %s", log_ref, 'static', value)
logger.info('Loading %s from %s directory: %s', log_ref, 'static', value)
elif lookup_media and default_storage.exists(value):
logger.info("Loading %s from %s directory: %s", log_ref, 'media', value)
logger.info('Loading %s from %s directory: %s', log_ref, 'media', value)
else:
add_dir_str = ' or media' if lookup_media else ''
logger.warning("The %s file '%s' could not be found in the static %s directories", log_ref, value, add_dir_str)
logger.warning(
"The %s file '%s' could not be found in the static %s directories",
log_ref,
value,
add_dir_str,
)
value = False
return value
@ -364,16 +391,21 @@ def get_frontend_settings(debug=True):
"""
# Legacy settings
pui_settings = get_setting('INVENTREE_PUI_SETTINGS', 'pui_settings', {}, typecast=dict)
pui_settings = get_setting(
'INVENTREE_PUI_SETTINGS', 'pui_settings', {}, typecast=dict
)
if len(pui_settings) > 0:
warnings.warn(
"The 'INVENTREE_PUI_SETTINGS' key is deprecated. Please use 'INVENTREE_FRONTEND_SETTINGS' instead",
DeprecationWarning, stacklevel=2
DeprecationWarning,
stacklevel=2,
)
# New settings
frontend_settings = get_setting('INVENTREE_FRONTEND_SETTINGS', 'frontend_settings', {}, typecast=dict)
frontend_settings = get_setting(
'INVENTREE_FRONTEND_SETTINGS', 'frontend_settings', {}, typecast=dict
)
# Merge settings
settings = {**pui_settings, **frontend_settings}
@ -385,10 +417,13 @@ def get_frontend_settings(debug=True):
if base_url:
warnings.warn(
"The 'INVENTREE_PUI_URL_BASE' key is deprecated. Please use 'INVENTREE_FRONTEND_URL_BASE' instead",
DeprecationWarning, stacklevel=2
DeprecationWarning,
stacklevel=2,
)
else:
base_url = get_setting('INVENTREE_FRONTEND_URL_BASE', 'frontend_url_base', 'platform')
base_url = get_setting(
'INVENTREE_FRONTEND_URL_BASE', 'frontend_url_base', 'platform'
)
settings['base_url'] = base_url

View File

@ -31,9 +31,7 @@ def health_status(request):
}
# The following keys are required to denote system health
health_keys = [
'django_q_running',
]
health_keys = ['django_q_running']
all_healthy = True
@ -72,11 +70,9 @@ def user_roles(request):
"""
user = request.user
roles = {
}
roles = {}
for role in RuleSet.RULESET_MODELS.keys():
permissions = {}
for perm in ['view', 'add', 'change', 'delete']:

View File

@ -29,6 +29,7 @@ def reload_unit_registry():
This function is called at startup, and whenever the database is updated.
"""
import time
t_start = time.time()
global _unit_registry
@ -52,7 +53,9 @@ def reload_unit_registry():
try:
reg.define(cu.fmt_string())
except Exception as e:
logger.exception('Failed to load custom unit: %s - %s', cu.fmt_string(), e)
logger.exception(
'Failed to load custom unit: %s - %s', cu.fmt_string(), e
)
# Once custom units are loaded, save registry
_unit_registry = reg
@ -124,7 +127,7 @@ def convert_physical_value(value: str, unit: str = None, strip_units=True):
if unit:
raise ValidationError(_(f'Could not convert {original} to {unit}'))
else:
raise ValidationError(_("Invalid quantity supplied"))
raise ValidationError(_('Invalid quantity supplied'))
# Calculate the "magnitude" of the value, as a float
# If the value is specified strangely (e.g. as a fraction or a dozen), this can cause issues

View File

@ -30,22 +30,22 @@ def is_email_configured():
# Display warning unless in test mode
if not testing: # pragma: no cover
logger.debug("EMAIL_HOST is not configured")
logger.debug('EMAIL_HOST is not configured')
# Display warning unless in test mode
if not settings.EMAIL_HOST_USER and not testing: # pragma: no cover
logger.debug("EMAIL_HOST_USER is not configured")
logger.debug('EMAIL_HOST_USER is not configured')
# Display warning unless in test mode
if not settings.EMAIL_HOST_PASSWORD and testing: # pragma: no cover
logger.debug("EMAIL_HOST_PASSWORD is not configured")
logger.debug('EMAIL_HOST_PASSWORD is not configured')
# Email sender must be configured
if not settings.DEFAULT_FROM_EMAIL:
configured = False
if not testing: # pragma: no cover
logger.debug("DEFAULT_FROM_EMAIL is not configured")
logger.debug('DEFAULT_FROM_EMAIL is not configured')
return configured
@ -75,7 +75,7 @@ def send_email(subject, body, recipients, from_email=None, html_message=None):
if settings.TESTING:
from_email = 'from@test.com'
else:
logger.error("send_email failed: DEFAULT_FROM_EMAIL not specified")
logger.error('send_email failed: DEFAULT_FROM_EMAIL not specified')
return
InvenTree.tasks.offload_task(
@ -85,5 +85,5 @@ def send_email(subject, body, recipients, from_email=None, html_message=None):
from_email,
recipients,
fail_silently=False,
html_message=html_message
html_message=html_message,
)

View File

@ -86,7 +86,7 @@ def exception_handler(exc, context):
# If in DEBUG mode, provide error information in the response
error_detail = str(exc)
else:
error_detail = _("Error details can be found in the admin panel")
error_detail = _('Error details can be found in the admin panel')
response_data = {
'error': type(exc).__name__,

View File

@ -18,7 +18,7 @@ class InvenTreeExchange(SimpleExchangeBackend):
Uses the plugin system to actually fetch the rates from an external API.
"""
name = "InvenTreeExchange"
name = 'InvenTreeExchange'
def get_rates(self, **kwargs) -> None:
"""Set the requested currency codes and get rates."""
@ -44,7 +44,9 @@ class InvenTreeExchange(SimpleExchangeBackend):
plugin = plugins[0]
if not plugin:
logger.warning('No active currency exchange plugins found - skipping update')
logger.warning(
'No active currency exchange plugins found - skipping update'
)
return {}
logger.info("Running exchange rate update using plugin '%s'", plugin.name)
@ -53,16 +55,22 @@ class InvenTreeExchange(SimpleExchangeBackend):
try:
rates = plugin.update_exchange_rates(base_currency, symbols)
except Exception as exc:
logger.exception("Exchange rate update failed: %s", exc)
logger.exception('Exchange rate update failed: %s', exc)
return {}
if not rates:
logger.warning("Exchange rate update failed - no data returned from plugin %s", slug)
logger.warning(
'Exchange rate update failed - no data returned from plugin %s', slug
)
return {}
# Update exchange rates based on returned data
if type(rates) is not dict:
logger.warning("Invalid exchange rate data returned from plugin %s (type %s)", slug, type(rates))
logger.warning(
'Invalid exchange rate data returned from plugin %s (type %s)',
slug,
type(rates),
)
return {}
# Ensure base currency is provided
@ -73,14 +81,20 @@ class InvenTreeExchange(SimpleExchangeBackend):
@atomic
def update_rates(self, base_currency=None, **kwargs):
"""Call to update all exchange rates"""
backend, _ = ExchangeBackend.objects.update_or_create(name=self.name, defaults={"base_currency": base_currency})
backend, _ = ExchangeBackend.objects.update_or_create(
name=self.name, defaults={'base_currency': base_currency}
)
if base_currency is None:
base_currency = currency_code_default()
symbols = currency_codes()
logger.info("Updating exchange rates for %s (%s currencies)", base_currency, len(symbols))
logger.info(
'Updating exchange rates for %s (%s currencies)',
base_currency,
len(symbols),
)
# Fetch new rates from the backend
# If the backend fails, the existing rates will not be updated
@ -95,6 +109,8 @@ class InvenTreeExchange(SimpleExchangeBackend):
for currency, amount in rates.items()
])
else:
logger.info("No exchange rates returned from backend - currencies not updated")
logger.info(
'No exchange rates returned from backend - currencies not updated'
)
logger.info("Updated exchange rates for %s", base_currency)
logger.info('Updated exchange rates for %s', base_currency)

View File

@ -35,7 +35,9 @@ class InvenTreeRestURLField(RestURLField):
import common.models
strict_urls = common.models.InvenTreeSetting.get_setting('INVENTREE_STRICT_URLS', True, cache=False)
strict_urls = common.models.InvenTreeSetting.get_setting(
'INVENTREE_STRICT_URLS', True, cache=False
)
if not strict_urls and data is not empty:
if '://' not in data:
@ -97,11 +99,8 @@ class InvenTreeModelMoneyField(ModelMoneyField):
# If no validators are provided, add some "standard" ones
if len(validators) == 0:
if not allow_negative:
validators.append(
MinMoneyValidator(0),
)
validators.append(MinMoneyValidator(0))
kwargs['validators'] = validators
@ -144,11 +143,7 @@ class DatePickerFormField(forms.DateField):
required = kwargs.get('required', False)
initial = kwargs.get('initial', None)
widget = forms.DateInput(
attrs={
'type': 'date',
}
)
widget = forms.DateInput(attrs={'type': 'date'})
forms.DateField.__init__(
self,
@ -156,7 +151,7 @@ class DatePickerFormField(forms.DateField):
initial=initial,
help_text=help_text,
widget=widget,
label=label
label=label,
)

View File

@ -36,7 +36,9 @@ class InvenTreeSearchFilter(filters.SearchFilter):
The following query params are available to 'augment' the search (in decreasing order of priority)
- search_regex: If True, search is performed on 'regex' comparison
"""
regex = InvenTree.helpers.str2bool(request.query_params.get('search_regex', False))
regex = InvenTree.helpers.str2bool(
request.query_params.get('search_regex', False)
)
search_fields = super().get_search_fields(view, request)
@ -56,7 +58,9 @@ class InvenTreeSearchFilter(filters.SearchFilter):
Depending on the request parameters, we may "augment" these somewhat
"""
whole = InvenTree.helpers.str2bool(request.query_params.get('search_whole', False))
whole = InvenTree.helpers.str2bool(
request.query_params.get('search_whole', False)
)
terms = []
@ -72,7 +76,7 @@ class InvenTreeSearchFilter(filters.SearchFilter):
if whole:
# Wrap the search term to enable word-boundary matching
term = r"\y" + term + r"\y"
term = r'\y' + term + r'\y'
terms.append(term)
@ -110,7 +114,6 @@ class InvenTreeOrderingFilter(filters.OrderingFilter):
ordering = []
for field in ordering_initial:
reverse = field.startswith('-')
if reverse:
@ -164,7 +167,4 @@ SEARCH_ORDER_FILTER_ALIAS = [
InvenTreeOrderingFilter,
]
ORDER_FILTER = [
rest_filters.DjangoFilterBackend,
filters.OrderingFilter,
]
ORDER_FILTER = [rest_filters.DjangoFilterBackend, filters.OrderingFilter]

View File

@ -36,10 +36,7 @@ def parse_format_string(fmt_string: str) -> dict:
else:
seen_groups.add(name)
info[group[1]] = {
'format': group[1],
'prefix': group[0],
}
info[group[1]] = {'format': group[1], 'prefix': group[0]}
return info
@ -67,7 +64,7 @@ def construct_format_regex(fmt_string: str) -> str:
Raises:
ValueError: Format string is invalid
"""
pattern = "^"
pattern = '^'
for group in string.Formatter().parse(fmt_string):
prefix = group[0] # Prefix (literal text appearing before this group)
@ -75,9 +72,23 @@ def construct_format_regex(fmt_string: str) -> str:
format = group[2] # Format specifier e.g :04d
rep = [
'+', '-', '.',
'{', '}', '(', ')',
'^', '$', '~', '!', '@', ':', ';', '|', '\'', '"',
'+',
'-',
'.',
'{',
'}',
'(',
')',
'^',
'$',
'~',
'!',
'@',
':',
';',
'|',
"'",
'"',
]
# Escape any special regex characters
@ -94,7 +105,6 @@ def construct_format_regex(fmt_string: str) -> str:
# Add a named capture group for the format entry
if name:
# Check if integer values are required
if format.endswith('d'):
chr = '\d'
@ -105,9 +115,9 @@ def construct_format_regex(fmt_string: str) -> str:
# TODO: Introspect required width
w = '+'
pattern += f"(?P<{name}>{chr}{w})"
pattern += f'(?P<{name}>{chr}{w})'
pattern += "$"
pattern += '$'
return pattern
@ -161,7 +171,9 @@ def extract_named_group(name: str, value: str, fmt_string: str) -> str:
result = re.match(pattern, value)
if not result:
raise ValueError(_("Provided value does not match required pattern: ") + fmt_string)
raise ValueError(
_('Provided value does not match required pattern: ') + fmt_string
)
# And return the value we are interested in
# Note: This will raise an IndexError if the named group was not matched
@ -186,7 +198,7 @@ def format_money(money: Money, decimal_places: int = None, format: str = None) -
if format:
pattern = parse_pattern(format)
else:
pattern = locale.currency_formats["standard"]
pattern = locale.currency_formats['standard']
if decimal_places is not None:
pattern.frac_prec = (decimal_places, decimal_places)

View File

@ -17,8 +17,7 @@ from allauth.core.exceptions import ImmediateHttpResponse
from allauth.socialaccount.adapter import DefaultSocialAccountAdapter
from allauth_2fa.adapter import OTPAdapter
from allauth_2fa.utils import user_has_valid_totp_device
from crispy_forms.bootstrap import (AppendedText, PrependedAppendedText,
PrependedText)
from crispy_forms.bootstrap import AppendedText, PrependedAppendedText, PrependedText
from crispy_forms.helper import FormHelper
from crispy_forms.layout import Field, Layout
from dj_rest_auth.registration.serializers import RegisterSerializer
@ -80,31 +79,19 @@ class HelperForm(forms.ModelForm):
field,
prepended_text=prefix,
appended_text=suffix,
placeholder=placeholder
placeholder=placeholder,
)
)
)
elif prefix:
layouts.append(
Field(
PrependedText(
field,
prefix,
placeholder=placeholder
)
)
Field(PrependedText(field, prefix, placeholder=placeholder))
)
elif suffix:
layouts.append(
Field(
AppendedText(
field,
suffix,
placeholder=placeholder
)
)
Field(AppendedText(field, suffix, placeholder=placeholder))
)
else:
@ -120,10 +107,7 @@ class EditUserForm(HelperForm):
"""Metaclass options."""
model = User
fields = [
'first_name',
'last_name',
]
fields = ['first_name', 'last_name']
class SetPasswordForm(HelperForm):
@ -133,11 +117,7 @@ class SetPasswordForm(HelperForm):
"""Metaclass options."""
model = User
fields = [
'enter_password',
'confirm_password',
'old_password',
]
fields = ['enter_password', 'confirm_password', 'old_password']
enter_password = forms.CharField(
max_length=100,
@ -146,7 +126,7 @@ class SetPasswordForm(HelperForm):
initial='',
widget=forms.PasswordInput(attrs={'autocomplete': 'off'}),
label=_('Enter password'),
help_text=_('Enter new password')
help_text=_('Enter new password'),
)
confirm_password = forms.CharField(
@ -156,14 +136,16 @@ class SetPasswordForm(HelperForm):
initial='',
widget=forms.PasswordInput(attrs={'autocomplete': 'off'}),
label=_('Confirm password'),
help_text=_('Confirm new password')
help_text=_('Confirm new password'),
)
old_password = forms.CharField(
label=_("Old password"),
label=_('Old password'),
strip=False,
required=False,
widget=forms.PasswordInput(attrs={'autocomplete': 'current-password', 'autofocus': True}),
widget=forms.PasswordInput(
attrs={'autocomplete': 'current-password', 'autofocus': True}
),
)
@ -196,22 +178,24 @@ class CustomSignupForm(SignupForm):
# check for two mail fields
if InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_TWICE'):
self.fields["email2"] = forms.EmailField(
label=_("Email (again)"),
self.fields['email2'] = forms.EmailField(
label=_('Email (again)'),
widget=forms.TextInput(
attrs={
"type": "email",
"placeholder": _("Email address confirmation"),
'type': 'email',
'placeholder': _('Email address confirmation'),
}
),
)
# check for two password fields
if not InvenTreeSetting.get_setting('LOGIN_SIGNUP_PWD_TWICE'):
self.fields.pop("password2")
self.fields.pop('password2')
# reorder fields
set_form_field_order(self, ["username", "email", "email2", "password1", "password2", ])
set_form_field_order(
self, ['username', 'email', 'email2', 'password1', 'password2']
)
def clean(self):
"""Make sure the supplied emails match if enabled in settings."""
@ -219,21 +203,26 @@ class CustomSignupForm(SignupForm):
# check for two mail fields
if InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_TWICE'):
email = cleaned_data.get("email")
email2 = cleaned_data.get("email2")
email = cleaned_data.get('email')
email2 = cleaned_data.get('email2')
if (email and email2) and email != email2:
self.add_error("email2", _("You must type the same email each time."))
self.add_error('email2', _('You must type the same email each time.'))
return cleaned_data
def registration_enabled():
"""Determine whether user registration is enabled."""
if InvenTreeSetting.get_setting('LOGIN_ENABLE_REG') or InvenTree.sso.registration_enabled():
if (
InvenTreeSetting.get_setting('LOGIN_ENABLE_REG')
or InvenTree.sso.registration_enabled()
):
if settings.EMAIL_HOST:
return True
else:
logger.error("Registration cannot be enabled, because EMAIL_HOST is not configured.")
logger.error(
'Registration cannot be enabled, because EMAIL_HOST is not configured.'
)
return False
@ -251,20 +240,26 @@ class RegistratonMixin:
def clean_email(self, email):
"""Check if the mail is valid to the pattern in LOGIN_SIGNUP_MAIL_RESTRICTION (if enabled in settings)."""
mail_restriction = InvenTreeSetting.get_setting('LOGIN_SIGNUP_MAIL_RESTRICTION', None)
mail_restriction = InvenTreeSetting.get_setting(
'LOGIN_SIGNUP_MAIL_RESTRICTION', None
)
if not mail_restriction:
return super().clean_email(email)
split_email = email.split('@')
if len(split_email) != 2:
logger.error('The user %s has an invalid email address', email)
raise forms.ValidationError(_('The provided primary email address is not valid.'))
raise forms.ValidationError(
_('The provided primary email address is not valid.')
)
mailoptions = mail_restriction.split(',')
for option in mailoptions:
if not option.startswith('@'):
log_error('LOGIN_SIGNUP_MAIL_RESTRICTION is not configured correctly')
raise forms.ValidationError(_('The provided primary email address is not valid.'))
raise forms.ValidationError(
_('The provided primary email address is not valid.')
)
else:
if split_email[1] == option[1:]:
return super().clean_email(email)
@ -284,7 +279,10 @@ class RegistratonMixin:
group = Group.objects.get(id=start_group)
user.groups.add(group)
except Group.DoesNotExist:
logger.exception('The setting `SIGNUP_GROUP` contains an non existent group', start_group)
logger.exception(
'The setting `SIGNUP_GROUP` contains an non existent group',
start_group,
)
user.save()
return user
@ -294,11 +292,13 @@ class CustomUrlMixin:
def get_email_confirmation_url(self, request, emailconfirmation):
"""Custom email confirmation (activation) url."""
url = reverse("account_confirm_email", args=[emailconfirmation.key])
url = reverse('account_confirm_email', args=[emailconfirmation.key])
return Site.objects.get_current().domain + url
class CustomAccountAdapter(CustomUrlMixin, RegistratonMixin, OTPAdapter, DefaultAccountAdapter):
class CustomAccountAdapter(
CustomUrlMixin, RegistratonMixin, OTPAdapter, DefaultAccountAdapter
):
"""Override of adapter to use dynamic settings."""
def send_mail(self, template_prefix, email, context):
@ -325,7 +325,9 @@ class CustomAccountAdapter(CustomUrlMixin, RegistratonMixin, OTPAdapter, Default
return url
class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocialAccountAdapter):
class CustomSocialAccountAdapter(
CustomUrlMixin, RegistratonMixin, DefaultSocialAccountAdapter
):
"""Override of adapter to use dynamic settings."""
def is_auto_signup_allowed(self, request, sociallogin):
@ -352,14 +354,14 @@ class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocial
if request.GET:
redirect_url += '?' + urlencode(request.GET)
raise ImmediateHttpResponse(
response=HttpResponseRedirect(redirect_url)
)
raise ImmediateHttpResponse(response=HttpResponseRedirect(redirect_url))
# Otherwise defer to the original allauth adapter.
return super().login(request, user)
def authentication_error(self, request, provider_id, error=None, exception=None, extra_context=None):
def authentication_error(
self, request, provider_id, error=None, exception=None, extra_context=None
):
"""Callback method for authentication errors."""
# Log the error to the database
@ -370,6 +372,7 @@ class CustomSocialAccountAdapter(CustomUrlMixin, RegistratonMixin, DefaultSocial
# override dj-rest-auth
class CustomRegisterSerializer(RegisterSerializer):
"""Override of serializer to use dynamic settings."""
email = serializers.EmailField()
def __init__(self, instance=None, data=..., **kwargs):

View File

@ -36,7 +36,7 @@ def generateTestKey(test_name):
Tests must be named such that they will have unique keys.
"""
key = test_name.strip().lower()
key = key.replace(" ", "")
key = key.replace(' ', '')
# Remove any characters that cannot be used to represent a variable
key = re.sub(r'[^a-zA-Z0-9]', '', key)
@ -56,7 +56,7 @@ def constructPathString(path, max_chars=250):
# Replace middle elements to limit the pathstring
if len(pathstring) > max_chars:
n = int(max_chars / 2 - 2)
pathstring = pathstring[:n] + "..." + pathstring[-n:]
pathstring = pathstring[:n] + '...' + pathstring[-n:]
return pathstring
@ -82,19 +82,18 @@ def TestIfImage(img):
def getBlankImage():
"""Return the qualified path for the 'blank image' placeholder."""
return getStaticUrl("img/blank_image.png")
return getStaticUrl('img/blank_image.png')
def getBlankThumbnail():
"""Return the qualified path for the 'blank image' thumbnail placeholder."""
return getStaticUrl("img/blank_image.thumbnail.png")
return getStaticUrl('img/blank_image.thumbnail.png')
def getLogoImage(as_file=False, custom=True):
"""Return the InvenTree logo image, or a custom logo if available."""
"""Return the path to the logo-file."""
if custom and settings.CUSTOM_LOGO:
static_storage = StaticFilesStorage()
if static_storage.exists(settings.CUSTOM_LOGO):
@ -106,13 +105,13 @@ def getLogoImage(as_file=False, custom=True):
if storage is not None:
if as_file:
return f"file://{storage.path(settings.CUSTOM_LOGO)}"
return f'file://{storage.path(settings.CUSTOM_LOGO)}'
return storage.url(settings.CUSTOM_LOGO)
# If we have got to this point, return the default logo
if as_file:
path = settings.STATIC_ROOT.joinpath('img/inventree.png')
return f"file://{path}"
return f'file://{path}'
return getStaticUrl('img/inventree.png')
@ -121,12 +120,11 @@ def getSplashScreen(custom=True):
static_storage = StaticFilesStorage()
if custom and settings.CUSTOM_SPLASH:
if static_storage.exists(settings.CUSTOM_SPLASH):
return static_storage.url(settings.CUSTOM_SPLASH)
# No custom splash screen
return static_storage.url("img/inventree_splash.jpg")
return static_storage.url('img/inventree_splash.jpg')
def TestIfImageURL(url):
@ -135,10 +133,15 @@ def TestIfImageURL(url):
Simply tests the extension against a set of allowed values
"""
return os.path.splitext(os.path.basename(url))[-1].lower() in [
'.jpg', '.jpeg', '.j2k',
'.png', '.bmp',
'.tif', '.tiff',
'.webp', '.gif',
'.jpg',
'.jpeg',
'.j2k',
'.png',
'.bmp',
'.tif',
'.tiff',
'.webp',
'.gif',
]
@ -153,8 +156,8 @@ def str2bool(text, test=True):
True if the text looks like the selected boolean value
"""
if test:
return str(text).lower() in ['1', 'y', 'yes', 't', 'true', 'ok', 'on', ]
return str(text).lower() in ['0', 'n', 'no', 'none', 'f', 'false', 'off', ]
return str(text).lower() in ['1', 'y', 'yes', 't', 'true', 'ok', 'on']
return str(text).lower() in ['0', 'n', 'no', 'none', 'f', 'false', 'off']
def str2int(text, default=None):
@ -191,7 +194,15 @@ def isNull(text):
Returns:
True if the text looks like a null value
"""
return str(text).strip().lower() in ['top', 'null', 'none', 'empty', 'false', '-1', '']
return str(text).strip().lower() in [
'top',
'null',
'none',
'empty',
'false',
'-1',
'',
]
def normalize(d):
@ -223,7 +234,7 @@ def increment(value):
# Provide a default value if provided with a null input
return '1'
pattern = r"(.*?)(\d+)?$"
pattern = r'(.*?)(\d+)?$'
result = re.search(pattern, value)
@ -282,7 +293,7 @@ def decimal2string(d):
if '.' not in s:
return s
return s.rstrip("0").rstrip(".")
return s.rstrip('0').rstrip('.')
def decimal2money(d, currency=None):
@ -354,17 +365,12 @@ def MakeBarcode(cls_name, object_pk: int, object_data=None, **kwargs):
def GetExportFormats():
"""Return a list of allowable file formats for exporting data."""
return [
'csv',
'tsv',
'xls',
'xlsx',
'json',
'yaml',
]
return ['csv', 'tsv', 'xls', 'xlsx', 'json', 'yaml']
def DownloadFile(data, filename, content_type='application/text', inline=False) -> StreamingHttpResponse:
def DownloadFile(
data, filename, content_type='application/text', inline=False
) -> StreamingHttpResponse:
"""Create a dynamic file for the user to download.
Args:
@ -389,7 +395,7 @@ def DownloadFile(data, filename, content_type='application/text', inline=False)
length = len(bytes(data, response.charset))
response['Content-Length'] = length
disposition = "inline" if inline else "attachment"
disposition = 'inline' if inline else 'attachment'
response['Content-Disposition'] = f'{disposition}; filename={filename}'
@ -449,7 +455,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
try:
expected_quantity = int(expected_quantity)
except ValueError:
raise ValidationError([_("Invalid quantity provided")])
raise ValidationError([_('Invalid quantity provided')])
if input_string:
input_string = str(input_string).strip()
@ -457,7 +463,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
input_string = ''
if len(input_string) == 0:
raise ValidationError([_("Empty serial number string")])
raise ValidationError([_('Empty serial number string')])
next_value = increment_serial_number(starting_value)
@ -467,7 +473,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
next_value = increment_serial_number(next_value)
# Split input string by whitespace or comma (,) characters
groups = re.split(r"[\s,]+", input_string)
groups = re.split(r'[\s,]+', input_string)
serials = []
errors = []
@ -487,7 +493,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
return
if serial in serials:
add_error(_("Duplicate serial") + f": {serial}")
add_error(_('Duplicate serial') + f': {serial}')
else:
serials.append(serial)
@ -502,7 +508,6 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
return serials
for group in groups:
# Calculate the "remaining" quantity of serial numbers
remaining = expected_quantity - len(serials)
@ -520,7 +525,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
if a == b:
# Invalid group
add_error(_(f"Invalid group range: {group}"))
add_error(_(f'Invalid group range: {group}'))
continue
group_items = []
@ -549,13 +554,21 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
break
if len(group_items) > remaining:
add_error(_(f"Group range {group} exceeds allowed quantity ({expected_quantity})"))
elif len(group_items) > 0 and group_items[0] == a and group_items[-1] == b:
add_error(
_(
f'Group range {group} exceeds allowed quantity ({expected_quantity})'
)
)
elif (
len(group_items) > 0
and group_items[0] == a
and group_items[-1] == b
):
# In this case, the range extraction looks like it has worked
for item in group_items:
add_serial(item)
else:
add_error(_(f"Invalid group range: {group}"))
add_error(_(f'Invalid group range: {group}'))
else:
# In the case of a different number of hyphens, simply add the entire group
@ -573,20 +586,24 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
sequence_count = max(0, expected_quantity - len(serials))
if len(items) > 2 or len(items) == 0:
add_error(_(f"Invalid group sequence: {group}"))
add_error(_(f'Invalid group sequence: {group}'))
continue
elif len(items) == 2:
try:
if items[1]:
sequence_count = int(items[1]) + 1
except ValueError:
add_error(_(f"Invalid group sequence: {group}"))
add_error(_(f'Invalid group sequence: {group}'))
continue
value = items[0]
# Keep incrementing up to the specified quantity
while value is not None and value not in sequence_items and counter < sequence_count:
while (
value is not None
and value not in sequence_items
and counter < sequence_count
):
sequence_items.append(value)
value = increment_serial_number(value)
counter += 1
@ -595,7 +612,7 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
for item in sequence_items:
add_serial(item)
else:
add_error(_(f"Invalid group sequence: {group}"))
add_error(_(f'Invalid group sequence: {group}'))
else:
# At this point, we assume that the 'group' is just a single serial value
@ -605,10 +622,14 @@ def extract_serial_numbers(input_string, expected_quantity: int, starting_value=
raise ValidationError(errors)
if len(serials) == 0:
raise ValidationError([_("No serial numbers found")])
raise ValidationError([_('No serial numbers found')])
if len(errors) == 0 and len(serials) != expected_quantity:
raise ValidationError([_(f"Number of unique serial numbers ({len(serials)}) must match quantity ({expected_quantity})")])
raise ValidationError([
_(
f'Number of unique serial numbers ({len(serials)}) must match quantity ({expected_quantity})'
)
])
return serials
@ -645,9 +666,7 @@ def validateFilterString(value, model=None):
pair = group.split('=')
if len(pair) != 2:
raise ValidationError(
f"Invalid group: {group}"
)
raise ValidationError(f'Invalid group: {group}')
k, v = pair
@ -655,9 +674,7 @@ def validateFilterString(value, model=None):
v = v.strip()
if not k or not v:
raise ValidationError(
f"Invalid group: {group}"
)
raise ValidationError(f'Invalid group: {group}')
results[k] = v
@ -666,9 +683,7 @@ def validateFilterString(value, model=None):
try:
model.objects.filter(**results)
except FieldError as e:
raise ValidationError(
str(e),
)
raise ValidationError(str(e))
return results
@ -706,7 +721,11 @@ def clean_decimal(number):
# Number cannot be converted to Decimal (eg. a string containing letters)
return Decimal(0)
return clean_number.quantize(Decimal(1)) if clean_number == clean_number.to_integral() else clean_number.normalize()
return (
clean_number.quantize(Decimal(1))
if clean_number == clean_number.to_integral()
else clean_number.normalize()
)
def strip_html_tags(value: str, raise_error=True, field_name=None):
@ -714,55 +733,45 @@ def strip_html_tags(value: str, raise_error=True, field_name=None):
If raise_error is True, a ValidationError will be thrown if HTML tags are detected
"""
cleaned = clean(
value,
strip=True,
tags=[],
attributes=[],
)
cleaned = clean(value, strip=True, tags=[], attributes=[])
# Add escaped characters back in
replacements = {
'&gt;': '>',
'&lt;': '<',
'&amp;': '&',
}
replacements = {'&gt;': '>', '&lt;': '<', '&amp;': '&'}
for o, r in replacements.items():
cleaned = cleaned.replace(o, r)
# If the length changed, it means that HTML tags were removed!
if len(cleaned) != len(value) and raise_error:
field = field_name or 'non_field_errors'
raise ValidationError({
field: [_("Remove HTML tags from this value")]
})
raise ValidationError({field: [_('Remove HTML tags from this value')]})
return cleaned
def remove_non_printable_characters(value: str, remove_newline=True, remove_ascii=True, remove_unicode=True):
def remove_non_printable_characters(
value: str, remove_newline=True, remove_ascii=True, remove_unicode=True
):
"""Remove non-printable / control characters from the provided string"""
cleaned = value
if remove_ascii:
# Remove ASCII control characters
# Note that we do not sub out 0x0A (\n) here, it is done separately below
cleaned = regex.sub(u'[\x00-\x09]+', '', cleaned)
cleaned = regex.sub(u'[\x0b-\x1F\x7F]+', '', cleaned)
cleaned = regex.sub('[\x00-\x09]+', '', cleaned)
cleaned = regex.sub('[\x0b-\x1f\x7f]+', '', cleaned)
if remove_newline:
cleaned = regex.sub(u'[\x0a]+', '', cleaned)
cleaned = regex.sub('[\x0a]+', '', cleaned)
if remove_unicode:
# Remove Unicode control characters
if remove_newline:
cleaned = regex.sub(u'[^\P{C}]+', '', cleaned)
cleaned = regex.sub('[^\P{C}]+', '', cleaned)
else:
# Use 'negative-lookahead' to exclude newline character
cleaned = regex.sub(u'(?![\x0A])[^\P{C}]+', '', cleaned)
cleaned = regex.sub('(?![\x0a])[^\P{C}]+', '', cleaned)
return cleaned
@ -783,7 +792,9 @@ def hash_barcode(barcode_data):
return str(hash.hexdigest())
def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = 'object_id'):
def get_objectreference(
obj, type_ref: str = 'content_type', object_ref: str = 'object_id'
):
"""Lookup method for the GenericForeignKey fields.
Attributes:
@ -821,11 +832,7 @@ def get_objectreference(obj, type_ref: str = 'content_type', object_ref: str = '
ret = {}
if url_fnc:
ret['link'] = url_fnc()
return {
'name': str(item),
'model': str(model_cls._meta.verbose_name),
**ret
}
return {'name': str(item), 'model': str(model_cls._meta.verbose_name), **ret}
def inheritors(cls):

View File

@ -19,8 +19,11 @@ import common.models
import InvenTree
import InvenTree.helpers_model
import InvenTree.version
from common.notifications import (InvenTreeNotificationBodies,
NotificationBody, trigger_notification)
from common.notifications import (
InvenTreeNotificationBodies,
NotificationBody,
trigger_notification,
)
from InvenTree.format import format_money
logger = logging.getLogger('inventree')
@ -49,7 +52,9 @@ def construct_absolute_url(*arg, **kwargs):
if not site_url:
# Otherwise, try to use the InvenTree setting
try:
site_url = common.models.InvenTreeSetting.get_setting('INVENTREE_BASE_URL', create=False, cache=False)
site_url = common.models.InvenTreeSetting.get_setting(
'INVENTREE_BASE_URL', create=False, cache=False
)
except (ProgrammingError, OperationalError):
pass
@ -100,12 +105,22 @@ def download_image_from_url(remote_url, timeout=2.5):
validator(remote_url)
# Calculate maximum allowable image size (in bytes)
max_size = int(common.models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE')) * 1024 * 1024
max_size = (
int(
common.models.InvenTreeSetting.get_setting(
'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE'
)
)
* 1024
* 1024
)
# Add user specified user-agent to request (if specified)
user_agent = common.models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT')
user_agent = common.models.InvenTreeSetting.get_setting(
'INVENTREE_DOWNLOAD_FROM_URL_USER_AGENT'
)
if user_agent:
headers = {"User-Agent": user_agent}
headers = {'User-Agent': user_agent}
else:
headers = None
@ -120,24 +135,28 @@ def download_image_from_url(remote_url, timeout=2.5):
# Throw an error if anything goes wrong
response.raise_for_status()
except requests.exceptions.ConnectionError as exc:
raise Exception(_("Connection error") + f": {str(exc)}")
raise Exception(_('Connection error') + f': {str(exc)}')
except requests.exceptions.Timeout as exc:
raise exc
except requests.exceptions.HTTPError:
raise requests.exceptions.HTTPError(_("Server responded with invalid status code") + f": {response.status_code}")
raise requests.exceptions.HTTPError(
_('Server responded with invalid status code') + f': {response.status_code}'
)
except Exception as exc:
raise Exception(_("Exception occurred") + f": {str(exc)}")
raise Exception(_('Exception occurred') + f': {str(exc)}')
if response.status_code != 200:
raise Exception(_("Server responded with invalid status code") + f": {response.status_code}")
raise Exception(
_('Server responded with invalid status code') + f': {response.status_code}'
)
try:
content_length = int(response.headers.get('Content-Length', 0))
except ValueError:
raise ValueError(_("Server responded with invalid Content-Length value"))
raise ValueError(_('Server responded with invalid Content-Length value'))
if content_length > max_size:
raise ValueError(_("Image size is too large"))
raise ValueError(_('Image size is too large'))
# Download the file, ensuring we do not exceed the reported size
file = io.BytesIO()
@ -149,12 +168,12 @@ def download_image_from_url(remote_url, timeout=2.5):
dl_size += len(chunk)
if dl_size > max_size:
raise ValueError(_("Image download exceeded maximum size"))
raise ValueError(_('Image download exceeded maximum size'))
file.write(chunk)
if dl_size == 0:
raise ValueError(_("Remote server returned empty response"))
raise ValueError(_('Remote server returned empty response'))
# Now, attempt to convert the downloaded data to a valid image file
# img.verify() will throw an exception if the image is not valid
@ -162,12 +181,18 @@ def download_image_from_url(remote_url, timeout=2.5):
img = Image.open(file).convert()
img.verify()
except Exception:
raise TypeError(_("Supplied URL is not a valid image file"))
raise TypeError(_('Supplied URL is not a valid image file'))
return img
def render_currency(money, decimal_places=None, currency=None, min_decimal_places=None, max_decimal_places=None):
def render_currency(
money,
decimal_places=None,
currency=None,
min_decimal_places=None,
max_decimal_places=None,
):
"""Render a currency / Money object to a formatted string (e.g. for reports)
Arguments:
@ -192,13 +217,19 @@ def render_currency(money, decimal_places=None, currency=None, min_decimal_place
pass
if decimal_places is None:
decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES', 6)
decimal_places = common.models.InvenTreeSetting.get_setting(
'PRICING_DECIMAL_PLACES', 6
)
if min_decimal_places is None:
min_decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES_MIN', 0)
min_decimal_places = common.models.InvenTreeSetting.get_setting(
'PRICING_DECIMAL_PLACES_MIN', 0
)
if max_decimal_places is None:
max_decimal_places = common.models.InvenTreeSetting.get_setting('PRICING_DECIMAL_PLACES', 6)
max_decimal_places = common.models.InvenTreeSetting.get_setting(
'PRICING_DECIMAL_PLACES', 6
)
value = Decimal(str(money.amount)).normalize()
value = str(value)
@ -229,7 +260,9 @@ def getModelsWithMixin(mixin_class) -> list:
from django.contrib.contenttypes.models import ContentType
try:
db_models = [x.model_class() for x in ContentType.objects.all() if x is not None]
db_models = [
x.model_class() for x in ContentType.objects.all() if x is not None
]
except (OperationalError, ProgrammingError):
# Database is likely not yet ready
db_models = []
@ -237,7 +270,12 @@ def getModelsWithMixin(mixin_class) -> list:
return [x for x in db_models if x is not None and issubclass(x, mixin_class)]
def notify_responsible(instance, sender, content: NotificationBody = InvenTreeNotificationBodies.NewOrder, exclude=None):
def notify_responsible(
instance,
sender,
content: NotificationBody = InvenTreeNotificationBodies.NewOrder,
exclude=None,
):
"""Notify all responsible parties of a change in an instance.
Parses the supplied content with the provided instance and sender and sends a notification to all responsible users,
@ -249,10 +287,18 @@ def notify_responsible(instance, sender, content: NotificationBody = InvenTreeNo
content (NotificationBody, optional): _description_. Defaults to InvenTreeNotificationBodies.NewOrder.
exclude (User, optional): User instance that should be excluded. Defaults to None.
"""
notify_users([instance.responsible], instance, sender, content=content, exclude=exclude)
notify_users(
[instance.responsible], instance, sender, content=content, exclude=exclude
)
def notify_users(users, instance, sender, content: NotificationBody = InvenTreeNotificationBodies.NewOrder, exclude=None):
def notify_users(
users,
instance,
sender,
content: NotificationBody = InvenTreeNotificationBodies.NewOrder,
exclude=None,
):
"""Notify all passed users or groups.
Parses the supplied content with the provided instance and sender and sends a notification to all users,
@ -278,10 +324,10 @@ def notify_users(users, instance, sender, content: NotificationBody = InvenTreeN
'instance': instance,
'name': content.name.format(**content_context),
'message': content.message.format(**content_context),
'link': InvenTree.helpers_model.construct_absolute_url(instance.get_absolute_url()),
'template': {
'subject': content.name.format(**content_context),
}
'link': InvenTree.helpers_model.construct_absolute_url(
instance.get_absolute_url()
),
'template': {'subject': content.name.format(**content_context)},
}
if content.template:

View File

@ -1,4 +1,5 @@
"""Functions for magic login."""
from django.conf import settings
from django.contrib.auth.models import User
from django.contrib.sites.models import Site
@ -17,15 +18,13 @@ def send_simple_login_email(user, link):
"""Send an email with the login link to this user."""
site = Site.objects.get_current()
context = {
"username": user.username,
"site_name": site.name,
"link": link,
}
email_plaintext_message = render_to_string("InvenTree/user_simple_login.txt", context)
context = {'username': user.username, 'site_name': site.name, 'link': link}
email_plaintext_message = render_to_string(
'InvenTree/user_simple_login.txt', context
)
send_mail(
_(f"[{site.name}] Log in to the app"),
_(f'[{site.name}] Log in to the app'),
email_plaintext_message,
settings.DEFAULT_FROM_EMAIL,
[user.email],
@ -35,7 +34,7 @@ def send_simple_login_email(user, link):
class GetSimpleLoginSerializer(serializers.Serializer):
"""Serializer for the simple login view."""
email = serializers.CharField(label=_("Email"))
email = serializers.CharField(label=_('Email'))
class GetSimpleLoginView(APIView):
@ -48,14 +47,14 @@ class GetSimpleLoginView(APIView):
"""Get the token for the current user or fail."""
serializer = self.serializer_class(data=request.data)
serializer.is_valid(raise_exception=True)
self.email_submitted(email=serializer.data["email"])
return Response({"status": "ok"})
self.email_submitted(email=serializer.data['email'])
return Response({'status': 'ok'})
def email_submitted(self, email):
"""Notify user about link."""
user = self.get_user(email)
if user is None:
print("user not found:", email)
print('user not found:', email)
return
link = self.create_link(user)
send_simple_login_email(user, link)
@ -69,7 +68,7 @@ class GetSimpleLoginView(APIView):
def create_link(self, user):
"""Create a login link for this user."""
link = reverse("sesame-login")
link = reverse('sesame-login')
link = self.request.build_absolute_uri(link)
link += sesame.utils.get_query_string(user)
return link

View File

@ -12,7 +12,7 @@ class Command(BaseCommand):
def handle(self, *args, **kwargs):
"""Cleanup old (undefined) settings in the database."""
logger.info("Collecting settings")
logger.info('Collecting settings')
from common.models import InvenTreeSetting, InvenTreeUserSetting
# general settings
@ -35,4 +35,4 @@ class Command(BaseCommand):
setting.delete()
logger.info("deleted user setting '%s'", setting.key)
logger.info("checked all settings")
logger.info('checked all settings')

View File

@ -13,7 +13,6 @@ from django.utils.translation import override as lang_over
def render_file(file_name, source, target, locales, ctx):
"""Renders a file into all provided locales."""
for locale in locales:
# Enforce lower-case for locale names
locale = locale.lower()
locale = locale.replace('_', '-')
@ -50,15 +49,19 @@ class Command(BaseCommand):
# render!
request = HttpRequest()
ctx = {}
processors = tuple(import_string(path) for path in settings.STATFILES_I18_PROCESSORS)
processors = tuple(
import_string(path) for path in settings.STATFILES_I18_PROCESSORS
)
for processor in processors:
ctx.update(processor(request))
for file in os.listdir(SOURCE_DIR, ):
for file in os.listdir(SOURCE_DIR):
path = os.path.join(SOURCE_DIR, file)
if os.path.exists(path) and os.path.isfile(path):
print(f"render {file}")
print(f'render {file}')
render_file(file, SOURCE_DIR, TARGET_DIR, locales, ctx)
else:
raise NotImplementedError('Using multi-level directories is not implemented at this point') # TODO multilevel dir if needed
print(f"rendered all files in {SOURCE_DIR}")
raise NotImplementedError(
'Using multi-level directories is not implemented at this point'
) # TODO multilevel dir if needed
print(f'rendered all files in {SOURCE_DIR}')

View File

@ -13,45 +13,50 @@ class Command(BaseCommand):
"""Rebuild all database models which leverage the MPTT structure."""
# Part model
try:
print("Rebuilding Part objects")
print('Rebuilding Part objects')
from part.models import Part
Part.objects.rebuild()
except Exception:
print("Error rebuilding Part objects")
print('Error rebuilding Part objects')
# Part category
try:
print("Rebuilding PartCategory objects")
print('Rebuilding PartCategory objects')
from part.models import PartCategory
PartCategory.objects.rebuild()
except Exception:
print("Error rebuilding PartCategory objects")
print('Error rebuilding PartCategory objects')
# StockItem model
try:
print("Rebuilding StockItem objects")
print('Rebuilding StockItem objects')
from stock.models import StockItem
StockItem.objects.rebuild()
except Exception:
print("Error rebuilding StockItem objects")
print('Error rebuilding StockItem objects')
# StockLocation model
try:
print("Rebuilding StockLocation objects")
print('Rebuilding StockLocation objects')
from stock.models import StockLocation
StockLocation.objects.rebuild()
except Exception:
print("Error rebuilding StockLocation objects")
print('Error rebuilding StockLocation objects')
# Build model
try:
print("Rebuilding Build objects")
print('Rebuilding Build objects')
from build.models import Build
Build.objects.rebuild()
except Exception:
print("Error rebuilding Build objects")
print('Error rebuilding Build objects')

View File

@ -37,20 +37,20 @@ class Command(BaseCommand):
def handle(self, *args, **kwargs):
"""Rebuild all thumbnail images."""
logger.info("Rebuilding Part thumbnails")
logger.info('Rebuilding Part thumbnails')
for part in Part.objects.exclude(image=None):
try:
self.rebuild_thumbnail(part)
except (OperationalError, ProgrammingError):
logger.exception("ERROR: Database read error.")
logger.exception('ERROR: Database read error.')
break
logger.info("Rebuilding Company thumbnails")
logger.info('Rebuilding Company thumbnails')
for company in Company.objects.exclude(image=None):
try:
self.rebuild_thumbnail(company)
except (OperationalError, ProgrammingError):
logger.exception("ERROR: abase read error.")
logger.exception('ERROR: abase read error.')
break

View File

@ -18,7 +18,12 @@ class Command(BaseCommand):
if not mail:
raise KeyError('A mail is required')
user = get_user_model()
mfa_user = [*set(user.objects.filter(email=mail) | user.objects.filter(emailaddress__email=mail))]
mfa_user = [
*set(
user.objects.filter(email=mail)
| user.objects.filter(emailaddress__email=mail)
)
]
if len(mfa_user) == 0:
print('No user with this mail associated')

View File

@ -12,12 +12,11 @@ class Command(BaseCommand):
def handle(self, *args, **kwargs):
"""Wait till the database is ready."""
self.stdout.write("Waiting for database...")
self.stdout.write('Waiting for database...')
connected = False
while not connected:
time.sleep(2)
try:
@ -26,12 +25,12 @@ class Command(BaseCommand):
connected = True
except OperationalError as e:
self.stdout.write(f"Could not connect to database: {e}")
self.stdout.write(f'Could not connect to database: {e}')
except ImproperlyConfigured as e:
self.stdout.write(f"Improperly configured: {e}")
self.stdout.write(f'Improperly configured: {e}')
else:
if not connection.is_usable():
self.stdout.write("Database configuration is not usable")
self.stdout.write('Database configuration is not usable')
if connected:
self.stdout.write("Database connection successful!")
self.stdout.write('Database connection successful!')

View File

@ -45,8 +45,9 @@ class InvenTreeMetadata(SimpleMetadata):
context = {}
if str2bool(request.query_params.get('context', False)):
if hasattr(self, 'serializer') and hasattr(self.serializer, 'get_context_data'):
if hasattr(self, 'serializer') and hasattr(
self.serializer, 'get_context_data'
):
context = self.serializer.get_context_data()
metadata['context'] = context
@ -68,7 +69,7 @@ class InvenTreeMetadata(SimpleMetadata):
metadata['model'] = tbl_label
table = f"{app_label}_{tbl_label}"
table = f'{app_label}_{tbl_label}'
actions = metadata.get('actions', None)
@ -86,12 +87,11 @@ class InvenTreeMetadata(SimpleMetadata):
}
# let the view define a custom rolemap
if hasattr(view, "rolemap"):
if hasattr(view, 'rolemap'):
rolemap.update(view.rolemap)
# Remove any HTTP methods that the user does not have permission for
for method, permission in rolemap.items():
result = check(user, table, permission)
if method in actions and not result:
@ -123,10 +123,7 @@ class InvenTreeMetadata(SimpleMetadata):
model_class = None
# Attributes to copy extra attributes from the model to the field (if they don't exist)
extra_attributes = [
'help_text',
'max_length',
]
extra_attributes = ['help_text', 'max_length']
try:
model_class = serializer.Meta.model
@ -142,11 +139,8 @@ class InvenTreeMetadata(SimpleMetadata):
# Iterate through simple fields
for name, field in model_fields.fields.items():
if name in serializer_info.keys():
if field.has_default():
default = field.default
if callable(default):
@ -162,13 +156,11 @@ class InvenTreeMetadata(SimpleMetadata):
for attr in extra_attributes:
if attr not in serializer_info[name]:
if hasattr(field, attr):
serializer_info[name][attr] = getattr(field, attr)
# Iterate through relations
for name, relation in model_fields.relations.items():
if name not in serializer_info.keys():
# Skip relation not defined in serializer
continue
@ -179,11 +171,17 @@ class InvenTreeMetadata(SimpleMetadata):
# Extract and provide the "limit_choices_to" filters
# This is used to automatically filter AJAX requests
serializer_info[name]['filters'] = relation.model_field.get_limit_choices_to()
serializer_info[name]['filters'] = (
relation.model_field.get_limit_choices_to()
)
for attr in extra_attributes:
if attr not in serializer_info[name] and hasattr(relation.model_field, attr):
serializer_info[name][attr] = getattr(relation.model_field, attr)
if attr not in serializer_info[name] and hasattr(
relation.model_field, attr
):
serializer_info[name][attr] = getattr(
relation.model_field, attr
)
if name in model_default_values:
serializer_info[name]['default'] = model_default_values[name]
@ -220,11 +218,9 @@ class InvenTreeMetadata(SimpleMetadata):
"""If there is an instance associated with this API View, introspect that instance to find any specific API info."""
if hasattr(instance, 'api_instance_filters'):
instance_filters = instance.api_instance_filters()
for field_name, field_filters in instance_filters.items():
if field_name not in serializer_info.keys():
# The field might be missing, but is added later on
# This function seems to get called multiple times?
@ -256,17 +252,20 @@ class InvenTreeMetadata(SimpleMetadata):
# Force non-nullable fields to read as "required"
# (even if there is a default value!)
if not field.allow_null and not (hasattr(field, 'allow_blank') and field.allow_blank):
if not field.allow_null and not (
hasattr(field, 'allow_blank') and field.allow_blank
):
field_info['required'] = True
# Introspect writable related fields
if field_info['type'] == 'field' and not field_info['read_only']:
# If the field is a PrimaryKeyRelatedField, we can extract the model from the queryset
if isinstance(field, serializers.PrimaryKeyRelatedField):
model = field.queryset.model
else:
logger.debug("Could not extract model for:", field_info.get('label'), '->', field)
logger.debug(
'Could not extract model for:', field_info.get('label'), '->', field
)
model = None
if model:
@ -287,4 +286,4 @@ class InvenTreeMetadata(SimpleMetadata):
return field_info
InvenTreeMetadata.label_lookup[DependentField] = "dependent field"
InvenTreeMetadata.label_lookup[DependentField] = 'dependent field'

View File

@ -9,14 +9,13 @@ from django.http import HttpResponse
from django.shortcuts import redirect
from django.urls import Resolver404, include, re_path, resolve, reverse_lazy
from allauth_2fa.middleware import (AllauthTwoFactorMiddleware,
BaseRequire2FAMiddleware)
from allauth_2fa.middleware import AllauthTwoFactorMiddleware, BaseRequire2FAMiddleware
from error_report.middleware import ExceptionProcessor
from InvenTree.urls import frontendpatterns
from users.models import ApiToken
logger = logging.getLogger("inventree")
logger = logging.getLogger('inventree')
class AuthRequiredMiddleware(object):
@ -64,11 +63,20 @@ class AuthRequiredMiddleware(object):
elif request.path_info.startswith('/accounts/'):
authorized = True
elif request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/') or request.path_info.startswith('/assets/') or request.path_info == f'/{settings.FRONTEND_URL_BASE}':
elif (
request.path_info.startswith(f'/{settings.FRONTEND_URL_BASE}/')
or request.path_info.startswith('/assets/')
or request.path_info == f'/{settings.FRONTEND_URL_BASE}'
):
authorized = True
elif 'Authorization' in request.headers.keys() or 'authorization' in request.headers.keys():
auth = request.headers.get('Authorization', request.headers.get('authorization')).strip()
elif (
'Authorization' in request.headers.keys()
or 'authorization' in request.headers.keys()
):
auth = request.headers.get(
'Authorization', request.headers.get('authorization')
).strip()
if auth.lower().startswith('token') and len(auth.split()) == 2:
token_key = auth.split()[1]
@ -78,13 +86,12 @@ class AuthRequiredMiddleware(object):
token = ApiToken.objects.get(key=token_key)
if token.active and token.user:
# Provide the user information to the request
request.user = token.user
authorized = True
except ApiToken.DoesNotExist:
logger.warning("Access denied for unknown token %s", token_key)
logger.warning('Access denied for unknown token %s', token_key)
# No authorization was found for the request
if not authorized:
@ -99,17 +106,16 @@ class AuthRequiredMiddleware(object):
]
# Do not redirect requests to any of these paths
paths_ignore = [
'/api/',
'/js/',
'/media/',
'/static/',
]
paths_ignore = ['/api/', '/js/', '/media/', '/static/']
if path not in urls and not any(path.startswith(p) for p in paths_ignore):
if path not in urls and not any(
path.startswith(p) for p in paths_ignore
):
# Save the 'next' parameter to pass through to the login view
return redirect(f'{reverse_lazy("account_login")}?next={request.path}')
return redirect(
f'{reverse_lazy("account_login")}?next={request.path}'
)
# Return a 401 (Unauthorized) response code for this request
return HttpResponse('Unauthorized', status=401)
@ -123,6 +129,7 @@ url_matcher = re_path('', include(frontendpatterns))
class Check2FAMiddleware(BaseRequire2FAMiddleware):
"""Check if user is required to have MFA enabled."""
def require_2fa(self, request):
"""Use setting to check if MFA should be enforced for frontend page."""
from common.models import InvenTreeSetting
@ -137,6 +144,7 @@ class Check2FAMiddleware(BaseRequire2FAMiddleware):
class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware):
"""This function ensures only frontend code triggers the MFA auth cycle."""
def process_request(self, request):
"""Check if requested url is forntend and enforce MFA check."""
try:
@ -148,6 +156,7 @@ class CustomAllauthTwoFactorMiddleware(AllauthTwoFactorMiddleware):
class InvenTreeRemoteUserMiddleware(PersistentRemoteUserMiddleware):
"""Middleware to check if HTTP-header based auth is enabled and to set it up."""
header = settings.REMOTE_LOGIN_HEADER
def process_request(self, request):

View File

@ -62,7 +62,7 @@ class DiffMixin:
return field_name in self.get_field_deltas()
class CleanMixin():
class CleanMixin:
"""Model mixin class which cleans inputs using the Mozilla bleach tools."""
# Define a list of field names which will *not* be cleaned
@ -74,13 +74,17 @@ class CleanMixin():
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers
)
def update(self, request, *args, **kwargs):
"""Override to clean data before processing it."""
partial = kwargs.pop('partial', False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=self.clean_data(request.data), partial=partial)
serializer = self.get_serializer(
instance, data=self.clean_data(request.data), partial=partial
)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
@ -113,9 +117,7 @@ class CleanMixin():
field = model._meta.get_field(field)
# The following field types allow newline characters
allow_newline = [
InvenTreeNotesField,
]
allow_newline = [InvenTreeNotesField]
for field_type in allow_newline:
if issubclass(type(field), field_type):
@ -127,7 +129,9 @@ class CleanMixin():
except FieldDoesNotExist:
pass
cleaned = remove_non_printable_characters(cleaned, remove_newline=remove_newline)
cleaned = remove_non_printable_characters(
cleaned, remove_newline=remove_newline
)
return cleaned
@ -148,7 +152,6 @@ class CleanMixin():
clean_data = {}
for k, v in data.items():
if k in self.SAFE_FIELDS:
ret = v
elif isinstance(v, str):
@ -177,16 +180,19 @@ class CreateAPI(CleanMixin, generics.CreateAPIView):
class RetrieveAPI(generics.RetrieveAPIView):
"""View for retrieve API."""
pass
class RetrieveUpdateAPI(CleanMixin, generics.RetrieveUpdateAPIView):
"""View for retrieve and update API."""
pass
class CustomDestroyModelMixin:
"""This mixin was created pass the kwargs from the API to the models."""
def destroy(self, request, *args, **kwargs):
"""Custom destroy method to pass kwargs."""
instance = self.get_object()
@ -198,11 +204,14 @@ class CustomDestroyModelMixin:
instance.delete(**kwargs)
class CustomRetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
CustomDestroyModelMixin,
generics.GenericAPIView):
class CustomRetrieveUpdateDestroyAPIView(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
CustomDestroyModelMixin,
generics.GenericAPIView,
):
"""This APIView was created pass the kwargs from the API to the models."""
def get(self, request, *args, **kwargs):
"""Custom get method to pass kwargs."""
return self.retrieve(request, *args, **kwargs)

View File

@ -58,6 +58,7 @@ class MetadataMixin(models.Model):
class Meta:
"""Meta for MetadataMixin."""
abstract = True
def save(self, *args, **kwargs):
@ -78,10 +79,13 @@ class MetadataMixin(models.Model):
self.metadata = {}
if type(self.metadata) is not dict:
raise ValidationError({'metadata': _('Metadata must be a python dict object')})
raise ValidationError({
'metadata': _('Metadata must be a python dict object')
})
metadata = models.JSONField(
blank=True, null=True,
blank=True,
null=True,
verbose_name=_('Plugin Metadata'),
help_text=_('JSON metadata field, for use by external plugins'),
)
@ -100,7 +104,9 @@ class MetadataMixin(models.Model):
return self.metadata.get(key, backup_value)
def set_metadata(self, key: str, data, commit: bool = True, overwrite: bool = False):
def set_metadata(
self, key: str, data, commit: bool = True, overwrite: bool = False
):
"""Save the provided metadata under the provided key.
Args:
@ -137,7 +143,6 @@ class DataImportMixin(object):
fields = cls.IMPORT_FIELDS
for name, field in fields.items():
# Attempt to extract base field information from the model
base_field = None
@ -207,7 +212,10 @@ class ReferenceIndexingMixin(models.Model):
# import at function level to prevent cyclic imports
from common.models import InvenTreeSetting
return InvenTreeSetting.get_setting(cls.REFERENCE_PATTERN_SETTING, create=False).strip()
return InvenTreeSetting.get_setting(
cls.REFERENCE_PATTERN_SETTING, create=False
).strip()
@classmethod
def get_reference_context(cls):
@ -216,10 +224,7 @@ class ReferenceIndexingMixin(models.Model):
- Returns a python dict object which contains the context data for formatting the reference string.
- The default implementation provides some default context information
"""
return {
'ref': cls.get_next_reference(),
'date': datetime.now(),
}
return {'ref': cls.get_next_reference(), 'date': datetime.now()}
@classmethod
def get_most_recent_item(cls):
@ -246,7 +251,9 @@ class ReferenceIndexingMixin(models.Model):
reference = latest.reference.strip
try:
reference = InvenTree.format.extract_named_group('ref', reference, cls.get_reference_pattern())
reference = InvenTree.format.extract_named_group(
'ref', reference, cls.get_reference_pattern()
)
except Exception:
# If reference cannot be extracted using the pattern, try just the integer value
reference = str(latest.reference_int)
@ -296,7 +303,7 @@ class ReferenceIndexingMixin(models.Model):
if recent:
reference = recent.reference
else:
reference = ""
reference = ''
return reference
@ -309,20 +316,20 @@ class ReferenceIndexingMixin(models.Model):
info = InvenTree.format.parse_format_string(pattern)
except Exception as exc:
raise ValidationError({
"value": _("Improperly formatted pattern") + ": " + str(exc)
'value': _('Improperly formatted pattern') + ': ' + str(exc)
})
# Check that only 'allowed' keys are provided
for key in info.keys():
if key not in ctx.keys():
raise ValidationError({
"value": _("Unknown format key specified") + f": '{key}'"
'value': _('Unknown format key specified') + f": '{key}'"
})
# Check that the 'ref' variable is specified
if 'ref' not in info.keys():
raise ValidationError({
'value': _("Missing required format key") + ": 'ref'"
'value': _('Missing required format key') + ": 'ref'"
})
@classmethod
@ -333,14 +340,16 @@ class ReferenceIndexingMixin(models.Model):
value = str(value).strip()
if len(value) == 0:
raise ValidationError(_("Reference field cannot be empty"))
raise ValidationError(_('Reference field cannot be empty'))
# An 'empty' pattern means no further validation is required
if not pattern:
return
if not InvenTree.format.validate_string(value, pattern):
raise ValidationError(_("Reference must match required pattern") + ": " + pattern)
raise ValidationError(
_('Reference must match required pattern') + ': ' + pattern
)
# Check that the reference field can be rebuild
cls.rebuild_reference_field(value, validate=True)
@ -361,7 +370,9 @@ class ReferenceIndexingMixin(models.Model):
"""
try:
# Extract named group based on provided pattern
reference = InvenTree.format.extract_named_group('ref', reference, cls.get_reference_pattern())
reference = InvenTree.format.extract_named_group(
'ref', reference, cls.get_reference_pattern()
)
except Exception:
pass
@ -369,16 +380,14 @@ class ReferenceIndexingMixin(models.Model):
if validate:
if reference_int > models.BigIntegerField.MAX_BIGINT:
raise ValidationError({
"reference": _("Reference number is too large")
})
raise ValidationError({'reference': _('Reference number is too large')})
return reference_int
reference_int = models.BigIntegerField(default=0)
def extract_int(reference, clip=0x7fffffff, allow_negative=False):
def extract_int(reference, clip=0x7FFFFFFF, allow_negative=False):
"""Extract an integer out of reference."""
# Default value if we cannot convert to an integer
ref_int = 0
@ -390,7 +399,7 @@ def extract_int(reference, clip=0x7fffffff, allow_negative=False):
return 0
# Look at the start of the string - can it be "integerized"?
result = re.match(r"^(\d+)", reference)
result = re.match(r'^(\d+)', reference)
if result and len(result.groups()) == 1:
ref = result.groups()[0]
@ -438,6 +447,7 @@ class InvenTreeAttachment(models.Model):
class Meta:
"""Metaclass options. Abstract ensures no database table is created."""
abstract = True
def getSubdir(self):
@ -445,7 +455,7 @@ class InvenTreeAttachment(models.Model):
Note: Re-implement this for each subclass of InvenTreeAttachment
"""
return "attachments"
return 'attachments'
def save(self, *args, **kwargs):
"""Provide better validation error."""
@ -472,28 +482,40 @@ class InvenTreeAttachment(models.Model):
return os.path.basename(self.attachment.name)
return str(self.link)
attachment = models.FileField(upload_to=rename_attachment, verbose_name=_('Attachment'),
help_text=_('Select file to attach'),
blank=True, null=True
)
link = InvenTree.fields.InvenTreeURLField(
blank=True, null=True,
verbose_name=_('Link'),
help_text=_('Link to external URL')
attachment = models.FileField(
upload_to=rename_attachment,
verbose_name=_('Attachment'),
help_text=_('Select file to attach'),
blank=True,
null=True,
)
comment = models.CharField(blank=True, max_length=100, verbose_name=_('Comment'), help_text=_('File comment'))
link = InvenTree.fields.InvenTreeURLField(
blank=True,
null=True,
verbose_name=_('Link'),
help_text=_('Link to external URL'),
)
comment = models.CharField(
blank=True,
max_length=100,
verbose_name=_('Comment'),
help_text=_('File comment'),
)
user = models.ForeignKey(
User,
on_delete=models.SET_NULL,
blank=True, null=True,
blank=True,
null=True,
verbose_name=_('User'),
help_text=_('User'),
)
upload_date = models.DateField(auto_now_add=True, null=True, blank=True, verbose_name=_('upload date'))
upload_date = models.DateField(
auto_now_add=True, null=True, blank=True, verbose_name=_('upload date')
)
@property
def basename(self):
@ -522,35 +544,59 @@ class InvenTreeAttachment(models.Model):
# Check that there are no directory tricks going on...
if new_file.parent != attachment_dir:
logger.error("Attempted to rename attachment outside valid directory: '%s'", new_file)
raise ValidationError(_("Invalid attachment directory"))
logger.error(
"Attempted to rename attachment outside valid directory: '%s'", new_file
)
raise ValidationError(_('Invalid attachment directory'))
# Ignore further checks if the filename is not actually being renamed
if new_file == old_file:
return
forbidden = ["'", '"', "#", "@", "!", "&", "^", "<", ">", ":", ";", "/", "\\", "|", "?", "*", "%", "~", "`"]
forbidden = [
"'",
'"',
'#',
'@',
'!',
'&',
'^',
'<',
'>',
':',
';',
'/',
'\\',
'|',
'?',
'*',
'%',
'~',
'`',
]
for c in forbidden:
if c in fn:
raise ValidationError(_(f"Filename contains illegal character '{c}'"))
if len(fn.split('.')) < 2:
raise ValidationError(_("Filename missing extension"))
raise ValidationError(_('Filename missing extension'))
if not old_file.exists():
logger.error("Trying to rename attachment '%s' which does not exist", old_file)
logger.error(
"Trying to rename attachment '%s' which does not exist", old_file
)
return
if new_file.exists():
raise ValidationError(_("Attachment with this filename already exists"))
raise ValidationError(_('Attachment with this filename already exists'))
try:
os.rename(old_file, new_file)
self.attachment.name = os.path.join(self.getSubdir(), fn)
self.save()
except Exception:
raise ValidationError(_("Error renaming file"))
raise ValidationError(_('Error renaming file'))
def fully_qualified_url(self):
"""Return a 'fully qualified' URL for this attachment.
@ -586,10 +632,12 @@ class InvenTreeTree(MPTTModel):
class Meta:
"""Metaclass defines extra model properties."""
abstract = True
class MPTTMeta:
"""Set insert order."""
order_insertion_by = ['name']
def delete(self, delete_children=False, delete_items=False):
@ -607,13 +655,19 @@ class InvenTreeTree(MPTTModel):
self.refresh_from_db()
except self.__class__.DoesNotExist:
# If the object no longer exists, raise a ValidationError
raise ValidationError("Object %s of type %s no longer exists", str(self), str(self.__class__))
raise ValidationError(
'Object %s of type %s no longer exists', str(self), str(self.__class__)
)
# Cache node ID values for lower nodes, before we delete this one
lower_nodes = list(self.get_descendants(include_self=False).values_list('pk', flat=True))
lower_nodes = list(
self.get_descendants(include_self=False).values_list('pk', flat=True)
)
# 1. Update nodes and items under the current node
self.handle_tree_delete(delete_children=delete_children, delete_items=delete_items)
self.handle_tree_delete(
delete_children=delete_children, delete_items=delete_items
)
# 2. Delete *this* node
super().delete()
@ -673,9 +727,7 @@ class InvenTreeTree(MPTTModel):
# - Move all items at any lower level to the parent of this item
# - Delete all descendant nodes
elif delete_children and not delete_items:
self.get_items(cascade=True).update(**{
self.ITEM_PARENT_KEY: self.parent
})
self.get_items(cascade=True).update(**{self.ITEM_PARENT_KEY: self.parent})
self.delete_nodes(child_nodes)
@ -690,9 +742,7 @@ class InvenTreeTree(MPTTModel):
# - Move all items directly associated with this node up one level
# - Move any direct child nodes up one level
elif not delete_children and not delete_items:
self.get_items(cascade=False).update(**{
self.ITEM_PARENT_KEY: self.parent
})
self.get_items(cascade=False).update(**{self.ITEM_PARENT_KEY: self.parent})
self.get_children().update(parent=self.parent)
def delete_nodes(self, nodes):
@ -719,8 +769,7 @@ class InvenTreeTree(MPTTModel):
super().validate_unique(exclude)
results = self.__class__.objects.filter(
name=self.name,
parent=self.parent
name=self.name, parent=self.parent
).exclude(pk=self.pk)
if results.exists():
@ -730,17 +779,11 @@ class InvenTreeTree(MPTTModel):
def api_instance_filters(self):
"""Instance filters for InvenTreeTree models."""
return {
'parent': {
'exclude_tree': self.pk,
}
}
return {'parent': {'exclude_tree': self.pk}}
def construct_pathstring(self):
"""Construct the pathstring for this tree node"""
return InvenTree.helpers.constructPathString(
[item.name for item in self.path]
)
return InvenTree.helpers.constructPathString([item.name for item in self.path])
def save(self, *args, **kwargs):
"""Custom save method for InvenTreeTree abstract model"""
@ -748,15 +791,12 @@ class InvenTreeTree(MPTTModel):
super().save(*args, **kwargs)
except InvalidMove:
# Provide better error for parent selection
raise ValidationError({
'parent': _("Invalid choice"),
})
raise ValidationError({'parent': _('Invalid choice')})
# Re-calculate the 'pathstring' field
pathstring = self.construct_pathstring()
if pathstring != self.pathstring:
if 'force_insert' in kwargs:
del kwargs['force_insert']
@ -781,33 +821,29 @@ class InvenTreeTree(MPTTModel):
self.__class__.objects.bulk_update(nodes_to_update, ['pathstring'])
name = models.CharField(
blank=False,
max_length=100,
verbose_name=_("Name"),
help_text=_("Name"),
blank=False, max_length=100, verbose_name=_('Name'), help_text=_('Name')
)
description = models.CharField(
blank=True,
max_length=250,
verbose_name=_("Description"),
help_text=_("Description (optional)")
verbose_name=_('Description'),
help_text=_('Description (optional)'),
)
# When a category is deleted, graft the children onto its parent
parent = TreeForeignKey('self',
on_delete=models.DO_NOTHING,
blank=True,
null=True,
verbose_name=_("parent"),
related_name='children')
parent = TreeForeignKey(
'self',
on_delete=models.DO_NOTHING,
blank=True,
null=True,
verbose_name=_('parent'),
related_name='children',
)
# The 'pathstring' field is calculated each time the model is saved
pathstring = models.CharField(
blank=True,
max_length=250,
verbose_name=_('Path'),
help_text=_('Path')
blank=True, max_length=250, verbose_name=_('Path'), help_text=_('Path')
)
def get_items(self, cascade=False):
@ -818,7 +854,7 @@ class InvenTreeTree(MPTTModel):
The default implementation returns an empty list
"""
raise NotImplementedError(f"items() method not implemented for {type(self)}")
raise NotImplementedError(f'items() method not implemented for {type(self)}')
def getUniqueParents(self):
"""Return a flat set of all parent items that exist above this node.
@ -889,16 +925,11 @@ class InvenTreeTree(MPTTModel):
name: <name>,
}
"""
return [
{
'pk': item.pk,
'name': item.name
} for item in self.path
]
return [{'pk': item.pk, 'name': item.name} for item in self.path]
def __str__(self):
"""String representation of a category is the full path to that category."""
return f"{self.pathstring} - {self.description}"
return f'{self.pathstring} - {self.description}'
class InvenTreeNotesMixin(models.Model):
@ -914,11 +945,11 @@ class InvenTreeNotesMixin(models.Model):
Note: abstract must be true, as this is only a mixin, not a separate table
"""
abstract = True
notes = InvenTree.fields.InvenTreeNotesField(
verbose_name=_('Notes'),
help_text=_('Markdown notes (optional)'),
verbose_name=_('Notes'), help_text=_('Markdown notes (optional)')
)
@ -941,18 +972,21 @@ class InvenTreeBarcodeMixin(models.Model):
Note: abstract must be true, as this is only a mixin, not a separate table
"""
abstract = True
barcode_data = models.CharField(
blank=True, max_length=500,
blank=True,
max_length=500,
verbose_name=_('Barcode Data'),
help_text=_('Third party barcode data'),
)
barcode_hash = models.CharField(
blank=True, max_length=128,
blank=True,
max_length=128,
verbose_name=_('Barcode Hash'),
help_text=_('Unique hash of barcode data')
help_text=_('Unique hash of barcode data'),
)
@classmethod
@ -964,21 +998,17 @@ class InvenTreeBarcodeMixin(models.Model):
def format_barcode(self, **kwargs):
"""Return a JSON string for formatting a QR code for this model instance."""
return InvenTree.helpers.MakeBarcode(
self.__class__.barcode_model_type(),
self.pk,
**kwargs
self.__class__.barcode_model_type(), self.pk, **kwargs
)
def format_matched_response(self):
"""Format a standard response for a matched barcode."""
data = {
'pk': self.pk,
}
data = {'pk': self.pk}
if hasattr(self, 'get_api_url'):
api_url = self.get_api_url()
data['api_url'] = f"{api_url}{self.pk}/"
data['api_url'] = f'{api_url}{self.pk}/'
if hasattr(self, 'get_absolute_url'):
data['web_url'] = self.get_absolute_url()
@ -995,7 +1025,9 @@ class InvenTreeBarcodeMixin(models.Model):
"""Check if a model instance exists with the specified third-party barcode hash."""
return cls.objects.filter(barcode_hash=barcode_hash).first()
def assign_barcode(self, barcode_hash=None, barcode_data=None, raise_error=True, save=True):
def assign_barcode(
self, barcode_hash=None, barcode_data=None, raise_error=True, save=True
):
"""Assign an external (third-party) barcode to this object."""
# Must provide either barcode_hash or barcode_data
if barcode_hash is None and barcode_data is None:
@ -1008,7 +1040,7 @@ class InvenTreeBarcodeMixin(models.Model):
# Check for existing item
if self.__class__.lookup_barcode(barcode_hash) is not None:
if raise_error:
raise ValidationError(_("Existing barcode found"))
raise ValidationError(_('Existing barcode found'))
else:
return False
@ -1044,20 +1076,24 @@ def after_error_logged(sender, instance: Error, created: bool, **kwargs):
users = get_user_model().objects.filter(is_staff=True)
link = InvenTree.helpers_model.construct_absolute_url(
reverse('admin:error_report_error_change', kwargs={'object_id': instance.pk})
reverse(
'admin:error_report_error_change', kwargs={'object_id': instance.pk}
)
)
context = {
'error': instance,
'name': _('Server Error'),
'message': _('An error has been logged by the server.'),
'link': link
'link': link,
}
target_users = []
for user in users:
if common.models.InvenTreeUserSetting.get_setting('NOTIFICATION_ERROR_REPORT', True, user=user):
if common.models.InvenTreeUserSetting.get_setting(
'NOTIFICATION_ERROR_REPORT', True, user=user
):
target_users.append(user)
if len(target_users) > 0:
@ -1066,7 +1102,7 @@ def after_error_logged(sender, instance: Error, created: bool, **kwargs):
'inventree.error_log',
context=context,
targets=target_users,
delivery_methods={common.notifications.UIMessageNotification, },
delivery_methods={common.notifications.UIMessageNotification},
)
except Exception as exc:

View File

@ -18,7 +18,7 @@ def get_model_for_view(view, raise_error=True):
if hasattr(view, 'get_serializer_class'):
return view.get_serializr_class().Meta.model
raise AttributeError(f"Serializer class not specified for {view.__class__}")
raise AttributeError(f'Serializer class not specified for {view.__class__}')
class RolePermission(permissions.BasePermission):
@ -62,7 +62,7 @@ class RolePermission(permissions.BasePermission):
}
# let the view define a custom rolemap
if hasattr(view, "rolemap"):
if hasattr(view, 'rolemap'):
rolemap.update(view.rolemap)
permission = rolemap[request.method]
@ -78,7 +78,7 @@ class RolePermission(permissions.BasePermission):
app_label = model._meta.app_label
model_name = model._meta.model_name
table = f"{app_label}_{model_name}"
table = f'{app_label}_{model_name}'
except AttributeError:
# We will assume that if the serializer class does *not* have a Meta,
# then we don't need a permission
@ -100,12 +100,18 @@ class IsStaffOrReadOnly(permissions.IsAdminUser):
def has_permission(self, request, view):
"""Check if the user is a superuser."""
return bool(request.user and request.user.is_staff or request.method in permissions.SAFE_METHODS)
return bool(
request.user
and request.user.is_staff
or request.method in permissions.SAFE_METHODS
)
def auth_exempt(view_func):
"""Mark a view function as being exempt from auth requirements."""
def wrapped_view(*args, **kwargs):
return view_func(*args, **kwargs)
wrapped_view.auth_exempt = True
return wraps(view_func)(wrapped_view)

View File

@ -16,11 +16,7 @@ def isImportingData():
def isRunningMigrations():
"""Return True if the database is currently running migrations."""
return any((x in sys.argv for x in [
'migrate',
'makemigrations',
'showmigrations'
]))
return any((x in sys.argv for x in ['migrate', 'makemigrations', 'showmigrations']))
def isInMainThread():
@ -29,13 +25,15 @@ def isInMainThread():
- The RUN_MAIN env is set in that case. However if --noreload is applied, this variable
is not set because there are no different threads.
"""
if "runserver" in sys.argv and "--noreload" not in sys.argv:
return os.environ.get('RUN_MAIN', None) == "true"
if 'runserver' in sys.argv and '--noreload' not in sys.argv:
return os.environ.get('RUN_MAIN', None) == 'true'
return True
def canAppAccessDatabase(allow_test: bool = False, allow_plugins: bool = False, allow_shell: bool = False):
def canAppAccessDatabase(
allow_test: bool = False, allow_plugins: bool = False, allow_shell: bool = False
):
"""Returns True if the apps.py file can access database records.
There are some circumstances where we don't want the ready function in apps.py

View File

@ -1,49 +1,196 @@
"""Functions to sanitize user input files."""
from bleach import clean
from bleach.css_sanitizer import CSSSanitizer
ALLOWED_ELEMENTS_SVG = [
'a', 'animate', 'animateColor', 'animateMotion',
'animateTransform', 'circle', 'defs', 'desc', 'ellipse', 'font-face',
'font-face-name', 'font-face-src', 'g', 'glyph', 'hkern',
'linearGradient', 'line', 'marker', 'metadata', 'missing-glyph',
'mpath', 'path', 'polygon', 'polyline', 'radialGradient', 'rect',
'set', 'stop', 'svg', 'switch', 'text', 'title', 'tspan', 'use'
'a',
'animate',
'animateColor',
'animateMotion',
'animateTransform',
'circle',
'defs',
'desc',
'ellipse',
'font-face',
'font-face-name',
'font-face-src',
'g',
'glyph',
'hkern',
'linearGradient',
'line',
'marker',
'metadata',
'missing-glyph',
'mpath',
'path',
'polygon',
'polyline',
'radialGradient',
'rect',
'set',
'stop',
'svg',
'switch',
'text',
'title',
'tspan',
'use',
]
ALLOWED_ATTRIBUTES_SVG = [
'accent-height', 'accumulate', 'additive', 'alphabetic',
'arabic-form', 'ascent', 'attributeName', 'attributeType',
'baseProfile', 'bbox', 'begin', 'by', 'calcMode', 'cap-height',
'class', 'color', 'color-rendering', 'content', 'cx', 'cy', 'd', 'dx',
'dy', 'descent', 'display', 'dur', 'end', 'fill', 'fill-opacity',
'fill-rule', 'font-family', 'font-size', 'font-stretch', 'font-style',
'font-variant', 'font-weight', 'from', 'fx', 'fy', 'g1', 'g2',
'glyph-name', 'gradientUnits', 'hanging', 'height', 'horiz-adv-x',
'horiz-origin-x', 'id', 'ideographic', 'k', 'keyPoints',
'keySplines', 'keyTimes', 'lang', 'marker-end', 'marker-mid',
'marker-start', 'markerHeight', 'markerUnits', 'markerWidth',
'mathematical', 'max', 'min', 'name', 'offset', 'opacity', 'orient',
'origin', 'overline-position', 'overline-thickness', 'panose-1',
'path', 'pathLength', 'points', 'preserveAspectRatio', 'r', 'refX',
'refY', 'repeatCount', 'repeatDur', 'requiredExtensions',
'requiredFeatures', 'restart', 'rotate', 'rx', 'ry', 'slope',
'stemh', 'stemv', 'stop-color', 'stop-opacity',
'strikethrough-position', 'strikethrough-thickness', 'stroke',
'stroke-dasharray', 'stroke-dashoffset', 'stroke-linecap',
'stroke-linejoin', 'stroke-miterlimit', 'stroke-opacity',
'stroke-width', 'systemLanguage', 'target', 'text-anchor', 'to',
'transform', 'type', 'u1', 'u2', 'underline-position',
'underline-thickness', 'unicode', 'unicode-range', 'units-per-em',
'values', 'version', 'viewBox', 'visibility', 'width', 'widths', 'x',
'x-height', 'x1', 'x2', 'xlink:actuate', 'xlink:arcrole',
'xlink:href', 'xlink:role', 'xlink:show', 'xlink:title',
'xlink:type', 'xml:base', 'xml:lang', 'xml:space', 'xmlns',
'xmlns:xlink', 'y', 'y1', 'y2', 'zoomAndPan', 'style'
'accent-height',
'accumulate',
'additive',
'alphabetic',
'arabic-form',
'ascent',
'attributeName',
'attributeType',
'baseProfile',
'bbox',
'begin',
'by',
'calcMode',
'cap-height',
'class',
'color',
'color-rendering',
'content',
'cx',
'cy',
'd',
'dx',
'dy',
'descent',
'display',
'dur',
'end',
'fill',
'fill-opacity',
'fill-rule',
'font-family',
'font-size',
'font-stretch',
'font-style',
'font-variant',
'font-weight',
'from',
'fx',
'fy',
'g1',
'g2',
'glyph-name',
'gradientUnits',
'hanging',
'height',
'horiz-adv-x',
'horiz-origin-x',
'id',
'ideographic',
'k',
'keyPoints',
'keySplines',
'keyTimes',
'lang',
'marker-end',
'marker-mid',
'marker-start',
'markerHeight',
'markerUnits',
'markerWidth',
'mathematical',
'max',
'min',
'name',
'offset',
'opacity',
'orient',
'origin',
'overline-position',
'overline-thickness',
'panose-1',
'path',
'pathLength',
'points',
'preserveAspectRatio',
'r',
'refX',
'refY',
'repeatCount',
'repeatDur',
'requiredExtensions',
'requiredFeatures',
'restart',
'rotate',
'rx',
'ry',
'slope',
'stemh',
'stemv',
'stop-color',
'stop-opacity',
'strikethrough-position',
'strikethrough-thickness',
'stroke',
'stroke-dasharray',
'stroke-dashoffset',
'stroke-linecap',
'stroke-linejoin',
'stroke-miterlimit',
'stroke-opacity',
'stroke-width',
'systemLanguage',
'target',
'text-anchor',
'to',
'transform',
'type',
'u1',
'u2',
'underline-position',
'underline-thickness',
'unicode',
'unicode-range',
'units-per-em',
'values',
'version',
'viewBox',
'visibility',
'width',
'widths',
'x',
'x-height',
'x1',
'x2',
'xlink:actuate',
'xlink:arcrole',
'xlink:href',
'xlink:role',
'xlink:show',
'xlink:title',
'xlink:type',
'xml:base',
'xml:lang',
'xml:space',
'xmlns',
'xmlns:xlink',
'y',
'y1',
'y2',
'zoomAndPan',
'style',
]
def sanitize_svg(file_data, strip: bool = True, elements: str = ALLOWED_ELEMENTS_SVG, attributes: str = ALLOWED_ATTRIBUTES_SVG) -> str:
def sanitize_svg(
file_data,
strip: bool = True,
elements: str = ALLOWED_ELEMENTS_SVG,
attributes: str = ALLOWED_ATTRIBUTES_SVG,
) -> str:
"""Sanitize a SVG file.
Args:
@ -65,7 +212,7 @@ def sanitize_svg(file_data, strip: bool = True, elements: str = ALLOWED_ELEMENTS
attributes=attributes,
strip=strip,
strip_comments=strip,
css_sanitizer=CSSSanitizer()
css_sanitizer=CSSSanitizer(),
)
return cleaned

View File

@ -37,7 +37,7 @@ def sentry_ignore_errors():
def init_sentry(dsn, sample_rate, tags):
"""Initialize sentry.io error reporting"""
logger.info("Initializing sentry.io integration")
logger.info('Initializing sentry.io integration')
sentry_sdk.init(
dsn=dsn,
@ -46,7 +46,9 @@ def init_sentry(dsn, sample_rate, tags):
send_default_pii=True,
ignore_errors=sentry_ignore_errors(),
release=InvenTree.version.INVENTREE_SW_VERSION,
environment='development' if InvenTree.version.isInvenTreeDevelopmentVersion() else 'production'
environment='development'
if InvenTree.version.isInvenTreeDevelopmentVersion()
else 'production',
)
for key, val in tags.items():
@ -62,11 +64,10 @@ def init_sentry(dsn, sample_rate, tags):
def report_exception(exc):
"""Report an exception to sentry.io"""
if settings.SENTRY_ENABLED and settings.SENTRY_DSN:
if not any(isinstance(exc, e) for e in sentry_ignore_errors()):
logger.info("Reporting exception to sentry.io: %s", exc)
logger.info('Reporting exception to sentry.io: %s', exc)
try:
sentry_sdk.capture_exception(exc)
except Exception:
logger.warning("Failed to report exception to sentry.io")
logger.warning('Failed to report exception to sentry.io')

View File

@ -37,9 +37,9 @@ class InvenTreeMoneySerializer(MoneyField):
def __init__(self, *args, **kwargs):
"""Override default values."""
kwargs["max_digits"] = kwargs.get("max_digits", 19)
self.decimal_places = kwargs["decimal_places"] = kwargs.get("decimal_places", 6)
kwargs["required"] = kwargs.get("required", False)
kwargs['max_digits'] = kwargs.get('max_digits', 19)
self.decimal_places = kwargs['decimal_places'] = kwargs.get('decimal_places', 6)
kwargs['required'] = kwargs.get('required', False)
super().__init__(*args, **kwargs)
@ -57,13 +57,18 @@ class InvenTreeMoneySerializer(MoneyField):
amount = Decimal(amount)
amount = round(amount, self.decimal_places)
except Exception:
raise ValidationError({
self.field_name: [_("Must be a valid number")],
})
raise ValidationError({self.field_name: [_('Must be a valid number')]})
currency = data.get(get_currency_field_name(self.field_name), self.default_currency)
currency = data.get(
get_currency_field_name(self.field_name), self.default_currency
)
if currency and amount is not None and not isinstance(amount, MONEY_CLASSES) and amount is not empty:
if (
currency
and amount is not None
and not isinstance(amount, MONEY_CLASSES)
and amount is not empty
):
return Money(amount, currency)
return amount
@ -76,7 +81,9 @@ class InvenTreeCurrencySerializer(serializers.ChoiceField):
"""Initialize the currency serializer"""
choices = currency_code_mappings()
allow_blank = kwargs.get('allow_blank', False) or kwargs.get('allow_null', False)
allow_blank = kwargs.get('allow_blank', False) or kwargs.get(
'allow_null', False
)
if allow_blank:
choices = [('', '---------')] + choices
@ -97,6 +104,7 @@ class InvenTreeCurrencySerializer(serializers.ChoiceField):
class DependentField(serializers.Field):
"""A dependent field can be used to dynamically return child fields based on the value of other fields."""
child = None
def __init__(self, *args, depends_on, field_serializer, **kwargs):
@ -126,7 +134,7 @@ class DependentField(serializers.Field):
def get_child(self, raise_exception=False):
"""This method tries to extract the child based on the provided data in the request by the client."""
data = deepcopy(self.context["request"].data)
data = deepcopy(self.context['request'].data)
def visit_parent(node):
"""Recursively extract the data for the parent field/serializer in reverse."""
@ -136,8 +144,9 @@ class DependentField(serializers.Field):
visit_parent(node.parent)
# only do for composite fields and stop right before the current field
if hasattr(node, "child") and node is not self and isinstance(data, dict):
if hasattr(node, 'child') and node is not self and isinstance(data, dict):
data = data.get(node.field_name, None)
visit_parent(self)
# ensure that data is a dictionary and that a parent exists
@ -152,7 +161,9 @@ class DependentField(serializers.Field):
# partially validate the data for options requests that set raise_exception while calling .get_child(...)
if raise_exception:
validation_data = {k: v for k, v in data.items() if k in self.depends_on}
serializer = self.parent.__class__(context=self.context, data=validation_data, partial=True)
serializer = self.parent.__class__(
context=self.context, data=validation_data, partial=True
)
serializer.is_valid(raise_exception=raise_exception)
# try to get the field serializer
@ -196,7 +207,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
"""Custom __init__ routine to ensure that *default* values (as specified in the ORM) are used by the DRF serializers, *if* the values are not provided by the user."""
# If instance is None, we are creating a new instance
if instance is None and data is not empty:
if data is None:
data = OrderedDict()
else:
@ -211,7 +221,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
fields = model_meta.get_field_info(ModelClass)
for field_name, field in fields.fields.items():
"""
Update the field IF (and ONLY IF):
@ -219,7 +228,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
- The field does not already have a value set
"""
if field.has_default() and field_name not in data:
value = field.default
# Account for callable functions
@ -247,9 +255,7 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
fields = model_meta.get_field_info(ModelClass)
for field_name, field in fields.fields.items():
if field.has_default() and field_name not in initials:
value = field.default
# Account for callable functions
@ -337,7 +343,6 @@ class InvenTreeModelSerializer(serializers.ModelSerializer):
try:
instance.full_clean()
except (ValidationError, DjangoValidationError) as exc:
data = exc.message_dict
# Change '__all__' key (django style) to 'non_field_errors' (DRF style)
@ -369,6 +374,7 @@ class InvenTreeTaggitSerializer(TaggitSerializer):
class InvenTreeTagModelSerializer(InvenTreeTaggitSerializer, InvenTreeModelSerializer):
"""Combination of InvenTreeTaggitSerializer and InvenTreeModelSerializer."""
pass
@ -377,38 +383,31 @@ class UserSerializer(InvenTreeModelSerializer):
class Meta:
"""Metaclass defines serializer fields."""
model = User
fields = [
'pk',
'username',
'first_name',
'last_name',
'email',
]
read_only_fields = [
'username',
]
model = User
fields = ['pk', 'username', 'first_name', 'last_name', 'email']
read_only_fields = ['username']
class ExendedUserSerializer(UserSerializer):
"""Serializer for a User with a bit more info."""
from users.serializers import GroupSerializer
groups = GroupSerializer(read_only=True, many=True)
class Meta(UserSerializer.Meta):
"""Metaclass defines serializer fields."""
fields = UserSerializer.Meta.fields + [
'groups',
'is_staff',
'is_superuser',
'is_active'
'is_active',
]
read_only_fields = UserSerializer.Meta.read_only_fields + [
'groups',
]
read_only_fields = UserSerializer.Meta.read_only_fields + ['groups']
def validate(self, attrs):
"""Expanded validation for changing user role."""
@ -424,17 +423,20 @@ class ExendedUserSerializer(UserSerializer):
# Staff can change any role except is_superuser
pass
else:
raise PermissionDenied(_("You do not have permission to change this user role."))
raise PermissionDenied(
_('You do not have permission to change this user role.')
)
return super().validate(attrs)
class UserCreateSerializer(ExendedUserSerializer):
"""Serializer for creating a new User."""
def validate(self, attrs):
"""Expanded valiadation for auth."""
# Check that the user trying to create a new user is a superuser
if not self.context['request'].user.is_superuser:
raise serializers.ValidationError(_("Only superusers can create new users"))
raise serializers.ValidationError(_('Only superusers can create new users'))
# Generate a random password
password = User.objects.make_random_password(length=14)
@ -451,8 +453,10 @@ class UserCreateSerializer(ExendedUserSerializer):
current_site = Site.objects.get_current()
domain = current_site.domain
instance.email_user(
subject=_(f"Welcome to {current_site.name}"),
message=_(f"Your account has been created.\n\nPlease use the password reset function to get access (at https://{domain})."),
subject=_(f'Welcome to {current_site.name}'),
message=_(
f'Your account has been created.\n\nPlease use the password reset function to get access (at https://{domain}).'
),
)
return instance
@ -509,17 +513,11 @@ class InvenTreeAttachmentSerializer(InvenTreeModelSerializer):
user_detail = UserSerializer(source='user', read_only=True, many=False)
attachment = InvenTreeAttachmentSerializerField(
required=False,
allow_null=False,
)
attachment = InvenTreeAttachmentSerializerField(required=False, allow_null=False)
# The 'filename' field must be present in the serializer
filename = serializers.CharField(
label=_('Filename'),
required=False,
source='basename',
allow_blank=False,
label=_('Filename'), required=False, source='basename', allow_blank=False
)
upload_date = serializers.DateField(read_only=True)
@ -553,7 +551,7 @@ class InvenTreeDecimalField(serializers.FloatField):
try:
return Decimal(str(data))
except Exception:
raise serializers.ValidationError(_("Invalid value"))
raise serializers.ValidationError(_('Invalid value'))
class DataFileUploadSerializer(serializers.Serializer):
@ -570,13 +568,11 @@ class DataFileUploadSerializer(serializers.Serializer):
class Meta:
"""Metaclass options."""
fields = [
'data_file',
]
fields = ['data_file']
data_file = serializers.FileField(
label=_("Data File"),
help_text=_("Select data file for upload"),
label=_('Data File'),
help_text=_('Select data file for upload'),
required=True,
allow_empty_file=False,
)
@ -590,20 +586,16 @@ class DataFileUploadSerializer(serializers.Serializer):
# Remove the leading . from the extension
ext = ext[1:]
accepted_file_types = [
'xls', 'xlsx',
'csv', 'tsv',
'xml',
]
accepted_file_types = ['xls', 'xlsx', 'csv', 'tsv', 'xml']
if ext not in accepted_file_types:
raise serializers.ValidationError(_("Unsupported file type"))
raise serializers.ValidationError(_('Unsupported file type'))
# Impose a 50MB limit on uploaded BOM files
max_upload_file_size = 50 * 1024 * 1024
if data_file.size > max_upload_file_size:
raise serializers.ValidationError(_("File is too large"))
raise serializers.ValidationError(_('File is too large'))
# Read file data into memory (bytes object)
try:
@ -624,10 +616,10 @@ class DataFileUploadSerializer(serializers.Serializer):
raise serializers.ValidationError(str(e))
if len(self.dataset.headers) == 0:
raise serializers.ValidationError(_("No columns found in file"))
raise serializers.ValidationError(_('No columns found in file'))
if len(self.dataset) == 0:
raise serializers.ValidationError(_("No data rows found in file"))
raise serializers.ValidationError(_('No data rows found in file'))
return data_file
@ -721,24 +713,14 @@ class DataFileExtractSerializer(serializers.Serializer):
class Meta:
"""Metaclass options."""
fields = [
'columns',
'rows',
]
fields = ['columns', 'rows']
# Mapping of columns
columns = serializers.ListField(
child=serializers.CharField(
allow_blank=True,
),
)
columns = serializers.ListField(child=serializers.CharField(allow_blank=True))
rows = serializers.ListField(
child=serializers.ListField(
child=serializers.CharField(
allow_blank=True,
allow_null=True,
),
child=serializers.CharField(allow_blank=True, allow_null=True)
)
)
@ -750,10 +732,10 @@ class DataFileExtractSerializer(serializers.Serializer):
self.rows = data.get('rows', [])
if len(self.rows) == 0:
raise serializers.ValidationError(_("No data rows provided"))
raise serializers.ValidationError(_('No data rows provided'))
if len(self.columns) == 0:
raise serializers.ValidationError(_("No data columns supplied"))
raise serializers.ValidationError(_('No data columns supplied'))
self.validate_extracted_columns()
@ -776,16 +758,9 @@ class DataFileExtractSerializer(serializers.Serializer):
processed_row = self.process_row(self.row_to_dict(row))
if processed_row:
rows.append({
"original": row,
"data": processed_row,
})
rows.append({'original': row, 'data': processed_row})
return {
'fields': model_fields,
'columns': self.columns,
'rows': rows,
}
return {'fields': model_fields, 'columns': self.columns, 'rows': rows}
def process_row(self, row):
"""Process a 'row' of data, which is a mapped column:value dict.
@ -799,12 +774,9 @@ class DataFileExtractSerializer(serializers.Serializer):
def row_to_dict(self, row):
"""Convert a "row" to a named data dict."""
row_dict = {
'errors': {},
}
row_dict = {'errors': {}}
for idx, value in enumerate(row):
if idx < len(self.columns):
col = self.columns[idx]
@ -824,16 +796,16 @@ class DataFileExtractSerializer(serializers.Serializer):
cols_seen = set()
for name, field in model_fields.items():
required = field.get('required', False)
# Check for missing required columns
if required:
if name not in self.columns:
raise serializers.ValidationError(_(f"Missing required column: '{name}'"))
raise serializers.ValidationError(
_(f"Missing required column: '{name}'")
)
for col in self.columns:
if not col:
continue
@ -856,16 +828,14 @@ class RemoteImageMixin(metaclass=serializers.SerializerMetaclass):
def skip_create_fields(self):
"""Ensure the 'remote_image' field is skipped when creating a new instance"""
return [
'remote_image',
]
return ['remote_image']
remote_image = serializers.URLField(
required=False,
allow_blank=False,
write_only=True,
label=_("Remote Image"),
help_text=_("URL of remote image file"),
label=_('Remote Image'),
help_text=_('URL of remote image file'),
)
def validate_remote_image(self, url):
@ -877,8 +847,12 @@ class RemoteImageMixin(metaclass=serializers.SerializerMetaclass):
if not url:
return
if not common_models.InvenTreeSetting.get_setting('INVENTREE_DOWNLOAD_FROM_URL'):
raise ValidationError(_("Downloading images from remote URL is not enabled"))
if not common_models.InvenTreeSetting.get_setting(
'INVENTREE_DOWNLOAD_FROM_URL'
):
raise ValidationError(
_('Downloading images from remote URL is not enabled')
)
try:
self.remote_image_file = download_image_from_url(url)

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,5 @@
"""API endpoints for social authentication with allauth."""
import logging
from importlib import import_module
@ -6,8 +7,7 @@ from django.urls import NoReverseMatch, include, path, reverse
from allauth.account.models import EmailAddress
from allauth.socialaccount import providers
from allauth.socialaccount.providers.oauth2.views import (OAuth2Adapter,
OAuth2LoginView)
from allauth.socialaccount.providers.oauth2.views import OAuth2Adapter, OAuth2LoginView
from drf_spectacular.utils import OpenApiResponse, extend_schema
from rest_framework.exceptions import NotFound
from rest_framework.permissions import AllowAny, IsAuthenticated
@ -23,6 +23,7 @@ logger = logging.getLogger('inventree')
class GenericOAuth2ApiLoginView(OAuth2LoginView):
"""Api view to login a user with a social account"""
def dispatch(self, request, *args, **kwargs):
"""Dispatch the regular login view directly."""
return self.login(request, *args, **kwargs)
@ -44,8 +45,16 @@ class GenericOAuth2ApiConnectView(GenericOAuth2ApiLoginView):
def handle_oauth2(adapter: OAuth2Adapter):
"""Define urls for oauth2 endpoints."""
return [
path('login/', GenericOAuth2ApiLoginView.adapter_view(adapter), name=f'{provider.id}_api_login'),
path('connect/', GenericOAuth2ApiConnectView.adapter_view(adapter), name=f'{provider.id}_api_connect'),
path(
'login/',
GenericOAuth2ApiLoginView.adapter_view(adapter),
name=f'{provider.id}_api_login',
),
path(
'connect/',
GenericOAuth2ApiConnectView.adapter_view(adapter),
name=f'{provider.id}_api_connect',
),
]
@ -64,15 +73,20 @@ social_auth_urlpatterns = []
provider_urlpatterns = []
for name, provider in providers.registry.provider_map.items():
try:
prov_mod = import_module(provider.get_package() + ".views")
prov_mod = import_module(provider.get_package() + '.views')
except ImportError:
logger.exception("Could not import authentication provider %s", name)
logger.exception('Could not import authentication provider %s', name)
continue
# Try to extract the adapter class
adapters = [cls for cls in prov_mod.__dict__.values() if isinstance(cls, type) and not cls == OAuth2Adapter and issubclass(cls, OAuth2Adapter)]
adapters = [
cls
for cls in prov_mod.__dict__.values()
if isinstance(cls, type)
and not cls == OAuth2Adapter
and issubclass(cls, OAuth2Adapter)
]
# Get urls
urls = []
@ -80,10 +94,17 @@ for name, provider in providers.registry.provider_map.items():
urls = handle_oauth2(adapter=adapters[0])
else:
if provider.id in legacy:
logger.warning('`%s` is not supported on platform UI. Use `%s` instead.', provider.id, legacy[provider.id])
logger.warning(
'`%s` is not supported on platform UI. Use `%s` instead.',
provider.id,
legacy[provider.id],
)
continue
else:
logger.error('Found handler that is not yet ready for platform UI: `%s`. Open an feature request on GitHub if you need it implemented.', provider.id)
logger.error(
'Found handler that is not yet ready for platform UI: `%s`. Open an feature request on GitHub if you need it implemented.',
provider.id,
)
continue
provider_urlpatterns += [path(f'{provider.id}/', include(urls))]
@ -93,6 +114,7 @@ social_auth_urlpatterns += provider_urlpatterns
class SocialProviderListView(ListAPI):
"""List of available social providers."""
permission_classes = (AllowAny,)
def get(self, request, *args, **kwargs):
@ -102,21 +124,27 @@ class SocialProviderListView(ListAPI):
provider_data = {
'id': provider.id,
'name': provider.name,
'configured': False
'configured': False,
}
try:
provider_data['login'] = request.build_absolute_uri(reverse(f'{provider.id}_api_login'))
provider_data['login'] = request.build_absolute_uri(
reverse(f'{provider.id}_api_login')
)
except NoReverseMatch:
provider_data['login'] = None
try:
provider_data['connect'] = request.build_absolute_uri(reverse(f'{provider.id}_api_connect'))
provider_data['connect'] = request.build_absolute_uri(
reverse(f'{provider.id}_api_connect')
)
except NoReverseMatch:
provider_data['connect'] = None
provider_data['configured'] = InvenTree.sso.check_provider(provider)
provider_data['display_name'] = InvenTree.sso.provider_display_name(provider)
provider_data['display_name'] = InvenTree.sso.provider_display_name(
provider
)
provider_list.append(provider_data)
@ -124,7 +152,7 @@ class SocialProviderListView(ListAPI):
'sso_enabled': InvenTree.sso.login_enabled(),
'sso_registration': InvenTree.sso.registration_enabled(),
'mfa_required': InvenTreeSetting.get_setting('LOGIN_ENFORCE_MFA'),
'providers': provider_list
'providers': provider_list,
}
return Response(data)
@ -151,6 +179,7 @@ class EmptyEmailAddressSerializer(InvenTreeModelSerializer):
class EmailListView(ListCreateAPI):
"""List of registered email addresses for current users."""
permission_classes = (IsAuthenticated,)
serializer_class = EmailAddressSerializer
@ -161,12 +190,15 @@ class EmailListView(ListCreateAPI):
class EmailActionMixin(CreateAPI):
"""Mixin to modify email addresses for current users."""
serializer_class = EmptyEmailAddressSerializer
permission_classes = (IsAuthenticated,)
def get_queryset(self):
"""Filter queryset for current user."""
return EmailAddress.objects.filter(user=self.request.user, pk=self.kwargs['pk']).first()
return EmailAddress.objects.filter(
user=self.request.user, pk=self.kwargs['pk']
).first()
@extend_schema(responses={200: OpenApiResponse(response=EmailAddressSerializer)})
def post(self, request, *args, **kwargs):

View File

@ -1,6 +1,5 @@
"""Helper functions for Single Sign On functionality"""
import logging
from common.models import InvenTreeSetting
@ -49,7 +48,7 @@ def check_provider(provider, raise_error=False):
if allauth.app_settings.SITES_ENABLED:
# At least one matching site must be specified
if not app.sites.exists():
logger.error("SocialApp %s has no sites configured", app)
logger.error('SocialApp %s has no sites configured', app)
return False
# At this point, we assume that the provider is correctly configured

View File

@ -13,7 +13,7 @@ from django_q.status import Stat
import InvenTree.email
import InvenTree.ready
logger = logging.getLogger("inventree")
logger = logging.getLogger('inventree')
def is_worker_running(**kwargs):
@ -33,9 +33,7 @@ def is_worker_running(**kwargs):
now = timezone.now()
past = now - timedelta(minutes=10)
results = Success.objects.filter(
started__gte=past
)
results = Success.objects.filter(started__gte=past)
# If any results are returned, then the background worker is running!
try:
@ -65,13 +63,13 @@ def check_system_health(**kwargs):
if not is_worker_running(**kwargs): # pragma: no cover
result = False
logger.warning(_("Background worker check failed"))
logger.warning(_('Background worker check failed'))
if not InvenTree.email.is_email_configured(): # pragma: no cover
result = False
logger.warning(_("Email backend not configured"))
logger.warning(_('Email backend not configured'))
if not result: # pragma: no cover
logger.warning(_("InvenTree system health checks failed"))
logger.warning(_('InvenTree system health checks failed'))
return result

View File

@ -9,68 +9,68 @@ class PurchaseOrderStatus(StatusCode):
"""Defines a set of status codes for a PurchaseOrder."""
# Order status codes
PENDING = 10, _("Pending"), 'secondary' # Order is pending (not yet placed)
PLACED = 20, _("Placed"), 'primary' # Order has been placed with supplier
COMPLETE = 30, _("Complete"), 'success' # Order has been completed
CANCELLED = 40, _("Cancelled"), 'danger' # Order was cancelled
LOST = 50, _("Lost"), 'warning' # Order was lost
RETURNED = 60, _("Returned"), 'warning' # Order was returned
PENDING = 10, _('Pending'), 'secondary' # Order is pending (not yet placed)
PLACED = 20, _('Placed'), 'primary' # Order has been placed with supplier
COMPLETE = 30, _('Complete'), 'success' # Order has been completed
CANCELLED = 40, _('Cancelled'), 'danger' # Order was cancelled
LOST = 50, _('Lost'), 'warning' # Order was lost
RETURNED = 60, _('Returned'), 'warning' # Order was returned
class PurchaseOrderStatusGroups:
"""Groups for PurchaseOrderStatus codes."""
# Open orders
OPEN = [
PurchaseOrderStatus.PENDING.value,
PurchaseOrderStatus.PLACED.value,
]
OPEN = [PurchaseOrderStatus.PENDING.value, PurchaseOrderStatus.PLACED.value]
# Failed orders
FAILED = [
PurchaseOrderStatus.CANCELLED.value,
PurchaseOrderStatus.LOST.value,
PurchaseOrderStatus.RETURNED.value
PurchaseOrderStatus.RETURNED.value,
]
class SalesOrderStatus(StatusCode):
"""Defines a set of status codes for a SalesOrder."""
PENDING = 10, _("Pending"), 'secondary' # Order is pending
IN_PROGRESS = 15, _("In Progress"), 'primary' # Order has been issued, and is in progress
SHIPPED = 20, _("Shipped"), 'success' # Order has been shipped to customer
CANCELLED = 40, _("Cancelled"), 'danger' # Order has been cancelled
LOST = 50, _("Lost"), 'warning' # Order was lost
RETURNED = 60, _("Returned"), 'warning' # Order was returned
PENDING = 10, _('Pending'), 'secondary' # Order is pending
IN_PROGRESS = (
15,
_('In Progress'),
'primary',
) # Order has been issued, and is in progress
SHIPPED = 20, _('Shipped'), 'success' # Order has been shipped to customer
CANCELLED = 40, _('Cancelled'), 'danger' # Order has been cancelled
LOST = 50, _('Lost'), 'warning' # Order was lost
RETURNED = 60, _('Returned'), 'warning' # Order was returned
class SalesOrderStatusGroups:
"""Groups for SalesOrderStatus codes."""
# Open orders
OPEN = [
SalesOrderStatus.PENDING.value,
SalesOrderStatus.IN_PROGRESS.value,
]
OPEN = [SalesOrderStatus.PENDING.value, SalesOrderStatus.IN_PROGRESS.value]
# Completed orders
COMPLETE = [
SalesOrderStatus.SHIPPED.value,
]
COMPLETE = [SalesOrderStatus.SHIPPED.value]
class StockStatus(StatusCode):
"""Status codes for Stock."""
OK = 10, _("OK"), 'success' # Item is OK
ATTENTION = 50, _("Attention needed"), 'warning' # Item requires attention
DAMAGED = 55, _("Damaged"), 'warning' # Item is damaged
DESTROYED = 60, _("Destroyed"), 'danger' # Item is destroyed
REJECTED = 65, _("Rejected"), 'danger' # Item is rejected
LOST = 70, _("Lost"), 'dark' # Item has been lost
QUARANTINED = 75, _("Quarantined"), 'info' # Item has been quarantined and is unavailable
RETURNED = 85, _("Returned"), 'warning' # Item has been returned from a customer
OK = 10, _('OK'), 'success' # Item is OK
ATTENTION = 50, _('Attention needed'), 'warning' # Item requires attention
DAMAGED = 55, _('Damaged'), 'warning' # Item is damaged
DESTROYED = 60, _('Destroyed'), 'danger' # Item is destroyed
REJECTED = 65, _('Rejected'), 'danger' # Item is rejected
LOST = 70, _('Lost'), 'dark' # Item has been lost
QUARANTINED = (
75,
_('Quarantined'),
'info',
) # Item has been quarantined and is unavailable
RETURNED = 85, _('Returned'), 'warning' # Item has been returned from a customer
class StockStatusGroups:
@ -129,7 +129,7 @@ class StockHistoryCode(StatusCode):
BUILD_CONSUMED = 57, _('Consumed by build order')
# Sales order codes
SHIPPED_AGAINST_SALES_ORDER = 60, _("Shipped against Sales Order")
SHIPPED_AGAINST_SALES_ORDER = 60, _('Shipped against Sales Order')
# Purchase order codes
RECEIVED_AGAINST_PURCHASE_ORDER = 70, _('Received against Purchase Order')
@ -145,59 +145,53 @@ class StockHistoryCode(StatusCode):
class BuildStatus(StatusCode):
"""Build status codes."""
PENDING = 10, _("Pending"), 'secondary' # Build is pending / active
PRODUCTION = 20, _("Production"), 'primary' # BuildOrder is in production
CANCELLED = 30, _("Cancelled"), 'danger' # Build was cancelled
COMPLETE = 40, _("Complete"), 'success' # Build is complete
PENDING = 10, _('Pending'), 'secondary' # Build is pending / active
PRODUCTION = 20, _('Production'), 'primary' # BuildOrder is in production
CANCELLED = 30, _('Cancelled'), 'danger' # Build was cancelled
COMPLETE = 40, _('Complete'), 'success' # Build is complete
class BuildStatusGroups:
"""Groups for BuildStatus codes."""
ACTIVE_CODES = [
BuildStatus.PENDING.value,
BuildStatus.PRODUCTION.value,
]
ACTIVE_CODES = [BuildStatus.PENDING.value, BuildStatus.PRODUCTION.value]
class ReturnOrderStatus(StatusCode):
"""Defines a set of status codes for a ReturnOrder"""
# Order is pending, waiting for receipt of items
PENDING = 10, _("Pending"), 'secondary'
PENDING = 10, _('Pending'), 'secondary'
# Items have been received, and are being inspected
IN_PROGRESS = 20, _("In Progress"), 'primary'
IN_PROGRESS = 20, _('In Progress'), 'primary'
COMPLETE = 30, _("Complete"), 'success'
CANCELLED = 40, _("Cancelled"), 'danger'
COMPLETE = 30, _('Complete'), 'success'
CANCELLED = 40, _('Cancelled'), 'danger'
class ReturnOrderStatusGroups:
"""Groups for ReturnOrderStatus codes."""
OPEN = [
ReturnOrderStatus.PENDING.value,
ReturnOrderStatus.IN_PROGRESS.value,
]
OPEN = [ReturnOrderStatus.PENDING.value, ReturnOrderStatus.IN_PROGRESS.value]
class ReturnOrderLineStatus(StatusCode):
"""Defines a set of status codes for a ReturnOrderLineItem"""
PENDING = 10, _("Pending"), 'secondary'
PENDING = 10, _('Pending'), 'secondary'
# Item is to be returned to customer, no other action
RETURN = 20, _("Return"), 'success'
RETURN = 20, _('Return'), 'success'
# Item is to be repaired, and returned to customer
REPAIR = 30, _("Repair"), 'primary'
REPAIR = 30, _('Repair'), 'primary'
# Item is to be replaced (new item shipped)
REPLACE = 40, _("Replace"), 'warning'
REPLACE = 40, _('Replace'), 'warning'
# Item is to be refunded (cannot be repaired)
REFUND = 50, _("Refund"), 'info'
REFUND = 50, _('Refund'), 'info'
# Item is rejected
REJECT = 60, _("Reject"), 'danger'
REJECT = 60, _('Reject'), 'danger'

View File

@ -16,20 +16,22 @@ from django.core.exceptions import AppRegistryNotReady
from django.core.management import call_command
from django.db import DEFAULT_DB_ALIAS, connections
from django.db.migrations.executor import MigrationExecutor
from django.db.utils import (NotSupportedError, OperationalError,
ProgrammingError)
from django.db.utils import NotSupportedError, OperationalError, ProgrammingError
from django.utils import timezone
import requests
from maintenance_mode.core import (get_maintenance_mode, maintenance_mode_on,
set_maintenance_mode)
from maintenance_mode.core import (
get_maintenance_mode,
maintenance_mode_on,
set_maintenance_mode,
)
from InvenTree.config import get_setting
from plugin import registry
from .version import isInvenTreeUpToDate
logger = logging.getLogger("inventree")
logger = logging.getLogger('inventree')
def schedule_task(taskname, **kwargs):
@ -44,7 +46,7 @@ def schedule_task(taskname, **kwargs):
try:
from django_q.models import Schedule
except AppRegistryNotReady: # pragma: no cover
logger.info("Could not start background tasks - App registry not ready")
logger.info('Could not start background tasks - App registry not ready')
return
try:
@ -57,11 +59,7 @@ def schedule_task(taskname, **kwargs):
else:
logger.info("Creating scheduled task '%s'", taskname)
Schedule.objects.create(
name=taskname,
func=taskname,
**kwargs
)
Schedule.objects.create(name=taskname, func=taskname, **kwargs)
except (OperationalError, ProgrammingError): # pragma: no cover
# Required if the DB is not ready yet
pass
@ -96,7 +94,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool:
from InvenTree.ready import isInTestMode
if n_days <= 0:
logger.info("Specified interval for task '%s' < 1 - task will not run", task_name)
logger.info(
"Specified interval for task '%s' < 1 - task will not run", task_name
)
return False
# Sleep a random number of seconds to prevent worker conflict
@ -119,7 +119,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool:
threshold = datetime.now() - timedelta(days=n_days)
if last_success > threshold:
logger.info("Last successful run for '%s' was too recent - skipping task", task_name)
logger.info(
"Last successful run for '%s' was too recent - skipping task", task_name
)
return False
# Check for any information we have about this task
@ -136,7 +138,9 @@ def check_daily_holdoff(task_name: str, n_days: int = 1) -> bool:
threshold = datetime.now() - timedelta(hours=12)
if last_attempt > threshold:
logger.info("Last attempt for '%s' was too recent - skipping task", task_name)
logger.info(
"Last attempt for '%s' was too recent - skipping task", task_name
)
return False
# Record this attempt
@ -152,17 +156,23 @@ def record_task_attempt(task_name: str):
logger.info("Logging task attempt for '%s'", task_name)
InvenTreeSetting.set_setting(f'_{task_name}_ATTEMPT', datetime.now().isoformat(), None)
InvenTreeSetting.set_setting(
f'_{task_name}_ATTEMPT', datetime.now().isoformat(), None
)
def record_task_success(task_name: str):
"""Record that a multi-day task was successful *now*"""
from common.models import InvenTreeSetting
InvenTreeSetting.set_setting(f'_{task_name}_SUCCESS', datetime.now().isoformat(), None)
InvenTreeSetting.set_setting(
f'_{task_name}_SUCCESS', datetime.now().isoformat(), None
)
def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs) -> bool:
def offload_task(
taskname, *args, force_async=False, force_sync=False, **kwargs
) -> bool:
"""Create an AsyncTask if workers are running. This is different to a 'scheduled' task, in that it only runs once!
If workers are not running or force_sync flag, is set then the task is ran synchronously.
@ -205,7 +215,6 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs)
raise_warning(f"WARNING: '{taskname}' not offloaded due to {str(exc)}")
return False
else:
if callable(taskname):
# function was passed - use that
_func = taskname
@ -215,14 +224,18 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs)
app, mod, func = taskname.split('.')
app_mod = app + '.' + mod
except ValueError:
raise_warning(f"WARNING: '{taskname}' not started - Malformed function path")
raise_warning(
f"WARNING: '{taskname}' not started - Malformed function path"
)
return False
# Import module from app
try:
_mod = importlib.import_module(app_mod)
except ModuleNotFoundError:
raise_warning(f"WARNING: '{taskname}' not started - No module named '{app_mod}'")
raise_warning(
f"WARNING: '{taskname}' not started - No module named '{app_mod}'"
)
return False
# Retrieve function
@ -236,7 +249,9 @@ def offload_task(taskname, *args, force_async=False, force_sync=False, **kwargs)
if not _func:
_func = eval(func) # pragma: no cover
except NameError:
raise_warning(f"WARNING: '{taskname}' not started - No function named '{func}'")
raise_warning(
f"WARNING: '{taskname}' not started - No function named '{func}'"
)
return False
# Workers are not running: run it as synchronous task
@ -263,18 +278,19 @@ class ScheduledTask:
interval: str
minutes: int = None
MINUTES = "I"
HOURLY = "H"
DAILY = "D"
WEEKLY = "W"
MONTHLY = "M"
QUARTERLY = "Q"
YEARLY = "Y"
MINUTES = 'I'
HOURLY = 'H'
DAILY = 'D'
WEEKLY = 'W'
MONTHLY = 'M'
QUARTERLY = 'Q'
YEARLY = 'Y'
TYPE = [MINUTES, HOURLY, DAILY, WEEKLY, MONTHLY, QUARTERLY, YEARLY]
class TaskRegister:
"""Registry for periodic tasks."""
task_list: List[ScheduledTask] = []
def register(self, task, schedule, minutes: int = None):
@ -320,6 +336,7 @@ def scheduled_task(interval: str, minutes: int = None, tasklist: TaskRegister =
_tasks.register(admin_class, interval, minutes=minutes)
return admin_class
return _task_wrapper
@ -332,7 +349,7 @@ def heartbeat():
try:
from django_q.models import Success
except AppRegistryNotReady: # pragma: no cover
logger.info("Could not perform heartbeat task - App registry not ready")
logger.info('Could not perform heartbeat task - App registry not ready')
return
threshold = timezone.now() - timedelta(minutes=30)
@ -340,8 +357,7 @@ def heartbeat():
# Delete heartbeat results more than half an hour old,
# otherwise they just create extra noise
heartbeats = Success.objects.filter(
func='InvenTree.tasks.heartbeat',
started__lte=threshold
func='InvenTree.tasks.heartbeat', started__lte=threshold
)
heartbeats.delete()
@ -359,16 +375,16 @@ def delete_successful_tasks():
threshold = timezone.now() - timedelta(days=days)
# Delete successful tasks
results = Success.objects.filter(
started__lte=threshold
)
results = Success.objects.filter(started__lte=threshold)
if results.count() > 0:
logger.info("Deleting %s successful task records", results.count())
logger.info('Deleting %s successful task records', results.count())
results.delete()
except AppRegistryNotReady: # pragma: no cover
logger.info("Could not perform 'delete_successful_tasks' - App registry not ready")
logger.info(
"Could not perform 'delete_successful_tasks' - App registry not ready"
)
@scheduled_task(ScheduledTask.DAILY)
@ -383,12 +399,10 @@ def delete_failed_tasks():
threshold = timezone.now() - timedelta(days=days)
# Delete failed tasks
results = Failure.objects.filter(
started__lte=threshold
)
results = Failure.objects.filter(started__lte=threshold)
if results.count() > 0:
logger.info("Deleting %s failed task records", results.count())
logger.info('Deleting %s failed task records', results.count())
results.delete()
except AppRegistryNotReady: # pragma: no cover
@ -406,47 +420,48 @@ def delete_old_error_logs():
days = InvenTreeSetting.get_setting('INVENTREE_DELETE_ERRORS_DAYS', 30)
threshold = timezone.now() - timedelta(days=days)
errors = Error.objects.filter(
when__lte=threshold,
)
errors = Error.objects.filter(when__lte=threshold)
if errors.count() > 0:
logger.info("Deleting %s old error logs", errors.count())
logger.info('Deleting %s old error logs', errors.count())
errors.delete()
except AppRegistryNotReady: # pragma: no cover
# Apps not yet loaded
logger.info("Could not perform 'delete_old_error_logs' - App registry not ready")
logger.info(
"Could not perform 'delete_old_error_logs' - App registry not ready"
)
@scheduled_task(ScheduledTask.DAILY)
def delete_old_notifications():
"""Delete old notification logs"""
try:
from common.models import (InvenTreeSetting, NotificationEntry,
NotificationMessage)
from common.models import (
InvenTreeSetting,
NotificationEntry,
NotificationMessage,
)
days = InvenTreeSetting.get_setting('INVENTREE_DELETE_NOTIFICATIONS_DAYS', 30)
threshold = timezone.now() - timedelta(days=days)
items = NotificationEntry.objects.filter(
updated__lte=threshold
)
items = NotificationEntry.objects.filter(updated__lte=threshold)
if items.count() > 0:
logger.info("Deleted %s old notification entries", items.count())
logger.info('Deleted %s old notification entries', items.count())
items.delete()
items = NotificationMessage.objects.filter(
creation__lte=threshold
)
items = NotificationMessage.objects.filter(creation__lte=threshold)
if items.count() > 0:
logger.info("Deleted %s old notification messages", items.count())
logger.info('Deleted %s old notification messages', items.count())
items.delete()
except AppRegistryNotReady:
logger.info("Could not perform 'delete_old_notifications' - App registry not ready")
logger.info(
"Could not perform 'delete_old_notifications' - App registry not ready"
)
@scheduled_task(ScheduledTask.DAILY)
@ -460,13 +475,17 @@ def check_for_updates():
logger.info("Could not perform 'check_for_updates' - App registry not ready")
return
interval = int(common.models.InvenTreeSetting.get_setting('INVENTREE_UPDATE_CHECK_INTERVAL', 7, cache=False))
interval = int(
common.models.InvenTreeSetting.get_setting(
'INVENTREE_UPDATE_CHECK_INTERVAL', 7, cache=False
)
)
# Check if we should check for updates *today*
if not check_daily_holdoff('check_for_updates', interval):
return
logger.info("Checking for InvenTree software updates")
logger.info('Checking for InvenTree software updates')
headers = {}
@ -475,15 +494,17 @@ def check_for_updates():
token = os.getenv('GITHUB_TOKEN', None)
if token:
headers['Authorization'] = f"Bearer {token}"
headers['Authorization'] = f'Bearer {token}'
response = requests.get(
'https://api.github.com/repos/inventree/inventree/releases/latest',
headers=headers
headers=headers,
)
if response.status_code != 200:
raise ValueError(f'Unexpected status code from GitHub API: {response.status_code}') # pragma: no cover
raise ValueError(
f'Unexpected status code from GitHub API: {response.status_code}'
) # pragma: no cover
data = json.loads(response.text)
@ -492,7 +513,7 @@ def check_for_updates():
if not tag:
raise ValueError("'tag_name' missing from GitHub response") # pragma: no cover
match = re.match(r"^.*(\d+)\.(\d+)\.(\d+).*$", tag)
match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', tag)
if len(match.groups()) != 3: # pragma: no cover
logger.warning("Version '%s' did not match expected pattern", tag)
@ -506,31 +527,26 @@ def check_for_updates():
logger.info("Latest InvenTree version: '%s'", tag)
# Save the version to the database
common.models.InvenTreeSetting.set_setting(
'_INVENTREE_LATEST_VERSION',
tag,
None
)
common.models.InvenTreeSetting.set_setting('_INVENTREE_LATEST_VERSION', tag, None)
# Record that this task was successful
record_task_success('check_for_updates')
# Send notification if there is a new version
if not isInvenTreeUpToDate():
logger.warning("InvenTree is not up-to-date, sending notification")
logger.warning('InvenTree is not up-to-date, sending notification')
plg = registry.get_plugin('InvenTreeCoreNotificationsPlugin')
if not plg:
logger.warning("Cannot send notification - plugin not found")
logger.warning('Cannot send notification - plugin not found')
return
plg = plg.plugin_config()
if not plg:
logger.warning("Cannot send notification - plugin config not found")
logger.warning('Cannot send notification - plugin config not found')
return
# Send notification
trigger_superuser_notification(
plg,
f'An update for InvenTree to version {tag} is available',
plg, f'An update for InvenTree to version {tag} is available'
)
@ -549,17 +565,21 @@ def update_exchange_rates(force: bool = False):
from InvenTree.exchange import InvenTreeExchange
except AppRegistryNotReady: # pragma: no cover
# Apps not yet loaded!
logger.info("Could not perform 'update_exchange_rates' - App registry not ready")
logger.info(
"Could not perform 'update_exchange_rates' - App registry not ready"
)
return
except Exception as exc: # pragma: no cover
logger.info("Could not perform 'update_exchange_rates' - %s", exc)
return
if not force:
interval = int(InvenTreeSetting.get_setting('CURRENCY_UPDATE_INTERVAL', 1, cache=False))
interval = int(
InvenTreeSetting.get_setting('CURRENCY_UPDATE_INTERVAL', 1, cache=False)
)
if not check_daily_holdoff('update_exchange_rates', interval):
logger.info("Skipping exchange rate update (interval not reached)")
logger.info('Skipping exchange rate update (interval not reached)')
return
backend = InvenTreeExchange()
@ -570,15 +590,17 @@ def update_exchange_rates(force: bool = False):
backend.update_rates(base_currency=base)
# Remove any exchange rates which are not in the provided currencies
Rate.objects.filter(backend="InvenTreeExchange").exclude(currency__in=currency_codes()).delete()
Rate.objects.filter(backend='InvenTreeExchange').exclude(
currency__in=currency_codes()
).delete()
# Record successful task execution
record_task_success('update_exchange_rates')
except (AppRegistryNotReady, OperationalError, ProgrammingError):
logger.warning("Could not update exchange rates - database not ready")
logger.warning('Could not update exchange rates - database not ready')
except Exception as e: # pragma: no cover
logger.exception("Error updating exchange rates: %s", str(type(e)))
logger.exception('Error updating exchange rates: %s', str(type(e)))
@scheduled_task(ScheduledTask.DAILY)
@ -590,16 +612,20 @@ def run_backup():
# Backups are not enabled - exit early
return
interval = int(InvenTreeSetting.get_setting('INVENTREE_BACKUP_DAYS', 1, cache=False))
interval = int(
InvenTreeSetting.get_setting('INVENTREE_BACKUP_DAYS', 1, cache=False)
)
# Check if should run this task *today*
if not check_daily_holdoff('run_backup', interval):
return
logger.info("Performing automated database backup task")
logger.info('Performing automated database backup task')
call_command("dbbackup", noinput=True, clean=True, compress=True, interactive=False)
call_command("mediabackup", noinput=True, clean=True, compress=True, interactive=False)
call_command('dbbackup', noinput=True, clean=True, compress=True, interactive=False)
call_command(
'mediabackup', noinput=True, clean=True, compress=True, interactive=False
)
# Record that this task was successful
record_task_success('run_backup')
@ -627,7 +653,7 @@ def check_for_migrations():
logger.info('There are %s pending migrations', n)
InvenTreeSetting.set_setting('_PENDING_MIGRATIONS', n, None)
logger.info("Checking for pending database migrations")
logger.info('Checking for pending database migrations')
# Force plugin registry reload
registry.check_reload()
@ -645,12 +671,12 @@ def check_for_migrations():
# Test if auto-updates are enabled
if not get_setting('INVENTREE_AUTO_UPDATE', 'auto_update'):
logger.info("Auto-update is disabled - skipping migrations")
logger.info('Auto-update is disabled - skipping migrations')
return
# Log open migrations
for migration in plan:
logger.info("- %s", str(migration[0]))
logger.info('- %s', str(migration[0]))
# Set the application to maintenance mode - no access from now on.
set_maintenance_mode(True)
@ -668,13 +694,13 @@ def check_for_migrations():
else:
set_pending_migrations(0)
logger.info("Completed %s migrations", n)
logger.info('Completed %s migrations', n)
# Make sure we are out of maintenance mode
if get_maintenance_mode():
logger.warning("Maintenance mode was not disabled - forcing it now")
logger.warning('Maintenance mode was not disabled - forcing it now')
set_maintenance_mode(False)
logger.info("Manually released maintenance mode")
logger.info('Manually released maintenance mode')
# We should be current now - triggering full reload to make sure all models
# are loaded fully in their new state.

View File

@ -17,6 +17,7 @@ class HTMLAPITests(InvenTreeTestCase):
which raised an AssertionError when using the HTML API interface,
while the regular JSON interface continued to work as expected.
"""
roles = 'all'
def test_part_api(self):
@ -60,12 +61,7 @@ class HTMLAPITests(InvenTreeTestCase):
class APITests(InvenTreeAPITestCase):
"""Tests for the InvenTree API."""
fixtures = [
'location',
'category',
'part',
'stock'
]
fixtures = ['location', 'category', 'part', 'stock']
token = None
auto_login = False
@ -73,11 +69,11 @@ class APITests(InvenTreeAPITestCase):
"""Helper function to use basic auth."""
# Use basic authentication
authstring = bytes("{u}:{p}".format(u=self.username, p=self.password), "ascii")
authstring = bytes('{u}:{p}'.format(u=self.username, p=self.password), 'ascii')
# Use "basic" auth by default
auth = b64encode(authstring).decode("ascii")
self.client.credentials(HTTP_AUTHORIZATION="Basic {auth}".format(auth=auth))
auth = b64encode(authstring).decode('ascii')
self.client.credentials(HTTP_AUTHORIZATION='Basic {auth}'.format(auth=auth))
def tokenAuth(self):
"""Helper function to use token auth."""
@ -270,33 +266,20 @@ class BulkDeleteTests(InvenTreeAPITestCase):
url = reverse('api-stock-test-result-list')
# DELETE without any of the required fields
response = self.delete(
url,
{},
expected_code=400
)
response = self.delete(url, {}, expected_code=400)
self.assertIn('List of items or filters must be provided for bulk deletion', str(response.data))
self.assertIn(
'List of items or filters must be provided for bulk deletion',
str(response.data),
)
# DELETE with invalid 'items'
response = self.delete(
url,
{
'items': {"hello": "world"},
},
expected_code=400,
)
response = self.delete(url, {'items': {'hello': 'world'}}, expected_code=400)
self.assertIn("'items' must be supplied as a list object", str(response.data))
# DELETE with invalid 'filters'
response = self.delete(
url,
{
'filters': [1, 2, 3],
},
expected_code=400,
)
response = self.delete(url, {'filters': [1, 2, 3]}, expected_code=400)
self.assertIn("'filters' must be supplied as a dict object", str(response.data))
@ -317,11 +300,7 @@ class SearchTests(InvenTreeAPITestCase):
def test_empty(self):
"""Test empty request"""
data = [
'',
None,
{},
]
data = ['', None, {}]
for d in data:
response = self.post(reverse('api-search'), d, expected_code=400)
@ -331,13 +310,8 @@ class SearchTests(InvenTreeAPITestCase):
"""Test individual result types"""
response = self.post(
reverse('api-search'),
{
'search': 'chair',
'limit': 3,
'part': {},
'build': {},
},
expected_code=200
{'search': 'chair', 'limit': 3, 'part': {}, 'build': {}},
expected_code=200,
)
# No build results
@ -354,12 +328,7 @@ class SearchTests(InvenTreeAPITestCase):
# Search for orders
response = self.post(
reverse('api-search'),
{
'search': '01',
'limit': 2,
'purchaseorder': {},
'salesorder': {},
},
{'search': '01', 'limit': 2, 'purchaseorder': {}, 'salesorder': {}},
expected_code=200,
)
@ -392,33 +361,25 @@ class SearchTests(InvenTreeAPITestCase):
'salesorder',
]
query = {
'search': 'c',
'limit': 3,
}
query = {'search': 'c', 'limit': 3}
for mdl in models:
query[mdl] = {}
response = self.post(
reverse('api-search'),
query,
expected_code=200
)
response = self.post(reverse('api-search'), query, expected_code=200)
# Check for 'permission denied' error
for mdl in models:
self.assertEqual(response.data[mdl]['error'], 'User does not have permission to view this model')
self.assertEqual(
response.data[mdl]['error'],
'User does not have permission to view this model',
)
# Assign view roles for some parts
self.assignRole('build.view')
self.assignRole('part.view')
response = self.post(
reverse('api-search'),
query,
expected_code=200
)
response = self.post(reverse('api-search'), query, expected_code=200)
# Check for expected results, based on permissions
# We expect results to be returned for the following model types
@ -438,4 +399,6 @@ class SearchTests(InvenTreeAPITestCase):
self.assertIn('count', result)
else:
self.assertIn('error', result)
self.assertEqual(result['error'], 'User does not have permission to view this model')
self.assertEqual(
result['error'], 'User does not have permission to view this model'
)

View File

@ -1,6 +1,5 @@
"""Tests for api_version."""
from django.urls import reverse
from InvenTree.api_version import INVENTREE_API_VERSION

View File

@ -66,6 +66,7 @@ class MiddlewareTests(InvenTreeTestCase):
def test_error_exceptions(self):
"""Test that ignored errors are not logged."""
def check(excpected_nbr=0):
# Check that errors are empty
errors = Error.objects.all()

View File

@ -67,15 +67,23 @@ class InvenTreeTaskTests(TestCase):
# Error runs
# Malformed taskname
with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTree' not started - Malformed function path"):
with self.assertWarnsMessage(
UserWarning, "WARNING: 'InvenTree' not started - Malformed function path"
):
InvenTree.tasks.offload_task('InvenTree')
# Non existent app
with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTreeABC.test_tasks.doesnotmatter' not started - No module named 'InvenTreeABC.test_tasks'"):
with self.assertWarnsMessage(
UserWarning,
"WARNING: 'InvenTreeABC.test_tasks.doesnotmatter' not started - No module named 'InvenTreeABC.test_tasks'",
):
InvenTree.tasks.offload_task('InvenTreeABC.test_tasks.doesnotmatter')
# Non existent function
with self.assertWarnsMessage(UserWarning, "WARNING: 'InvenTree.test_tasks.doesnotexist' not started - No function named 'doesnotexist'"):
with self.assertWarnsMessage(
UserWarning,
"WARNING: 'InvenTree.test_tasks.doesnotexist' not started - No function named 'doesnotexist'",
):
InvenTree.tasks.offload_task('InvenTree.test_tasks.doesnotexist')
def test_task_hearbeat(self):
@ -86,7 +94,9 @@ class InvenTreeTaskTests(TestCase):
"""Test the task delete_successful_tasks."""
from django_q.models import Success
Success.objects.create(name='abc', func='abc', stopped=threshold, started=threshold_low)
Success.objects.create(
name='abc', func='abc', stopped=threshold, started=threshold_low
)
InvenTree.tasks.offload_task(InvenTree.tasks.delete_successful_tasks)
results = Success.objects.filter(started__lte=threshold)
self.assertEqual(len(results), 0)
@ -99,14 +109,14 @@ class InvenTreeTaskTests(TestCase):
error_obj.save()
# Check that it is not empty
errors = Error.objects.filter(when__lte=threshold,)
errors = Error.objects.filter(when__lte=threshold)
self.assertNotEqual(len(errors), 0)
# Run action
InvenTree.tasks.offload_task(InvenTree.tasks.delete_old_error_logs)
# Check that it is empty again
errors = Error.objects.filter(when__lte=threshold,)
errors = Error.objects.filter(when__lte=threshold)
self.assertEqual(len(errors), 0)
def test_task_check_for_updates(self):
@ -146,7 +156,9 @@ class InvenTreeTaskTests(TestCase):
# Cleanup
try:
migration_name = InvenTree.tasks.get_migration_plan()[0][0].name + '.py'
migration_path = settings.BASE_DIR / 'InvenTree' / 'migrations' / migration_name
migration_path = (
settings.BASE_DIR / 'InvenTree' / 'migrations' / migration_name
)
migration_path.unlink()
except IndexError: # pragma: no cover
pass

View File

@ -44,7 +44,7 @@ class URLTest(TestCase):
('part', 'templates'),
('report', 'templates'),
('stock', 'templates'),
('templates', ),
('templates',),
]
template_files = []
@ -53,11 +53,9 @@ class URLTest(TestCase):
tld = os.path.join(here, '..')
for directory in template_dirs:
template_dir = os.path.join(tld, *directory)
for path in Path(template_dir).rglob(suffix):
f = os.path.abspath(path)
if f not in template_files:
@ -69,25 +67,18 @@ class URLTest(TestCase):
"""Search for all instances of {% url %} in supplied template file."""
urls = []
pattern = "{% url ['\"]([^'\"]+)['\"]([^%]*)%}"
pattern = '{% url [\'"]([^\'"]+)[\'"]([^%]*)%}'
with open(input_file, 'r') as f:
data = f.read()
results = re.findall(pattern, data)
for result in results:
if len(result) == 2:
urls.append([
result[0].strip(),
result[1].strip()
])
urls.append([result[0].strip(), result[1].strip()])
elif len(result) == 1: # pragma: no cover
urls.append([
result[0].strip(),
''
])
urls.append([result[0].strip(), ''])
return urls
@ -100,16 +91,16 @@ class URLTest(TestCase):
pk = None
# TODO: Handle reverse lookup of admin URLs!
if url.startswith("admin:"):
if url.startswith('admin:'):
return
# TODO can this be more elegant?
if url.startswith("account_"):
if url.startswith('account_'):
return
if pk:
# We will assume that there is at least one item in the database
reverse(url, kwargs={"pk": 1})
reverse(url, kwargs={'pk': 1})
else:
reverse(url)
@ -122,14 +113,14 @@ class URLTest(TestCase):
def test_html_templates(self):
"""Test all HTML templates for broken url tags."""
template_files = self.find_files("*.html")
template_files = self.find_files('*.html')
for f in template_files:
self.check_file(f)
def test_js_templates(self):
"""Test all JS templates for broken url tags."""
template_files = self.find_files("*.js")
template_files = self.find_files('*.js')
for f in template_files:
self.check_file(f)

View File

@ -23,13 +23,13 @@ class ViewTests(InvenTreeTestCase):
def test_index_redirect(self):
"""Top-level URL should redirect to "index" page."""
response = self.client.get("/")
response = self.client.get('/')
self.assertEqual(response.status_code, 302)
def get_index_page(self):
"""Retrieve the index page (used for subsequent unit tests)"""
response = self.client.get("/index/")
response = self.client.get('/index/')
self.assertEqual(response.status_code, 200)
@ -60,35 +60,20 @@ class ViewTests(InvenTreeTestCase):
self.assertEqual(response.status_code, 200)
content = response.content.decode()
user_panels = [
'account',
'user-display',
'user-home',
'user-reports',
]
user_panels = ['account', 'user-display', 'user-home', 'user-reports']
staff_panels = [
'server',
'login',
'barcodes',
'pricing',
'parts',
'stock',
]
staff_panels = ['server', 'login', 'barcodes', 'pricing', 'parts', 'stock']
plugin_panels = [
'plugin',
]
plugin_panels = ['plugin']
# Default user has staff access, so all panels will be present
for panel in user_panels + staff_panels + plugin_panels:
self.assertIn(f"select-{panel}", content)
self.assertIn(f"panel-{panel}", content)
self.assertIn(f'select-{panel}', content)
self.assertIn(f'panel-{panel}', content)
# Now create a user who does not have staff access
pleb_user = get_user_model().objects.create_user(
username='pleb',
password='notstaff',
username='pleb', password='notstaff'
)
pleb_user.groups.add(self.group)
@ -98,10 +83,7 @@ class ViewTests(InvenTreeTestCase):
self.client.logout()
result = self.client.login(
username='pleb',
password='notstaff',
)
result = self.client.login(username='pleb', password='notstaff')
self.assertTrue(result)
@ -111,22 +93,24 @@ class ViewTests(InvenTreeTestCase):
# Normal user still has access to user-specific panels
for panel in user_panels:
self.assertIn(f"select-{panel}", content)
self.assertIn(f"panel-{panel}", content)
self.assertIn(f'select-{panel}', content)
self.assertIn(f'panel-{panel}', content)
# Normal user does NOT have access to global or plugin settings
for panel in staff_panels + plugin_panels:
self.assertNotIn(f"select-{panel}", content)
self.assertNotIn(f"panel-{panel}", content)
self.assertNotIn(f'select-{panel}', content)
self.assertNotIn(f'panel-{panel}', content)
def test_url_login(self):
"""Test logging in via arguments"""
# Log out
self.client.logout()
response = self.client.get("/index/")
response = self.client.get('/index/')
self.assertEqual(response.status_code, 302)
# Try login with url
response = self.client.get(f"/accounts/login/?next=/&login={self.username}&password={self.password}")
response = self.client.get(
f'/accounts/login/?next=/&login={self.username}&password={self.password}'
)
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, '/')

View File

@ -45,12 +45,12 @@ class ConversionTest(TestCase):
def test_prefixes(self):
"""Test inputs where prefixes are used"""
tests = {
"3": 3,
"3m": 3,
"3mm": 0.003,
"3k": 3000,
"3u": 0.000003,
"3 inch": 0.0762,
'3': 3,
'3m': 3,
'3mm': 0.003,
'3k': 3000,
'3u': 0.000003,
'3 inch': 0.0762,
}
for val, expected in tests.items():
@ -60,13 +60,13 @@ class ConversionTest(TestCase):
def test_base_units(self):
"""Test conversion to specified base units"""
tests = {
"3": 3,
"3 dozen": 36,
"50 dozen kW": 600000,
"1 / 10": 0.1,
"1/2 kW": 500,
"1/2 dozen kW": 6000,
"0.005 MW": 5000,
'3': 3,
'3 dozen': 36,
'50 dozen kW': 600000,
'1 / 10': 0.1,
'1/2 kW': 500,
'1/2 dozen kW': 6000,
'0.005 MW': 5000,
}
for val, expected in tests.items():
@ -104,13 +104,7 @@ class ConversionTest(TestCase):
def test_invalid_units(self):
"""Test conversion with bad units"""
tests = {
'3': '10',
'13': '-?-',
'-3': 'xyz',
'-12': '-12',
'1/0': '1/0',
}
tests = {'3': '10', '13': '-?-', '-3': 'xyz', '-12': '-12', '1/0': '1/0'}
for val, unit in tests.items():
with self.assertRaises(ValidationError):
@ -118,12 +112,7 @@ class ConversionTest(TestCase):
def test_invalid_values(self):
"""Test conversion of invalid inputs"""
inputs = [
'-x',
'1/0',
'xyz',
'12B45C'
]
inputs = ['-x', '1/0', 'xyz', '12B45C']
for val in inputs:
# Test with a provided unit
@ -148,9 +137,7 @@ class ConversionTest(TestCase):
# Create a new custom unit
CustomUnit.objects.create(
name='fanciful_unit',
definition='henry / mm',
symbol='hpmm',
name='fanciful_unit', definition='henry / mm', symbol='hpmm'
)
# Reload registry
@ -171,7 +158,9 @@ class ConversionTest(TestCase):
for val, expected in tests.items():
# Convert, and leave units
q = InvenTree.conversion.convert_physical_value(val, 'henry / km', strip_units=False)
q = InvenTree.conversion.convert_physical_value(
val, 'henry / km', strip_units=False
)
self.assertAlmostEqual(float(q.magnitude), expected, 2)
# Convert and strip units
@ -184,24 +173,24 @@ class ValidatorTest(TestCase):
def test_overage(self):
"""Test overage validator."""
validate_overage("100%")
validate_overage("10")
validate_overage("45.2 %")
validate_overage('100%')
validate_overage('10')
validate_overage('45.2 %')
with self.assertRaises(django_exceptions.ValidationError):
validate_overage("-1")
validate_overage('-1')
with self.assertRaises(django_exceptions.ValidationError):
validate_overage("-2.04 %")
validate_overage('-2.04 %')
with self.assertRaises(django_exceptions.ValidationError):
validate_overage("105%")
validate_overage('105%')
with self.assertRaises(django_exceptions.ValidationError):
validate_overage("xxx %")
validate_overage('xxx %')
with self.assertRaises(django_exceptions.ValidationError):
validate_overage("aaaa")
validate_overage('aaaa')
def test_url_validation(self):
"""Test for AllowedURLValidator"""
@ -241,7 +230,7 @@ class FormatTest(TestCase):
def test_parse(self):
"""Tests for the 'parse_format_string' function"""
# Extract data from a valid format string
fmt = "PO-{abc:02f}-{ref:04d}-{date}-???"
fmt = 'PO-{abc:02f}-{ref:04d}-{date}-???'
info = InvenTree.format.parse_format_string(fmt)
@ -250,22 +239,17 @@ class FormatTest(TestCase):
self.assertIn('date', info)
# Try with invalid strings
for fmt in [
'PO-{{xyz}',
'PO-{xyz}}',
'PO-{xyz}-{',
]:
for fmt in ['PO-{{xyz}', 'PO-{xyz}}', 'PO-{xyz}-{']:
with self.assertRaises(ValueError):
InvenTree.format.parse_format_string(fmt)
def test_create_regex(self):
"""Test function for creating a regex from a format string"""
tests = {
"PO-123-{ref:04f}": r"^PO\-123\-(?P<ref>.+)$",
"{PO}-???-{ref}-{date}-22": r"^(?P<PO>.+)\-...\-(?P<ref>.+)\-(?P<date>.+)\-22$",
"ABC-123-###-{ref}": r"^ABC\-123\-\d\d\d\-(?P<ref>.+)$",
"ABC-123": r"^ABC\-123$",
'PO-123-{ref:04f}': r'^PO\-123\-(?P<ref>.+)$',
'{PO}-???-{ref}-{date}-22': r'^(?P<PO>.+)\-...\-(?P<ref>.+)\-(?P<date>.+)\-22$',
'ABC-123-###-{ref}': r'^ABC\-123\-\d\d\d\-(?P<ref>.+)$',
'ABC-123': r'^ABC\-123$',
}
for fmt, reg in tests.items():
@ -275,55 +259,44 @@ class FormatTest(TestCase):
"""Test that string validation works as expected"""
# These tests should pass
for value, pattern in {
"ABC-hello-123": "???-{q}-###",
"BO-1234": "BO-{ref}",
"111.222.fred.china": "???.###.{name}.{place}",
"PO-1234": "PO-{ref:04d}"
'ABC-hello-123': '???-{q}-###',
'BO-1234': 'BO-{ref}',
'111.222.fred.china': '???.###.{name}.{place}',
'PO-1234': 'PO-{ref:04d}',
}.items():
self.assertTrue(InvenTree.format.validate_string(value, pattern))
# These tests should fail
for value, pattern in {
"ABC-hello-123": "###-{q}-???",
"BO-1234": "BO.{ref}",
"BO-####": "BO-{pattern}-{next}",
"BO-123d": "BO-{ref:04d}"
'ABC-hello-123': '###-{q}-???',
'BO-1234': 'BO.{ref}',
'BO-####': 'BO-{pattern}-{next}',
'BO-123d': 'BO-{ref:04d}',
}.items():
self.assertFalse(InvenTree.format.validate_string(value, pattern))
def test_extract_value(self):
"""Test that we can extract named values based on a format string"""
# Simple tests based on a straight-forward format string
fmt = "PO-###-{ref:04d}"
fmt = 'PO-###-{ref:04d}'
tests = {
"123": "PO-123-123",
"456": "PO-123-456",
"789": "PO-123-789",
}
tests = {'123': 'PO-123-123', '456': 'PO-123-456', '789': 'PO-123-789'}
for k, v in tests.items():
self.assertEqual(InvenTree.format.extract_named_group('ref', v, fmt), k)
# However these ones should fail
tests = {
'abc': 'PO-123-abc',
'xyz': 'PO-123-xyz',
}
tests = {'abc': 'PO-123-abc', 'xyz': 'PO-123-xyz'}
for v in tests.values():
with self.assertRaises(ValueError):
InvenTree.format.extract_named_group('ref', v, fmt)
# More complex tests
fmt = "PO-{date}-{test}-???-{ref}-###"
val = "PO-2022-02-01-hello-ABC-12345-222"
fmt = 'PO-{date}-{test}-???-{ref}-###'
val = 'PO-2022-02-01-hello-ABC-12345-222'
data = {
'date': '2022-02-01',
'test': 'hello',
'ref': '12345',
}
data = {'date': '2022-02-01', 'test': 'hello', 'ref': '12345'}
for k, v in data.items():
self.assertEqual(InvenTree.format.extract_named_group(k, val, fmt), v)
@ -332,60 +305,46 @@ class FormatTest(TestCase):
# Raises a ValueError as the format string is bad
with self.assertRaises(ValueError):
InvenTree.format.extract_named_group(
"test",
"PO-1234-5",
"PO-{test}-{"
)
InvenTree.format.extract_named_group('test', 'PO-1234-5', 'PO-{test}-{')
# Raises a NameError as the named group does not exist in the format string
with self.assertRaises(NameError):
InvenTree.format.extract_named_group(
"missing",
"PO-12345",
"PO-{test}",
)
InvenTree.format.extract_named_group('missing', 'PO-12345', 'PO-{test}')
# Raises a ValueError as the value does not match the format string
with self.assertRaises(ValueError):
InvenTree.format.extract_named_group(
"test",
"PO-1234",
"PO-{test}-1234",
)
InvenTree.format.extract_named_group('test', 'PO-1234', 'PO-{test}-1234')
with self.assertRaises(ValueError):
InvenTree.format.extract_named_group(
"test",
"PO-ABC-xyz",
"PO-###-{test}",
)
InvenTree.format.extract_named_group('test', 'PO-ABC-xyz', 'PO-###-{test}')
def test_currency_formatting(self):
"""Test that currency formatting works correctly for multiple currencies"""
test_data = (
(Money( 3651.285718, "USD"), 4, "$3,651.2857" ), # noqa: E201,E202
(Money(487587.849178, "CAD"), 5, "CA$487,587.84918"), # noqa: E201,E202
(Money( 0.348102, "EUR"), 1, "€0.3" ), # noqa: E201,E202
(Money( 0.916530, "GBP"), 1, "£0.9" ), # noqa: E201,E202
(Money( 61.031024, "JPY"), 3, "¥61.031" ), # noqa: E201,E202
(Money( 49609.694602, "JPY"), 1, "¥49,609.7" ), # noqa: E201,E202
(Money(155565.264777, "AUD"), 2, "A$155,565.26" ), # noqa: E201,E202
(Money( 0.820437, "CNY"), 4, "CN¥0.8204" ), # noqa: E201,E202
(Money( 7587.849178, "EUR"), 0, "€7,588" ), # noqa: E201,E202
(Money( 0.348102, "GBP"), 3, "£0.348" ), # noqa: E201,E202
(Money( 0.652923, "CHF"), 0, "CHF1" ), # noqa: E201,E202
(Money( 0.820437, "CNY"), 1, "CN¥0.8" ), # noqa: E201,E202
(Money(98789.5295680, "CHF"), 0, "CHF98,790" ), # noqa: E201,E202
(Money( 0.585787, "USD"), 1, "$0.6" ), # noqa: E201,E202
(Money( 0.690541, "CAD"), 3, "CA$0.691" ), # noqa: E201,E202
(Money( 427.814104, "AUD"), 5, "A$427.81410" ), # noqa: E201,E202
(Money(3651.285718, 'USD'), 4, '$3,651.2857'), # noqa: E201,E202
(Money(487587.849178, 'CAD'), 5, 'CA$487,587.84918'), # noqa: E201,E202
(Money(0.348102, 'EUR'), 1, '€0.3'), # noqa: E201,E202
(Money(0.916530, 'GBP'), 1, '£0.9'), # noqa: E201,E202
(Money(61.031024, 'JPY'), 3, '¥61.031'), # noqa: E201,E202
(Money(49609.694602, 'JPY'), 1, '¥49,609.7'), # noqa: E201,E202
(Money(155565.264777, 'AUD'), 2, 'A$155,565.26'), # noqa: E201,E202
(Money(0.820437, 'CNY'), 4, 'CN¥0.8204'), # noqa: E201,E202
(Money(7587.849178, 'EUR'), 0, '€7,588'), # noqa: E201,E202
(Money(0.348102, 'GBP'), 3, '£0.348'), # noqa: E201,E202
(Money(0.652923, 'CHF'), 0, 'CHF1'), # noqa: E201,E202
(Money(0.820437, 'CNY'), 1, 'CN¥0.8'), # noqa: E201,E202
(Money(98789.5295680, 'CHF'), 0, 'CHF98,790'), # noqa: E201,E202
(Money(0.585787, 'USD'), 1, '$0.6'), # noqa: E201,E202
(Money(0.690541, 'CAD'), 3, 'CA$0.691'), # noqa: E201,E202
(Money(427.814104, 'AUD'), 5, 'A$427.81410'), # noqa: E201,E202
)
with self.settings(LANGUAGE_CODE="en-us"):
with self.settings(LANGUAGE_CODE='en-us'):
for value, decimal_places, expected_result in test_data:
result = InvenTree.format.format_money(value, decimal_places=decimal_places)
result = InvenTree.format.format_money(
value, decimal_places=decimal_places
)
assert result == expected_result
@ -394,30 +353,35 @@ class TestHelpers(TestCase):
def test_absolute_url(self):
"""Test helper function for generating an absolute URL"""
base = "https://demo.inventree.org:12345"
base = 'https://demo.inventree.org:12345'
InvenTreeSetting.set_setting('INVENTREE_BASE_URL', base, change_user=None)
tests = {
"": base,
"api/": base + "/api/",
"/api/": base + "/api/",
"api": base + "/api",
"media/label/output/": base + "/media/label/output/",
"static/logo.png": base + "/static/logo.png",
"https://www.google.com": "https://www.google.com",
"https://demo.inventree.org:12345/out.html": "https://demo.inventree.org:12345/out.html",
"https://demo.inventree.org/test.html": "https://demo.inventree.org/test.html",
"http://www.cwi.nl:80/%7Eguido/Python.html": "http://www.cwi.nl:80/%7Eguido/Python.html",
"test.org": base + "/test.org",
'': base,
'api/': base + '/api/',
'/api/': base + '/api/',
'api': base + '/api',
'media/label/output/': base + '/media/label/output/',
'static/logo.png': base + '/static/logo.png',
'https://www.google.com': 'https://www.google.com',
'https://demo.inventree.org:12345/out.html': 'https://demo.inventree.org:12345/out.html',
'https://demo.inventree.org/test.html': 'https://demo.inventree.org/test.html',
'http://www.cwi.nl:80/%7Eguido/Python.html': 'http://www.cwi.nl:80/%7Eguido/Python.html',
'test.org': base + '/test.org',
}
for url, expected in tests.items():
# Test with supplied base URL
self.assertEqual(InvenTree.helpers_model.construct_absolute_url(url, site_url=base), expected)
self.assertEqual(
InvenTree.helpers_model.construct_absolute_url(url, site_url=base),
expected,
)
# Test without supplied base URL
self.assertEqual(InvenTree.helpers_model.construct_absolute_url(url), expected)
self.assertEqual(
InvenTree.helpers_model.construct_absolute_url(url), expected
)
def test_image_url(self):
"""Test if a filename looks like an image."""
@ -453,7 +417,9 @@ class TestHelpers(TestCase):
"""Test static url helpers."""
self.assertEqual(helpers.getStaticUrl('test.jpg'), '/static/test.jpg')
self.assertEqual(helpers.getBlankImage(), '/static/img/blank_image.png')
self.assertEqual(helpers.getBlankThumbnail(), '/static/img/blank_image.thumbnail.png')
self.assertEqual(
helpers.getBlankThumbnail(), '/static/img/blank_image.thumbnail.png'
)
def testMediaUrl(self):
"""Test getMediaUrl."""
@ -476,12 +442,7 @@ class TestHelpers(TestCase):
def test_download_image(self):
"""Test function for downloading image from remote URL"""
# Run check with a sequence of bad URLs
for url in [
"blog",
"htp://test.com/?",
"google",
"\\invalid-url"
]:
for url in ['blog', 'htp://test.com/?', 'google', '\\invalid-url']:
with self.assertRaises(django_exceptions.ValidationError):
InvenTree.helpers_model.download_image_from_url(url)
@ -496,16 +457,17 @@ class TestHelpers(TestCase):
with self.assertRaises(expected_error):
while tries < retries:
try:
InvenTree.helpers_model.download_image_from_url(url, timeout=timeout)
InvenTree.helpers_model.download_image_from_url(
url, timeout=timeout
)
break
except Exception as exc:
if type(exc) is expected_error:
# Re-throw this error
raise exc
else:
print("Unexpected error:", type(exc), exc)
print('Unexpected error:', type(exc), exc)
tries += 1
time.sleep(10 * tries)
@ -518,16 +480,20 @@ class TestHelpers(TestCase):
# TODO: Re-implement this test when we are happier with the external service
# dl_helper("https://httpstat.us/200?sleep=5000", requests.exceptions.ReadTimeout, timeout=1)
large_img = "https://github.com/inventree/InvenTree/raw/master/InvenTree/InvenTree/static/img/paper_splash_large.jpg"
large_img = 'https://github.com/inventree/InvenTree/raw/master/InvenTree/InvenTree/static/img/paper_splash_large.jpg'
InvenTreeSetting.set_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 1, change_user=None)
InvenTreeSetting.set_setting(
'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 1, change_user=None
)
# Attempt to download an image which is too large
with self.assertRaises(ValueError):
InvenTree.helpers_model.download_image_from_url(large_img, timeout=10)
# Increase allowable download size
InvenTreeSetting.set_setting('INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 5, change_user=None)
InvenTreeSetting.set_setting(
'INVENTREE_DOWNLOAD_IMAGE_MAX_SIZE', 5, change_user=None
)
# Download a valid image (should not throw an error)
InvenTree.helpers_model.download_image_from_url(large_img, timeout=10)
@ -561,14 +527,14 @@ class TestIncrement(TestCase):
def tests(self):
"""Test 'intelligent' incrementing function."""
tests = [
("", '1'),
(1, "2"),
("001", "002"),
("1001", "1002"),
("ABC123", "ABC124"),
("XYZ0", "XYZ1"),
("123Q", "123Q"),
("QQQ", "QQQ"),
('', '1'),
(1, '2'),
('001', '002'),
('1001', '1002'),
('ABC123', 'ABC124'),
('XYZ0', 'XYZ1'),
('123Q', '123Q'),
('QQQ', 'QQQ'),
]
for test in tests:
@ -584,13 +550,7 @@ class TestMakeBarcode(TestCase):
def test_barcode_extended(self):
"""Test creation of barcode with extended data."""
bc = helpers.MakeBarcode(
"part",
3,
{
"id": 3,
"url": "www.google.com",
},
brief=False
'part', 3, {'id': 3, 'url': 'www.google.com'}, brief=False
)
self.assertIn('part', bc)
@ -604,10 +564,7 @@ class TestMakeBarcode(TestCase):
def test_barcode_brief(self):
"""Test creation of simple barcode."""
bc = helpers.MakeBarcode(
"stockitem",
7,
)
bc = helpers.MakeBarcode('stockitem', 7)
data = json.loads(bc)
self.assertEqual(len(data), 1)
@ -619,16 +576,14 @@ class TestDownloadFile(TestCase):
def test_download(self):
"""Tests for DownloadFile."""
helpers.DownloadFile("hello world", "out.txt")
helpers.DownloadFile(bytes(b"hello world"), "out.bin")
helpers.DownloadFile('hello world', 'out.txt')
helpers.DownloadFile(bytes(b'hello world'), 'out.bin')
class TestMPTT(TestCase):
"""Tests for the MPTT tree models."""
fixtures = [
'location',
]
fixtures = ['location']
@classmethod
def setUpTestData(cls):
@ -681,62 +636,62 @@ class TestSerialNumberExtraction(TestCase):
e = helpers.extract_serial_numbers
# Test a range of numbers
sn = e("1-5", 5, 1)
sn = e('1-5', 5, 1)
self.assertEqual(len(sn), 5)
for i in range(1, 6):
self.assertIn(str(i), sn)
sn = e("11-30", 20, 1)
sn = e('11-30', 20, 1)
self.assertEqual(len(sn), 20)
sn = e("1, 2, 3, 4, 5", 5, 1)
sn = e('1, 2, 3, 4, 5', 5, 1)
self.assertEqual(len(sn), 5)
# Test partially specifying serials
sn = e("1, 2, 4+", 5, 1)
sn = e('1, 2, 4+', 5, 1)
self.assertEqual(len(sn), 5)
self.assertEqual(sn, ['1', '2', '4', '5', '6'])
# Test groups are not interpolated if enough serials are supplied
sn = e("1, 2, 3, AF5-69H, 5", 5, 1)
sn = e('1, 2, 3, AF5-69H, 5', 5, 1)
self.assertEqual(len(sn), 5)
self.assertEqual(sn, ['1', '2', '3', 'AF5-69H', '5'])
# Test groups are not interpolated with more than one hyphen in a word
sn = e("1, 2, TG-4SR-92, 4+", 5, 1)
sn = e('1, 2, TG-4SR-92, 4+', 5, 1)
self.assertEqual(len(sn), 5)
self.assertEqual(sn, ['1', '2', "TG-4SR-92", '4', '5'])
self.assertEqual(sn, ['1', '2', 'TG-4SR-92', '4', '5'])
# Test multiple placeholders
sn = e("1 2 ~ ~ ~", 5, 2)
sn = e('1 2 ~ ~ ~', 5, 2)
self.assertEqual(len(sn), 5)
self.assertEqual(sn, ['1', '2', '3', '4', '5'])
sn = e("1-5, 10-15", 11, 1)
sn = e('1-5, 10-15', 11, 1)
self.assertIn('3', sn)
self.assertIn('13', sn)
sn = e("1+", 10, 1)
sn = e('1+', 10, 1)
self.assertEqual(len(sn), 10)
self.assertEqual(sn, [str(_) for _ in range(1, 11)])
sn = e("4, 1+2", 4, 1)
sn = e('4, 1+2', 4, 1)
self.assertEqual(len(sn), 4)
self.assertEqual(sn, ['4', '1', '2', '3'])
sn = e("~", 1, 1)
sn = e('~', 1, 1)
self.assertEqual(len(sn), 1)
self.assertEqual(sn, ['2'])
sn = e("~", 1, 3)
sn = e('~', 1, 3)
self.assertEqual(len(sn), 1)
self.assertEqual(sn, ['4'])
sn = e("~+", 2, 4)
sn = e('~+', 2, 4)
self.assertEqual(len(sn), 2)
self.assertEqual(sn, ['5', '6'])
sn = e("~+3", 4, 4)
sn = e('~+3', 4, 4)
self.assertEqual(len(sn), 4)
self.assertEqual(sn, ['5', '6', '7', '8'])
@ -746,70 +701,70 @@ class TestSerialNumberExtraction(TestCase):
# Test duplicates
with self.assertRaises(ValidationError):
e("1,2,3,3,3", 5, 1)
e('1,2,3,3,3', 5, 1)
# Test invalid length
with self.assertRaises(ValidationError):
e("1,2,3", 5, 1)
e('1,2,3', 5, 1)
# Test empty string
with self.assertRaises(ValidationError):
e(", , ,", 0, 1)
e(', , ,', 0, 1)
# Test incorrect sign in group
with self.assertRaises(ValidationError):
e("10-2", 8, 1)
e('10-2', 8, 1)
# Test invalid group
with self.assertRaises(ValidationError):
e("1-5-10", 10, 1)
e('1-5-10', 10, 1)
with self.assertRaises(ValidationError):
e("10, a, 7-70j", 4, 1)
e('10, a, 7-70j', 4, 1)
# Test groups are not interpolated with word characters
with self.assertRaises(ValidationError):
e("1, 2, 3, E-5", 5, 1)
e('1, 2, 3, E-5', 5, 1)
# Extract a range of values with a smaller range
with self.assertRaises(ValidationError) as exc:
e("11-50", 10, 1)
e('11-50', 10, 1)
self.assertIn('Range quantity exceeds 10', str(exc))
# Test groups are not interpolated with alpha characters
with self.assertRaises(ValidationError) as exc:
e("1, A-2, 3+", 5, 1)
e('1, A-2, 3+', 5, 1)
self.assertIn('Invalid group range: A-2', str(exc))
def test_combinations(self):
"""Test complex serial number combinations."""
e = helpers.extract_serial_numbers
sn = e("1 3-5 9+2", 7, 1)
sn = e('1 3-5 9+2', 7, 1)
self.assertEqual(len(sn), 7)
self.assertEqual(sn, ['1', '3', '4', '5', '9', '10', '11'])
sn = e("1,3-5,9+2", 7, 1)
sn = e('1,3-5,9+2', 7, 1)
self.assertEqual(len(sn), 7)
self.assertEqual(sn, ['1', '3', '4', '5', '9', '10', '11'])
sn = e("~+2", 3, 13)
sn = e('~+2', 3, 13)
self.assertEqual(len(sn), 3)
self.assertEqual(sn, ['14', '15', '16'])
sn = e("~+", 2, 13)
sn = e('~+', 2, 13)
self.assertEqual(len(sn), 2)
self.assertEqual(sn, ['14', '15'])
# Test multiple increment groups
sn = e("~+4, 20+4, 30+4", 15, 10)
sn = e('~+4, 20+4, 30+4', 15, 10)
self.assertEqual(len(sn), 15)
for v in [14, 24, 34]:
self.assertIn(str(v), sn)
# Test multiple range groups
sn = e("11-20, 41-50, 91-100", 30, 1)
sn = e('11-20, 41-50, 91-100', 30, 1)
self.assertEqual(len(sn), 30)
for v in range(11, 21):
@ -848,7 +803,7 @@ class TestVersionNumber(TestCase):
"""Test that the git commit information is extracted successfully."""
envs = {
'INVENTREE_COMMIT_HASH': 'abcdef',
'INVENTREE_COMMIT_DATE': '2022-12-31'
'INVENTREE_COMMIT_DATE': '2022-12-31',
}
# Check that the environment variables take priority
@ -861,10 +816,16 @@ class TestVersionNumber(TestCase):
# Check that the current .git values work too
hash = str(subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8').strip()
hash = str(
subprocess.check_output('git rev-parse --short HEAD'.split()), 'utf-8'
).strip()
self.assertEqual(hash, version.inventreeCommitHash())
d = str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8').strip().split(' ')[0]
d = (
str(subprocess.check_output('git show -s --format=%ci'.split()), 'utf-8')
.strip()
.split(' ')[0]
)
self.assertEqual(d, version.inventreeCommitDate())
@ -898,7 +859,7 @@ class CurrencyTests(TestCase):
break
else: # pragma: no cover
print("Exchange rate update failed - retrying")
print('Exchange rate update failed - retrying')
print(f'Expected {currency_codes()}, got {[a.currency for a in rates]}')
time.sleep(1)
@ -975,16 +936,14 @@ class TestSettings(InvenTreeTestCase):
self.assertEqual(user_count(), 1)
# not enough set
self.run_reload({
'INVENTREE_ADMIN_USER': 'admin'
})
self.run_reload({'INVENTREE_ADMIN_USER': 'admin'})
self.assertEqual(user_count(), 1)
# enough set
self.run_reload({
'INVENTREE_ADMIN_USER': 'admin', # set username
'INVENTREE_ADMIN_EMAIL': 'info@example.com', # set email
'INVENTREE_ADMIN_PASSWORD': 'password123' # set password
'INVENTREE_ADMIN_PASSWORD': 'password123', # set password
})
self.assertEqual(user_count(), 2)
@ -1027,27 +986,27 @@ class TestSettings(InvenTreeTestCase):
"""Test get_config_file."""
# normal run - not configured
valid = [
'inventree/config.yaml',
'inventree/data/config.yaml',
]
valid = ['inventree/config.yaml', 'inventree/data/config.yaml']
self.assertTrue(any(opt in str(config.get_config_file()).lower() for opt in valid))
self.assertTrue(
any(opt in str(config.get_config_file()).lower() for opt in valid)
)
# with env set
with self.in_env_context({'INVENTREE_CONFIG_FILE': 'my_special_conf.yaml'}):
self.assertIn('inventree/my_special_conf.yaml', str(config.get_config_file()).lower())
self.assertIn(
'inventree/my_special_conf.yaml', str(config.get_config_file()).lower()
)
def test_helpers_plugin_file(self):
"""Test get_plugin_file."""
# normal run - not configured
valid = [
'inventree/plugins.txt',
'inventree/data/plugins.txt',
]
valid = ['inventree/plugins.txt', 'inventree/data/plugins.txt']
self.assertTrue(any(opt in str(config.get_plugin_file()).lower() for opt in valid))
self.assertTrue(
any(opt in str(config.get_plugin_file()).lower() for opt in valid)
)
# with env set
with self.in_env_context({'INVENTREE_PLUGIN_FILE': 'my_special_plugins.txt'}):
@ -1064,11 +1023,15 @@ class TestSettings(InvenTreeTestCase):
self.assertEqual(config.get_setting(TEST_ENV_NAME, None), '321')
# test typecasting to dict - None should be mapped to empty dict
self.assertEqual(config.get_setting(TEST_ENV_NAME, None, None, typecast=dict), {})
self.assertEqual(
config.get_setting(TEST_ENV_NAME, None, None, typecast=dict), {}
)
# test typecasting to dict - valid JSON string should be mapped to corresponding dict
with self.in_env_context({TEST_ENV_NAME: '{"a": 1}'}):
self.assertEqual(config.get_setting(TEST_ENV_NAME, None, typecast=dict), {"a": 1})
self.assertEqual(
config.get_setting(TEST_ENV_NAME, None, typecast=dict), {'a': 1}
)
# test typecasting to dict - invalid JSON string should be mapped to empty dict
with self.in_env_context({TEST_ENV_NAME: "{'a': 1}"}):
@ -1084,8 +1047,8 @@ class TestInstanceName(InvenTreeTestCase):
self.assertEqual(version.inventreeInstanceTitle(), 'InvenTree')
# set up required setting
InvenTreeSetting.set_setting("INVENTREE_INSTANCE_TITLE", True, self.user)
InvenTreeSetting.set_setting("INVENTREE_INSTANCE", "Testing title", self.user)
InvenTreeSetting.set_setting('INVENTREE_INSTANCE_TITLE', True, self.user)
InvenTreeSetting.set_setting('INVENTREE_INSTANCE', 'Testing title', self.user)
self.assertEqual(version.inventreeInstanceTitle(), 'Testing title')
@ -1096,7 +1059,9 @@ class TestInstanceName(InvenTreeTestCase):
def test_instance_url(self):
"""Test instance url settings."""
# Set up required setting
InvenTreeSetting.set_setting("INVENTREE_BASE_URL", "http://127.1.2.3", self.user)
InvenTreeSetting.set_setting(
'INVENTREE_BASE_URL', 'http://127.1.2.3', self.user
)
# The site should also be changed
site_obj = Site.objects.all().order_by('id').first()
@ -1106,12 +1071,7 @@ class TestInstanceName(InvenTreeTestCase):
class TestOffloadTask(InvenTreeTestCase):
"""Tests for offloading tasks to the background worker"""
fixtures = [
'category',
'part',
'location',
'stock',
]
fixtures = ['category', 'part', 'location', 'stock']
def test_offload_tasks(self):
"""Test that we can offload various tasks to the background worker thread.
@ -1127,38 +1087,38 @@ class TestOffloadTask(InvenTreeTestCase):
Ref: https://github.com/inventree/InvenTree/pull/3273
"""
self.assertTrue(offload_task(
'dummy_tasks.stock',
item=StockItem.objects.get(pk=1),
loc=StockLocation.objects.get(pk=1),
force_async=True
))
self.assertTrue(
offload_task(
'dummy_tasks.stock',
item=StockItem.objects.get(pk=1),
loc=StockLocation.objects.get(pk=1),
force_async=True,
)
)
self.assertTrue(offload_task(
'dummy_task.numbers',
1, 2, 3, 4, 5,
force_async=True
))
self.assertTrue(
offload_task('dummy_task.numbers', 1, 2, 3, 4, 5, force_async=True)
)
# Offload a dummy task, but force sync
# This should fail, because the function does not exist
with self.assertLogs(logger='inventree', level='WARNING') as log:
self.assertFalse(offload_task(
'dummy_task.numbers',
1, 1, 1,
force_sync=True
))
self.assertFalse(
offload_task('dummy_task.numbers', 1, 1, 1, force_sync=True)
)
self.assertIn("Malformed function path", str(log.output))
self.assertIn('Malformed function path', str(log.output))
# Offload dummy task with a Part instance
# This should succeed, ensuring that the Part instance is correctly pickled
self.assertTrue(offload_task(
'dummy_tasks.parts',
part=Part.objects.get(pk=1),
cat=PartCategory.objects.get(pk=1),
force_async=True
))
self.assertTrue(
offload_task(
'dummy_tasks.parts',
part=Part.objects.get(pk=1),
cat=PartCategory.objects.get(pk=1),
force_async=True,
)
)
def test_daily_holdoff(self):
"""Tests for daily task holdoff helper functions"""
@ -1180,7 +1140,9 @@ class TestOffloadTask(InvenTreeTestCase):
# An attempt has been logged, but it is too recent
result = InvenTree.tasks.check_daily_holdoff('dummy_task')
self.assertFalse(result)
self.assertIn("Last attempt for 'dummy_task' was too recent", str(cm.output))
self.assertIn(
"Last attempt for 'dummy_task' was too recent", str(cm.output)
)
# Mark last attempt a few days ago - should now return True
t_old = datetime.now() - timedelta(days=3)
@ -1191,13 +1153,17 @@ class TestOffloadTask(InvenTreeTestCase):
self.assertTrue(result)
# Last attempt should have been updated
self.assertNotEqual(t_old, InvenTreeSetting.get_setting('_dummy_task_ATTEMPT', '', cache=False))
self.assertNotEqual(
t_old, InvenTreeSetting.get_setting('_dummy_task_ATTEMPT', '', cache=False)
)
# Last attempt should prevent us now
with self.assertLogs(logger='inventree', level='INFO') as cm:
result = InvenTree.tasks.check_daily_holdoff('dummy_task')
self.assertFalse(result)
self.assertIn("Last attempt for 'dummy_task' was too recent", str(cm.output))
self.assertIn(
"Last attempt for 'dummy_task' was too recent", str(cm.output)
)
# Configure so a task was successful too recently
InvenTreeSetting.set_setting('_dummy_task_ATTEMPT', t_old, None)

View File

@ -24,7 +24,7 @@ def reload_translation_stats():
keys = _translation_stats.keys()
# Note that the names used in the stats file may not align 100%
for (code, _lang) in settings.LANGUAGES:
for code, _lang in settings.LANGUAGES:
if code in keys:
# Direct match, move on
continue

View File

@ -39,7 +39,7 @@ def getMigrationFileNames(app):
files = local_dir.joinpath('..', app, 'migrations').iterdir()
# Regex pattern for migration files
regex = re.compile(r"^[\d]+_.*\.py$")
regex = re.compile(r'^[\d]+_.*\.py$')
migration_files = []
@ -56,7 +56,6 @@ def getOldestMigrationFile(app, exclude_extension=True, ignore_initial=True):
oldest_file = None
for f in getMigrationFileNames(app):
if ignore_initial and f.startswith('0001_initial'):
continue
@ -115,9 +114,7 @@ class UserMixin:
# Create a user to log in with
cls.user = get_user_model().objects.create_user(
username=cls.username,
password=cls.password,
email=cls.email
username=cls.username, password=cls.password, email=cls.email
)
# Create a group for the user
@ -163,15 +160,15 @@ class UserMixin:
raise TypeError('assignRole: assign_all must be a boolean value')
if not role and not assign_all:
raise ValueError('assignRole: either role must be provided, or assign_all must be set')
raise ValueError(
'assignRole: either role must be provided, or assign_all must be set'
)
if not assign_all and role:
rule, perm = role.split('.')
for ruleset in group.rule_sets.all():
if assign_all or ruleset.name == rule:
if assign_all or perm == 'view':
ruleset.can_view = True
elif assign_all or perm == 'change':
@ -205,37 +202,24 @@ class ExchangeRateMixin:
def generate_exchange_rates(self):
"""Helper function which generates some exchange rates to work with"""
rates = {
'AUD': 1.5,
'CAD': 1.7,
'GBP': 0.9,
'USD': 1.0,
}
rates = {'AUD': 1.5, 'CAD': 1.7, 'GBP': 0.9, 'USD': 1.0}
# Create a dummy backend
ExchangeBackend.objects.create(
name='InvenTreeExchange',
base_currency='USD',
)
ExchangeBackend.objects.create(name='InvenTreeExchange', base_currency='USD')
backend = ExchangeBackend.objects.get(name='InvenTreeExchange')
items = []
for currency, rate in rates.items():
items.append(
Rate(
currency=currency,
value=rate,
backend=backend,
)
)
items.append(Rate(currency=currency, value=rate, backend=backend))
Rate.objects.bulk_create(items)
class InvenTreeTestCase(ExchangeRateMixin, UserMixin, TestCase):
"""Testcase with user setup buildin."""
pass
@ -243,7 +227,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
"""Base class for running InvenTree API tests."""
@contextmanager
def assertNumQueriesLessThan(self, value, using='default', verbose=False, debug=False):
def assertNumQueriesLessThan(
self, value, using='default', verbose=False, debug=False
):
"""Context manager to check that the number of queries is less than a certain value.
Example:
@ -252,17 +238,17 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
Ref: https://stackoverflow.com/questions/1254170/django-is-there-a-way-to-count-sql-queries-from-an-unit-test/59089020#59089020
"""
with CaptureQueriesContext(connections[using]) as context:
yield # your test will be run here
yield # your test will be run here
if verbose:
msg = "\r\n%s" % json.dumps(context.captured_queries, indent=4)
msg = '\r\n%s' % json.dumps(context.captured_queries, indent=4)
else:
msg = None
n = len(context.captured_queries)
if debug:
print(f"Expected less than {value} queries, got {n} queries")
print(f'Expected less than {value} queries, got {n} queries')
self.assertLess(n, value, msg=msg)
@ -273,8 +259,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
return
if expected_code != response.status_code:
print(f"Unexpected {method} response at '{url}': status_code = {response.status_code}")
print(
f"Unexpected {method} response at '{url}': status_code = {response.status_code}"
)
if hasattr(response, 'data'):
print('data:', response.data)
@ -359,7 +346,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
return response
def download_file(self, url, data, expected_code=None, expected_fn=None, decode=True):
def download_file(
self, url, data, expected_code=None, expected_fn=None, decode=True
):
"""Download a file from the server, and return an in-memory file."""
response = self.client.get(url, data=data, format='json')
@ -367,7 +356,9 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
# Check that the response is of the correct type
if not isinstance(response, StreamingHttpResponse):
raise ValueError("Response is not a StreamingHttpResponse object as expected")
raise ValueError(
'Response is not a StreamingHttpResponse object as expected'
)
# Extract filename
disposition = response.headers['Content-Disposition']
@ -394,7 +385,14 @@ class InvenTreeAPITestCase(ExchangeRateMixin, UserMixin, APITestCase):
return file
def process_csv(self, file_object, delimiter=',', required_cols=None, excluded_cols=None, required_rows=None):
def process_csv(
self,
file_object,
delimiter=',',
required_cols=None,
excluded_cols=None,
required_rows=None,
):
"""Helper function to process and validate a downloaded csv file."""
# Check that the correct object type has been passed
self.assertTrue(isinstance(file_object, io.StringIO))

View File

@ -10,9 +10,11 @@ from django.urls import include, path, re_path
from django.views.decorators.csrf import csrf_exempt
from django.views.generic.base import RedirectView
from dj_rest_auth.registration.views import (ConfirmEmailView,
SocialAccountDisconnectView,
SocialAccountListView)
from dj_rest_auth.registration.views import (
ConfirmEmailView,
SocialAccountDisconnectView,
SocialAccountListView,
)
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView
from sesame.views import LoginView
@ -28,36 +30,49 @@ import stock.api
import users.api
from build.urls import build_urls
from common.urls import common_urls
from company.urls import (company_urls, manufacturer_part_urls,
supplier_part_urls)
from company.urls import company_urls, manufacturer_part_urls, supplier_part_urls
from order.urls import order_urls
from part.urls import part_urls
from plugin.urls import get_plugin_urls
from stock.urls import stock_urls
from web.urls import urlpatterns as platform_urls
from .api import (APISearchView, InfoView, NotFoundView, VersionTextView,
VersionView)
from .api import APISearchView, InfoView, NotFoundView, VersionTextView, VersionView
from .magic_login import GetSimpleLoginView
from .social_auth_urls import (EmailListView, EmailPrimaryView,
EmailRemoveView, EmailVerifyView,
SocialProviderListView, social_auth_urlpatterns)
from .views import (AboutView, AppearanceSelectView, CustomConnectionsView,
CustomEmailView, CustomLoginView,
CustomPasswordResetFromKeyView,
CustomSessionDeleteOtherView, CustomSessionDeleteView,
DatabaseStatsView, DynamicJsView, EditUserView, IndexView,
NotificationsView, SearchView, SetPasswordView,
SettingsView, auth_request)
from .social_auth_urls import (
EmailListView,
EmailPrimaryView,
EmailRemoveView,
EmailVerifyView,
SocialProviderListView,
social_auth_urlpatterns,
)
from .views import (
AboutView,
AppearanceSelectView,
CustomConnectionsView,
CustomEmailView,
CustomLoginView,
CustomPasswordResetFromKeyView,
CustomSessionDeleteOtherView,
CustomSessionDeleteView,
DatabaseStatsView,
DynamicJsView,
EditUserView,
IndexView,
NotificationsView,
SearchView,
SetPasswordView,
SettingsView,
auth_request,
)
admin.site.site_header = "InvenTree Admin"
admin.site.site_header = 'InvenTree Admin'
apipatterns = [
# Global search
path('search/', APISearchView.as_view(), name='api-search'),
re_path(r'^settings/', include(common.api.settings_api_urls)),
re_path(r'^part/', include(part.api.part_api_urls)),
re_path(r'^bom/', include(part.api.bom_api_urls)),
@ -69,117 +84,288 @@ apipatterns = [
re_path(r'^report/', include(report.api.report_api_urls)),
re_path(r'^user/', include(users.api.user_urls)),
re_path(r'^admin/', include(common.api.admin_api_urls)),
# Plugin endpoints
path('', include(plugin.api.plugin_api_urls)),
# Common endpoints endpoint
path('', include(common.api.common_api_urls)),
# OpenAPI Schema
re_path('schema/', SpectacularAPIView.as_view(custom_settings={'SCHEMA_PATH_PREFIX': '/api/'}), name='schema'),
re_path(
'schema/',
SpectacularAPIView.as_view(custom_settings={'SCHEMA_PATH_PREFIX': '/api/'}),
name='schema',
),
# InvenTree information endpoints
path("version-text", VersionTextView.as_view(), name="api-version-text"), # version text
path(
'version-text', VersionTextView.as_view(), name='api-version-text'
), # version text
path('version/', VersionView.as_view(), name='api-version'), # version info
path('', InfoView.as_view(), name='api-inventree-info'), # server info
# Auth API endpoints
path('auth/', include([
re_path(r'^registration/account-confirm-email/(?P<key>[-:\w]+)/$', ConfirmEmailView.as_view(), name='account_confirm_email'),
path('registration/', include('dj_rest_auth.registration.urls')),
path('providers/', SocialProviderListView.as_view(), name='social_providers'),
path('emails/', include([path('<int:pk>/', include([
path('primary/', EmailPrimaryView.as_view(), name='email-primary'),
path('verify/', EmailVerifyView.as_view(), name='email-verify'),
path('remove/', EmailRemoveView().as_view(), name='email-remove'),])),
path('', EmailListView.as_view(), name='email-list')
])),
path('social/', include(social_auth_urlpatterns)),
path('social/', SocialAccountListView.as_view(), name='social_account_list'),
path('social/<int:pk>/disconnect/', SocialAccountDisconnectView.as_view(), name='social_account_disconnect'),
path('', include('dj_rest_auth.urls')),
])),
path(
'auth/',
include([
re_path(
r'^registration/account-confirm-email/(?P<key>[-:\w]+)/$',
ConfirmEmailView.as_view(),
name='account_confirm_email',
),
path('registration/', include('dj_rest_auth.registration.urls')),
path(
'providers/', SocialProviderListView.as_view(), name='social_providers'
),
path(
'emails/',
include([
path(
'<int:pk>/',
include([
path(
'primary/',
EmailPrimaryView.as_view(),
name='email-primary',
),
path(
'verify/',
EmailVerifyView.as_view(),
name='email-verify',
),
path(
'remove/',
EmailRemoveView().as_view(),
name='email-remove',
),
]),
),
path('', EmailListView.as_view(), name='email-list'),
]),
),
path('social/', include(social_auth_urlpatterns)),
path(
'social/', SocialAccountListView.as_view(), name='social_account_list'
),
path(
'social/<int:pk>/disconnect/',
SocialAccountDisconnectView.as_view(),
name='social_account_disconnect',
),
path('', include('dj_rest_auth.urls')),
]),
),
# Magic login URLs
path("email/generate/", csrf_exempt(GetSimpleLoginView().as_view()), name="sesame-generate",),
path("email/login/", LoginView.as_view(), name="sesame-login"),
path(
'email/generate/',
csrf_exempt(GetSimpleLoginView().as_view()),
name='sesame-generate',
),
path('email/login/', LoginView.as_view(), name='sesame-login'),
# Unknown endpoint
re_path(r'^.*$', NotFoundView.as_view(), name='api-404'),
]
settings_urls = [
re_path(r'^i18n/?', include('django.conf.urls.i18n')),
re_path(r'^appearance/?', AppearanceSelectView.as_view(), name='settings-appearance'),
re_path(
r'^appearance/?', AppearanceSelectView.as_view(), name='settings-appearance'
),
# Catch any other urls
re_path(r'^.*$', SettingsView.as_view(template_name='InvenTree/settings/settings.html'), name='settings'),
re_path(
r'^.*$',
SettingsView.as_view(template_name='InvenTree/settings/settings.html'),
name='settings',
),
]
notifications_urls = [
# Catch any other urls
re_path(r'^.*$', NotificationsView.as_view(), name='notifications'),
re_path(r'^.*$', NotificationsView.as_view(), name='notifications')
]
# These javascript files are served "dynamically" - i.e. rendered on demand
dynamic_javascript_urls = [
re_path(r'^calendar.js', DynamicJsView.as_view(template_name='js/dynamic/calendar.js'), name='calendar.js'),
re_path(r'^nav.js', DynamicJsView.as_view(template_name='js/dynamic/nav.js'), name='nav.js'),
re_path(r'^permissions.js', DynamicJsView.as_view(template_name='js/dynamic/permissions.js'), name='permissions.js'),
re_path(r'^settings.js', DynamicJsView.as_view(template_name='js/dynamic/settings.js'), name='settings.js'),
re_path(
r'^calendar.js',
DynamicJsView.as_view(template_name='js/dynamic/calendar.js'),
name='calendar.js',
),
re_path(
r'^nav.js',
DynamicJsView.as_view(template_name='js/dynamic/nav.js'),
name='nav.js',
),
re_path(
r'^permissions.js',
DynamicJsView.as_view(template_name='js/dynamic/permissions.js'),
name='permissions.js',
),
re_path(
r'^settings.js',
DynamicJsView.as_view(template_name='js/dynamic/settings.js'),
name='settings.js',
),
]
# These javascript files are passed through the Django translation layer
translated_javascript_urls = [
re_path(r'^api.js', DynamicJsView.as_view(template_name='js/translated/api.js'), name='api.js'),
re_path(r'^attachment.js', DynamicJsView.as_view(template_name='js/translated/attachment.js'), name='attachment.js'),
re_path(r'^barcode.js', DynamicJsView.as_view(template_name='js/translated/barcode.js'), name='barcode.js'),
re_path(r'^bom.js', DynamicJsView.as_view(template_name='js/translated/bom.js'), name='bom.js'),
re_path(r'^build.js', DynamicJsView.as_view(template_name='js/translated/build.js'), name='build.js'),
re_path(r'^charts.js', DynamicJsView.as_view(template_name='js/translated/charts.js'), name='charts.js'),
re_path(r'^company.js', DynamicJsView.as_view(template_name='js/translated/company.js'), name='company.js'),
re_path(r'^filters.js', DynamicJsView.as_view(template_name='js/translated/filters.js'), name='filters.js'),
re_path(r'^forms.js', DynamicJsView.as_view(template_name='js/translated/forms.js'), name='forms.js'),
re_path(r'^helpers.js', DynamicJsView.as_view(template_name='js/translated/helpers.js'), name='helpers.js'),
re_path(r'^index.js', DynamicJsView.as_view(template_name='js/translated/index.js'), name='index.js'),
re_path(r'^label.js', DynamicJsView.as_view(template_name='js/translated/label.js'), name='label.js'),
re_path(r'^model_renderers.js', DynamicJsView.as_view(template_name='js/translated/model_renderers.js'), name='model_renderers.js'),
re_path(r'^modals.js', DynamicJsView.as_view(template_name='js/translated/modals.js'), name='modals.js'),
re_path(r'^order.js', DynamicJsView.as_view(template_name='js/translated/order.js'), name='order.js'),
re_path(r'^part.js', DynamicJsView.as_view(template_name='js/translated/part.js'), name='part.js'),
re_path(r'^purchase_order.js', DynamicJsView.as_view(template_name='js/translated/purchase_order.js'), name='purchase_order.js'),
re_path(r'^return_order.js', DynamicJsView.as_view(template_name='js/translated/return_order.js'), name='return_order.js'),
re_path(r'^report.js', DynamicJsView.as_view(template_name='js/translated/report.js'), name='report.js'),
re_path(r'^sales_order.js', DynamicJsView.as_view(template_name='js/translated/sales_order.js'), name='sales_order.js'),
re_path(r'^search.js', DynamicJsView.as_view(template_name='js/translated/search.js'), name='search.js'),
re_path(r'^stock.js', DynamicJsView.as_view(template_name='js/translated/stock.js'), name='stock.js'),
re_path(r'^status_codes.js', DynamicJsView.as_view(template_name='js/translated/status_codes.js'), name='status_codes.js'),
re_path(r'^plugin.js', DynamicJsView.as_view(template_name='js/translated/plugin.js'), name='plugin.js'),
re_path(r'^pricing.js', DynamicJsView.as_view(template_name='js/translated/pricing.js'), name='pricing.js'),
re_path(r'^news.js', DynamicJsView.as_view(template_name='js/translated/news.js'), name='news.js'),
re_path(r'^tables.js', DynamicJsView.as_view(template_name='js/translated/tables.js'), name='tables.js'),
re_path(r'^table_filters.js', DynamicJsView.as_view(template_name='js/translated/table_filters.js'), name='table_filters.js'),
re_path(r'^notification.js', DynamicJsView.as_view(template_name='js/translated/notification.js'), name='notification.js'),
re_path(
r'^api.js',
DynamicJsView.as_view(template_name='js/translated/api.js'),
name='api.js',
),
re_path(
r'^attachment.js',
DynamicJsView.as_view(template_name='js/translated/attachment.js'),
name='attachment.js',
),
re_path(
r'^barcode.js',
DynamicJsView.as_view(template_name='js/translated/barcode.js'),
name='barcode.js',
),
re_path(
r'^bom.js',
DynamicJsView.as_view(template_name='js/translated/bom.js'),
name='bom.js',
),
re_path(
r'^build.js',
DynamicJsView.as_view(template_name='js/translated/build.js'),
name='build.js',
),
re_path(
r'^charts.js',
DynamicJsView.as_view(template_name='js/translated/charts.js'),
name='charts.js',
),
re_path(
r'^company.js',
DynamicJsView.as_view(template_name='js/translated/company.js'),
name='company.js',
),
re_path(
r'^filters.js',
DynamicJsView.as_view(template_name='js/translated/filters.js'),
name='filters.js',
),
re_path(
r'^forms.js',
DynamicJsView.as_view(template_name='js/translated/forms.js'),
name='forms.js',
),
re_path(
r'^helpers.js',
DynamicJsView.as_view(template_name='js/translated/helpers.js'),
name='helpers.js',
),
re_path(
r'^index.js',
DynamicJsView.as_view(template_name='js/translated/index.js'),
name='index.js',
),
re_path(
r'^label.js',
DynamicJsView.as_view(template_name='js/translated/label.js'),
name='label.js',
),
re_path(
r'^model_renderers.js',
DynamicJsView.as_view(template_name='js/translated/model_renderers.js'),
name='model_renderers.js',
),
re_path(
r'^modals.js',
DynamicJsView.as_view(template_name='js/translated/modals.js'),
name='modals.js',
),
re_path(
r'^order.js',
DynamicJsView.as_view(template_name='js/translated/order.js'),
name='order.js',
),
re_path(
r'^part.js',
DynamicJsView.as_view(template_name='js/translated/part.js'),
name='part.js',
),
re_path(
r'^purchase_order.js',
DynamicJsView.as_view(template_name='js/translated/purchase_order.js'),
name='purchase_order.js',
),
re_path(
r'^return_order.js',
DynamicJsView.as_view(template_name='js/translated/return_order.js'),
name='return_order.js',
),
re_path(
r'^report.js',
DynamicJsView.as_view(template_name='js/translated/report.js'),
name='report.js',
),
re_path(
r'^sales_order.js',
DynamicJsView.as_view(template_name='js/translated/sales_order.js'),
name='sales_order.js',
),
re_path(
r'^search.js',
DynamicJsView.as_view(template_name='js/translated/search.js'),
name='search.js',
),
re_path(
r'^stock.js',
DynamicJsView.as_view(template_name='js/translated/stock.js'),
name='stock.js',
),
re_path(
r'^status_codes.js',
DynamicJsView.as_view(template_name='js/translated/status_codes.js'),
name='status_codes.js',
),
re_path(
r'^plugin.js',
DynamicJsView.as_view(template_name='js/translated/plugin.js'),
name='plugin.js',
),
re_path(
r'^pricing.js',
DynamicJsView.as_view(template_name='js/translated/pricing.js'),
name='pricing.js',
),
re_path(
r'^news.js',
DynamicJsView.as_view(template_name='js/translated/news.js'),
name='news.js',
),
re_path(
r'^tables.js',
DynamicJsView.as_view(template_name='js/translated/tables.js'),
name='tables.js',
),
re_path(
r'^table_filters.js',
DynamicJsView.as_view(template_name='js/translated/table_filters.js'),
name='table_filters.js',
),
re_path(
r'^notification.js',
DynamicJsView.as_view(template_name='js/translated/notification.js'),
name='notification.js',
),
]
backendpatterns = [
# "Dynamic" javascript files which are rendered using InvenTree templating.
re_path(r'^js/dynamic/', include(dynamic_javascript_urls)),
re_path(r'^js/i18n/', include(translated_javascript_urls)),
re_path(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
re_path(r'^auth/?', auth_request),
re_path(r'^api/', include(apipatterns)),
re_path(r'^api-doc/', SpectacularRedocView.as_view(url_name='schema'), name='api-doc'),
re_path(
r'^api-doc/', SpectacularRedocView.as_view(url_name='schema'), name='api-doc'
),
]
classic_frontendpatterns = [
# Apps
re_path(r'^build/', include(build_urls)),
re_path(r'^common/', include(common_urls)),
@ -189,38 +375,48 @@ classic_frontendpatterns = [
re_path(r'^part/', include(part_urls)),
re_path(r'^stock/', include(stock_urls)),
re_path(r'^supplier-part/', include(supplier_part_urls)),
re_path(r'^edit-user/', EditUserView.as_view(), name='edit-user'),
re_path(r'^set-password/', SetPasswordView.as_view(), name='set-password'),
re_path(r'^index/', IndexView.as_view(), name='index'),
re_path(r'^notifications/', include(notifications_urls)),
re_path(r'^search/', SearchView.as_view(), name='search'),
re_path(r'^settings/', include(settings_urls)),
re_path(r'^about/', AboutView.as_view(), name='about'),
re_path(r'^stats/', DatabaseStatsView.as_view(), name='stats'),
# DB user sessions
path('accounts/sessions/other/delete/', view=CustomSessionDeleteOtherView.as_view(), name='session_delete_other', ),
re_path(r'^accounts/sessions/(?P<pk>\w+)/delete/$', view=CustomSessionDeleteView.as_view(), name='session_delete', ),
path(
'accounts/sessions/other/delete/',
view=CustomSessionDeleteOtherView.as_view(),
name='session_delete_other',
),
re_path(
r'^accounts/sessions/(?P<pk>\w+)/delete/$',
view=CustomSessionDeleteView.as_view(),
name='session_delete',
),
# Single Sign On / allauth
# overrides of urlpatterns
re_path(r'^accounts/email/', CustomEmailView.as_view(), name='account_email'),
re_path(r'^accounts/social/connections/', CustomConnectionsView.as_view(), name='socialaccount_connections'),
re_path(r"^accounts/password/reset/key/(?P<uidb36>[0-9A-Za-z]+)-(?P<key>.+)/$", CustomPasswordResetFromKeyView.as_view(), name="account_reset_password_from_key"),
re_path(
r'^accounts/social/connections/',
CustomConnectionsView.as_view(),
name='socialaccount_connections',
),
re_path(
r'^accounts/password/reset/key/(?P<uidb36>[0-9A-Za-z]+)-(?P<key>.+)/$',
CustomPasswordResetFromKeyView.as_view(),
name='account_reset_password_from_key',
),
# Override login page
re_path("accounts/login/", CustomLoginView.as_view(), name="account_login"),
re_path(r'^accounts/', include('allauth_2fa.urls')), # MFA support
re_path(r'^accounts/', include('allauth.urls')), # included urlpatterns
re_path('accounts/login/', CustomLoginView.as_view(), name='account_login'),
re_path(r'^accounts/', include('allauth_2fa.urls')), # MFA support
re_path(r'^accounts/', include('allauth.urls')), # included urlpatterns
]
urlpatterns = []
if settings.INVENTREE_ADMIN_ENABLED:
admin_url = settings.INVENTREE_ADMIN_URL,
admin_url = (settings.INVENTREE_ADMIN_URL,)
urlpatterns += [
path(f'{admin_url}/error_log/', include('error_report.urls')),
path(f'{admin_url}/', admin.site.urls, name='inventree-admin'),
@ -251,8 +447,13 @@ if settings.DEBUG:
# Redirect for favicon.ico
urlpatterns.append(
path('favicon.ico', RedirectView.as_view(url=f'{settings.STATIC_URL}img/favicon/favicon.ico'))
path(
'favicon.ico',
RedirectView.as_view(url=f'{settings.STATIC_URL}img/favicon/favicon.ico'),
)
)
# Send any unknown URLs to the parts page
urlpatterns += [re_path(r'^.*$', RedirectView.as_view(url='/index/', permanent=False), name='index')]
urlpatterns += [
re_path(r'^.*$', RedirectView.as_view(url='/index/', permanent=False), name='index')
]

View File

@ -69,7 +69,9 @@ class AllowedURLValidator(validators.URLValidator):
self.schemes = allowable_url_schemes()
# Determine if 'strict' URL validation is required (i.e. if the URL must have a schema prefix)
strict_urls = common.models.InvenTreeSetting.get_setting('INVENTREE_STRICT_URLS', True, cache=False)
strict_urls = common.models.InvenTreeSetting.get_setting(
'INVENTREE_STRICT_URLS', True, cache=False
)
if not strict_urls:
# Allow URLs which do not have a provided schema
@ -117,7 +119,7 @@ def validate_overage(value):
i = Decimal(value)
if i < 0:
raise ValidationError(_("Overage value must not be negative"))
raise ValidationError(_('Overage value must not be negative'))
# Looks like a number
return True
@ -133,17 +135,15 @@ def validate_overage(value):
f = float(v)
if f < 0:
raise ValidationError(_("Overage value must not be negative"))
raise ValidationError(_('Overage value must not be negative'))
elif f > 100:
raise ValidationError(_("Overage must not exceed 100%"))
raise ValidationError(_('Overage must not exceed 100%'))
return True
except ValueError:
pass
raise ValidationError(
_("Invalid value for overage")
)
raise ValidationError(_('Invalid value for overage'))
def validate_part_name_format(value):
@ -179,8 +179,6 @@ def validate_part_name_format(value):
try:
Template(value).render({'part': p})
except Exception as exc:
raise ValidationError({
'value': str(exc)
})
raise ValidationError({'value': str(exc)})
return True

View File

@ -19,7 +19,7 @@ from dulwich.repo import NotGitRepository, Repo
from .api_version import INVENTREE_API_TEXT, INVENTREE_API_VERSION
# InvenTree software version
INVENTREE_SW_VERSION = "0.14.0 dev"
INVENTREE_SW_VERSION = '0.14.0 dev'
# Discover git
try:
@ -32,8 +32,8 @@ except (NotGitRepository, FileNotFoundError):
def checkMinPythonVersion():
"""Check that the Python version is at least 3.9"""
version = sys.version.split(" ")[0]
docs = "https://docs.inventree.org/en/stable/start/intro/#python-requirements"
version = sys.version.split(' ')[0]
docs = 'https://docs.inventree.org/en/stable/start/intro/#python-requirements'
msg = f"""
InvenTree requires Python 3.9 or above - you are running version {version}.
@ -47,22 +47,22 @@ def checkMinPythonVersion():
if sys.version_info.major == 3 and sys.version_info.minor < 9:
raise RuntimeError(msg)
print(f"Python version {version} - {sys.executable}")
print(f'Python version {version} - {sys.executable}')
def inventreeInstanceName():
"""Returns the InstanceName settings for the current database."""
import common.models
return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "")
return common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE', '')
def inventreeInstanceTitle():
"""Returns the InstanceTitle for the current database."""
import common.models
if common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE_TITLE", False):
return common.models.InvenTreeSetting.get_setting("INVENTREE_INSTANCE", "")
if common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE_TITLE', False):
return common.models.InvenTreeSetting.get_setting('INVENTREE_INSTANCE', '')
return 'InvenTree'
@ -76,7 +76,7 @@ def inventreeVersionTuple(version=None):
if version is None:
version = INVENTREE_SW_VERSION
match = re.match(r"^.*(\d+)\.(\d+)\.(\d+).*$", str(version))
match = re.match(r'^.*(\d+)\.(\d+)\.(\d+).*$', str(version))
return [int(g) for g in match.groups()]
@ -93,14 +93,14 @@ def inventreeDocsVersion():
Release -> "major.minor.sub" e.g. "0.5.2"
"""
if isInvenTreeDevelopmentVersion():
return "latest"
return 'latest'
return INVENTREE_SW_VERSION # pragma: no cover
def inventreeDocUrl():
"""Return URL for InvenTree documentation site."""
tag = inventreeDocsVersion()
return f"https://docs.inventree.org/en/{tag}"
return f'https://docs.inventree.org/en/{tag}'
def inventreeAppUrl():
@ -110,12 +110,12 @@ def inventreeAppUrl():
def inventreeCreditsUrl():
"""Return URL for InvenTree credits site."""
return "https://docs.inventree.org/en/latest/credits/"
return 'https://docs.inventree.org/en/latest/credits/'
def inventreeGithubUrl():
"""Return URL for InvenTree github site."""
return "https://github.com/InvenTree/InvenTree/"
return 'https://github.com/InvenTree/InvenTree/'
def isInvenTreeUpToDate():
@ -124,7 +124,10 @@ def isInvenTreeUpToDate():
A background task periodically queries GitHub for latest version, and stores it to the database as "_INVENTREE_LATEST_VERSION"
"""
import common.models
latest = common.models.InvenTreeSetting.get_setting('_INVENTREE_LATEST_VERSION', backup_value=None, create=False)
latest = common.models.InvenTreeSetting.get_setting(
'_INVENTREE_LATEST_VERSION', backup_value=None, create=False
)
# No record for "latest" version - we must assume we are up to date!
if not latest:
@ -144,24 +147,26 @@ def inventreeApiVersion():
def parse_version_text():
"""Parse the version text to structured data."""
patched_data = INVENTREE_API_TEXT.split("\n\n")
patched_data = INVENTREE_API_TEXT.split('\n\n')
# Remove first newline on latest version
patched_data[0] = patched_data[0].replace("\n", "", 1)
patched_data[0] = patched_data[0].replace('\n', '', 1)
version_data = {}
for version in patched_data:
data = version.split("\n")
data = version.split('\n')
version_split = data[0].split(' -> ')
version_detail = version_split[1].split(':', 1) if len(version_split) > 1 else ['', ]
version_detail = (
version_split[1].split(':', 1) if len(version_split) > 1 else ['']
)
new_data = {
"version": version_split[0].strip(),
"date": version_detail[0].strip(),
"gh": version_detail[1].strip() if len(version_detail) > 1 else None,
"text": data[1:],
"latest": False,
'version': version_split[0].strip(),
'date': version_detail[0].strip(),
'gh': version_detail[1].strip() if len(version_detail) > 1 else None,
'text': data[1:],
'latest': False,
}
version_data[new_data["version"]] = new_data
version_data[new_data['version']] = new_data
return version_data
@ -183,7 +188,7 @@ def inventreeApiText(versions: int = 10, start_version: int = 0):
start_version = INVENTREE_API_VERSION - versions
return {
f"v{a}": version_data.get(f"v{a}", None)
f'v{a}': version_data.get(f'v{a}', None)
for a in range(start_version, start_version + versions)
}
@ -222,7 +227,9 @@ def inventreeCommitDate():
if main_commit is None:
return None
commit_dt = dt.fromtimestamp(main_commit.commit_time) + td(seconds=main_commit.commit_timezone)
commit_dt = dt.fromtimestamp(main_commit.commit_time) + td(
seconds=main_commit.commit_timezone
)
return str(commit_dt.date())

View File

@ -5,8 +5,7 @@ as JSON objects and passing them to modal forms (using jQuery / bootstrap).
"""
from django.contrib.auth import password_validation
from django.contrib.auth.mixins import (LoginRequiredMixin,
PermissionRequiredMixin)
from django.contrib.auth.mixins import LoginRequiredMixin, PermissionRequiredMixin
from django.core.exceptions import ValidationError
from django.http import HttpResponse, HttpResponseRedirect, JsonResponse
from django.shortcuts import redirect
@ -20,8 +19,7 @@ from django.views.generic.base import RedirectView, TemplateView
from allauth.account.forms import AddEmailForm
from allauth.account.models import EmailAddress
from allauth.account.views import (EmailView, LoginView,
PasswordResetFromKeyView)
from allauth.account.views import EmailView, LoginView, PasswordResetFromKeyView
from allauth.socialaccount.forms import DisconnectForm
from allauth.socialaccount.views import ConnectionsView
from djmoney.contrib.exchange.models import ExchangeBackend, Rate
@ -104,7 +102,6 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
return True
for required in roles_required:
(role, permission) = required.split('.')
if role not in RuleSet.RULESET_NAMES:
@ -138,12 +135,14 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
app_label = model._meta.app_label
model_name = model._meta.model_name
table = f"{app_label}_{model_name}"
table = f'{app_label}_{model_name}'
permission = self.get_permission_class()
if not permission:
raise AttributeError(f"permission_class not defined for {type(self).__name__}")
raise AttributeError(
f'permission_class not defined for {type(self).__name__}'
)
# Check if the user has the required permission
return RuleSet.check_table_permission(user, table, permission)
@ -182,7 +181,6 @@ class InvenTreeRoleMixin(PermissionRequiredMixin):
}
for view_class in permission_map.keys():
if issubclass(type(self), view_class):
return permission_map[view_class]
@ -282,9 +280,7 @@ class AjaxMixin(InvenTreeRoleMixin):
data['title'] = self.get_form_title()
data['html_form'] = render_to_string(
self.ajax_template_name,
context,
request=request
self.ajax_template_name, context, request=request
)
# Custom feedback`data
@ -329,7 +325,9 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
"""
super(UpdateView, self).get(request, *args, **kwargs)
return self.renderJsonResponse(request, self.get_form(), context=self.get_context_data())
return self.renderJsonResponse(
request, self.get_form(), context=self.get_context_data()
)
def save(self, object, form, **kwargs):
"""Method for updating the object in the database. Default implementation is very simple, but can be overridden if required.
@ -379,7 +377,6 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
data[key] = value
if valid:
# Save the updated object to the database
self.save(self.object, form)
@ -399,8 +396,8 @@ class AjaxUpdateView(AjaxMixin, UpdateView):
class EditUserView(AjaxUpdateView):
"""View for editing user information."""
ajax_template_name = "modal_form.html"
ajax_form_title = _("Edit User Information")
ajax_template_name = 'modal_form.html'
ajax_form_title = _('Edit User Information')
form_class = EditUserForm
def get_object(self):
@ -411,8 +408,8 @@ class EditUserView(AjaxUpdateView):
class SetPasswordView(AjaxUpdateView):
"""View for setting user password."""
ajax_template_name = "InvenTree/password.html"
ajax_form_title = _("Set Password")
ajax_template_name = 'InvenTree/password.html'
ajax_form_title = _('Set Password')
form_class = SetPasswordForm
def get_object(self):
@ -494,14 +491,14 @@ class SearchView(TemplateView):
class DynamicJsView(TemplateView):
"""View for returning javacsript files, which instead of being served dynamically, are passed through the django translation engine!"""
template_name = ""
template_name = ''
content_type = 'text/javascript'
class SettingsView(TemplateView):
"""View for configuring User settings."""
template_name = "InvenTree/settings/settings.html"
template_name = 'InvenTree/settings/settings.html'
def get_context_data(self, **kwargs):
"""Add data for template."""
@ -509,33 +506,37 @@ class SettingsView(TemplateView):
ctx['settings'] = common_models.InvenTreeSetting.objects.all().order_by('key')
ctx["base_currency"] = common_settings.currency_code_default()
ctx["currencies"] = common_settings.currency_codes
ctx['base_currency'] = common_settings.currency_code_default()
ctx['currencies'] = common_settings.currency_codes
ctx["rates"] = Rate.objects.filter(backend="InvenTreeExchange")
ctx['rates'] = Rate.objects.filter(backend='InvenTreeExchange')
ctx["categories"] = PartCategory.objects.all().order_by('tree_id', 'lft', 'name')
ctx['categories'] = PartCategory.objects.all().order_by(
'tree_id', 'lft', 'name'
)
# When were the rates last updated?
try:
backend = ExchangeBackend.objects.filter(name='InvenTreeExchange')
if backend.exists():
backend = backend.first()
ctx["rates_updated"] = backend.last_update
ctx['rates_updated'] = backend.last_update
except Exception:
ctx["rates_updated"] = None
ctx['rates_updated'] = None
# Forms and context for allauth
ctx['add_email_form'] = AddEmailForm
ctx["can_add_email"] = EmailAddress.objects.can_add_email(self.request.user)
ctx['can_add_email'] = EmailAddress.objects.can_add_email(self.request.user)
# Form and context for allauth social-accounts
ctx["request"] = self.request
ctx['request'] = self.request
ctx['social_form'] = DisconnectForm(request=self.request)
# user db sessions
ctx['session_key'] = self.request.session.session_key
ctx['session_list'] = self.request.user.session_set.filter(expire_date__gt=now()).order_by('-last_activity')
ctx['session_list'] = self.request.user.session_set.filter(
expire_date__gt=now()
).order_by('-last_activity')
return ctx
@ -550,20 +551,23 @@ class AllauthOverrides(LoginRequiredMixin):
class CustomEmailView(AllauthOverrides, EmailView):
"""Override of allauths EmailView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("settings")
success_url = reverse_lazy('settings')
class CustomConnectionsView(AllauthOverrides, ConnectionsView):
"""Override of allauths ConnectionsView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("settings")
success_url = reverse_lazy('settings')
class CustomPasswordResetFromKeyView(PasswordResetFromKeyView):
"""Override of allauths PasswordResetFromKeyView to always show the settings but leave the functions allow."""
success_url = reverse_lazy("account_login")
success_url = reverse_lazy('account_login')
class UserSessionOverride():
class UserSessionOverride:
"""Overrides sucessurl to lead to settings."""
def get_success_url(self):
@ -573,11 +577,13 @@ class UserSessionOverride():
class CustomSessionDeleteView(UserSessionOverride, SessionDeleteView):
"""Revert to settings after session delete."""
pass
class CustomSessionDeleteOtherView(UserSessionOverride, SessionDeleteOtherView):
"""Revert to settings after session delete."""
pass
@ -607,7 +613,9 @@ class AppearanceSelectView(RedirectView):
def get_user_theme(self):
"""Get current user color theme."""
try:
user_theme = common_models.ColorTheme.objects.filter(user=self.request.user).get()
user_theme = common_models.ColorTheme.objects.filter(
user=self.request.user
).get()
except common_models.ColorTheme.DoesNotExist:
user_theme = None
@ -638,18 +646,18 @@ class AppearanceSelectView(RedirectView):
class DatabaseStatsView(AjaxView):
"""View for displaying database statistics."""
ajax_template_name = "stats.html"
ajax_form_title = _("System Information")
ajax_template_name = 'stats.html'
ajax_form_title = _('System Information')
class AboutView(AjaxView):
"""A view for displaying InvenTree version information"""
ajax_template_name = "about.html"
ajax_form_title = _("About InvenTree")
ajax_template_name = 'about.html'
ajax_form_title = _('About InvenTree')
class NotificationsView(TemplateView):
"""View for showing notifications."""
template_name = "InvenTree/notifications/notifications.html"
template_name = 'InvenTree/notifications/notifications.html'

View File

@ -10,6 +10,8 @@ import os # pragma: no cover
from django.core.wsgi import get_wsgi_application # pragma: no cover
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "InvenTree.settings") # pragma: no cover
os.environ.setdefault(
'DJANGO_SETTINGS_MODULE', 'InvenTree.settings'
) # pragma: no cover
application = get_wsgi_application() # pragma: no cover

View File

@ -22,7 +22,7 @@ class SettingsAdmin(ImportExportModelAdmin):
class UserSettingsAdmin(ImportExportModelAdmin):
"""Admin settings for InvenTreeUserSetting."""
list_display = ('key', 'value', 'user', )
list_display = ('key', 'value', 'user')
def get_readonly_fields(self, request, obj=None): # pragma: no cover
"""Prevent the 'key' field being edited once the setting is created."""
@ -40,23 +40,31 @@ class WebhookAdmin(ImportExportModelAdmin):
class NotificationEntryAdmin(admin.ModelAdmin):
"""Admin settings for NotificationEntry."""
list_display = ('key', 'uid', 'updated', )
list_display = ('key', 'uid', 'updated')
class NotificationMessageAdmin(admin.ModelAdmin):
"""Admin settings for NotificationMessage."""
list_display = ('age_human', 'user', 'category', 'name', 'read', 'target_object', 'source_object', )
list_display = (
'age_human',
'user',
'category',
'name',
'read',
'target_object',
'source_object',
)
list_filter = ('category', 'read', 'user', )
list_filter = ('category', 'read', 'user')
search_fields = ('name', 'category', 'message', )
search_fields = ('name', 'category', 'message')
class NewsFeedEntryAdmin(admin.ModelAdmin):
"""Admin settings for NewsFeedEntry."""
list_display = ('title', 'author', 'published', 'summary', )
list_display = ('title', 'author', 'published', 'summary')
admin.site.register(common.models.InvenTreeSetting, SettingsAdmin)

View File

@ -23,8 +23,13 @@ from InvenTree.api import BulkDeleteMixin, MetadataView
from InvenTree.config import CONFIG_LOOKUPS
from InvenTree.filters import ORDER_FILTER, SEARCH_ORDER_FILTER
from InvenTree.helpers import inheritors
from InvenTree.mixins import (ListAPI, ListCreateAPI, RetrieveAPI,
RetrieveUpdateAPI, RetrieveUpdateDestroyAPI)
from InvenTree.mixins import (
ListAPI,
ListCreateAPI,
RetrieveAPI,
RetrieveUpdateAPI,
RetrieveUpdateDestroyAPI,
)
from InvenTree.permissions import IsStaffOrReadOnly, IsSuperuser
from plugin.models import NotificationUserSetting
from plugin.serializers import NotificationUserSettingSerializer
@ -41,6 +46,7 @@ class CsrfExemptMixin(object):
class WebhookView(CsrfExemptMixin, APIView):
"""Endpoint for receiving webhooks."""
authentication_classes = []
permission_classes = []
model_class = common.models.WebhookEndpoint
@ -66,8 +72,7 @@ class WebhookView(CsrfExemptMixin, APIView):
async_task(self._process_payload, message.id)
else:
self._process_result(
self.webhook.process_payload(message, payload, headers),
message,
self.webhook.process_payload(message, payload, headers), message
)
data = self.webhook.get_return(payload, headers, request)
@ -76,8 +81,7 @@ class WebhookView(CsrfExemptMixin, APIView):
def _process_payload(self, message_id):
message = common.models.WebhookMessage.objects.get(message_id=message_id)
self._process_result(
self.webhook.process_payload(message, message.body, message.header),
message,
self.webhook.process_payload(message, message.body, message.header), message
)
def _process_result(self, result, message):
@ -108,9 +112,7 @@ class WebhookView(CsrfExemptMixin, APIView):
class CurrencyExchangeView(APIView):
"""API endpoint for displaying currency information"""
permission_classes = [
permissions.IsAuthenticated,
]
permission_classes = [permissions.IsAuthenticated]
def get(self, request, format=None):
"""Return information on available currency conversions"""
@ -133,7 +135,9 @@ class CurrencyExchangeView(APIView):
updated = None
response = {
'base_currency': common.models.InvenTreeSetting.get_setting('INVENTREE_DEFAULT_CURRENCY', 'USD'),
'base_currency': common.models.InvenTreeSetting.get_setting(
'INVENTREE_DEFAULT_CURRENCY', 'USD'
),
'exchange_rates': {},
'updated': updated,
}
@ -150,10 +154,7 @@ class CurrencyRefreshView(APIView):
User must be a 'staff' user to access this endpoint
"""
permission_classes = [
permissions.IsAuthenticated,
permissions.IsAdminUser,
]
permission_classes = [permissions.IsAuthenticated, permissions.IsAdminUser]
def post(self, request, *args, **kwargs):
"""Performing a POST request will update currency exchange rates"""
@ -161,9 +162,7 @@ class CurrencyRefreshView(APIView):
update_exchange_rates(force=True)
return Response({
'success': 'Exchange rates updated',
})
return Response({'success': 'Exchange rates updated'})
class SettingsList(ListAPI):
@ -174,21 +173,15 @@ class SettingsList(ListAPI):
filter_backends = SEARCH_ORDER_FILTER
ordering_fields = [
'pk',
'key',
'name',
]
ordering_fields = ['pk', 'key', 'name']
search_fields = [
'key',
]
search_fields = ['key']
class GlobalSettingsList(SettingsList):
"""API endpoint for accessing a list of global settings objects."""
queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith="_")
queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith='_')
serializer_class = common.serializers.GlobalSettingsSerializer
def list(self, request, *args, **kwargs):
@ -221,25 +214,24 @@ class GlobalSettingsDetail(RetrieveUpdateAPI):
"""
lookup_field = 'key'
queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith="_")
queryset = common.models.InvenTreeSetting.objects.exclude(key__startswith='_')
serializer_class = common.serializers.GlobalSettingsSerializer
def get_object(self):
"""Attempt to find a global setting object with the provided key."""
key = str(self.kwargs['key']).upper()
if key.startswith('_') or key not in common.models.InvenTreeSetting.SETTINGS.keys():
if (
key.startswith('_')
or key not in common.models.InvenTreeSetting.SETTINGS.keys()
):
raise NotFound()
return common.models.InvenTreeSetting.get_setting_object(
key,
cache=False, create=True
key, cache=False, create=True
)
permission_classes = [
permissions.IsAuthenticated,
GlobalSettingsPermissions,
]
permission_classes = [permissions.IsAuthenticated, GlobalSettingsPermissions]
class UserSettingsList(SettingsList):
@ -294,18 +286,17 @@ class UserSettingsDetail(RetrieveUpdateAPI):
"""Attempt to find a user setting object with the provided key."""
key = str(self.kwargs['key']).upper()
if key.startswith('_') or key not in common.models.InvenTreeUserSetting.SETTINGS.keys():
if (
key.startswith('_')
or key not in common.models.InvenTreeUserSetting.SETTINGS.keys()
):
raise NotFound()
return common.models.InvenTreeUserSetting.get_setting_object(
key,
user=self.request.user,
cache=False, create=True
key, user=self.request.user, cache=False, create=True
)
permission_classes = [
UserSettingsPermissions,
]
permission_classes = [UserSettingsPermissions]
class NotificationUserSettingsList(SettingsList):
@ -334,39 +325,29 @@ class NotificationUserSettingsDetail(RetrieveUpdateAPI):
queryset = NotificationUserSetting.objects.all()
serializer_class = NotificationUserSettingSerializer
permission_classes = [UserSettingsPermissions, ]
permission_classes = [UserSettingsPermissions]
class NotificationMessageMixin:
"""Generic mixin for NotificationMessage."""
queryset = common.models.NotificationMessage.objects.all()
serializer_class = common.serializers.NotificationMessageSerializer
permission_classes = [UserSettingsPermissions, ]
permission_classes = [UserSettingsPermissions]
class NotificationList(NotificationMessageMixin, BulkDeleteMixin, ListAPI):
"""List view for all notifications of the current user."""
permission_classes = [permissions.IsAuthenticated, ]
permission_classes = [permissions.IsAuthenticated]
filter_backends = SEARCH_ORDER_FILTER
ordering_fields = [
'category',
'name',
'read',
'creation',
]
ordering_fields = ['category', 'name', 'read', 'creation']
search_fields = [
'name',
'message',
]
search_fields = ['name', 'message']
filterset_fields = [
'category',
'read',
]
filterset_fields = ['category', 'read']
def filter_queryset(self, queryset):
"""Only list notifications which apply to the current user."""
@ -401,29 +382,27 @@ class NotificationReadAll(NotificationMessageMixin, RetrieveAPI):
self.queryset.filter(user=request.user, read=False).update(read=True)
return Response({'status': 'ok'})
except Exception as exc:
raise serializers.ValidationError(detail=serializers.as_serializer_error(exc))
raise serializers.ValidationError(
detail=serializers.as_serializer_error(exc)
)
class NewsFeedMixin:
"""Generic mixin for NewsFeedEntry."""
queryset = common.models.NewsFeedEntry.objects.all()
serializer_class = common.serializers.NewsFeedEntrySerializer
permission_classes = [IsAdminUser, ]
permission_classes = [IsAdminUser]
class NewsFeedEntryList(NewsFeedMixin, BulkDeleteMixin, ListAPI):
"""List view for all news items."""
filter_backends = ORDER_FILTER
ordering_fields = [
'published',
'author',
'read',
]
ordering_fields = ['published', 'author', 'read']
filterset_fields = [
'read',
]
filterset_fields = ['read']
class NewsFeedEntryDetail(NewsFeedMixin, RetrieveUpdateDestroyAPI):
@ -435,14 +414,14 @@ class ConfigList(ListAPI):
queryset = CONFIG_LOOKUPS
serializer_class = common.serializers.ConfigSerializer
permission_classes = [IsSuperuser, ]
permission_classes = [IsSuperuser]
class ConfigDetail(RetrieveAPI):
"""Detail view for an individual configuration."""
serializer_class = common.serializers.ConfigSerializer
permission_classes = [IsSuperuser, ]
permission_classes = [IsSuperuser]
def get_object(self):
"""Attempt to find a config object with the provided key."""
@ -458,7 +437,7 @@ class NotesImageList(ListCreateAPI):
queryset = common.models.NotesImage.objects.all()
serializer_class = common.serializers.NotesImageSerializer
permission_classes = [permissions.IsAuthenticated, ]
permission_classes = [permissions.IsAuthenticated]
def perform_create(self, serializer):
"""Create (upload) a new notes image"""
@ -475,14 +454,9 @@ class ProjectCodeList(ListCreateAPI):
permission_classes = [permissions.IsAuthenticated, IsStaffOrReadOnly]
filter_backends = SEARCH_ORDER_FILTER
ordering_fields = [
'code',
]
ordering_fields = ['code']
search_fields = [
'code',
'description',
]
search_fields = ['code', 'description']
class ProjectCodeDetail(RetrieveUpdateDestroyAPI):
@ -515,14 +489,14 @@ class FlagList(ListAPI):
queryset = settings.FLAGS
serializer_class = common.serializers.FlagSerializer
permission_classes = [permissions.AllowAny, ]
permission_classes = [permissions.AllowAny]
class FlagDetail(RetrieveAPI):
"""Detail view for an individual feature flag."""
serializer_class = common.serializers.FlagSerializer
permission_classes = [permissions.AllowAny, ]
permission_classes = [permissions.AllowAny]
def get_object(self):
"""Attempt to find a config object with the provided key."""
@ -535,97 +509,175 @@ class FlagDetail(RetrieveAPI):
settings_api_urls = [
# User settings
re_path(r'^user/', include([
# User Settings Detail
re_path(r'^(?P<key>\w+)/', UserSettingsDetail.as_view(), name='api-user-setting-detail'),
# User Settings List
re_path(r'^.*$', UserSettingsList.as_view(), name='api-user-setting-list'),
])),
re_path(
r'^user/',
include([
# User Settings Detail
re_path(
r'^(?P<key>\w+)/',
UserSettingsDetail.as_view(),
name='api-user-setting-detail',
),
# User Settings List
re_path(r'^.*$', UserSettingsList.as_view(), name='api-user-setting-list'),
]),
),
# Notification settings
re_path(r'^notification/', include([
# Notification Settings Detail
path(r'<int:pk>/', NotificationUserSettingsDetail.as_view(), name='api-notification-setting-detail'),
# Notification Settings List
re_path(r'^.*$', NotificationUserSettingsList.as_view(), name='api-notification-setting-list'),
])),
re_path(
r'^notification/',
include([
# Notification Settings Detail
path(
r'<int:pk>/',
NotificationUserSettingsDetail.as_view(),
name='api-notification-setting-detail',
),
# Notification Settings List
re_path(
r'^.*$',
NotificationUserSettingsList.as_view(),
name='api-notification-setting-list',
),
]),
),
# Global settings
re_path(r'^global/', include([
# Global Settings Detail
re_path(r'^(?P<key>\w+)/', GlobalSettingsDetail.as_view(), name='api-global-setting-detail'),
# Global Settings List
re_path(r'^.*$', GlobalSettingsList.as_view(), name='api-global-setting-list'),
])),
re_path(
r'^global/',
include([
# Global Settings Detail
re_path(
r'^(?P<key>\w+)/',
GlobalSettingsDetail.as_view(),
name='api-global-setting-detail',
),
# Global Settings List
re_path(
r'^.*$', GlobalSettingsList.as_view(), name='api-global-setting-list'
),
]),
),
]
common_api_urls = [
# Webhooks
path('webhook/<slug:endpoint>/', WebhookView.as_view(), name='api-webhook'),
# Uploaded images for notes
re_path(r'^notes-image-upload/', NotesImageList.as_view(), name='api-notes-image-list'),
re_path(
r'^notes-image-upload/', NotesImageList.as_view(), name='api-notes-image-list'
),
# Project codes
re_path(r'^project-code/', include([
path(r'<int:pk>/', include([
re_path(r'^metadata/', MetadataView.as_view(), {'model': common.models.ProjectCode}, name='api-project-code-metadata'),
re_path(r'^.*$', ProjectCodeDetail.as_view(), name='api-project-code-detail'),
])),
re_path(r'^.*$', ProjectCodeList.as_view(), name='api-project-code-list'),
])),
re_path(
r'^project-code/',
include([
path(
r'<int:pk>/',
include([
re_path(
r'^metadata/',
MetadataView.as_view(),
{'model': common.models.ProjectCode},
name='api-project-code-metadata',
),
re_path(
r'^.*$',
ProjectCodeDetail.as_view(),
name='api-project-code-detail',
),
]),
),
re_path(r'^.*$', ProjectCodeList.as_view(), name='api-project-code-list'),
]),
),
# Custom physical units
re_path(r'^units/', include([
path(r'<int:pk>/', include([
re_path(r'^.*$', CustomUnitDetail.as_view(), name='api-custom-unit-detail'),
])),
re_path(r'^.*$', CustomUnitList.as_view(), name='api-custom-unit-list'),
])),
re_path(
r'^units/',
include([
path(
r'<int:pk>/',
include([
re_path(
r'^.*$',
CustomUnitDetail.as_view(),
name='api-custom-unit-detail',
)
]),
),
re_path(r'^.*$', CustomUnitList.as_view(), name='api-custom-unit-list'),
]),
),
# Currencies
re_path(r'^currency/', include([
re_path(r'^exchange/', CurrencyExchangeView.as_view(), name='api-currency-exchange'),
re_path(r'^refresh/', CurrencyRefreshView.as_view(), name='api-currency-refresh'),
])),
re_path(
r'^currency/',
include([
re_path(
r'^exchange/',
CurrencyExchangeView.as_view(),
name='api-currency-exchange',
),
re_path(
r'^refresh/', CurrencyRefreshView.as_view(), name='api-currency-refresh'
),
]),
),
# Notifications
re_path(r'^notifications/', include([
# Individual purchase order detail URLs
path(r'<int:pk>/', include([
re_path(r'.*$', NotificationDetail.as_view(), name='api-notifications-detail'),
])),
# Read all
re_path(r'^readall/', NotificationReadAll.as_view(), name='api-notifications-readall'),
# Notification messages list
re_path(r'^.*$', NotificationList.as_view(), name='api-notifications-list'),
])),
re_path(
r'^notifications/',
include([
# Individual purchase order detail URLs
path(
r'<int:pk>/',
include([
re_path(
r'.*$',
NotificationDetail.as_view(),
name='api-notifications-detail',
)
]),
),
# Read all
re_path(
r'^readall/',
NotificationReadAll.as_view(),
name='api-notifications-readall',
),
# Notification messages list
re_path(r'^.*$', NotificationList.as_view(), name='api-notifications-list'),
]),
),
# News
re_path(r'^news/', include([
path(r'<int:pk>/', include([
re_path(r'.*$', NewsFeedEntryDetail.as_view(), name='api-news-detail'),
])),
re_path(r'^.*$', NewsFeedEntryList.as_view(), name='api-news-list'),
])),
re_path(
r'^news/',
include([
path(
r'<int:pk>/',
include([
re_path(
r'.*$', NewsFeedEntryDetail.as_view(), name='api-news-detail'
)
]),
),
re_path(r'^.*$', NewsFeedEntryList.as_view(), name='api-news-list'),
]),
),
# Flags
path('flags/', include([
path('<str:key>/', FlagDetail.as_view(), name='api-flag-detail'),
re_path(r'^.*$', FlagList.as_view(), name='api-flag-list'),
])),
path(
'flags/',
include([
path('<str:key>/', FlagDetail.as_view(), name='api-flag-detail'),
re_path(r'^.*$', FlagList.as_view(), name='api-flag-list'),
]),
),
# Status
path('generic/status/', include([
path(f'<str:{StatusView.MODEL_REF}>/', include([
path('', StatusView.as_view(), name='api-status'),
])),
path('', AllStatusViews.as_view(), name='api-status-all'),
])),
path(
'generic/status/',
include([
path(
f'<str:{StatusView.MODEL_REF}>/',
include([path('', StatusView.as_view(), name='api-status')]),
),
path('', AllStatusViews.as_view(), name='api-status-all'),
]),
),
]
admin_api_urls = [

View File

@ -30,10 +30,14 @@ class CommonConfig(AppConfig):
try:
import common.models
if common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED', backup_value=False, create=False, cache=False):
logger.info("Clearing SERVER_RESTART_REQUIRED flag")
if common.models.InvenTreeSetting.get_setting(
'SERVER_RESTART_REQUIRED', backup_value=False, create=False, cache=False
):
logger.info('Clearing SERVER_RESTART_REQUIRED flag')
if not InvenTree.ready.isImportingData():
common.models.InvenTreeSetting.set_setting('SERVER_RESTART_REQUIRED', False, None)
common.models.InvenTreeSetting.set_setting(
'SERVER_RESTART_REQUIRED', False, None
)
except Exception:
pass

View File

@ -49,12 +49,12 @@ class FileManager:
ext = os.path.splitext(file.name)[-1].lower().replace('.', '')
try:
if ext in ['csv', 'tsv', ]:
if ext in ['csv', 'tsv']:
# These file formats need string decoding
raw_data = file.read().decode('utf-8')
# Reset stream position to beginning of file
file.seek(0)
elif ext in ['xls', 'xlsx', 'json', 'yaml', ]:
elif ext in ['xls', 'xlsx', 'json', 'yaml']:
raw_data = file.read()
# Reset stream position to beginning of file
file.seek(0)
@ -81,7 +81,12 @@ class FileManager:
def update_headers(self):
"""Update headers."""
self.HEADERS = self.REQUIRED_HEADERS + self.ITEM_MATCH_HEADERS + self.OPTIONAL_MATCH_HEADERS + self.OPTIONAL_HEADERS
self.HEADERS = (
self.REQUIRED_HEADERS
+ self.ITEM_MATCH_HEADERS
+ self.OPTIONAL_MATCH_HEADERS
+ self.OPTIONAL_HEADERS
)
def setup(self):
"""Setup headers should be overridden in usage to set the Different Headers."""
@ -149,15 +154,9 @@ class FileManager:
break
if not guess_exists:
headers.append({
'name': header,
'guess': guess
})
headers.append({'name': header, 'guess': guess})
else:
headers.append({
'name': header,
'guess': None
})
headers.append({'name': header, 'guess': None})
return headers
@ -180,7 +179,6 @@ class FileManager:
rows = []
for i in range(self.row_count()):
data = list(self.get_row_data(i))
# Is the row completely empty? Skip!
@ -203,10 +201,7 @@ class FileManager:
if empty:
continue
row = {
'data': data,
'index': i
}
row = {'data': data, 'index': i}
rows.append(row)

View File

@ -9,10 +9,7 @@ from .files import FileManager
class UploadFileForm(forms.Form):
"""Step 1 of FileManagementFormView."""
file = forms.FileField(
label=_('File'),
help_text=_('Select file to upload'),
)
file = forms.FileField(label=_('File'), help_text=_('Select file to upload'))
def __init__(self, *args, **kwargs):
"""Update label and help_text."""
@ -67,9 +64,7 @@ class MatchFieldForm(forms.Form):
self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + headers_choices,
required=False,
widget=forms.Select(attrs={
'class': 'select fieldselect',
})
widget=forms.Select(attrs={'class': 'select fieldselect'}),
)
if col['guess']:
self.fields[field_name].initial = col['guess']
@ -107,7 +102,9 @@ class MatchItemForm(forms.Form):
field_name = col_guess.lower() + '-' + str(row['index'])
# check if field def was overridden
overriden_field = self.get_special_field(col_guess, row, file_manager)
overriden_field = self.get_special_field(
col_guess, row, file_manager
)
if overriden_field:
self.fields[field_name] = overriden_field
@ -117,23 +114,23 @@ class MatchItemForm(forms.Form):
value = row.get(col_guess.lower(), '')
# Set field input box
self.fields[field_name] = forms.CharField(
required=True,
initial=value,
required=True, initial=value
)
# Create item selection box
elif col_guess in file_manager.OPTIONAL_MATCH_HEADERS:
# Get item options
item_options = [(option.id, option) for option in row['match_options_' + col_guess]]
item_options = [
(option.id, option)
for option in row['match_options_' + col_guess]
]
# Get item match
item_match = row['match_' + col_guess]
# Set field select box
self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + item_options,
required=False,
widget=forms.Select(attrs={
'class': 'select bomselect',
})
widget=forms.Select(attrs={'class': 'select bomselect'}),
)
# Update select box when match was found
if item_match:
@ -142,7 +139,9 @@ class MatchItemForm(forms.Form):
# Create item selection box
elif col_guess in file_manager.ITEM_MATCH_HEADERS:
# Get item options
item_options = [(option.id, option) for option in row['item_options']]
item_options = [
(option.id, option) for option in row['item_options']
]
# Get item match
item_match = row['item_match']
# Set field name
@ -151,9 +150,7 @@ class MatchItemForm(forms.Form):
self.fields[field_name] = forms.ChoiceField(
choices=[('', '-' * 10)] + item_options,
required=False,
widget=forms.Select(attrs={
'class': 'select bomselect',
})
widget=forms.Select(attrs={'class': 'select bomselect'}),
)
# Update select box when match was found
if item_match:
@ -169,8 +166,7 @@ class MatchItemForm(forms.Form):
value = row.get(col_guess.lower(), '')
# Set field input box
self.fields[field_name] = forms.CharField(
required=False,
initial=value,
required=False, initial=value
)
def get_special_field(self, col_guess, row, file_manager):

File diff suppressed because it is too large Load Diff

View File

@ -24,7 +24,7 @@ class NotificationMethod:
METHOD_NAME = ''
METHOD_ICON = None
CONTEXT_BUILTIN = ['name', 'message', ]
CONTEXT_BUILTIN = ['name', 'message']
CONTEXT_EXTRA = []
GLOBAL_SETTING = None
USER_SETTING = None
@ -39,11 +39,15 @@ class NotificationMethod:
"""
# Check if a sending fnc is defined
if (not hasattr(self, 'send')) and (not hasattr(self, 'send_bulk')):
raise NotImplementedError('A NotificationMethod must either define a `send` or a `send_bulk` method')
raise NotImplementedError(
'A NotificationMethod must either define a `send` or a `send_bulk` method'
)
# No method name is no good
if self.METHOD_NAME in ('', None):
raise NotImplementedError(f'The NotificationMethod {self.__class__} did not provide a METHOD_NAME')
raise NotImplementedError(
f'The NotificationMethod {self.__class__} did not provide a METHOD_NAME'
)
# Check if plugin is disabled - if so do not gather targets etc.
if self.global_setting_disable():
@ -61,9 +65,10 @@ class NotificationMethod:
def check_context(self, context):
"""Check that all values defined in the methods CONTEXT were provided in the current context."""
def check(ref, obj):
# the obj is not accessible so we are on the end
if not isinstance(obj, (list, dict, tuple, )):
if not isinstance(obj, (list, dict, tuple)):
return ref
# check if the ref exists
@ -82,7 +87,9 @@ class NotificationMethod:
return check(ref[1:], obj[ref[0]])
# other cases -> raise
raise NotImplementedError('This type can not be used as a context reference')
raise NotImplementedError(
'This type can not be used as a context reference'
)
missing = []
for item in (*self.CONTEXT_BUILTIN, *self.CONTEXT_EXTRA):
@ -91,7 +98,9 @@ class NotificationMethod:
missing.append(ret)
if missing:
raise NotImplementedError(f'The `context` is missing the following items:\n{missing}')
raise NotImplementedError(
f'The `context` is missing the following items:\n{missing}'
)
return context
@ -142,7 +151,12 @@ class NotificationMethod:
def usersetting(self, target):
"""Returns setting for this method for a given user."""
return NotificationUserSetting.get_setting(f'NOTIFICATION_METHOD_{self.METHOD_NAME.upper()}', user=target, method=self.METHOD_NAME)
return NotificationUserSetting.get_setting(
f'NOTIFICATION_METHOD_{self.METHOD_NAME.upper()}',
user=target,
method=self.METHOD_NAME,
)
# endregion
@ -160,6 +174,8 @@ class BulkNotificationMethod(NotificationMethod):
def send_bulk(self):
"""This function must be overridden."""
raise NotImplementedError('The `send` method must be overridden!')
# endregion
@ -181,17 +197,25 @@ class MethodStorageClass:
selected_classes (class, optional): References to the classes that should be registered. Defaults to None.
"""
logger.debug('Collecting notification methods')
current_method = InvenTree.helpers.inheritors(NotificationMethod) - IGNORED_NOTIFICATION_CLS
current_method = (
InvenTree.helpers.inheritors(NotificationMethod) - IGNORED_NOTIFICATION_CLS
)
# for testing selective loading is made available
if selected_classes:
current_method = [item for item in current_method if item is selected_classes]
current_method = [
item for item in current_method if item is selected_classes
]
# make sure only one of each method is added
filtered_list = {}
for item in current_method:
plugin = item.get_plugin(item)
ref = f'{plugin.package_path}_{item.METHOD_NAME}' if plugin else item.METHOD_NAME
ref = (
f'{plugin.package_path}_{item.METHOD_NAME}'
if plugin
else item.METHOD_NAME
)
item.plugin = plugin() if plugin else None
filtered_list[ref] = item
@ -217,9 +241,7 @@ class MethodStorageClass:
# make sure the setting exists
self.user_settings[new_key] = item.USER_SETTING
NotificationUserSetting.get_setting(
key=new_key,
user=user,
method=item.METHOD_NAME,
key=new_key, user=user, method=item.METHOD_NAME
)
# save definition
@ -231,7 +253,7 @@ class MethodStorageClass:
return methods
IGNORED_NOTIFICATION_CLS = {SingleNotificationMethod, BulkNotificationMethod, }
IGNORED_NOTIFICATION_CLS = {SingleNotificationMethod, BulkNotificationMethod}
storage = MethodStorageClass()
@ -275,6 +297,7 @@ class NotificationBody:
app_label: App label (slugified) of the model
model_name': Name (slugified) of the model
"""
name: str
slug: str
message: str
@ -286,24 +309,25 @@ class InvenTreeNotificationBodies:
Contains regularly used notification bodies.
"""
NewOrder = NotificationBody(
name=_("New {verbose_name}"),
name=_('New {verbose_name}'),
slug='{app_label}.new_{model_name}',
message=_("A new order has been created and assigned to you"),
message=_('A new order has been created and assigned to you'),
template='email/new_order_assigned.html',
)
"""Send when a new order (build, sale or purchase) was created."""
OrderCanceled = NotificationBody(
name=_("{verbose_name} canceled"),
name=_('{verbose_name} canceled'),
slug='{app_label}.canceled_{model_name}',
message=_("A order that is assigned to you was canceled"),
message=_('A order that is assigned to you was canceled'),
template='email/canceled_order_assigned.html',
)
"""Send when a order (sale, return or purchase) was canceled."""
ItemsReceived = NotificationBody(
name=_("Items Received"),
name=_('Items Received'),
slug='purchase_order.items_received',
message=_('Items have been received against a purchase order'),
template='email/purchase_order_received.html',
@ -340,13 +364,19 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
if not obj_ref_value:
obj_ref_value = getattr(obj, 'id', None)
if not obj_ref_value:
raise KeyError(f"Could not resolve an object reference for '{str(obj)}' with {obj_ref}, pk, id")
raise KeyError(
f"Could not resolve an object reference for '{str(obj)}' with {obj_ref}, pk, id"
)
# Check if we have notified recently...
delta = timedelta(days=1)
if common.models.NotificationEntry.check_recent(category, obj_ref_value, delta):
logger.info("Notification '%s' has recently been sent for '%s' - SKIPPING", category, str(obj))
logger.info(
"Notification '%s' has recently been sent for '%s' - SKIPPING",
category,
str(obj),
)
return
logger.info("Gathering users for notification '%s'", category)
@ -383,7 +413,9 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
target_users.add(user)
# Unhandled type
else:
logger.error("Unknown target passed to trigger_notification method: %s", target)
logger.error(
'Unknown target passed to trigger_notification method: %s', target
)
if target_users:
logger.info("Sending notification '%s' for '%s'", category, str(obj))
@ -392,7 +424,7 @@ def trigger_notification(obj, category=None, obj_ref='pk', **kwargs):
if delivery_methods is None:
delivery_methods = storage.liste
else:
delivery_methods = (delivery_methods - IGNORED_NOTIFICATION_CLS)
delivery_methods = delivery_methods - IGNORED_NOTIFICATION_CLS
for method in delivery_methods:
logger.info("Triggering notification method '%s'", method.METHOD_NAME)
@ -422,17 +454,15 @@ def trigger_superuser_notification(plugin: PluginConfig, msg: str):
trigger_notification(
plugin,
'inventree.plugin',
context={
'error': plugin,
'name': _('Error raised by plugin'),
'message': msg,
},
context={'error': plugin, 'name': _('Error raised by plugin'), 'message': msg},
targets=users,
delivery_methods={UIMessageNotification, },
delivery_methods={UIMessageNotification},
)
def deliver_notification(cls: NotificationMethod, obj, category: str, targets, context: dict):
def deliver_notification(
cls: NotificationMethod, obj, category: str, targets, context: dict
):
"""Send notification with the provided class.
This:
@ -447,7 +477,12 @@ def deliver_notification(cls: NotificationMethod, obj, category: str, targets, c
if method.targets and len(method.targets) > 0:
# Log start
logger.info("Notify users via '%s' for notification '%s' for '%s'", method.METHOD_NAME, category, str(obj))
logger.info(
"Notify users via '%s' for notification '%s' for '%s'",
method.METHOD_NAME,
category,
str(obj),
)
# Run setup for delivery method
method.setup()
@ -472,6 +507,12 @@ def deliver_notification(cls: NotificationMethod, obj, category: str, targets, c
method.cleanup()
# Log results
logger.info("Notified %s users via '%s' for notification '%s' for '%s' successfully", success_count, method.METHOD_NAME, category, str(obj))
logger.info(
"Notified %s users via '%s' for notification '%s' for '%s' successfully",
success_count,
method.METHOD_NAME,
category,
str(obj),
)
if not success:
logger.info("There were some problems")
logger.info('There were some problems')

View File

@ -1,6 +1,5 @@
"""JSON serializers for common components."""
from django.urls import reverse
from flags.state import flag_state
@ -9,8 +8,10 @@ from rest_framework import serializers
import common.models as common_models
from InvenTree.helpers import get_objectreference
from InvenTree.helpers_model import construct_absolute_url
from InvenTree.serializers import (InvenTreeImageSerializerField,
InvenTreeModelSerializer)
from InvenTree.serializers import (
InvenTreeImageSerializerField,
InvenTreeModelSerializer,
)
from users.serializers import OwnerSerializer
@ -62,10 +63,7 @@ class SettingsSerializer(InvenTreeModelSerializer):
if choices:
for choice in choices:
results.append({
'value': choice[0],
'display_name': choice[1],
})
results.append({'value': choice[0], 'display_name': choice[1]})
return results
@ -131,8 +129,10 @@ class GenericReferencedSettingSerializer(SettingsSerializer):
def __init__(self, *args, **kwargs):
"""Init overrides the Meta class to make it dynamic."""
class CustomMeta:
"""Scaffold for custom Meta class."""
fields = [
'pk',
'key',
@ -204,10 +204,12 @@ class NotificationMessageSerializer(InvenTreeModelSerializer):
request = self.context['request']
if request.user and request.user.is_staff:
meta = obj.target_object._meta
target['link'] = construct_absolute_url(reverse(
f'admin:{meta.db_table}_change',
kwargs={'object_id': obj.target_object_id}
))
target['link'] = construct_absolute_url(
reverse(
f'admin:{meta.db_table}_change',
kwargs={'object_id': obj.target_object_id},
)
)
return target
@ -257,17 +259,9 @@ class NotesImageSerializer(InvenTreeModelSerializer):
"""Meta options for NotesImageSerializer."""
model = common_models.NotesImage
fields = [
'pk',
'image',
'user',
'date',
]
fields = ['pk', 'image', 'user', 'date']
read_only_fields = [
'date',
'user',
]
read_only_fields = ['date', 'user']
image = InvenTreeImageSerializerField(required=True)
@ -279,13 +273,7 @@ class ProjectCodeSerializer(InvenTreeModelSerializer):
"""Meta options for ProjectCodeSerializer."""
model = common_models.ProjectCode
fields = [
'pk',
'code',
'description',
'responsible',
'responsible_detail',
]
fields = ['pk', 'code', 'description', 'responsible', 'responsible_detail']
responsible_detail = OwnerSerializer(source='responsible', read_only=True)
@ -313,9 +301,4 @@ class CustomUnitSerializer(InvenTreeModelSerializer):
"""Meta options for CustomUnitSerializer."""
model = common_models.CustomUnit
fields = [
'pk',
'name',
'symbol',
'definition',
]
fields = ['pk', 'name', 'symbol', 'definition']

View File

@ -20,7 +20,9 @@ def currency_code_default():
return cached_value
try:
code = InvenTreeSetting.get_setting('INVENTREE_DEFAULT_CURRENCY', backup_value='', create=True, cache=True)
code = InvenTreeSetting.get_setting(
'INVENTREE_DEFAULT_CURRENCY', backup_value='', create=True, cache=True
)
except Exception: # pragma: no cover
# Database may not yet be ready, no need to throw an error here
code = ''

View File

@ -27,7 +27,9 @@ def delete_old_notifications():
try:
from common.models import NotificationEntry
except AppRegistryNotReady: # pragma: no cover
logger.info("Could not perform 'delete_old_notifications' - App registry not ready")
logger.info(
"Could not perform 'delete_old_notifications' - App registry not ready"
)
return
before = timezone.now() - timedelta(days=90)
@ -49,7 +51,7 @@ def update_news_feed():
try:
d = feedparser.parse(settings.INVENTREE_NEWS_URL)
except Exception as entry: # pragma: no cover
logger.warning("update_news_feed: Error parsing the newsfeed", entry)
logger.warning('update_news_feed: Error parsing the newsfeed', entry)
return
# Get a reference list
@ -87,13 +89,15 @@ def delete_old_notes_images():
try:
from common.models import NotesImage
except AppRegistryNotReady:
logger.info("Could not perform 'delete_old_notes_images' - App registry not ready")
logger.info(
"Could not perform 'delete_old_notes_images' - App registry not ready"
)
return
# Remove any notes which point to non-existent image files
for note in NotesImage.objects.all():
if not os.path.exists(note.image.path):
logger.info("Deleting note %s - image file does not exist", note.image.path)
logger.info('Deleting note %s - image file does not exist', note.image.path)
note.delete()
note_classes = getModelsWithMixin(InvenTreeNotesMixin)
@ -112,7 +116,7 @@ def delete_old_notes_images():
break
if not found:
logger.info("Deleting note %s - image file not linked to a note", img)
logger.info('Deleting note %s - image file not linked to a note', img)
note.delete()
# Finally, remove any images in the notes dir which are not linked to a note
@ -127,7 +131,6 @@ def delete_old_notes_images():
all_notes = NotesImage.objects.all()
for image in images:
found = False
for note in all_notes:
img_path = os.path.basename(note.image.path)
@ -136,5 +139,5 @@ def delete_old_notes_images():
break
if not found:
logger.info("Deleting note %s - image file not linked to a note", image)
logger.info('Deleting note %s - image file not linked to a note', image)
os.remove(os.path.join(notes_dir, image))

View File

@ -1,8 +1,12 @@
"""Tests for basic notification methods and functions in InvenTree."""
import plugin.templatetags.plugin_extras as plugin_tags
from common.notifications import (BulkNotificationMethod, NotificationMethod,
SingleNotificationMethod, storage)
from common.notifications import (
BulkNotificationMethod,
NotificationMethod,
SingleNotificationMethod,
storage,
)
from part.test_part import BaseNotificationIntegrationTest
from plugin.models import NotificationUserSetting
@ -23,37 +27,31 @@ class BaseNotificationTests(BaseNotificationIntegrationTest):
"""A comment so we do not need a pass."""
class NoNameNotificationMethod(NotificationMethod):
def send(self):
"""A comment so we do not need a pass."""
class WrongContextNotificationMethod(NotificationMethod):
METHOD_NAME = 'WrongContextNotification'
CONTEXT_EXTRA = [
'aa',
('aa', 'bb', ),
('templates', 'ccc', ),
(123, )
]
CONTEXT_EXTRA = ['aa', ('aa', 'bb'), ('templates', 'ccc'), (123,)]
def send(self):
"""A comment so we do not need a pass."""
# no send / send bulk
with self.assertRaises(NotImplementedError):
FalseNotificationMethod('', '', '', '', )
FalseNotificationMethod('', '', '', '')
# no METHOD_NAME
with self.assertRaises(NotImplementedError):
NoNameNotificationMethod('', '', '', '', )
NoNameNotificationMethod('', '', '', '')
# a not existent context check
with self.assertRaises(NotImplementedError):
WrongContextNotificationMethod('', '', '', '', )
WrongContextNotificationMethod('', '', '', '')
# no get_targets
with self.assertRaises(NotImplementedError):
AnotherFalseNotificationMethod('', '', '', {'name': 1, 'message': 2, }, )
AnotherFalseNotificationMethod('', '', '', {'name': 1, 'message': 2})
def test_failing_passing(self):
"""Ensure that an error in one deliverymethod is not blocking all mehthods."""
@ -67,7 +65,7 @@ class BaseNotificationTests(BaseNotificationIntegrationTest):
METHOD_NAME = 'ErrorImplementation'
def get_targets(self):
return [1, ]
return [1]
def send(self, target):
raise KeyError('This could be any error')
@ -91,7 +89,7 @@ class BulkNotificationMethodTests(BaseNotificationIntegrationTest):
METHOD_NAME = 'WrongImplementationBulk'
def get_targets(self):
return [1, ]
return [1]
with self.assertLogs(logger='inventree', level='ERROR'):
self._notification_run(WrongImplementation)
@ -113,11 +111,12 @@ class SingleNotificationMethodTests(BaseNotificationIntegrationTest):
METHOD_NAME = 'WrongImplementationSingle'
def get_targets(self):
return [1, ]
return [1]
with self.assertLogs(logger='inventree', level='ERROR'):
self._notification_run(WrongImplementation)
# A integration test for notifications is provided in test_part.PartNotificationTest
@ -144,7 +143,7 @@ class NotificationUserSettingTests(BaseNotificationIntegrationTest):
}
def get_targets(self):
return [1, ]
return [1]
def send_bulk(self):
return True
@ -158,10 +157,14 @@ class NotificationUserSettingTests(BaseNotificationIntegrationTest):
# assertions for settings
self.assertEqual(setting.name, 'Enable test notifications')
self.assertEqual(setting.default_value, True)
self.assertEqual(setting.description, 'Allow sending of test for event notifications')
self.assertEqual(
setting.description, 'Allow sending of test for event notifications'
)
self.assertEqual(setting.units, 'alpha')
# test tag and array
self.assertEqual(plugin_tags.notification_settings_list({'user': self.user}), array)
self.assertEqual(
plugin_tags.notification_settings_list({'user': self.user}), array
)
self.assertEqual(array[0]['key'], 'NOTIFICATION_METHOD_TEST')
self.assertEqual(array[0]['method'], 'test')

View File

@ -15,4 +15,4 @@ class TaskTest(TestCase):
"""Test that the task `delete_old_notifications` runs through without errors."""
# check empty run
self.assertEqual(NotificationEntry.objects.all().count(), 0)
offload_task(common_tasks.delete_old_notifications,)
offload_task(common_tasks.delete_old_notifications)

View File

@ -17,16 +17,23 @@ from django.urls import reverse
import PIL
from InvenTree.helpers import str2bool
from InvenTree.unit_test import (InvenTreeAPITestCase, InvenTreeTestCase,
PluginMixin)
from InvenTree.unit_test import InvenTreeAPITestCase, InvenTreeTestCase, PluginMixin
from plugin import registry
from plugin.models import NotificationUserSetting
from .api import WebhookView
from .models import (ColorTheme, CustomUnit, InvenTreeSetting,
InvenTreeUserSetting, NotesImage, NotificationEntry,
NotificationMessage, ProjectCode, WebhookEndpoint,
WebhookMessage)
from .models import (
ColorTheme,
CustomUnit,
InvenTreeSetting,
InvenTreeUserSetting,
NotesImage,
NotificationEntry,
NotificationMessage,
ProjectCode,
WebhookEndpoint,
WebhookMessage,
)
CONTENT_TYPE_JSON = 'application/json'
@ -34,9 +41,7 @@ CONTENT_TYPE_JSON = 'application/json'
class SettingsTest(InvenTreeTestCase):
"""Tests for the 'settings' model."""
fixtures = [
'settings',
]
fixtures = ['settings']
def test_settings_objects(self):
"""Test fixture loading and lookup for settings."""
@ -50,7 +55,9 @@ class SettingsTest(InvenTreeTestCase):
self.assertEqual(instance_name.value, 'My very first InvenTree Instance')
# Check object lookup (case insensitive)
self.assertEqual(InvenTreeSetting.get_setting_object('iNvEnTrEE_inSTanCE').pk, 1)
self.assertEqual(
InvenTreeSetting.get_setting_object('iNvEnTrEE_inSTanCE').pk, 1
)
def test_settings_functions(self):
"""Test settings functions and properties."""
@ -61,14 +68,25 @@ class SettingsTest(InvenTreeTestCase):
stale_ref = 'STOCK_STALE_DAYS'
stale_days = InvenTreeSetting.get_setting_object(stale_ref, cache=False)
report_size_obj = InvenTreeSetting.get_setting_object('REPORT_DEFAULT_PAGE_SIZE')
report_test_obj = InvenTreeSetting.get_setting_object('REPORT_ENABLE_TEST_REPORT')
report_size_obj = InvenTreeSetting.get_setting_object(
'REPORT_DEFAULT_PAGE_SIZE'
)
report_test_obj = InvenTreeSetting.get_setting_object(
'REPORT_ENABLE_TEST_REPORT'
)
# check settings base fields
self.assertEqual(instance_obj.name, 'Server Instance Name')
self.assertEqual(instance_obj.get_setting_name(instance_ref), 'Server Instance Name')
self.assertEqual(instance_obj.description, 'String descriptor for the server instance')
self.assertEqual(instance_obj.get_setting_description(instance_ref), 'String descriptor for the server instance')
self.assertEqual(
instance_obj.get_setting_name(instance_ref), 'Server Instance Name'
)
self.assertEqual(
instance_obj.description, 'String descriptor for the server instance'
)
self.assertEqual(
instance_obj.get_setting_description(instance_ref),
'String descriptor for the server instance',
)
# check units
self.assertEqual(instance_obj.units, '')
@ -90,7 +108,9 @@ class SettingsTest(InvenTreeTestCase):
# check as_int
self.assertEqual(stale_days.as_int(), 0)
self.assertEqual(instance_obj.as_int(), 'InvenTree') # not an int -> return default
self.assertEqual(
instance_obj.as_int(), 'InvenTree'
) # not an int -> return default
# check as_bool
self.assertEqual(report_test_obj.as_bool(), True)
@ -116,62 +136,66 @@ class SettingsTest(InvenTreeTestCase):
def test_all_settings(self):
"""Make sure that the all_settings function returns correctly"""
result = InvenTreeSetting.all_settings()
self.assertIn("INVENTREE_INSTANCE", result)
self.assertIn('INVENTREE_INSTANCE', result)
self.assertIsInstance(result['INVENTREE_INSTANCE'], InvenTreeSetting)
@mock.patch("common.models.InvenTreeSetting.get_setting_definition")
@mock.patch('common.models.InvenTreeSetting.get_setting_definition')
def test_check_all_settings(self, get_setting_definition):
"""Make sure that the check_all_settings function returns correctly"""
# define partial schema
settings_definition = {
"AB": { # key that's has not already been accessed
"required": True,
'AB': { # key that's has not already been accessed
'required': True
},
"CD": {
"required": True,
"protected": True,
},
"EF": {}
'CD': {'required': True, 'protected': True},
'EF': {},
}
def mocked(key, **kwargs):
return settings_definition.get(key, {})
get_setting_definition.side_effect = mocked
self.assertEqual(InvenTreeSetting.check_all_settings(settings_definition=settings_definition), (False, ["AB", "CD"]))
InvenTreeSetting.set_setting('AB', "hello", self.user)
InvenTreeSetting.set_setting('CD', "world", self.user)
self.assertEqual(
InvenTreeSetting.check_all_settings(
settings_definition=settings_definition
),
(False, ['AB', 'CD']),
)
InvenTreeSetting.set_setting('AB', 'hello', self.user)
InvenTreeSetting.set_setting('CD', 'world', self.user)
self.assertEqual(InvenTreeSetting.check_all_settings(), (True, []))
@mock.patch("common.models.InvenTreeSetting.get_setting_definition")
@mock.patch('common.models.InvenTreeSetting.get_setting_definition')
def test_settings_validator(self, get_setting_definition):
"""Make sure that the validator function gets called on set setting."""
def validator(x):
if x == "hello":
if x == 'hello':
return x
raise ValidationError(f"{x} is not valid")
raise ValidationError(f'{x} is not valid')
mock_validator = mock.Mock(side_effect=validator)
# define partial schema
settings_definition = {
"AB": { # key that's has not already been accessed
"validator": mock_validator,
},
'AB': { # key that's has not already been accessed
'validator': mock_validator
}
}
def mocked(key, **kwargs):
return settings_definition.get(key, {})
get_setting_definition.side_effect = mocked
InvenTreeSetting.set_setting("AB", "hello", self.user)
mock_validator.assert_called_with("hello")
InvenTreeSetting.set_setting('AB', 'hello', self.user)
mock_validator.assert_called_with('hello')
with self.assertRaises(ValidationError):
InvenTreeSetting.set_setting("AB", "world", self.user)
mock_validator.assert_called_with("world")
InvenTreeSetting.set_setting('AB', 'world', self.user)
mock_validator.assert_called_with('world')
def run_settings_check(self, key, setting):
"""Test that all settings are valid.
@ -194,7 +218,9 @@ class SettingsTest(InvenTreeTestCase):
self.assertIn('django.utils.functional.lazy', str(type(description)))
if key != key.upper():
raise ValueError(f"Setting key '{key}' is not uppercase") # pragma: no cover
raise ValueError(
f"Setting key '{key}' is not uppercase"
) # pragma: no cover
# Check that only allowed keys are provided
allowed_keys = [
@ -232,7 +258,6 @@ class SettingsTest(InvenTreeTestCase):
- Ensure that every setting has a description, which is translated
"""
for key, setting in InvenTreeSetting.SETTINGS.items():
try:
self.run_settings_check(key, setting)
except Exception as exc: # pragma: no cover
@ -249,7 +274,6 @@ class SettingsTest(InvenTreeTestCase):
def test_defaults(self):
"""Populate the settings with default values."""
for key in InvenTreeSetting.SETTINGS.keys():
value = InvenTreeSetting.get_setting_default(key)
InvenTreeSetting.set_setting(key, value, self.user)
@ -261,10 +285,14 @@ class SettingsTest(InvenTreeTestCase):
if setting.is_bool():
if setting.default_value in ['', None]:
raise ValueError(f'Default value for boolean setting {key} not provided') # pragma: no cover
raise ValueError(
f'Default value for boolean setting {key} not provided'
) # pragma: no cover
if setting.default_value not in [True, False]:
raise ValueError(f'Non-boolean default value specified for {key}') # pragma: no cover
raise ValueError(
f'Non-boolean default value specified for {key}'
) # pragma: no cover
def test_global_setting_caching(self):
"""Test caching operations for the global settings class"""
@ -294,9 +322,7 @@ class SettingsTest(InvenTreeTestCase):
# Generate a number of new users
for idx in range(5):
get_user_model().objects.create(
username=f"User_{idx}",
password="hunter42",
email="email@dot.com",
username=f'User_{idx}', password='hunter42', email='email@dot.com'
)
key = 'SEARCH_PREVIEW_RESULTS'
@ -305,7 +331,10 @@ class SettingsTest(InvenTreeTestCase):
for user in get_user_model().objects.all():
setting = InvenTreeUserSetting.get_setting_object(key, user=user)
cache_key = setting.cache_key
self.assertEqual(cache_key, f"InvenTreeUserSetting:SEARCH_PREVIEW_RESULTS_user:{user.username}")
self.assertEqual(
cache_key,
f'InvenTreeUserSetting:SEARCH_PREVIEW_RESULTS_user:{user.username}',
)
InvenTreeUserSetting.set_setting(key, user.pk, None, user=user)
self.assertIsNotNone(cache.get(cache_key))
@ -333,7 +362,9 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
response = self.get(url, expected_code=200)
n_public_settings = len([k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_')])
n_public_settings = len([
k for k in InvenTreeSetting.SETTINGS.keys() if not k.startswith('_')
])
# Number of results should match the number of settings
self.assertEqual(len(response.data), n_public_settings)
@ -358,13 +389,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
# Test setting via the API
for val in ['cat', 'hat', 'bat', 'mat']:
response = self.patch(
url,
{
'value': val,
},
expected_code=200
)
response = self.patch(url, {'value': val}, expected_code=200)
self.assertEqual(response.data['value'], val)
@ -374,7 +399,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
def test_api_detail(self):
"""Test that we can access the detail view for a setting based on the <key>."""
# These keys are invalid, and should return 404
for key in ["apple", "carrot", "dog"]:
for key in ['apple', 'carrot', 'dog']:
response = self.get(
reverse('api-global-setting-detail', kwargs={'key': key}),
expected_code=404,
@ -394,13 +419,7 @@ class GlobalSettingsApiTest(InvenTreeAPITestCase):
self.assertEqual(response.data['value'], 'InvenTree')
# Now, the object should have been created in the DB
self.patch(
url,
{
'value': 'My new title',
},
expected_code=200,
)
self.patch(url, {'value': 'My new title'}, expected_code=200)
setting = InvenTreeSetting.objects.get(key=key)
@ -451,8 +470,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
"""Test a boolean user setting value."""
# Ensure we have a boolean setting available
setting = InvenTreeUserSetting.get_setting_object(
'SEARCH_PREVIEW_SHOW_PARTS',
user=self.user
'SEARCH_PREVIEW_SHOW_PARTS', user=self.user
)
# Check default values
@ -465,20 +483,16 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
self.assertEqual(response.data['pk'], setting.pk)
self.assertEqual(response.data['key'], 'SEARCH_PREVIEW_SHOW_PARTS')
self.assertEqual(response.data['description'], 'Display parts in search preview window')
self.assertEqual(
response.data['description'], 'Display parts in search preview window'
)
self.assertEqual(response.data['type'], 'boolean')
self.assertEqual(len(response.data['choices']), 0)
self.assertTrue(str2bool(response.data['value']))
# Assign some truthy values
for v in ['true', True, 1, 'y', 'TRUE']:
self.patch(
url,
{
'value': str(v),
},
expected_code=200,
)
self.patch(url, {'value': str(v)}, expected_code=200)
response = self.get(url, expected_code=200)
@ -486,13 +500,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Assign some false(ish) values
for v in ['false', False, '0', 'n', 'FalSe']:
self.patch(
url,
{
'value': str(v),
},
expected_code=200,
)
self.patch(url, {'value': str(v)}, expected_code=200)
response = self.get(url, expected_code=200)
@ -500,13 +508,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Assign some invalid values
for v in ['x', '', 'invalid', None, '-1', 'abcde']:
response = self.patch(
url,
{
'value': str(v),
},
expected_code=200
)
response = self.patch(url, {'value': str(v)}, expected_code=200)
# Invalid values evaluate to False
self.assertFalse(str2bool(response.data['value']))
@ -514,8 +516,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
def test_user_setting_choice(self):
"""Test a user setting with choices."""
setting = InvenTreeUserSetting.get_setting_object(
'DATE_DISPLAY_FORMAT',
user=self.user
'DATE_DISPLAY_FORMAT', user=self.user
)
url = reverse('api-user-setting-detail', kwargs={'key': setting.key})
@ -525,37 +526,21 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Check that a valid option can be assigned via the API
for opt in ['YYYY-MM-DD', 'DD-MM-YYYY', 'MM/DD/YYYY']:
self.patch(
url,
{
'value': opt,
},
expected_code=200,
)
self.patch(url, {'value': opt}, expected_code=200)
setting.refresh_from_db()
self.assertEqual(setting.value, opt)
# Send an invalid option
for opt in ['cat', 'dog', 12345]:
response = self.patch(
url,
{
'value': opt,
},
expected_code=400,
)
response = self.patch(url, {'value': opt}, expected_code=400)
self.assertIn('Chosen value is not a valid option', str(response.data))
def test_user_setting_integer(self):
"""Test a integer user setting value."""
setting = InvenTreeUserSetting.get_setting_object(
'SEARCH_PREVIEW_RESULTS',
user=self.user,
cache=False,
'SEARCH_PREVIEW_RESULTS', user=self.user, cache=False
)
url = reverse('api-user-setting-detail', kwargs={'key': setting.key})
@ -573,13 +558,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Set valid options via the api
for v in [5, 15, 25]:
self.patch(
url,
{
'value': v,
},
expected_code=200,
)
self.patch(url, {'value': v}, expected_code=200)
setting.refresh_from_db()
self.assertEqual(setting.to_native_value(), v)
@ -587,14 +566,7 @@ class UserSettingsApiTest(InvenTreeAPITestCase):
# Set invalid options via the API
# Note that this particular setting has a MinValueValidator(1) associated with it
for v in [0, -1, -5]:
response = self.patch(
url,
{
'value': v,
},
expected_code=400,
)
response = self.patch(url, {'value': v}, expected_code=400)
class NotificationUserSettingsApiTest(InvenTreeAPITestCase):
@ -608,9 +580,15 @@ class NotificationUserSettingsApiTest(InvenTreeAPITestCase):
def test_setting(self):
"""Test the string name for NotificationUserSetting."""
NotificationUserSetting.set_setting('NOTIFICATION_METHOD_MAIL', True, change_user=self.user, user=self.user)
test_setting = NotificationUserSetting.get_setting_object('NOTIFICATION_METHOD_MAIL', user=self.user)
self.assertEqual(str(test_setting), 'NOTIFICATION_METHOD_MAIL (for testuser): True')
NotificationUserSetting.set_setting(
'NOTIFICATION_METHOD_MAIL', True, change_user=self.user, user=self.user
)
test_setting = NotificationUserSetting.get_setting_object(
'NOTIFICATION_METHOD_MAIL', user=self.user
)
self.assertEqual(
str(test_setting), 'NOTIFICATION_METHOD_MAIL (for testuser): True'
)
class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase):
@ -638,26 +616,38 @@ class PluginSettingsApiTest(PluginMixin, InvenTreeAPITestCase):
registry.set_plugin_state('sample', True)
# get data
url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'API_KEY'})
url = reverse(
'api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'API_KEY'}
)
response = self.get(url, expected_code=200)
# check the right setting came through
self.assertTrue(response.data['key'], 'API_KEY')
self.assertTrue(response.data['plugin'], 'sample')
self.assertTrue(response.data['type'], 'string')
self.assertTrue(response.data['description'], 'Key required for accessing external API')
self.assertTrue(
response.data['description'], 'Key required for accessing external API'
)
# Failure mode tests
# Non-existent plugin
url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'doesnotexist', 'key': 'doesnotmatter'})
url = reverse(
'api-plugin-setting-detail',
kwargs={'plugin': 'doesnotexist', 'key': 'doesnotmatter'},
)
response = self.get(url, expected_code=404)
self.assertIn("Plugin 'doesnotexist' not installed", str(response.data))
# Wrong key
url = reverse('api-plugin-setting-detail', kwargs={'plugin': 'sample', 'key': 'doesnotexist'})
url = reverse(
'api-plugin-setting-detail',
kwargs={'plugin': 'sample', 'key': 'doesnotexist'},
)
response = self.get(url, expected_code=404)
self.assertIn("Plugin 'sample' has no setting matching 'doesnotexist'", str(response.data))
self.assertIn(
"Plugin 'sample' has no setting matching 'doesnotexist'", str(response.data)
)
def test_invalid_setting_key(self):
"""Test that an invalid setting key returns a 404."""
@ -684,32 +674,30 @@ class WebhookMessageTests(TestCase):
def test_missing_token(self):
"""Tests that token checks work."""
response = self.client.post(
self.url,
content_type=CONTENT_TYPE_JSON,
)
response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON)
assert response.status_code == HTTPStatus.FORBIDDEN
assert (
json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR
json.loads(response.content)['detail']
== WebhookView.model_class.MESSAGE_TOKEN_ERROR
)
def test_bad_token(self):
"""Test that a wrong token is not working."""
response = self.client.post(
self.url,
content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': '1234567fghj'},
self.url, content_type=CONTENT_TYPE_JSON, **{'HTTP_TOKEN': '1234567fghj'}
)
assert response.status_code == HTTPStatus.FORBIDDEN
assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR)
assert (
json.loads(response.content)['detail']
== WebhookView.model_class.MESSAGE_TOKEN_ERROR
)
def test_bad_url(self):
"""Test that a wrongly formed url is not working."""
response = self.client.post(
'/api/webhook/1234/',
content_type=CONTENT_TYPE_JSON,
'/api/webhook/1234/', content_type=CONTENT_TYPE_JSON
)
assert response.status_code == HTTPStatus.NOT_FOUND
@ -725,7 +713,8 @@ class WebhookMessageTests(TestCase):
assert response.status_code == HTTPStatus.NOT_ACCEPTABLE
assert (
json.loads(response.content)['detail'] == 'Expecting property name enclosed in double quotes'
json.loads(response.content)['detail']
== 'Expecting property name enclosed in double quotes'
)
def test_success_no_token_check(self):
@ -735,10 +724,7 @@ class WebhookMessageTests(TestCase):
self.endpoint_def.save()
# check
response = self.client.post(
self.url,
content_type=CONTENT_TYPE_JSON,
)
response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON)
assert response.status_code == HTTPStatus.OK
assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK
@ -751,13 +737,13 @@ class WebhookMessageTests(TestCase):
self.endpoint_def.save()
# check
response = self.client.post(
self.url,
content_type=CONTENT_TYPE_JSON,
)
response = self.client.post(self.url, content_type=CONTENT_TYPE_JSON)
assert response.status_code == HTTPStatus.FORBIDDEN
assert (json.loads(response.content)['detail'] == WebhookView.model_class.MESSAGE_TOKEN_ERROR)
assert (
json.loads(response.content)['detail']
== WebhookView.model_class.MESSAGE_TOKEN_ERROR
)
def test_success_hmac(self):
"""Test with a valid HMAC provided."""
@ -783,7 +769,7 @@ class WebhookMessageTests(TestCase):
"""
response = self.client.post(
self.url,
data={"this": "is a message"},
data={'this': 'is a message'},
content_type=CONTENT_TYPE_JSON,
**{'HTTP_TOKEN': str(self.endpoint_def.token)},
)
@ -791,15 +777,13 @@ class WebhookMessageTests(TestCase):
assert response.status_code == HTTPStatus.OK
assert str(response.content, 'utf-8') == WebhookView.model_class.MESSAGE_OK
message = WebhookMessage.objects.get()
assert message.body == {"this": "is a message"}
assert message.body == {'this': 'is a message'}
class NotificationTest(InvenTreeAPITestCase):
"""Tests for NotificationEntry."""
fixtures = [
'users',
]
fixtures = ['users']
def test_check_notification_entries(self):
"""Test that notification entries can be created."""
@ -832,7 +816,10 @@ class NotificationTest(InvenTreeAPITestCase):
self.assertIn('GET', response.data['actions'])
self.assertNotIn('POST', response.data['actions'])
self.assertEqual(response.data['description'], 'List view for all notifications of the current user.')
self.assertEqual(
response.data['description'],
'List view for all notifications of the current user.',
)
# POST action should fail (not allowed)
response = self.post(url, {}, expected_code=405)
@ -867,13 +854,7 @@ class NotificationTest(InvenTreeAPITestCase):
ntf.save()
# Read out via API again
response = self.get(
url,
{
'read': True,
},
expected_code=200
)
response = self.get(url, {'read': True}, expected_code=200)
# Check validity of returned data
self.assertEqual(len(response.data), 3)
@ -882,15 +863,7 @@ class NotificationTest(InvenTreeAPITestCase):
# Now, let's bulk delete all 'unread' notifications via the API,
# but only associated with the logged in user
response = self.delete(
url,
{
'filters': {
'read': False,
}
},
expected_code=204,
)
response = self.delete(url, {'filters': {'read': False}}, expected_code=204)
# Only 7 notifications should have been deleted,
# as the notifications associated with other users must remain untouched
@ -907,13 +880,17 @@ class CommonTest(InvenTreeAPITestCase):
from plugin import registry
# set flag true
common.models.InvenTreeSetting.set_setting('SERVER_RESTART_REQUIRED', True, None)
common.models.InvenTreeSetting.set_setting(
'SERVER_RESTART_REQUIRED', True, None
)
# reload the app
registry.reload_plugins()
# now it should be false again
self.assertFalse(common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED'))
self.assertFalse(
common.models.InvenTreeSetting.get_setting('SERVER_RESTART_REQUIRED')
)
def test_config_api(self):
"""Test config URLs."""
@ -926,8 +903,13 @@ class CommonTest(InvenTreeAPITestCase):
# Successful checks
data = [
self.get(reverse('api-config-list'), expected_code=200).data[0], # list endpoint
self.get(reverse('api-config-detail', kwargs={'key': 'INVENTREE_DEBUG'}), expected_code=200).data, # detail endpoint
self.get(reverse('api-config-list'), expected_code=200).data[
0
], # list endpoint
self.get(
reverse('api-config-detail', kwargs={'key': 'INVENTREE_DEBUG'}),
expected_code=200,
).data, # detail endpoint
]
for item in data:
@ -956,21 +938,33 @@ class CommonTest(InvenTreeAPITestCase):
self.assertEqual(response.data[0]['key'], 'EXPERIMENTAL')
self.assertTrue(response.data[0]['conditions'])
response = self.get(reverse('api-flag-detail', kwargs={'key': 'EXPERIMENTAL'}), expected_code=200)
response = self.get(
reverse('api-flag-detail', kwargs={'key': 'EXPERIMENTAL'}),
expected_code=200,
)
self.assertEqual(len(response.data), 3)
self.assertEqual(response.data['key'], 'EXPERIMENTAL')
self.assertTrue(response.data['conditions'])
# Try without param -> false
response = self.get(reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), expected_code=200)
response = self.get(
reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), expected_code=200
)
self.assertFalse(response.data['state'])
# Try with param -> true
response = self.get(reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}), {'ngen': ''}, expected_code=200)
response = self.get(
reverse('api-flag-detail', kwargs={'key': 'NEXT_GEN'}),
{'ngen': ''},
expected_code=200,
)
self.assertTrue(response.data['state'])
# Try non existent flag
response = self.get(reverse('api-flag-detail', kwargs={'key': 'NON_EXISTENT'}), expected_code=404)
response = self.get(
reverse('api-flag-detail', kwargs={'key': 'NON_EXISTENT'}),
expected_code=404,
)
# Turn into normal user again
self.user.is_superuser = False
@ -1038,7 +1032,7 @@ class CurrencyAPITests(InvenTreeAPITestCase):
# Delay and try again
time.sleep(10)
raise TimeoutError("Could not refresh currency exchange data after 5 attempts")
raise TimeoutError('Could not refresh currency exchange data after 5 attempts')
class NotesImageTest(InvenTreeAPITestCase):
@ -1052,25 +1046,29 @@ class NotesImageTest(InvenTreeAPITestCase):
response = self.post(
reverse('api-notes-image-list'),
data={
'image': SimpleUploadedFile('test.txt', b"this is not an image file", content_type='text/plain'),
},
format='multipart',
expected_code=400
)
self.assertIn("Upload a valid image", str(response.data['image']))
# Test upload of an invalid image file
response = self.post(
reverse('api-notes-image-list'),
data={
'image': SimpleUploadedFile('test.png', b"this is not an image file", content_type='image/png'),
'image': SimpleUploadedFile(
'test.txt', b'this is not an image file', content_type='text/plain'
)
},
format='multipart',
expected_code=400,
)
self.assertIn("Upload a valid image", str(response.data['image']))
self.assertIn('Upload a valid image', str(response.data['image']))
# Test upload of an invalid image file
response = self.post(
reverse('api-notes-image-list'),
data={
'image': SimpleUploadedFile(
'test.png', b'this is not an image file', content_type='image/png'
)
},
format='multipart',
expected_code=400,
)
self.assertIn('Upload a valid image', str(response.data['image']))
# Check that no extra database entries have been created
self.assertEqual(NotesImage.objects.count(), n)
@ -1089,10 +1087,12 @@ class NotesImageTest(InvenTreeAPITestCase):
self.post(
reverse('api-notes-image-list'),
data={
'image': SimpleUploadedFile('test.png', contents, content_type='image/png'),
'image': SimpleUploadedFile(
'test.png', contents, content_type='image/png'
)
},
format='multipart',
expected_code=201
expected_code=201,
)
# Check that a new file has been created
@ -1136,7 +1136,7 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# Delete it
self.delete(
reverse('api-project-code-detail', kwargs={'pk': code.pk}),
expected_code=204
expected_code=204,
)
# Check it is gone
@ -1147,25 +1147,22 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# Create a new project code
response = self.post(
self.url,
data={
'code': 'PRJ-001',
'description': 'Test project code',
},
expected_code=400
data={'code': 'PRJ-001', 'description': 'Test project code'},
expected_code=400,
)
self.assertIn('project code with this Project Code already exists', str(response.data['code']))
self.assertIn(
'project code with this Project Code already exists',
str(response.data['code']),
)
def test_write_access(self):
"""Test that non-staff users have read-only access"""
# By default user has staff access, can create a new project code
response = self.post(
self.url,
data={
'code': 'PRJ-xxx',
'description': 'Test project code',
},
expected_code=201
data={'code': 'PRJ-xxx', 'description': 'Test project code'},
expected_code=201,
)
pk = response.data['pk']
@ -1173,10 +1170,8 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# Test we can edit, also
response = self.patch(
reverse('api-project-code-detail', kwargs={'pk': pk}),
data={
'code': 'PRJ-999',
},
expected_code=200
data={'code': 'PRJ-999'},
expected_code=200,
)
self.assertEqual(response.data['code'], 'PRJ-999')
@ -1188,20 +1183,15 @@ class ProjectCodesTest(InvenTreeAPITestCase):
# As user does not have staff access, should return 403 for list endpoint
response = self.post(
self.url,
data={
'code': 'PRJ-123',
'description': 'Test project code'
},
expected_code=403
data={'code': 'PRJ-123', 'description': 'Test project code'},
expected_code=403,
)
# Should also return 403 for detail endpoint
response = self.patch(
reverse('api-project-code-detail', kwargs={'pk': pk}),
data={
'code': 'PRJ-999',
},
expected_code=403
data={'code': 'PRJ-999'},
expected_code=403,
)
@ -1219,8 +1209,14 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
super().setUpTestData()
units = [
CustomUnit(name='metres_per_amp', definition='meter / ampere', symbol='m/A'),
CustomUnit(name='hectares_per_second', definition='hectares per second', symbol='ha/s'),
CustomUnit(
name='metres_per_amp', definition='meter / ampere', symbol='m/A'
),
CustomUnit(
name='hectares_per_second',
definition='hectares per second',
symbol='ha/s',
),
]
CustomUnit.objects.bulk_create(units)
@ -1240,10 +1236,8 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
self.patch(
reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}),
{
'name': 'new_unit_name',
},
expected_code=403
{'name': 'new_unit_name'},
expected_code=403,
)
# Ok, what if we have permission?
@ -1252,9 +1246,7 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
self.patch(
reverse('api-custom-unit-detail', kwargs={'pk': unit.pk}),
{
'name': 'new_unit_name',
},
{'name': 'new_unit_name'},
# expected_code=200
)
@ -1269,21 +1261,9 @@ class CustomUnitAPITest(InvenTreeAPITestCase):
self.user.save()
# Test invalid 'name' values (must be valid identifier)
invalid_name_values = [
'1',
'1abc',
'abc def',
'abc-def',
'abc.def',
]
invalid_name_values = ['1', '1abc', 'abc def', 'abc-def', 'abc.def']
url = reverse('api-custom-unit-detail', kwargs={'pk': unit.pk})
for name in invalid_name_values:
self.patch(
url,
{
'name': name,
},
expected_code=400
)
self.patch(url, {'name': name}, expected_code=400)

View File

@ -1,4 +1,3 @@
"""URL lookup for common views."""
common_urls = [
]
common_urls = []

View File

@ -81,11 +81,7 @@ class FileManagementFormView(MultiStepFormView):
('fields', forms.MatchFieldForm),
('items', forms.MatchItemForm),
]
form_steps_description = [
_("Upload File"),
_("Match Fields"),
_("Match Items"),
]
form_steps_description = [_('Upload File'), _('Match Fields'), _('Match Items')]
media_folder = 'file_upload/'
extra_context_data = {}
@ -95,8 +91,12 @@ class FileManagementFormView(MultiStepFormView):
super().__init__(self, *args, **kwargs)
# Check for file manager class
if not hasattr(self, 'file_manager_class') and not issubclass(self.file_manager_class, FileManager):
raise NotImplementedError('A subclass of a file manager class needs to be set!')
if not hasattr(self, 'file_manager_class') and not issubclass(
self.file_manager_class, FileManager
):
raise NotImplementedError(
'A subclass of a file manager class needs to be set!'
)
def get_context_data(self, form=None, **kwargs):
"""Handle context data."""
@ -106,7 +106,6 @@ class FileManagementFormView(MultiStepFormView):
context = super().get_context_data(form=form, **kwargs)
if self.steps.current in ('fields', 'items'):
# Get columns and row data
self.columns = self.file_manager.columns()
self.rows = self.file_manager.rows()
@ -140,7 +139,9 @@ class FileManagementFormView(MultiStepFormView):
# Get file
file = upload_files.get('upload-file', None)
if file:
self.file_manager = self.file_manager_class(file=file, name=self.name)
self.file_manager = self.file_manager_class(
file=file, name=self.name
)
def get_form_kwargs(self, step=None):
"""Update kwargs to dynamically build forms."""
@ -150,15 +151,11 @@ class FileManagementFormView(MultiStepFormView):
if step == 'upload':
# Dynamically build upload form
if self.name:
kwargs = {
'name': self.name
}
kwargs = {'name': self.name}
return kwargs
elif step == 'fields':
# Dynamically build match field form
kwargs = {
'file_manager': self.file_manager
}
kwargs = {'file_manager': self.file_manager}
return kwargs
elif step == 'items':
# Dynamically build match item form
@ -206,7 +203,6 @@ class FileManagementFormView(MultiStepFormView):
self.row_data = {}
for item, value in form_data.items():
# Column names as passed as col_name_<idx> where idx is an integer
# Extract the column names
@ -220,7 +216,6 @@ class FileManagementFormView(MultiStepFormView):
# Extract the column selections (in the 'select fields' view)
if item.startswith('fields-'):
try:
col_name = item.replace('fields-', '')
except ValueError:
@ -258,10 +253,7 @@ class FileManagementFormView(MultiStepFormView):
self.columns = []
for idx, value in self.column_names.items():
header = ({
'name': value,
'guess': self.column_selections.get(idx, ''),
})
header = {'name': value, 'guess': self.column_selections.get(idx, '')}
self.columns.append(header)
if self.row_data:
@ -280,18 +272,10 @@ class FileManagementFormView(MultiStepFormView):
'guess': self.column_selections[idx],
}
cell_data = {
'cell': item,
'idx': idx,
'column': column_data,
}
cell_data = {'cell': item, 'idx': idx, 'column': column_data}
data.append(cell_data)
row = {
'index': row_idx,
'data': data,
'errors': {},
}
row = {'index': row_idx, 'data': data, 'errors': {}}
self.rows.append(row)
@ -344,11 +328,7 @@ class FileManagementFormView(MultiStepFormView):
try:
if idx not in items:
# Insert into items
items.update({
idx: {
self.form_field_map[field]: form_value,
}
})
items.update({idx: {self.form_field_map[field]: form_value}})
else:
# Update items
items[idx][self.form_field_map[field]] = form_value
@ -383,14 +363,15 @@ class FileManagementFormView(MultiStepFormView):
duplicates = []
for col in self.column_names:
if col in self.column_selections:
guess = self.column_selections[col]
else:
guess = None
if guess:
n = list(self.column_selections.values()).count(self.column_selections[col])
n = list(self.column_selections.values()).count(
self.column_selections[col]
)
if n > 1 and self.column_selections[col] not in duplicates:
duplicates.append(self.column_selections[col])
@ -459,7 +440,9 @@ class FileManagementAjaxView(AjaxView):
wizard_back = self.request.POST.get('act-btn_back', None)
if wizard_back:
back_step_index = self.get_step_index() - 1
self.storage.current_step = list(self.get_form_list().keys())[back_step_index]
self.storage.current_step = list(self.get_form_list().keys())[
back_step_index
]
return self.renderJsonResponse(request, data={'form_valid': None})
# validate form
@ -499,13 +482,19 @@ class FileManagementAjaxView(AjaxView):
data = {}
self.setTemplate()
return super().renderJsonResponse(request, form=form, data=data, context=context)
return super().renderJsonResponse(
request, form=form, data=data, context=context
)
def get_data(self) -> dict:
"""Get extra context data."""
data = super().get_data()
data['hideErrorMessage'] = '1' # hide the error
buttons = [{'name': 'back', 'title': _('Previous Step')}] if self.get_step_index() > 0 else []
buttons = (
[{'name': 'back', 'title': _('Previous Step')}]
if self.get_step_index() > 0
else []
)
data['buttons'] = buttons # set buttons
return data

View File

@ -9,9 +9,16 @@ from import_export.fields import Field
from InvenTree.admin import InvenTreeResource
from part.models import Part
from .models import (Address, Company, Contact, ManufacturerPart,
ManufacturerPartAttachment, ManufacturerPartParameter,
SupplierPart, SupplierPriceBreak)
from .models import (
Address,
Company,
Contact,
ManufacturerPart,
ManufacturerPartAttachment,
ManufacturerPartParameter,
SupplierPart,
SupplierPriceBreak,
)
class CompanyResource(InvenTreeResource):
@ -19,6 +26,7 @@ class CompanyResource(InvenTreeResource):
class Meta:
"""Metaclass defines extra options"""
model = Company
skip_unchanged = True
report_skipped = False
@ -32,10 +40,7 @@ class CompanyAdmin(ImportExportModelAdmin):
list_display = ('name', 'website', 'contact')
search_fields = [
'name',
'description',
]
search_fields = ['name', 'description']
class SupplierPartResource(InvenTreeResource):
@ -43,6 +48,7 @@ class SupplierPartResource(InvenTreeResource):
class Meta:
"""Metaclass defines extra admin options"""
model = SupplierPart
skip_unchanged = True
report_skipped = True
@ -70,18 +76,11 @@ class SupplierPartAdmin(ImportExportModelAdmin):
list_display = ('part', 'supplier', 'SKU')
search_fields = [
'supplier__name',
'part__name',
'manufacturer_part__MPN',
'SKU',
]
search_fields = ['supplier__name', 'part__name', 'manufacturer_part__MPN', 'SKU']
inlines = [
SupplierPriceBreakInline,
]
inlines = [SupplierPriceBreakInline]
autocomplete_fields = ('part', 'supplier', 'manufacturer_part',)
autocomplete_fields = ('part', 'supplier', 'manufacturer_part')
class ManufacturerPartResource(InvenTreeResource):
@ -89,6 +88,7 @@ class ManufacturerPartResource(InvenTreeResource):
class Meta:
"""Metaclass defines extra admin options"""
model = ManufacturerPart
skip_unchanged = True
report_skipped = True
@ -98,7 +98,9 @@ class ManufacturerPartResource(InvenTreeResource):
part_name = Field(attribute='part__full_name', readonly=True)
manufacturer = Field(attribute='manufacturer', widget=widgets.ForeignKeyWidget(Company))
manufacturer = Field(
attribute='manufacturer', widget=widgets.ForeignKeyWidget(Company)
)
manufacturer_name = Field(attribute='manufacturer__name', readonly=True)
@ -110,13 +112,9 @@ class ManufacturerPartAdmin(ImportExportModelAdmin):
list_display = ('part', 'manufacturer', 'MPN')
search_fields = [
'manufacturer__name',
'part__name',
'MPN',
]
search_fields = ['manufacturer__name', 'part__name', 'MPN']
autocomplete_fields = ('part', 'manufacturer',)
autocomplete_fields = ('part', 'manufacturer')
class ManufacturerPartAttachmentAdmin(ImportExportModelAdmin):
@ -132,6 +130,7 @@ class ManufacturerPartParameterResource(InvenTreeResource):
class Meta:
"""Metaclass defines extra admin options"""
model = ManufacturerPartParameter
skip_unchanged = True
report_skipped = True
@ -145,11 +144,7 @@ class ManufacturerPartParameterAdmin(ImportExportModelAdmin):
list_display = ('manufacturer_part', 'name', 'value')
search_fields = [
'manufacturer_part__manufacturer__name',
'name',
'value'
]
search_fields = ['manufacturer_part__manufacturer__name', 'name', 'value']
autocomplete_fields = ('manufacturer_part',)
@ -159,6 +154,7 @@ class SupplierPriceBreakResource(InvenTreeResource):
class Meta:
"""Metaclass defines extra admin options"""
model = SupplierPriceBreak
skip_unchanged = True
report_skipped = False
@ -192,6 +188,7 @@ class AddressResource(InvenTreeResource):
class Meta:
"""Metaclass defining extra options"""
model = Address
skip_unchanged = True
report_skipped = False
@ -207,11 +204,7 @@ class AddressAdmin(ImportExportModelAdmin):
list_display = ('company', 'line1', 'postal_code', 'country')
search_fields = [
'company',
'country',
'postal_code',
]
search_fields = ['company', 'country', 'postal_code']
class ContactResource(InvenTreeResource):
@ -219,6 +212,7 @@ class ContactResource(InvenTreeResource):
class Meta:
"""Metaclass defining extra options"""
model = Contact
skip_unchanged = True
report_skipped = False
@ -234,11 +228,7 @@ class ContactAdmin(ImportExportModelAdmin):
list_display = ('company', 'name', 'role', 'email', 'phone')
search_fields = [
'company',
'name',
'email',
]
search_fields = ['company', 'name', 'email']
admin.site.register(Company, CompanyAdmin)

View File

@ -6,23 +6,37 @@ from django.urls import include, path, re_path
from django_filters import rest_framework as rest_filters
import part.models
from InvenTree.api import (AttachmentMixin, ListCreateDestroyAPIView,
MetadataView)
from InvenTree.filters import (ORDER_FILTER, SEARCH_ORDER_FILTER,
SEARCH_ORDER_FILTER_ALIAS)
from InvenTree.api import AttachmentMixin, ListCreateDestroyAPIView, MetadataView
from InvenTree.filters import (
ORDER_FILTER,
SEARCH_ORDER_FILTER,
SEARCH_ORDER_FILTER_ALIAS,
)
from InvenTree.helpers import str2bool
from InvenTree.mixins import ListCreateAPI, RetrieveUpdateDestroyAPI
from .models import (Address, Company, CompanyAttachment, Contact,
ManufacturerPart, ManufacturerPartAttachment,
ManufacturerPartParameter, SupplierPart,
SupplierPriceBreak)
from .serializers import (AddressSerializer, CompanyAttachmentSerializer,
CompanySerializer, ContactSerializer,
ManufacturerPartAttachmentSerializer,
ManufacturerPartParameterSerializer,
ManufacturerPartSerializer, SupplierPartSerializer,
SupplierPriceBreakSerializer)
from .models import (
Address,
Company,
CompanyAttachment,
Contact,
ManufacturerPart,
ManufacturerPartAttachment,
ManufacturerPartParameter,
SupplierPart,
SupplierPriceBreak,
)
from .serializers import (
AddressSerializer,
CompanyAttachmentSerializer,
CompanySerializer,
ContactSerializer,
ManufacturerPartAttachmentSerializer,
ManufacturerPartParameterSerializer,
ManufacturerPartSerializer,
SupplierPartSerializer,
SupplierPriceBreakSerializer,
)
class CompanyList(ListCreateAPI):
@ -46,24 +60,11 @@ class CompanyList(ListCreateAPI):
filter_backends = SEARCH_ORDER_FILTER
filterset_fields = [
'is_customer',
'is_manufacturer',
'is_supplier',
'name',
]
filterset_fields = ['is_customer', 'is_manufacturer', 'is_supplier', 'name']
search_fields = [
'name',
'description',
'website',
]
search_fields = ['name', 'description', 'website']
ordering_fields = [
'name',
'parts_supplied',
'parts_manufactured',
]
ordering_fields = ['name', 'parts_supplied', 'parts_manufactured']
ordering = 'name'
@ -88,9 +89,7 @@ class CompanyAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
queryset = CompanyAttachment.objects.all()
serializer_class = CompanyAttachmentSerializer
filterset_fields = [
'company',
]
filterset_fields = ['company']
class CompanyAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI):
@ -108,18 +107,11 @@ class ContactList(ListCreateDestroyAPIView):
filter_backends = SEARCH_ORDER_FILTER
filterset_fields = [
'company',
]
filterset_fields = ['company']
search_fields = [
'company__name',
'name',
]
search_fields = ['company__name', 'name']
ordering_fields = [
'name',
]
ordering_fields = ['name']
ordering = 'name'
@ -139,13 +131,9 @@ class AddressList(ListCreateDestroyAPIView):
filter_backends = SEARCH_ORDER_FILTER
filterset_fields = [
'company',
]
filterset_fields = ['company']
ordering_fields = [
'title',
]
ordering_fields = ['title']
ordering = 'title'
@ -164,13 +152,7 @@ class ManufacturerPartFilter(rest_filters.FilterSet):
"""Metaclass options."""
model = ManufacturerPart
fields = [
'manufacturer',
'MPN',
'part',
'tags__name',
'tags__slug',
]
fields = ['manufacturer', 'MPN', 'part', 'tags__name', 'tags__slug']
# Filter by 'active' status of linked part
active = rest_filters.BooleanFilter(field_name='part__active')
@ -184,10 +166,7 @@ class ManufacturerPartList(ListCreateDestroyAPIView):
"""
queryset = ManufacturerPart.objects.all().prefetch_related(
'part',
'manufacturer',
'supplier_parts',
'tags',
'part', 'manufacturer', 'supplier_parts', 'tags'
)
serializer_class = ManufacturerPartSerializer
@ -200,7 +179,9 @@ class ManufacturerPartList(ListCreateDestroyAPIView):
params = self.request.query_params
kwargs['part_detail'] = str2bool(params.get('part_detail', None))
kwargs['manufacturer_detail'] = str2bool(params.get('manufacturer_detail', None))
kwargs['manufacturer_detail'] = str2bool(
params.get('manufacturer_detail', None)
)
kwargs['pretty'] = str2bool(params.get('pretty', None))
except AttributeError:
pass
@ -241,9 +222,7 @@ class ManufacturerPartAttachmentList(AttachmentMixin, ListCreateDestroyAPIView):
queryset = ManufacturerPartAttachment.objects.all()
serializer_class = ManufacturerPartAttachmentSerializer
filterset_fields = [
'manufacturer_part',
]
filterset_fields = ['manufacturer_part']
class ManufacturerPartAttachmentDetail(AttachmentMixin, RetrieveUpdateDestroyAPI):
@ -258,17 +237,17 @@ class ManufacturerPartParameterFilter(rest_filters.FilterSet):
class Meta:
"""Metaclass options"""
model = ManufacturerPartParameter
fields = [
'name',
'value',
'units',
'manufacturer_part',
]
fields = ['name', 'value', 'units', 'manufacturer_part']
manufacturer = rest_filters.ModelChoiceFilter(queryset=Company.objects.all(), field_name='manufacturer_part__manufacturer')
manufacturer = rest_filters.ModelChoiceFilter(
queryset=Company.objects.all(), field_name='manufacturer_part__manufacturer'
)
part = rest_filters.ModelChoiceFilter(queryset=part.models.Part.objects.all(), field_name='manufacturer_part__part')
part = rest_filters.ModelChoiceFilter(
queryset=part.models.Part.objects.all(), field_name='manufacturer_part__part'
)
class ManufacturerPartParameterList(ListCreateDestroyAPIView):
@ -284,9 +263,7 @@ class ManufacturerPartParameterList(ListCreateDestroyAPIView):
try:
params = self.request.query_params
optional_fields = [
'manufacturer_part_detail',
]
optional_fields = ['manufacturer_part_detail']
for key in optional_fields:
kwargs[key] = str2bool(params.get(key, None))
@ -300,11 +277,7 @@ class ManufacturerPartParameterList(ListCreateDestroyAPIView):
filter_backends = SEARCH_ORDER_FILTER
search_fields = [
'name',
'value',
'units',
]
search_fields = ['name', 'value', 'units']
class ManufacturerPartParameterDetail(RetrieveUpdateDestroyAPI):
@ -337,7 +310,7 @@ class SupplierPartFilter(rest_filters.FilterSet):
MPN = rest_filters.CharFilter(
label='Manufacturer Part Number',
field_name='manufacturer_part__MPN',
lookup_expr='iexact'
lookup_expr='iexact',
)
@ -348,9 +321,7 @@ class SupplierPartList(ListCreateDestroyAPIView):
- POST: Create a new SupplierPart object
"""
queryset = SupplierPart.objects.all().prefetch_related(
'tags',
)
queryset = SupplierPart.objects.all().prefetch_related('tags')
filterset_class = SupplierPartFilter
def get_queryset(self, *args, **kwargs):
@ -376,7 +347,9 @@ class SupplierPartList(ListCreateDestroyAPIView):
company = params.get('company', None)
if company is not None:
queryset = queryset.filter(Q(manufacturer_part__manufacturer=company) | Q(supplier=company)).distinct()
queryset = queryset.filter(
Q(manufacturer_part__manufacturer=company) | Q(supplier=company)
).distinct()
return queryset
@ -387,7 +360,9 @@ class SupplierPartList(ListCreateDestroyAPIView):
params = self.request.query_params
kwargs['part_detail'] = str2bool(params.get('part_detail', None))
kwargs['supplier_detail'] = str2bool(params.get('supplier_detail', True))
kwargs['manufacturer_detail'] = str2bool(params.get('manufacturer_detail', None))
kwargs['manufacturer_detail'] = str2bool(
params.get('manufacturer_detail', None)
)
kwargs['pretty'] = str2bool(params.get('pretty', None))
except AttributeError:
pass
@ -446,8 +421,7 @@ class SupplierPartDetail(RetrieveUpdateDestroyAPI):
queryset = SupplierPart.objects.all()
serializer_class = SupplierPartSerializer
read_only_fields = [
]
read_only_fields = []
class SupplierPriceBreakFilter(rest_filters.FilterSet):
@ -457,10 +431,7 @@ class SupplierPriceBreakFilter(rest_filters.FilterSet):
"""Metaclass options"""
model = SupplierPriceBreak
fields = [
'part',
'quantity',
]
fields = ['part', 'quantity']
base_part = rest_filters.ModelChoiceFilter(
label='Base Part',
@ -469,9 +440,7 @@ class SupplierPriceBreakFilter(rest_filters.FilterSet):
)
supplier = rest_filters.ModelChoiceFilter(
label='Supplier',
queryset=Company.objects.all(),
field_name='part__supplier',
label='Supplier', queryset=Company.objects.all(), field_name='part__supplier'
)
@ -503,9 +472,7 @@ class SupplierPriceBreakList(ListCreateAPI):
filter_backends = ORDER_FILTER
ordering_fields = [
'quantity',
]
ordering_fields = ['quantity']
ordering = 'quantity'
@ -518,37 +485,74 @@ class SupplierPriceBreakDetail(RetrieveUpdateDestroyAPI):
manufacturer_part_api_urls = [
# Base URL for ManufacturerPartAttachment API endpoints
re_path(r'^attachment/', include([
path(r'<int:pk>/', ManufacturerPartAttachmentDetail.as_view(), name='api-manufacturer-part-attachment-detail'),
re_path(r'^$', ManufacturerPartAttachmentList.as_view(), name='api-manufacturer-part-attachment-list'),
])),
re_path(r'^parameter/', include([
path(r'<int:pk>/', ManufacturerPartParameterDetail.as_view(), name='api-manufacturer-part-parameter-detail'),
# Catch anything else
re_path(r'^.*$', ManufacturerPartParameterList.as_view(), name='api-manufacturer-part-parameter-list'),
])),
re_path(r'^(?P<pk>\d+)/?', include([
re_path('^metadata/', MetadataView.as_view(), {'model': ManufacturerPart}, name='api-manufacturer-part-metadata'),
re_path('^.*$', ManufacturerPartDetail.as_view(), name='api-manufacturer-part-detail'),
])),
re_path(
r'^attachment/',
include([
path(
r'<int:pk>/',
ManufacturerPartAttachmentDetail.as_view(),
name='api-manufacturer-part-attachment-detail',
),
re_path(
r'^$',
ManufacturerPartAttachmentList.as_view(),
name='api-manufacturer-part-attachment-list',
),
]),
),
re_path(
r'^parameter/',
include([
path(
r'<int:pk>/',
ManufacturerPartParameterDetail.as_view(),
name='api-manufacturer-part-parameter-detail',
),
# Catch anything else
re_path(
r'^.*$',
ManufacturerPartParameterList.as_view(),
name='api-manufacturer-part-parameter-list',
),
]),
),
re_path(
r'^(?P<pk>\d+)/?',
include([
re_path(
'^metadata/',
MetadataView.as_view(),
{'model': ManufacturerPart},
name='api-manufacturer-part-metadata',
),
re_path(
'^.*$',
ManufacturerPartDetail.as_view(),
name='api-manufacturer-part-detail',
),
]),
),
# Catch anything else
re_path(r'^.*$', ManufacturerPartList.as_view(), name='api-manufacturer-part-list'),
]
supplier_part_api_urls = [
re_path(r'^(?P<pk>\d+)/?', include([
re_path('^metadata/', MetadataView.as_view(), {'model': SupplierPart}, name='api-supplier-part-metadata'),
re_path('^.*$', SupplierPartDetail.as_view(), name='api-supplier-part-detail'),
])),
re_path(
r'^(?P<pk>\d+)/?',
include([
re_path(
'^metadata/',
MetadataView.as_view(),
{'model': SupplierPart},
name='api-supplier-part-metadata',
),
re_path(
'^.*$', SupplierPartDetail.as_view(), name='api-supplier-part-detail'
),
]),
),
# Catch anything else
re_path(r'^.*$', SupplierPartList.as_view(), name='api-supplier-part-list'),
]
@ -556,39 +560,74 @@ supplier_part_api_urls = [
company_api_urls = [
re_path(r'^part/manufacturer/', include(manufacturer_part_api_urls)),
re_path(r'^part/', include(supplier_part_api_urls)),
# Supplier price breaks
re_path(r'^price-break/', include([
re_path(r'^(?P<pk>\d+)/?', SupplierPriceBreakDetail.as_view(), name='api-part-supplier-price-detail'),
re_path(r'^.*$', SupplierPriceBreakList.as_view(), name='api-part-supplier-price-list'),
])),
re_path(r'^(?P<pk>\d+)/?', include([
re_path(r'^metadata/', MetadataView.as_view(), {'model': Company}, name='api-company-metadata'),
re_path(r'^.*$', CompanyDetail.as_view(), name='api-company-detail'),
])),
re_path(r'^attachment/', include([
path(r'<int:pk>/', CompanyAttachmentDetail.as_view(), name='api-company-attachment-detail'),
re_path(r'^$', CompanyAttachmentList.as_view(), name='api-company-attachment-list'),
])),
re_path(r'^contact/', include([
re_path(r'^(?P<pk>\d+)/?', include([
re_path('^metadata/', MetadataView.as_view(), {'model': Contact}, name='api-contact-metadata'),
re_path('^.*$', ContactDetail.as_view(), name='api-contact-detail'),
])),
re_path(r'^.*$', ContactList.as_view(), name='api-contact-list'),
])),
re_path(r'^address/', include([
path('<int:pk>/', AddressDetail.as_view(), name='api-address-detail'),
re_path(r'^.*$', AddressList.as_view(), name='api-address-list'),
])),
re_path(
r'^price-break/',
include([
re_path(
r'^(?P<pk>\d+)/?',
SupplierPriceBreakDetail.as_view(),
name='api-part-supplier-price-detail',
),
re_path(
r'^.*$',
SupplierPriceBreakList.as_view(),
name='api-part-supplier-price-list',
),
]),
),
re_path(
r'^(?P<pk>\d+)/?',
include([
re_path(
r'^metadata/',
MetadataView.as_view(),
{'model': Company},
name='api-company-metadata',
),
re_path(r'^.*$', CompanyDetail.as_view(), name='api-company-detail'),
]),
),
re_path(
r'^attachment/',
include([
path(
r'<int:pk>/',
CompanyAttachmentDetail.as_view(),
name='api-company-attachment-detail',
),
re_path(
r'^$',
CompanyAttachmentList.as_view(),
name='api-company-attachment-list',
),
]),
),
re_path(
r'^contact/',
include([
re_path(
r'^(?P<pk>\d+)/?',
include([
re_path(
'^metadata/',
MetadataView.as_view(),
{'model': Contact},
name='api-contact-metadata',
),
re_path('^.*$', ContactDetail.as_view(), name='api-contact-detail'),
]),
),
re_path(r'^.*$', ContactList.as_view(), name='api-contact-list'),
]),
),
re_path(
r'^address/',
include([
path('<int:pk>/', AddressDetail.as_view(), name='api-address-detail'),
re_path(r'^.*$', AddressList.as_view(), name='api-address-list'),
]),
),
re_path(r'^.*$', CompanyList.as_view(), name='api-company-list'),
]

View File

@ -29,8 +29,12 @@ import InvenTree.tasks
import InvenTree.validators
from common.settings import currency_code_default
from InvenTree.fields import InvenTreeURLField, RoundingDecimalField
from InvenTree.models import (InvenTreeAttachment, InvenTreeBarcodeMixin,
InvenTreeNotesMixin, MetadataMixin)
from InvenTree.models import (
InvenTreeAttachment,
InvenTreeBarcodeMixin,
InvenTreeNotesMixin,
MetadataMixin,
)
from InvenTree.status_codes import PurchaseOrderStatusGroups
@ -87,20 +91,24 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
class Meta:
"""Metaclass defines extra model options"""
ordering = ['name', ]
ordering = ['name']
constraints = [
UniqueConstraint(fields=['name', 'email'], name='unique_name_email_pair')
]
verbose_name_plural = "Companies"
verbose_name_plural = 'Companies'
@staticmethod
def get_api_url():
"""Return the API URL associated with the Company model"""
return reverse('api-company-list')
name = models.CharField(max_length=100, blank=False,
help_text=_('Company name'),
verbose_name=_('Company name'))
name = models.CharField(
max_length=100,
blank=False,
help_text=_('Company name'),
verbose_name=_('Company name'),
)
description = models.CharField(
max_length=500,
@ -110,41 +118,62 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
)
website = InvenTreeURLField(
blank=True,
verbose_name=_('Website'),
help_text=_('Company website URL')
blank=True, verbose_name=_('Website'), help_text=_('Company website URL')
)
phone = models.CharField(max_length=50,
verbose_name=_('Phone number'),
blank=True, help_text=_('Contact phone number'))
phone = models.CharField(
max_length=50,
verbose_name=_('Phone number'),
blank=True,
help_text=_('Contact phone number'),
)
email = models.EmailField(blank=True, null=True,
verbose_name=_('Email'), help_text=_('Contact email address'))
email = models.EmailField(
blank=True,
null=True,
verbose_name=_('Email'),
help_text=_('Contact email address'),
)
contact = models.CharField(max_length=100,
verbose_name=_('Contact'),
blank=True, help_text=_('Point of contact'))
contact = models.CharField(
max_length=100,
verbose_name=_('Contact'),
blank=True,
help_text=_('Point of contact'),
)
link = InvenTreeURLField(blank=True, verbose_name=_('Link'), help_text=_('Link to external company information'))
link = InvenTreeURLField(
blank=True,
verbose_name=_('Link'),
help_text=_('Link to external company information'),
)
image = StdImageField(
upload_to=rename_company_image,
null=True,
blank=True,
variations={
'thumbnail': (128, 128),
'preview': (256, 256),
},
variations={'thumbnail': (128, 128), 'preview': (256, 256)},
delete_orphans=True,
verbose_name=_('Image'),
)
is_customer = models.BooleanField(default=False, verbose_name=_('is customer'), help_text=_('Do you sell items to this company?'))
is_customer = models.BooleanField(
default=False,
verbose_name=_('is customer'),
help_text=_('Do you sell items to this company?'),
)
is_supplier = models.BooleanField(default=True, verbose_name=_('is supplier'), help_text=_('Do you purchase items from this company?'))
is_supplier = models.BooleanField(
default=True,
verbose_name=_('is supplier'),
help_text=_('Do you purchase items from this company?'),
)
is_manufacturer = models.BooleanField(default=False, verbose_name=_('is manufacturer'), help_text=_('Does this company manufacture parts?'))
is_manufacturer = models.BooleanField(
default=False,
verbose_name=_('is manufacturer'),
help_text=_('Does this company manufacture parts?'),
)
currency = models.CharField(
max_length=3,
@ -186,7 +215,7 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
def __str__(self):
"""Get string representation of a Company."""
return f"{self.name} - {self.description}"
return f'{self.name} - {self.description}'
def get_absolute_url(self):
"""Get the web URL for the detail view for this Company."""
@ -207,13 +236,18 @@ class Company(InvenTreeNotesMixin, MetadataMixin, models.Model):
@property
def parts(self):
"""Return SupplierPart objects which are supplied or manufactured by this company."""
return SupplierPart.objects.filter(Q(supplier=self.id) | Q(manufacturer_part__manufacturer=self.id)).distinct()
return SupplierPart.objects.filter(
Q(supplier=self.id) | Q(manufacturer_part__manufacturer=self.id)
).distinct()
@property
def stock_items(self):
"""Return a list of all stock items supplied or manufactured by this company."""
stock = apps.get_model('stock', 'StockItem')
return stock.objects.filter(Q(supplier_part__supplier=self.id) | Q(supplier_part__manufacturer_part__manufacturer=self.id)).distinct()
return stock.objects.filter(
Q(supplier_part__supplier=self.id)
| Q(supplier_part__manufacturer_part__manufacturer=self.id)
).distinct()
class CompanyAttachment(InvenTreeAttachment):
@ -229,7 +263,8 @@ class CompanyAttachment(InvenTreeAttachment):
return os.path.join('company_files', str(self.company.pk))
company = models.ForeignKey(
Company, on_delete=models.CASCADE,
Company,
on_delete=models.CASCADE,
verbose_name=_('Company'),
related_name='attachments',
)
@ -251,8 +286,9 @@ class Contact(MetadataMixin, models.Model):
"""Return the API URL associated with the Contcat model"""
return reverse('api-contact-list')
company = models.ForeignKey(Company, related_name='contacts',
on_delete=models.CASCADE)
company = models.ForeignKey(
Company, related_name='contacts', on_delete=models.CASCADE
)
name = models.CharField(max_length=100)
@ -281,7 +317,8 @@ class Address(models.Model):
class Meta:
"""Metaclass defines extra model options"""
verbose_name_plural = "Addresses"
verbose_name_plural = 'Addresses'
def __init__(self, *args, **kwargs):
"""Custom init function"""
@ -289,27 +326,30 @@ class Address(models.Model):
def __str__(self):
"""Defines string representation of address to supple a one-line to API calls"""
available_lines = [self.line1,
self.line2,
self.postal_code,
self.postal_city,
self.province,
self.country
]
available_lines = [
self.line1,
self.line2,
self.postal_code,
self.postal_city,
self.province,
self.country,
]
populated_lines = []
for line in available_lines:
if len(line) > 0:
populated_lines.append(line)
return ", ".join(populated_lines)
return ', '.join(populated_lines)
def save(self, *args, **kwargs):
"""Run checks when saving an address:
- If this address is marked as "primary", ensure that all other addresses for this company are marked as non-primary
"""
others = list(Address.objects.filter(company=self.company).exclude(pk=self.pk).all())
others = list(
Address.objects.filter(company=self.company).exclude(pk=self.pk).all()
)
# If this is the *only* address for this company, make it the primary one
if len(others) == 0:
@ -329,63 +369,88 @@ class Address(models.Model):
"""Return the API URL associated with the Contcat model"""
return reverse('api-address-list')
company = models.ForeignKey(Company, related_name='addresses',
on_delete=models.CASCADE,
verbose_name=_('Company'),
help_text=_('Select company'))
company = models.ForeignKey(
Company,
related_name='addresses',
on_delete=models.CASCADE,
verbose_name=_('Company'),
help_text=_('Select company'),
)
title = models.CharField(max_length=100,
verbose_name=_('Address title'),
help_text=_('Title describing the address entry'),
blank=False)
title = models.CharField(
max_length=100,
verbose_name=_('Address title'),
help_text=_('Title describing the address entry'),
blank=False,
)
primary = models.BooleanField(default=False,
verbose_name=_('Primary address'),
help_text=_('Set as primary address'))
primary = models.BooleanField(
default=False,
verbose_name=_('Primary address'),
help_text=_('Set as primary address'),
)
line1 = models.CharField(max_length=50,
verbose_name=_('Line 1'),
help_text=_('Address line 1'),
blank=True)
line1 = models.CharField(
max_length=50,
verbose_name=_('Line 1'),
help_text=_('Address line 1'),
blank=True,
)
line2 = models.CharField(max_length=50,
verbose_name=_('Line 2'),
help_text=_('Address line 2'),
blank=True)
line2 = models.CharField(
max_length=50,
verbose_name=_('Line 2'),
help_text=_('Address line 2'),
blank=True,
)
postal_code = models.CharField(max_length=10,
verbose_name=_('Postal code'),
help_text=_('Postal code'),
blank=True)
postal_code = models.CharField(
max_length=10,
verbose_name=_('Postal code'),
help_text=_('Postal code'),
blank=True,
)
postal_city = models.CharField(max_length=50,
verbose_name=_('City/Region'),
help_text=_('Postal code city/region'),
blank=True)
postal_city = models.CharField(
max_length=50,
verbose_name=_('City/Region'),
help_text=_('Postal code city/region'),
blank=True,
)
province = models.CharField(max_length=50,
verbose_name=_('State/Province'),
help_text=_('State or province'),
blank=True)
province = models.CharField(
max_length=50,
verbose_name=_('State/Province'),
help_text=_('State or province'),
blank=True,
)
country = models.CharField(max_length=50,
verbose_name=_('Country'),
help_text=_('Address country'),
blank=True)
country = models.CharField(
max_length=50,
verbose_name=_('Country'),
help_text=_('Address country'),
blank=True,
)
shipping_notes = models.CharField(max_length=100,
verbose_name=_('Courier shipping notes'),
help_text=_('Notes for shipping courier'),
blank=True)
shipping_notes = models.CharField(
max_length=100,
verbose_name=_('Courier shipping notes'),
help_text=_('Notes for shipping courier'),
blank=True,
)
internal_shipping_notes = models.CharField(max_length=100,
verbose_name=_('Internal shipping notes'),
help_text=_('Shipping notes for internal use'),
blank=True)
internal_shipping_notes = models.CharField(
max_length=100,
verbose_name=_('Internal shipping notes'),
help_text=_('Shipping notes for internal use'),
blank=True,
)
link = InvenTreeURLField(blank=True,
verbose_name=_('Link'),
help_text=_('Link to address information (external)'))
link = InvenTreeURLField(
blank=True,
verbose_name=_('Link'),
help_text=_('Link to address information (external)'),
)
class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
@ -401,6 +466,7 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
class Meta:
"""Metaclass defines extra model options"""
unique_together = ('part', 'manufacturer', 'MPN')
@staticmethod
@ -408,23 +474,21 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
"""Return the API URL associated with the ManufacturerPart instance"""
return reverse('api-manufacturer-part-list')
part = models.ForeignKey('part.Part', on_delete=models.CASCADE,
related_name='manufacturer_parts',
verbose_name=_('Base Part'),
limit_choices_to={
'purchaseable': True,
},
help_text=_('Select part'),
)
part = models.ForeignKey(
'part.Part',
on_delete=models.CASCADE,
related_name='manufacturer_parts',
verbose_name=_('Base Part'),
limit_choices_to={'purchaseable': True},
help_text=_('Select part'),
)
manufacturer = models.ForeignKey(
Company,
on_delete=models.CASCADE,
null=True,
related_name='manufactured_parts',
limit_choices_to={
'is_manufacturer': True
},
limit_choices_to={'is_manufacturer': True},
verbose_name=_('Manufacturer'),
help_text=_('Select manufacturer'),
)
@ -433,19 +497,22 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
null=True,
max_length=100,
verbose_name=_('MPN'),
help_text=_('Manufacturer Part Number')
help_text=_('Manufacturer Part Number'),
)
link = InvenTreeURLField(
blank=True, null=True,
blank=True,
null=True,
verbose_name=_('Link'),
help_text=_('URL for external manufacturer part link')
help_text=_('URL for external manufacturer part link'),
)
description = models.CharField(
max_length=250, blank=True, null=True,
max_length=250,
blank=True,
null=True,
verbose_name=_('Description'),
help_text=_('Manufacturer part description')
help_text=_('Manufacturer part description'),
)
tags = TaggableManager(blank=True)
@ -456,12 +523,20 @@ class ManufacturerPart(MetadataMixin, InvenTreeBarcodeMixin, models.Model):
manufacturer_part = None
try:
manufacturer_part = ManufacturerPart.objects.get(part=part, manufacturer=manufacturer, MPN=mpn)
manufacturer_part = ManufacturerPart.objects.get(
part=part, manufacturer=manufacturer, MPN=mpn
)
except ManufacturerPart.DoesNotExist:
pass
if not manufacturer_part:
manufacturer_part = ManufacturerPart(part=part, manufacturer=manufacturer, MPN=mpn, description=description, link=link)
manufacturer_part = ManufacturerPart(
part=part,
manufacturer=manufacturer,
MPN=mpn,
description=description,
link=link,
)
manufacturer_part.save()
return manufacturer_part
@ -489,10 +564,14 @@ class ManufacturerPartAttachment(InvenTreeAttachment):
def getSubdir(self):
"""Return the subdirectory where attachment files for the ManufacturerPart model are located"""
return os.path.join("manufacturer_part_files", str(self.manufacturer_part.id))
return os.path.join('manufacturer_part_files', str(self.manufacturer_part.id))
manufacturer_part = models.ForeignKey(ManufacturerPart, on_delete=models.CASCADE,
verbose_name=_('Manufacturer Part'), related_name='attachments')
manufacturer_part = models.ForeignKey(
ManufacturerPart,
on_delete=models.CASCADE,
verbose_name=_('Manufacturer Part'),
related_name='attachments',
)
class ManufacturerPartParameter(models.Model):
@ -505,6 +584,7 @@ class ManufacturerPartParameter(models.Model):
class Meta:
"""Metaclass defines extra model options"""
unique_together = ('manufacturer_part', 'name')
@staticmethod
@ -523,21 +603,22 @@ class ManufacturerPartParameter(models.Model):
max_length=500,
blank=False,
verbose_name=_('Name'),
help_text=_('Parameter name')
help_text=_('Parameter name'),
)
value = models.CharField(
max_length=500,
blank=False,
verbose_name=_('Value'),
help_text=_('Parameter value')
help_text=_('Parameter value'),
)
units = models.CharField(
max_length=64,
blank=True, null=True,
blank=True,
null=True,
verbose_name=_('Units'),
help_text=_('Parameter units')
help_text=_('Parameter units'),
)
@ -551,10 +632,10 @@ class SupplierPartManager(models.Manager):
def get_queryset(self):
"""Prefetch related fields when querying against the SupplierPart model"""
# Always prefetch related models
return super().get_queryset().prefetch_related(
'part',
'supplier',
'manufacturer_part__manufacturer',
return (
super()
.get_queryset()
.prefetch_related('part', 'supplier', 'manufacturer_part__manufacturer')
)
@ -580,6 +661,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
class Meta:
"""Metaclass defines extra model options"""
unique_together = ('part', 'supplier', 'SKU')
# This model was moved from the 'Part' app
@ -600,11 +682,7 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
def api_instance_filters(self):
"""Return custom API filters for this particular instance"""
return {
'manufacturer_part': {
'part': self.part.pk
}
}
return {'manufacturer_part': {'part': self.part.pk}}
def clean(self):
"""Custom clean action for the SupplierPart model:
@ -624,36 +702,38 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
try:
# Attempt conversion to specified unit
native_value = InvenTree.conversion.convert_physical_value(
self.pack_quantity, self.part.units,
strip_units=False
self.pack_quantity, self.part.units, strip_units=False
)
# If part units are not provided, value must be dimensionless
if not self.part.units and not InvenTree.conversion.is_dimensionless(native_value):
if not self.part.units and not InvenTree.conversion.is_dimensionless(
native_value
):
raise ValidationError({
'pack_quantity': _("Pack units must be compatible with the base part units")
'pack_quantity': _(
'Pack units must be compatible with the base part units'
)
})
# Native value must be greater than zero
if float(native_value.magnitude) <= 0:
raise ValidationError({
'pack_quantity': _("Pack units must be greater than zero")
'pack_quantity': _('Pack units must be greater than zero')
})
# Update native pack units value
self.pack_quantity_native = Decimal(native_value.magnitude)
except ValidationError as e:
raise ValidationError({
'pack_quantity': e.messages
})
raise ValidationError({'pack_quantity': e.messages})
# Ensure that the linked manufacturer_part points to the same part!
if self.manufacturer_part and self.part:
if self.manufacturer_part.part != self.part:
raise ValidationError({
'manufacturer_part': _("Linked manufacturer part must reference the same base part"),
'manufacturer_part': _(
'Linked manufacturer part must reference the same base part'
)
})
def save(self, *args, **kwargs):
@ -666,7 +746,9 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
# Retrieve manufacturer part
try:
manufacturer_part = ManufacturerPart.objects.get(manufacturer__name=manufacturer_name, MPN=MPN)
manufacturer_part = ManufacturerPart.objects.get(
manufacturer__name=manufacturer_name, MPN=MPN
)
except (ValueError, Company.DoesNotExist):
# ManufacturerPart does not exist
pass
@ -676,7 +758,9 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
# Connect ManufacturerPart to SupplierPart
self.manufacturer_part = manufacturer_part
else:
raise ValidationError(f'SupplierPart {self.__str__} is already linked to {self.manufacturer_part}')
raise ValidationError(
f'SupplierPart {self.__str__} is already linked to {self.manufacturer_part}'
)
self.clean()
self.validate_unique()
@ -684,17 +768,17 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
super().save(*args, **kwargs)
part = models.ForeignKey(
'part.Part', on_delete=models.CASCADE,
'part.Part',
on_delete=models.CASCADE,
related_name='supplier_parts',
verbose_name=_('Base Part'),
limit_choices_to={
'purchaseable': True,
},
limit_choices_to={'purchaseable': True},
help_text=_('Select part'),
)
supplier = models.ForeignKey(
Company, on_delete=models.CASCADE,
Company,
on_delete=models.CASCADE,
related_name='supplied_parts',
limit_choices_to={'is_supplier': True},
verbose_name=_('Supplier'),
@ -703,50 +787,71 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
SKU = models.CharField(
max_length=100,
verbose_name=__("SKU = Stock Keeping Unit (supplier part number)", 'SKU'),
help_text=_('Supplier stock keeping unit')
verbose_name=__('SKU = Stock Keeping Unit (supplier part number)', 'SKU'),
help_text=_('Supplier stock keeping unit'),
)
manufacturer_part = models.ForeignKey(
ManufacturerPart, on_delete=models.CASCADE,
blank=True, null=True,
ManufacturerPart,
on_delete=models.CASCADE,
blank=True,
null=True,
related_name='supplier_parts',
verbose_name=_('Manufacturer Part'),
help_text=_('Select manufacturer part'),
)
link = InvenTreeURLField(
blank=True, null=True,
blank=True,
null=True,
verbose_name=_('Link'),
help_text=_('URL for external supplier part link')
help_text=_('URL for external supplier part link'),
)
description = models.CharField(
max_length=250, blank=True, null=True,
max_length=250,
blank=True,
null=True,
verbose_name=_('Description'),
help_text=_('Supplier part description')
help_text=_('Supplier part description'),
)
note = models.CharField(
max_length=100, blank=True, null=True,
max_length=100,
blank=True,
null=True,
verbose_name=_('Note'),
help_text=_('Notes')
help_text=_('Notes'),
)
base_cost = models.DecimalField(max_digits=10, decimal_places=3, default=0, validators=[MinValueValidator(0)], verbose_name=_('base cost'), help_text=_('Minimum charge (e.g. stocking fee)'))
base_cost = models.DecimalField(
max_digits=10,
decimal_places=3,
default=0,
validators=[MinValueValidator(0)],
verbose_name=_('base cost'),
help_text=_('Minimum charge (e.g. stocking fee)'),
)
packaging = models.CharField(max_length=50, blank=True, null=True, verbose_name=_('Packaging'), help_text=_('Part packaging'))
packaging = models.CharField(
max_length=50,
blank=True,
null=True,
verbose_name=_('Packaging'),
help_text=_('Part packaging'),
)
pack_quantity = models.CharField(
max_length=25,
verbose_name=_('Pack Quantity'),
help_text=_('Total quantity supplied in a single pack. Leave empty for single items.'),
help_text=_(
'Total quantity supplied in a single pack. Leave empty for single items.'
),
blank=True,
)
pack_quantity_native = RoundingDecimalField(
max_digits=20, decimal_places=10, default=1,
null=True,
max_digits=20, decimal_places=10, default=1, null=True
)
def base_quantity(self, quantity=1) -> Decimal:
@ -756,20 +861,29 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
return q
multiple = models.PositiveIntegerField(default=1, validators=[MinValueValidator(1)], verbose_name=_('multiple'), help_text=_('Order multiple'))
multiple = models.PositiveIntegerField(
default=1,
validators=[MinValueValidator(1)],
verbose_name=_('multiple'),
help_text=_('Order multiple'),
)
# TODO - Reimplement lead-time as a charfield with special validation (pattern matching).
# lead_time = models.DurationField(blank=True, null=True)
available = models.DecimalField(
max_digits=10, decimal_places=3, default=0,
max_digits=10,
decimal_places=3,
default=0,
validators=[MinValueValidator(0)],
verbose_name=_('Available'),
help_text=_('Quantity available from supplier'),
)
availability_updated = models.DateTimeField(
null=True, blank=True, verbose_name=_('Availability Updated'),
null=True,
blank=True,
verbose_name=_('Availability Updated'),
help_text=_('Date of last update of availability data'),
)
@ -821,17 +935,15 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
if self.price_breaks.filter(quantity=quantity, part=self.pk).exists():
return
SupplierPriceBreak.objects.create(
part=self,
quantity=quantity,
price=price
)
SupplierPriceBreak.objects.create(part=self, quantity=quantity, price=price)
get_price = common.models.get_price
def open_orders(self):
"""Return a database query for PurchaseOrder line items for this SupplierPart, limited to purchase orders that are open / outstanding."""
return self.purchase_order_line_items.prefetch_related('order').filter(order__status__in=PurchaseOrderStatusGroups.OPEN)
return self.purchase_order_line_items.prefetch_related('order').filter(
order__status__in=PurchaseOrderStatusGroups.OPEN
)
def on_order(self):
"""Return the total quantity of items currently on order.
@ -852,7 +964,10 @@ class SupplierPart(MetadataMixin, InvenTreeBarcodeMixin, common.models.MetaMixin
def purchase_orders(self):
"""Returns a list of purchase orders relating to this supplier part."""
return [line.order for line in self.purchase_order_line_items.all().prefetch_related('order')]
return [
line.order
for line in self.purchase_order_line_items.all().prefetch_related('order')
]
@property
def pretty_name(self):
@ -891,7 +1006,8 @@ class SupplierPriceBreak(common.models.PriceBreak):
class Meta:
"""Metaclass defines extra model options"""
unique_together = ("part", "quantity")
unique_together = ('part', 'quantity')
# This model was moved from the 'Part' app
db_table = 'part_supplierpricebreak'
@ -905,22 +1021,31 @@ class SupplierPriceBreak(common.models.PriceBreak):
"""Return the API URL associated with the SupplierPriceBreak model"""
return reverse('api-part-supplier-price-list')
part = models.ForeignKey(SupplierPart, on_delete=models.CASCADE, related_name='pricebreaks', verbose_name=_('Part'),)
part = models.ForeignKey(
SupplierPart,
on_delete=models.CASCADE,
related_name='pricebreaks',
verbose_name=_('Part'),
)
@receiver(post_save, sender=SupplierPriceBreak, dispatch_uid='post_save_supplier_price_break')
@receiver(
post_save, sender=SupplierPriceBreak, dispatch_uid='post_save_supplier_price_break'
)
def after_save_supplier_price(sender, instance, created, **kwargs):
"""Callback function when a SupplierPriceBreak is created or updated"""
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
if instance.part and instance.part.part:
instance.part.part.schedule_pricing_update(create=True)
@receiver(post_delete, sender=SupplierPriceBreak, dispatch_uid='post_delete_supplier_price_break')
@receiver(
post_delete,
sender=SupplierPriceBreak,
dispatch_uid='post_delete_supplier_price_break',
)
def after_delete_supplier_price(sender, instance, **kwargs):
"""Callback function when a SupplierPriceBreak is deleted"""
if InvenTree.ready.canAppAccessDatabase() and not InvenTree.ready.isImportingData():
if instance.part and instance.part.part:
instance.part.part.schedule_pricing_update(create=False)

View File

@ -10,20 +10,29 @@ from sql_util.utils import SubqueryCount
from taggit.serializers import TagListSerializerField
import part.filters
from InvenTree.serializers import (InvenTreeAttachmentSerializer,
InvenTreeCurrencySerializer,
InvenTreeDecimalField,
InvenTreeImageSerializerField,
InvenTreeModelSerializer,
InvenTreeMoneySerializer,
InvenTreeTagModelSerializer,
RemoteImageMixin)
from InvenTree.serializers import (
InvenTreeAttachmentSerializer,
InvenTreeCurrencySerializer,
InvenTreeDecimalField,
InvenTreeImageSerializerField,
InvenTreeModelSerializer,
InvenTreeMoneySerializer,
InvenTreeTagModelSerializer,
RemoteImageMixin,
)
from part.serializers import PartBriefSerializer
from .models import (Address, Company, CompanyAttachment, Contact,
ManufacturerPart, ManufacturerPartAttachment,
ManufacturerPartParameter, SupplierPart,
SupplierPriceBreak)
from .models import (
Address,
Company,
CompanyAttachment,
Contact,
ManufacturerPart,
ManufacturerPartAttachment,
ManufacturerPartParameter,
SupplierPart,
SupplierPriceBreak,
)
class CompanyBriefSerializer(InvenTreeModelSerializer):
@ -33,13 +42,7 @@ class CompanyBriefSerializer(InvenTreeModelSerializer):
"""Metaclass options."""
model = Company
fields = [
'pk',
'url',
'name',
'description',
'image',
]
fields = ['pk', 'url', 'name', 'description', 'image']
url = serializers.CharField(source='get_absolute_url', read_only=True)
@ -86,7 +89,7 @@ class AddressBriefSerializer(InvenTreeModelSerializer):
'province',
'country',
'shipping_notes',
'internal_shipping_notes'
'internal_shipping_notes',
]
@ -119,7 +122,7 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
'parts_manufactured',
'remote_image',
'address_count',
'primary_address'
'primary_address',
]
@staticmethod
@ -130,13 +133,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
parts_manufactured=SubqueryCount('manufactured_parts')
)
queryset = queryset.annotate(
parts_supplied=SubqueryCount('supplied_parts')
)
queryset = queryset.annotate(parts_supplied=SubqueryCount('supplied_parts'))
queryset = queryset.annotate(
address_count=SubqueryCount('addresses')
)
queryset = queryset.annotate(address_count=SubqueryCount('addresses'))
return queryset
@ -150,7 +149,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
parts_manufactured = serializers.IntegerField(read_only=True)
address_count = serializers.IntegerField(read_only=True)
currency = InvenTreeCurrencySerializer(help_text=_('Default currency used for this supplier'), required=True)
currency = InvenTreeCurrencySerializer(
help_text=_('Default currency used for this supplier'), required=True
)
def save(self):
"""Save the Company instance"""
@ -167,12 +168,9 @@ class CompanySerializer(RemoteImageMixin, InvenTreeModelSerializer):
remote_img.save(buffer, format=fmt)
# Construct a simplified name for the image
filename = f"company_{company.pk}_image.{fmt.lower()}"
filename = f'company_{company.pk}_image.{fmt.lower()}'
company.image.save(
filename,
ContentFile(buffer.getvalue()),
)
company.image.save(filename, ContentFile(buffer.getvalue()))
return self.instance
@ -182,11 +180,10 @@ class CompanyAttachmentSerializer(InvenTreeAttachmentSerializer):
class Meta:
"""Metaclass defines serializer options"""
model = CompanyAttachment
fields = InvenTreeAttachmentSerializer.attachment_fields([
'company',
])
fields = InvenTreeAttachmentSerializer.attachment_fields(['company'])
class ContactSerializer(InvenTreeModelSerializer):
@ -196,14 +193,7 @@ class ContactSerializer(InvenTreeModelSerializer):
"""Metaclass options"""
model = Contact
fields = [
'pk',
'company',
'name',
'phone',
'email',
'role',
]
fields = ['pk', 'company', 'name', 'phone', 'email', 'role']
class ManufacturerPartSerializer(InvenTreeTagModelSerializer):
@ -224,7 +214,6 @@ class ManufacturerPartSerializer(InvenTreeTagModelSerializer):
'MPN',
'link',
'barcode_hash',
'tags',
]
@ -249,11 +238,15 @@ class ManufacturerPartSerializer(InvenTreeTagModelSerializer):
part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
manufacturer_detail = CompanyBriefSerializer(source='manufacturer', many=False, read_only=True)
manufacturer_detail = CompanyBriefSerializer(
source='manufacturer', many=False, read_only=True
)
pretty_name = serializers.CharField(read_only=True)
manufacturer = serializers.PrimaryKeyRelatedField(queryset=Company.objects.filter(is_manufacturer=True))
manufacturer = serializers.PrimaryKeyRelatedField(
queryset=Company.objects.filter(is_manufacturer=True)
)
class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer):
@ -264,9 +257,7 @@ class ManufacturerPartAttachmentSerializer(InvenTreeAttachmentSerializer):
model = ManufacturerPartAttachment
fields = InvenTreeAttachmentSerializer.attachment_fields([
'manufacturer_part',
])
fields = InvenTreeAttachmentSerializer.attachment_fields(['manufacturer_part'])
class ManufacturerPartParameterSerializer(InvenTreeModelSerializer):
@ -295,7 +286,9 @@ class ManufacturerPartParameterSerializer(InvenTreeModelSerializer):
if not man_detail:
self.fields.pop('manufacturer_part_detail')
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', many=False, read_only=True)
manufacturer_part_detail = ManufacturerPartSerializer(
source='manufacturer_part', many=False, read_only=True
)
class SupplierPartSerializer(InvenTreeTagModelSerializer):
@ -330,14 +323,10 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
'supplier_detail',
'url',
'updated',
'tags',
]
read_only_fields = [
'availability_updated',
'barcode_hash',
]
read_only_fields = ['availability_updated', 'barcode_hash']
tags = TagListSerializerField(required=False)
@ -377,19 +366,27 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
part_detail = PartBriefSerializer(source='part', many=False, read_only=True)
supplier_detail = CompanyBriefSerializer(source='supplier', many=False, read_only=True)
supplier_detail = CompanyBriefSerializer(
source='supplier', many=False, read_only=True
)
manufacturer_detail = CompanyBriefSerializer(source='manufacturer_part.manufacturer', many=False, read_only=True)
manufacturer_detail = CompanyBriefSerializer(
source='manufacturer_part.manufacturer', many=False, read_only=True
)
pretty_name = serializers.CharField(read_only=True)
supplier = serializers.PrimaryKeyRelatedField(queryset=Company.objects.filter(is_supplier=True))
supplier = serializers.PrimaryKeyRelatedField(
queryset=Company.objects.filter(is_supplier=True)
)
manufacturer = serializers.CharField(read_only=True)
MPN = serializers.CharField(read_only=True)
manufacturer_part_detail = ManufacturerPartSerializer(source='manufacturer_part', part_detail=False, read_only=True)
manufacturer_part_detail = ManufacturerPartSerializer(
source='manufacturer_part', part_detail=False, read_only=True
)
url = serializers.CharField(source='get_absolute_url', read_only=True)
@ -403,9 +400,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
Fields:
in_stock: Current stock quantity for each SupplierPart
"""
queryset = queryset.annotate(
in_stock=part.filters.annotate_total_stock()
)
queryset = queryset.annotate(in_stock=part.filters.annotate_total_stock())
return queryset
@ -436,10 +431,7 @@ class SupplierPartSerializer(InvenTreeTagModelSerializer):
MPN = self.initial_data.get('MPN', None)
if manufacturer and MPN:
kwargs = {
'manufacturer': manufacturer,
'MPN': MPN,
}
kwargs = {'manufacturer': manufacturer, 'MPN': MPN}
supplier_part.save(**kwargs)
return supplier_part
@ -479,17 +471,19 @@ class SupplierPriceBreakSerializer(InvenTreeModelSerializer):
quantity = InvenTreeDecimalField()
price = InvenTreeMoneySerializer(
allow_null=True,
required=True,
label=_('Price'),
)
price = InvenTreeMoneySerializer(allow_null=True, required=True, label=_('Price'))
price_currency = InvenTreeCurrencySerializer()
supplier = serializers.PrimaryKeyRelatedField(source='part.supplier', many=False, read_only=True)
supplier = serializers.PrimaryKeyRelatedField(
source='part.supplier', many=False, read_only=True
)
supplier_detail = CompanyBriefSerializer(source='part.supplier', many=False, read_only=True)
supplier_detail = CompanyBriefSerializer(
source='part.supplier', many=False, read_only=True
)
# Detail serializer for SupplierPart
part_detail = SupplierPartSerializer(source='part', brief=True, many=False, read_only=True)
part_detail = SupplierPartSerializer(
source='part', brief=True, many=False, read_only=True
)

View File

@ -12,10 +12,7 @@ from .models import Address, Company, Contact, ManufacturerPart, SupplierPart
class CompanyTest(InvenTreeAPITestCase):
"""Series of tests for the Company DRF API."""
roles = [
'purchase_order.add',
'purchase_order.change',
]
roles = ['purchase_order.add', 'purchase_order.change']
@classmethod
def setUpTestData(cls):
@ -23,9 +20,18 @@ class CompanyTest(InvenTreeAPITestCase):
super().setUpTestData()
# Create some company objects to work with
cls.acme = Company.objects.create(name='ACME', description='Supplier', is_customer=False, is_supplier=True)
Company.objects.create(name='Drippy Cup Co.', description='Customer', is_customer=True, is_supplier=False)
Company.objects.create(name='Sippy Cup Emporium', description='Another supplier')
cls.acme = Company.objects.create(
name='ACME', description='Supplier', is_customer=False, is_supplier=True
)
Company.objects.create(
name='Drippy Cup Co.',
description='Customer',
is_customer=True,
is_supplier=False,
)
Company.objects.create(
name='Sippy Cup Emporium', description='Another supplier'
)
def test_company_list(self):
"""Test the list API endpoint for the Company model"""
@ -82,22 +88,13 @@ class CompanyTest(InvenTreeAPITestCase):
url = reverse('api-company-list')
# Name is required
response = self.post(
url,
{
'description': 'A description!',
},
expected_code=400
)
response = self.post(url, {'description': 'A description!'}, expected_code=400)
# Minimal example, checking default values
response = self.post(
url,
{
'name': 'My API Company',
'description': 'A company created via the API',
},
expected_code=201
{'name': 'My API Company', 'description': 'A company created via the API'},
expected_code=201,
)
self.assertTrue(response.data['is_supplier'])
@ -110,14 +107,14 @@ class CompanyTest(InvenTreeAPITestCase):
response = self.post(
url,
{
'name': "Another Company",
'description': "Also created via the API!",
'name': 'Another Company',
'description': 'Also created via the API!',
'currency': 'AUD',
'is_supplier': False,
'is_manufacturer': True,
'is_customer': True,
},
expected_code=201
expected_code=201,
)
self.assertEqual(response.data['currency'], 'AUD')
@ -128,12 +125,8 @@ class CompanyTest(InvenTreeAPITestCase):
# Attempt to create with invalid currency
response = self.post(
url,
{
'name': "A name",
'description': 'A description',
'currency': 'POQD',
},
expected_code=400
{'name': 'A name', 'description': 'A description', 'currency': 'POQD'},
expected_code=400,
)
self.assertTrue('currency' in response.data)
@ -151,10 +144,8 @@ class ContactTest(InvenTreeAPITestCase):
# Create some companies
companies = [
Company(
name=f"Company {idx}",
description="Some company"
) for idx in range(3)
Company(name=f'Company {idx}', description='Some company')
for idx in range(3)
]
Company.objects.bulk_create(companies)
@ -164,10 +155,7 @@ class ContactTest(InvenTreeAPITestCase):
# Create some contacts
for cmp in Company.objects.all():
contacts += [
Contact(
company=cmp,
name=f"My name {idx}",
) for idx in range(3)
Contact(company=cmp, name=f'My name {idx}') for idx in range(3)
]
Contact.objects.bulk_create(contacts)
@ -187,13 +175,7 @@ class ContactTest(InvenTreeAPITestCase):
# Filter by particular company
for cmp in Company.objects.all():
response = self.get(
self.url,
{
'company': cmp.pk,
},
expected_code=200
)
response = self.get(self.url, {'company': cmp.pk}, expected_code=200)
self.assertEqual(len(response.data), 3)
@ -205,23 +187,13 @@ class ContactTest(InvenTreeAPITestCase):
# Without required permissions, creation should fail
self.post(
self.url,
{
'company': company.pk,
'name': 'Joe Bloggs',
},
expected_code=403
self.url, {'company': company.pk, 'name': 'Joe Bloggs'}, expected_code=403
)
self.assignRole('return_order.add')
self.post(
self.url,
{
'company': company.pk,
'name': 'Joe Bloggs',
},
expected_code=201
self.url, {'company': company.pk, 'name': 'Joe Bloggs'}, expected_code=201
)
self.assertEqual(Contact.objects.count(), n + 1)
@ -239,23 +211,11 @@ class ContactTest(InvenTreeAPITestCase):
for key in ['pk', 'name', 'role']:
self.assertIn(key, data)
self.patch(
url,
{
'role': 'model',
},
expected_code=403
)
self.patch(url, {'role': 'model'}, expected_code=403)
self.assignRole('purchase_order.change')
self.patch(
url,
{
'role': 'x',
},
expected_code=200
)
self.patch(url, {'role': 'x'}, expected_code=200)
# Get the contact again
contact = Contact.objects.first()
@ -291,10 +251,8 @@ class AddressTest(InvenTreeAPITestCase):
cls.num_addr = 3
# Create some companies
companies = [
Company(
name=f"Company {idx}",
description="Some company"
) for idx in range(cls.num_companies)
Company(name=f'Company {idx}', description='Some company')
for idx in range(cls.num_companies)
]
Company.objects.bulk_create(companies)
@ -304,10 +262,8 @@ class AddressTest(InvenTreeAPITestCase):
# Create some contacts
for cmp in Company.objects.all():
addresses += [
Address(
company=cmp,
title=f"Address no. {idx}",
) for idx in range(cls.num_addr)
Address(company=cmp, title=f'Address no. {idx}')
for idx in range(cls.num_addr)
]
cls.url = reverse('api-address-list')
@ -332,21 +288,11 @@ class AddressTest(InvenTreeAPITestCase):
"""Test creating a new address"""
company = Company.objects.first()
self.post(self.url,
{
'company': company.pk,
'title': 'HQ'
},
expected_code=403)
self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=403)
self.assignRole('purchase_order.add')
self.post(self.url,
{
'company': company.pk,
'title': 'HQ'
},
expected_code=201)
self.post(self.url, {'company': company.pk, 'title': 'HQ'}, expected_code=201)
def test_get(self):
"""Test that objects are properly returned from a get"""
@ -357,7 +303,15 @@ class AddressTest(InvenTreeAPITestCase):
self.assertEqual(response.data['pk'], addr.pk)
for key in ['title', 'line1', 'line2', 'postal_code', 'postal_city', 'province', 'country']:
for key in [
'title',
'line1',
'line2',
'postal_code',
'postal_city',
'province',
'country',
]:
self.assertIn(key, response.data)
def test_edit(self):
@ -366,23 +320,11 @@ class AddressTest(InvenTreeAPITestCase):
url = reverse('api-address-detail', kwargs={'pk': addr.pk})
self.patch(
url,
{
'title': 'Hello'
},
expected_code=403
)
self.patch(url, {'title': 'Hello'}, expected_code=403)
self.assignRole('purchase_order.change')
self.patch(
url,
{
'title': 'World'
},
expected_code=200
)
self.patch(url, {'title': 'World'}, expected_code=200)
data = self.get(url, expected_code=200).data
@ -415,10 +357,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
'supplier_part',
]
roles = [
'part.add',
'part.change',
]
roles = ['part.add', 'part.change']
def test_manufacturer_part_list(self):
"""Test the ManufacturerPart API list functionality"""
@ -429,11 +368,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), 3)
# Create manufacturer part
data = {
'part': 1,
'manufacturer': 7,
'MPN': 'MPN_TEST',
}
data = {'part': 1, 'manufacturer': 7, 'MPN': 'MPN_TEST'}
response = self.client.post(url, data, format='json')
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data['MPN'], 'MPN_TEST')
@ -456,9 +391,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(response.data['MPN'], 'MPN123')
# Change the MPN
data = {
'MPN': 'MPN-TEST-123',
}
data = {'MPN': 'MPN-TEST-123'}
response = self.client.patch(url, data, format='json')
@ -485,7 +418,7 @@ class ManufacturerTest(InvenTreeAPITestCase):
'MPN': 'PART_NUMBER',
'link': 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
},
expected_code=201
expected_code=201,
)
pk = response.data['pk']
@ -503,10 +436,16 @@ class ManufacturerTest(InvenTreeAPITestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
# Check link is not modified
self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E')
self.assertEqual(
response.data['link'],
'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
)
# Check link is not modified
self.assertEqual(response.data['link'], 'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E')
self.assertEqual(
response.data['link'],
'https://www.axel-larsson.se/Exego.aspx?p_id=341&ArtNr=0804020E',
)
class SupplierPartTest(InvenTreeAPITestCase):
@ -521,12 +460,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
'supplier_part',
]
roles = [
'part.add',
'part.change',
'part.add',
'purchase_order.change',
]
roles = ['part.add', 'part.change', 'part.add', 'purchase_order.change']
def test_supplier_part_list(self):
"""Test the SupplierPart API list functionality"""
@ -543,10 +477,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
self.assertEqual(len(response.data), supplier.supplied_parts.count())
# Filter by Part reference
expected = {
1: 4,
25: 2,
}
expected = {1: 4, 25: 2}
for pk, n in expected.items():
response = self.get(url, {'part': pk}, expected_code=200)
@ -559,12 +490,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
# Should fail when sending an invalid 'available' field
response = self.post(
url,
{
'part': 1,
'supplier': 2,
'SKU': 'QQ',
'available': 'not a number',
},
{'part': 1, 'supplier': 2, 'SKU': 'QQ', 'available': 'not a number'},
expected_code=400,
)
@ -572,13 +498,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
# Create a SupplierPart without specifying available quantity
response = self.post(
url,
{
'part': 1,
'supplier': 2,
'SKU': 'QQ',
},
expected_code=201
url, {'part': 1, 'supplier': 2, 'SKU': 'QQ'}, expected_code=201
)
sp = SupplierPart.objects.get(pk=response.data['pk'])
@ -589,9 +509,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
# Now, *update* the available quantity via the API
self.patch(
reverse('api-supplier-part-detail', kwargs={'pk': sp.pk}),
{
'available': 1234,
},
{'available': 1234},
expected_code=200,
)
@ -602,12 +520,7 @@ class SupplierPartTest(InvenTreeAPITestCase):
# We should also be able to create a SupplierPart with initial 'available' quantity
response = self.post(
url,
{
'part': 1,
'supplier': 2,
'SKU': 'QQQ',
'available': 999,
},
{'part': 1, 'supplier': 2, 'SKU': 'QQQ', 'available': 999},
expected_code=201,
)
@ -629,11 +542,7 @@ class CompanyMetadataAPITest(InvenTreeAPITestCase):
'supplier_part',
]
roles = [
'company.change',
'purchase_order.change',
'part.change',
]
roles = ['company.change', 'purchase_order.change', 'part.change']
def metatester(self, apikey, model):
"""Generic tester"""
@ -651,17 +560,15 @@ class CompanyMetadataAPITest(InvenTreeAPITestCase):
self.patch(
url,
{
'metadata': {
f'abc-{numstr}': f'xyz-{apikey}-{numstr}',
}
},
expected_code=200
{'metadata': {f'abc-{numstr}': f'xyz-{apikey}-{numstr}'}},
expected_code=200,
)
# Refresh
modeldata.refresh_from_db()
self.assertEqual(modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}')
self.assertEqual(
modeldata.get_metadata(f'abc-{numstr}'), f'xyz-{apikey}-{numstr}'
)
def test_metadata(self):
"""Test all endpoints"""

View File

@ -16,9 +16,7 @@ class TestForwardMigrations(MigratorTestCase):
Company = self.old_state.apps.get_model('company', 'company')
Company.objects.create(
name='MSPC',
description='Michael Scotts Paper Company',
is_supplier=True
name='MSPC', description='Michael Scotts Paper Company', is_supplier=True
)
def test_migrations(self):
@ -49,7 +47,10 @@ class TestManufacturerField(MigratorTestCase):
part = Part.objects.create(
name='Screw',
description='A single screw',
level=0, tree_id=0, lft=0, rght=0
level=0,
tree_id=0,
lft=0,
rght=0,
)
# Create a company to act as the supplier
@ -62,17 +63,11 @@ class TestManufacturerField(MigratorTestCase):
# Add some SupplierPart objects
SupplierPart.objects.create(
part=part,
supplier=supplier,
SKU='SCREW.001',
manufacturer_name='ACME',
part=part, supplier=supplier, SKU='SCREW.001', manufacturer_name='ACME'
)
SupplierPart.objects.create(
part=part,
supplier=supplier,
SKU='SCREW.002',
manufacturer_name='Zero Corp',
part=part, supplier=supplier, SKU='SCREW.002', manufacturer_name='Zero Corp'
)
self.assertEqual(Company.objects.count(), 1)
@ -186,17 +181,12 @@ class TestManufacturerPart(MigratorTestCase):
# No Manufacturer
SupplierPart.objects.create(
part=part,
supplier=supplier_2,
SKU='CAP-CER-01UF-2',
MPN='MUR-CAP-123456',
part=part, supplier=supplier_2, SKU='CAP-CER-01UF-2', MPN='MUR-CAP-123456'
)
# No Manufacturer data
SupplierPart.objects.create(
part=part,
supplier=supplier_2,
SKU='CAP-CER-01UF-3',
part=part, supplier=supplier_2, SKU='CAP-CER-01UF-3'
)
def test_manufacturer_part_objects(self):
@ -238,17 +228,20 @@ class TestCurrencyMigration(MigratorTestCase):
Part = self.old_state.apps.get_model('part', 'part')
part = Part.objects.create(
name="PART", description="A purchaseable part",
name='PART',
description='A purchaseable part',
purchaseable=True,
level=0,
tree_id=0,
lft=0,
rght=0
rght=0,
)
Company = self.old_state.apps.get_model('company', 'company')
supplier = Company.objects.create(name='Supplier', description='A supplier', is_supplier=True)
supplier = Company.objects.create(
name='Supplier', description='A supplier', is_supplier=True
)
SupplierPart = self.old_state.apps.get_model('company', 'supplierpart')
@ -256,8 +249,12 @@ class TestCurrencyMigration(MigratorTestCase):
Currency = self.old_state.apps.get_model('common', 'currency')
aud = Currency.objects.create(symbol='$', suffix='AUD', description='Australian Dollars', value=1.0)
usd = Currency.objects.create(symbol='$', suffix='USD', description='US Dollars', value=1.0)
aud = Currency.objects.create(
symbol='$', suffix='AUD', description='Australian Dollars', value=1.0
)
usd = Currency.objects.create(
symbol='$', suffix='USD', description='US Dollars', value=1.0
)
PB = self.old_state.apps.get_model('company', 'supplierpricebreak')
@ -312,7 +309,7 @@ class TestAddressMigration(MigratorTestCase):
a2 = Address.objects.filter(company=c2.pk).first()
self.assertEqual(a1.line1, self.short_l1)
self.assertEqual(a1.line2, "")
self.assertEqual(a1.line2, '')
self.assertEqual(a2.line1, self.long_l1)
self.assertEqual(a2.line2, self.l2)
self.assertEqual(c1.address, '')
@ -332,12 +329,18 @@ class TestSupplierPartQuantity(MigratorTestCase):
SupplierPart = self.old_state.apps.get_model('company', 'supplierpart')
self.part = Part.objects.create(
name="PART", description="A purchaseable part",
name='PART',
description='A purchaseable part',
purchaseable=True,
level=0, tree_id=0, lft=0, rght=0
level=0,
tree_id=0,
lft=0,
rght=0,
)
self.supplier = Company.objects.create(name='Supplier', description='A supplier', is_supplier=True)
self.supplier = Company.objects.create(
name='Supplier', description='A supplier', is_supplier=True
)
self.supplier_parts = []
@ -356,7 +359,6 @@ class TestSupplierPartQuantity(MigratorTestCase):
SupplierPart = self.new_state.apps.get_model('company', 'supplierpart')
for i, sp in enumerate(SupplierPart.objects.all()):
self.assertEqual(sp.pack_quantity, str(i + 1))
self.assertEqual(sp.pack_quantity_native, i + 1)

View File

@ -15,17 +15,15 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
def test_pack_quantity_dimensionless(self):
"""Test valid values for the 'pack_quantity' field"""
# Create a part without units (dimensionless)
part = Part.objects.create(name='Test Part', description='Test part description', component=True)
part = Part.objects.create(
name='Test Part', description='Test part description', component=True
)
# Create a supplier (company)
company = Company.objects.create(name='Test Company', is_supplier=True)
# Create a supplier part for this part
sp = SupplierPart.objects.create(
part=part,
supplier=company,
SKU='TEST-SKU'
)
sp = SupplierPart.objects.create(part=part, supplier=company, SKU='TEST-SKU')
# All these values are valid for a dimensionless part
pass_tests = {
@ -37,14 +35,7 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
}
# All these values are invalid for a dimensionless part
fail_tests = [
'1.2m',
'-1',
'0',
'0.0',
'100 feet',
'0 amps'
]
fail_tests = ['1.2m', '-1', '0', '0.0', '100 feet', '0 amps']
for test, expected in pass_tests.items():
sp.pack_quantity = test
@ -59,17 +50,18 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
def test_pack_quantity(self):
"""Test pack_quantity for a part with a specified dimension"""
# Create a part with units 'm'
part = Part.objects.create(name='Test Part', description='Test part description', component=True, units='m')
part = Part.objects.create(
name='Test Part',
description='Test part description',
component=True,
units='m',
)
# Create a supplier (company)
company = Company.objects.create(name='Test Company', is_supplier=True)
# Create a supplier part for this part
sp = SupplierPart.objects.create(
part=part,
supplier=company,
SKU='TEST-SKU'
)
sp = SupplierPart.objects.create(part=part, supplier=company, SKU='TEST-SKU')
# All these values are valid for a part with dimension 'm'
pass_tests = {
@ -87,23 +79,14 @@ class SupplierPartPackUnitsTests(InvenTreeTestCase):
# All these values are invalid for a part with dimension 'm'
# Either the values are invalid, or the units are incompatible
fail_tests = [
'-1',
'-1m',
'0',
'0m',
'12 deg',
'57 amps',
'-12 oz',
'17 yaks',
]
fail_tests = ['-1', '-1m', '0', '0m', '12 deg', '57 amps', '-12 oz', '17 yaks']
for test, expected in pass_tests.items():
sp.pack_quantity = test
sp.full_clean()
self.assertEqual(
round(Decimal(sp.pack_quantity_native), 10),
round(Decimal(str(expected)), 10)
round(Decimal(str(expected)), 10),
)
for test in fail_tests:

View File

@ -36,7 +36,9 @@ class CompanyViewTest(InvenTreeTestCase):
def test_manufacturer_part_detail_view(self):
"""Test the manufacturer part detail view."""
response = self.client.get(reverse('manufacturer-part-detail', kwargs={'pk': 1}))
response = self.client.get(
reverse('manufacturer-part-detail', kwargs={'pk': 1})
)
self.assertEqual(response.status_code, 200)
self.assertContains(response, 'MPN123')

View File

@ -8,8 +8,14 @@ from django.test import TestCase
from part.models import Part
from .models import (Address, Company, Contact, ManufacturerPart, SupplierPart,
rename_company_image)
from .models import (
Address,
Company,
Contact,
ManufacturerPart,
SupplierPart,
rename_company_image,
)
class CompanySimpleTest(TestCase):
@ -31,11 +37,13 @@ class CompanySimpleTest(TestCase):
"""Perform initialization for the tests in this class"""
super().setUpTestData()
Company.objects.create(name='ABC Co.',
description='Seller of ABC products',
website='www.abc-sales.com',
is_customer=False,
is_supplier=True)
Company.objects.create(
name='ABC Co.',
description='Seller of ABC products',
website='www.abc-sales.com',
is_customer=False,
is_supplier=True,
)
cls.acme0001 = SupplierPart.objects.get(SKU='ACME0001')
cls.acme0002 = SupplierPart.objects.get(SKU='ACME0002')
@ -95,9 +103,9 @@ class CompanySimpleTest(TestCase):
"""Unit tests for supplier part pricing"""
m2x4 = Part.objects.get(name='M2x4 LPHS')
self.assertEqual(m2x4.get_price_info(5.5), "38.5 - 41.25")
self.assertEqual(m2x4.get_price_info(10), "70 - 75")
self.assertEqual(m2x4.get_price_info(100), "125 - 350")
self.assertEqual(m2x4.get_price_info(5.5), '38.5 - 41.25')
self.assertEqual(m2x4.get_price_info(10), '70 - 75')
self.assertEqual(m2x4.get_price_info(100), '125 - 350')
pmin, pmax = m2x4.get_price_range(5)
self.assertEqual(pmin, 35)
@ -113,18 +121,14 @@ class CompanySimpleTest(TestCase):
"""Test validation for currency selection."""
# Create a company with a valid currency code (should pass)
company = Company.objects.create(
name='Test',
description='Toast',
currency='AUD',
name='Test', description='Toast', currency='AUD'
)
company.full_clean()
# Create a company with an invalid currency code (should fail)
company = Company.objects.create(
name='test',
description='Toasty',
currency='XZY',
name='test', description='Toasty', currency='XZY'
)
with self.assertRaises(ValidationError):
@ -154,7 +158,9 @@ class ContactSimpleTest(TestCase):
def setUp(self):
"""Initialization for the tests in this class"""
# Create a simple company
self.c = Company.objects.create(name='Test Corp.', description='We make stuff good')
self.c = Company.objects.create(
name='Test Corp.', description='We make stuff good'
)
# Add some contacts
Contact.objects.create(name='Joe Smith', company=self.c)
@ -178,7 +184,9 @@ class AddressTest(TestCase):
def setUp(self):
"""Initialization for the tests in this class"""
# Create a simple company
self.c = Company.objects.create(name='Test Corp.', description='We make stuff good')
self.c = Company.objects.create(
name='Test Corp.', description='We make stuff good'
)
def test_create(self):
"""Test that object creation with only company supplied is successful"""
@ -214,27 +222,28 @@ class AddressTest(TestCase):
def test_model_str(self):
"""Test value of __str__"""
t = "Test address"
l1 = "Busy street 56"
l2 = "Red building"
pcd = "12345"
pct = "City"
pv = "Province"
cn = "COUNTRY"
addr = Address.objects.create(company=self.c,
title=t,
line1=l1,
line2=l2,
postal_code=pcd,
postal_city=pct,
province=pv,
country=cn)
t = 'Test address'
l1 = 'Busy street 56'
l2 = 'Red building'
pcd = '12345'
pct = 'City'
pv = 'Province'
cn = 'COUNTRY'
addr = Address.objects.create(
company=self.c,
title=t,
line1=l1,
line2=l2,
postal_code=pcd,
postal_city=pct,
province=pv,
country=cn,
)
self.assertEqual(str(addr), f'{l1}, {l2}, {pcd}, {pct}, {pv}, {cn}')
addr2 = Address.objects.create(company=self.c,
title=t,
line1=l1,
postal_code=pcd)
addr2 = Address.objects.create(
company=self.c, title=t, line1=l1, postal_code=pcd
)
self.assertEqual(str(addr2), f'{l1}, {pcd}')
@ -242,13 +251,7 @@ class AddressTest(TestCase):
class ManufacturerPartSimpleTest(TestCase):
"""Unit tests for the ManufacturerPart model"""
fixtures = [
'category',
'company',
'location',
'part',
'manufacturer_part',
]
fixtures = ['category', 'company', 'location', 'part', 'manufacturer_part']
def setUp(self):
"""Initialization for the unit tests in this class"""
@ -266,9 +269,7 @@ class ManufacturerPartSimpleTest(TestCase):
# Create a supplier part
supplier = Company.objects.get(pk=5)
supplier_part = SupplierPart.objects.create(
part=self.part,
supplier=supplier,
SKU='SKU_TEST',
part=self.part, supplier=supplier, SKU='SKU_TEST'
)
supplier_part.save()

View File

@ -5,29 +5,41 @@ from django.urls import include, path, re_path
from . import views
company_urls = [
# Detail URLs for a specific Company instance
path(r'<int:pk>/', include([
re_path(r'^.*$', views.CompanyDetail.as_view(), name='company-detail'),
])),
path(
r'<int:pk>/',
include([
re_path(r'^.*$', views.CompanyDetail.as_view(), name='company-detail')
]),
),
re_path(r'suppliers/', views.CompanyIndex.as_view(), name='supplier-index'),
re_path(r'manufacturers/', views.CompanyIndex.as_view(), name='manufacturer-index'),
re_path(r'customers/', views.CompanyIndex.as_view(), name='customer-index'),
# Redirect any other patterns to the 'company' index which displays all companies
re_path(r'^.*$', views.CompanyIndex.as_view(), name='company-index'),
]
manufacturer_part_urls = [
path(r'<int:pk>/', views.ManufacturerPartDetail.as_view(template_name='company/manufacturer_part.html'), name='manufacturer-part-detail'),
path(
r'<int:pk>/',
views.ManufacturerPartDetail.as_view(
template_name='company/manufacturer_part.html'
),
name='manufacturer-part-detail',
)
]
supplier_part_urls = [
path(r'<int:pk>/', include([
re_path('^.*$', views.SupplierPartDetail.as_view(template_name='company/supplier_part.html'), name='supplier-part-detail'),
]))
path(
r'<int:pk>/',
include([
re_path(
'^.*$',
views.SupplierPartDetail.as_view(
template_name='company/supplier_part.html'
),
name='supplier-part-detail',
)
]),
)
]

View File

@ -44,14 +44,14 @@ class CompanyIndex(InvenTreeRoleMixin, ListView):
'button_text': _('New Customer'),
'filters': {'is_customer': 'true'},
'pagetype': 'customers',
}
},
}
default = {
'title': _('Companies'),
'button_text': _('New Company'),
'filters': {},
'pagetype': 'companies'
'pagetype': 'companies',
}
context = None
@ -88,6 +88,7 @@ class CompanyIndex(InvenTreeRoleMixin, ListView):
class CompanyDetail(InvenTreePluginViewMixin, DetailView):
"""Detail view for Company object."""
context_obect_name = 'company'
template_name = 'company/detail.html'
queryset = Company.objects.all()
@ -97,6 +98,7 @@ class CompanyDetail(InvenTreePluginViewMixin, DetailView):
class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView):
"""Detail view for ManufacturerPart."""
model = ManufacturerPart
template_name = 'company/manufacturer_part_detail.html'
context_object_name = 'part'
@ -106,6 +108,7 @@ class ManufacturerPartDetail(InvenTreePluginViewMixin, DetailView):
class SupplierPartDetail(InvenTreePluginViewMixin, DetailView):
"""Detail view for SupplierPart."""
model = SupplierPart
template_name = 'company/supplier_part_detail.html'
context_object_name = 'part'

View File

@ -9,9 +9,4 @@ States can be extended with custom options for each InvenTree instance - those o
from .states import StatusCode
from .transition import StateTransitionMixin, TransitionMethod, storage
__all__ = [
StatusCode,
storage,
TransitionMethod,
StateTransitionMixin,
]
__all__ = [StatusCode, storage, TransitionMethod, StateTransitionMixin]

View File

@ -18,9 +18,7 @@ class StatusView(APIView):
all available 'StockStatus' codes
"""
permission_classes = [
permissions.IsAuthenticated,
]
permission_classes = [permissions.IsAuthenticated]
# Override status_class for implementing subclass
MODEL_REF = 'statusmodel'
@ -30,7 +28,9 @@ class StatusView(APIView):
status_model = self.kwargs.get(self.MODEL_REF, None)
if status_model is None:
raise ValidationError(f"StatusView view called without '{self.MODEL_REF}' parameter")
raise ValidationError(
f"StatusView view called without '{self.MODEL_REF}' parameter"
)
return status_model
@ -39,15 +39,12 @@ class StatusView(APIView):
status_class = self.get_status_model()
if not inspect.isclass(status_class):
raise NotImplementedError("`status_class` not a class")
raise NotImplementedError('`status_class` not a class')
if not issubclass(status_class, StatusCode):
raise NotImplementedError("`status_class` not a valid StatusCode class")
raise NotImplementedError('`status_class` not a valid StatusCode class')
data = {
'class': status_class.__name__,
'values': status_class.dict(),
}
data = {'class': status_class.__name__, 'values': status_class.dict()}
return Response(data)
@ -55,9 +52,7 @@ class StatusView(APIView):
class AllStatusViews(StatusView):
"""Endpoint for listing all defined status models."""
permission_classes = [
permissions.IsAuthenticated,
]
permission_classes = [permissions.IsAuthenticated]
def get(self, request, *args, **kwargs):
"""Perform a GET request to learn information about status codes"""

View File

@ -1,4 +1,5 @@
"""Generic implementation of status for InvenTree models."""
import enum
import re
@ -148,12 +149,10 @@ class StatusCode(BaseEnum):
@classmethod
def dict(cls, key=None):
"""Return a dict representation containing all required information"""
return {x.name: {
'color': x.color,
'key': x.value,
'label': x.label,
'name': x.name,
} for x in cls.values(key)}
return {
x.name: {'color': x.color, 'key': x.value, 'label': x.label, 'name': x.name}
for x in cls.values(key)
}
@classmethod
def list(cls):

View File

@ -35,8 +35,7 @@ class TransitionTests(InvenTreeTestCase):
def test_class(self):
"""Ensure that the class itself works."""
class ErrorImplementation(TransitionMethod):
...
class ErrorImplementation(TransitionMethod): ...
with self.assertRaises(NotImplementedError):
ErrorImplementation()
@ -90,7 +89,6 @@ class TransitionTests(InvenTreeTestCase):
class ValidImplementation(TransitionMethod):
def transition(self, *args, **kwargs):
global raise_function
if raise_function:
@ -103,7 +101,9 @@ class TransitionTests(InvenTreeTestCase):
self.assertIn(ValidImplementation, storage.list)
# Ensure that the function is called
self.assertEqual(StateTransitionMixin.handle_transition(0, 1, self, self, dflt), 1234)
self.assertEqual(
StateTransitionMixin.handle_transition(0, 1, self, self, dflt), 1234
)
_clean_storage([ValidImplementationNoEffect, ValidImplementation])

View File

@ -1,4 +1,5 @@
"""Tests for the generic states module."""
from django.test.client import RequestFactory
from django.utils.translation import gettext_lazy as _
@ -13,9 +14,9 @@ from .states import StatusCode
class GeneralStatus(StatusCode):
"""Defines a set of status codes for tests."""
PENDING = 10, _("Pending"), 'secondary'
PLACED = 20, _("Placed"), 'primary'
COMPLETE = 30, _("Complete"), 'success'
PENDING = 10, _('Pending'), 'secondary'
PLACED = 20, _('Placed'), 'primary'
COMPLETE = 30, _('Complete'), 'success'
ABC = None # This should be ignored
_DEF = None # This should be ignored
jkl = None # This should be ignored
@ -27,6 +28,7 @@ class GeneralStatus(StatusCode):
class GeneralStateTest(InvenTreeTestCase):
"""Test that the StatusCode class works."""
def test_code_definition(self):
"""Test that the status code class has been defined correctly."""
self.assertEqual(GeneralStatus.PENDING, 10)
@ -36,36 +38,84 @@ class GeneralStateTest(InvenTreeTestCase):
def test_code_functions(self):
"""Test that the status code class functions work correctly"""
# render
self.assertEqual(GeneralStatus.render(10), "<span class='badge rounded-pill bg-secondary'>Pending</span>")
self.assertEqual(GeneralStatus.render(20), "<span class='badge rounded-pill bg-primary'>Placed</span>")
self.assertEqual(
GeneralStatus.render(10),
"<span class='badge rounded-pill bg-secondary'>Pending</span>",
)
self.assertEqual(
GeneralStatus.render(20),
"<span class='badge rounded-pill bg-primary'>Placed</span>",
)
# render with invalid key
self.assertEqual(GeneralStatus.render(100), 100)
# list
self.assertEqual(GeneralStatus.list(), [{'color': 'secondary', 'key': 10, 'label': 'Pending', 'name': 'PENDING'}, {'color': 'primary', 'key': 20, 'label': 'Placed', 'name': 'PLACED'}, {'color': 'success', 'key': 30, 'label': 'Complete', 'name': 'COMPLETE'}])
self.assertEqual(
GeneralStatus.list(),
[
{
'color': 'secondary',
'key': 10,
'label': 'Pending',
'name': 'PENDING',
},
{'color': 'primary', 'key': 20, 'label': 'Placed', 'name': 'PLACED'},
{
'color': 'success',
'key': 30,
'label': 'Complete',
'name': 'COMPLETE',
},
],
)
# text
self.assertEqual(GeneralStatus.text(10), 'Pending')
self.assertEqual(GeneralStatus.text(20), 'Placed')
# items
self.assertEqual(list(GeneralStatus.items()), [(10, 'Pending'), (20, 'Placed'), (30, 'Complete')])
self.assertEqual(
list(GeneralStatus.items()),
[(10, 'Pending'), (20, 'Placed'), (30, 'Complete')],
)
# keys
self.assertEqual(list(GeneralStatus.keys()), ([10, 20, 30]))
# labels
self.assertEqual(list(GeneralStatus.labels()), ['Pending', 'Placed', 'Complete'])
self.assertEqual(
list(GeneralStatus.labels()), ['Pending', 'Placed', 'Complete']
)
# names
self.assertEqual(GeneralStatus.names(), {'PENDING': 10, 'PLACED': 20, 'COMPLETE': 30})
self.assertEqual(
GeneralStatus.names(), {'PENDING': 10, 'PLACED': 20, 'COMPLETE': 30}
)
# dict
self.assertEqual(GeneralStatus.dict(), {
'PENDING': {'key': 10, 'name': 'PENDING', 'label': 'Pending', 'color': 'secondary'},
'PLACED': {'key': 20, 'name': 'PLACED', 'label': 'Placed', 'color': 'primary'},
'COMPLETE': {'key': 30, 'name': 'COMPLETE', 'label': 'Complete', 'color': 'success'},
})
self.assertEqual(
GeneralStatus.dict(),
{
'PENDING': {
'key': 10,
'name': 'PENDING',
'label': 'Pending',
'color': 'secondary',
},
'PLACED': {
'key': 20,
'name': 'PLACED',
'label': 'Placed',
'color': 'primary',
},
'COMPLETE': {
'key': 30,
'name': 'COMPLETE',
'label': 'Complete',
'color': 'success',
},
},
)
# label
self.assertEqual(GeneralStatus.label(10), 'Pending')
@ -74,7 +124,10 @@ class GeneralStateTest(InvenTreeTestCase):
"""Test that the status code tag functions."""
from .tags import status_label
self.assertEqual(status_label('general', 10), "<span class='badge rounded-pill bg-secondary'>Pending</span>")
self.assertEqual(
status_label('general', 10),
"<span class='badge rounded-pill bg-secondary'>Pending</span>",
)
# invalid type
with self.assertRaises(ValueError) as e:
@ -87,24 +140,54 @@ class GeneralStateTest(InvenTreeTestCase):
def test_api(self):
"""Test StatusView API view."""
view = StatusView.as_view()
rqst = RequestFactory().get('status/',)
rqst = RequestFactory().get('status/')
force_authenticate(rqst, user=self.user)
# Correct call
resp = view(rqst, **{StatusView.MODEL_REF: GeneralStatus})
self.assertEqual(resp.data, {'class': 'GeneralStatus', 'values': {'COMPLETE': {'key': 30, 'name': 'COMPLETE', 'label': 'Complete', 'color': 'success'}, 'PENDING': {'key': 10, 'name': 'PENDING', 'label': 'Pending', 'color': 'secondary'}, 'PLACED': {'key': 20, 'name': 'PLACED', 'label': 'Placed', 'color': 'primary'}}})
self.assertEqual(
resp.data,
{
'class': 'GeneralStatus',
'values': {
'COMPLETE': {
'key': 30,
'name': 'COMPLETE',
'label': 'Complete',
'color': 'success',
},
'PENDING': {
'key': 10,
'name': 'PENDING',
'label': 'Pending',
'color': 'secondary',
},
'PLACED': {
'key': 20,
'name': 'PLACED',
'label': 'Placed',
'color': 'primary',
},
},
},
)
# No status defined
resp = view(rqst, **{StatusView.MODEL_REF: None})
self.assertEqual(resp.status_code, 400)
self.assertEqual(str(resp.rendered_content, 'utf-8'), '["StatusView view called without \'statusmodel\' parameter"]')
self.assertEqual(
str(resp.rendered_content, 'utf-8'),
'["StatusView view called without \'statusmodel\' parameter"]',
)
# Invalid call - not a class
with self.assertRaises(NotImplementedError) as e:
resp = view(rqst, **{StatusView.MODEL_REF: 'invalid'})
self.assertEqual(str(e.exception), "`status_class` not a class")
self.assertEqual(str(e.exception), '`status_class` not a class')
# Invalid call - not the right class
with self.assertRaises(NotImplementedError) as e:
resp = view(rqst, **{StatusView.MODEL_REF: object})
self.assertEqual(str(e.exception), "`status_class` not a valid StatusCode class")
self.assertEqual(
str(e.exception), '`status_class` not a valid StatusCode class'
)

View File

@ -1,4 +1,5 @@
"""Classes and functions for plugin controlled object state transitions."""
import InvenTree.helpers
@ -15,8 +16,10 @@ class TransitionMethod:
- The needed functions are implemented
"""
# Check if a sending fnc is defined
if (not hasattr(self, 'transition')):
raise NotImplementedError('A TransitionMethod must define a `transition` method')
if not hasattr(self, 'transition'):
raise NotImplementedError(
'A TransitionMethod must define a `transition` method'
)
class TransitionMethodStorageClass:
@ -62,7 +65,9 @@ class StateTransitionMixin:
```
"""
def handle_transition(self, current_state, target_state, instance, default_action, **kwargs):
def handle_transition(
self, current_state, target_state, instance, default_action, **kwargs
):
"""Handle a state transition for an object.
Args:
@ -74,7 +79,9 @@ class StateTransitionMixin:
# Check if there is a custom override function for this transition
for override in storage.list:
rslt = override.transition(current_state, target_state, instance, default_action, **kwargs)
rslt = override.transition(
current_state, target_state, instance, default_action, **kwargs
)
if rslt:
return rslt

View File

@ -5,6 +5,4 @@ from django import template
register = template.Library()
from generic.states.tags import status_label # noqa: E402
__all__ = [
status_label,
]
__all__ = [status_label]

View File

@ -2,7 +2,7 @@
import multiprocessing
bind = "0.0.0.0:8000"
bind = '0.0.0.0:8000'
workers = multiprocessing.cpu_count() * 2 + 1

View File

@ -7,6 +7,7 @@ import label.models
class LabelAdmin(admin.ModelAdmin):
"""Admin class for the various label models"""
list_display = ('name', 'description', 'label', 'filters', 'enabled')

View File

@ -18,8 +18,7 @@ import label.models
import label.serializers
from InvenTree.api import MetadataView
from InvenTree.filters import InvenTreeSearchFilter
from InvenTree.mixins import (ListCreateAPI, RetrieveAPI,
RetrieveUpdateDestroyAPI)
from InvenTree.mixins import ListCreateAPI, RetrieveAPI, RetrieveUpdateDestroyAPI
from part.models import Part
from plugin.builtin.labels.inventree_label import InvenTreeLabelPlugin
from plugin.registry import registry
@ -59,7 +58,7 @@ class LabelFilterMixin:
for id in ids:
try:
valid_ids.append(int(id))
except (ValueError):
except ValueError:
pass
# Filter queryset by matching ID values
@ -120,34 +119,23 @@ class LabelListView(LabelFilterMixin, ListCreateAPI):
return queryset
filter_backends = [
DjangoFilterBackend,
InvenTreeSearchFilter
]
filter_backends = [DjangoFilterBackend, InvenTreeSearchFilter]
filterset_fields = [
'enabled',
]
filterset_fields = ['enabled']
search_fields = [
'name',
'description',
]
search_fields = ['name', 'description']
@method_decorator(cache_page(5), name='dispatch')
class LabelPrintMixin(LabelFilterMixin):
"""Mixin for printing labels."""
rolemap = {
"GET": "view",
"POST": "view",
}
rolemap = {'GET': 'view', 'POST': 'view'}
def check_permissions(self, request):
"""Override request method to GET so that also non superusers can print using a post request."""
if request.method == "POST":
request = clone_request(request, "GET")
if request.method == 'POST':
request = clone_request(request, 'GET')
return super().check_permissions(request)
@method_decorator(never_cache)
@ -161,7 +149,9 @@ class LabelPrintMixin(LabelFilterMixin):
plugin = self.get_plugin(self.request)
kwargs.setdefault('context', self.get_serializer_context())
serializer = plugin.get_printing_options_serializer(self.request, *args, **kwargs)
serializer = plugin.get_printing_options_serializer(
self.request, *args, **kwargs
)
# if no serializer is defined, return an empty serializer
if not serializer:
@ -171,8 +161,12 @@ class LabelPrintMixin(LabelFilterMixin):
def get(self, request, *args, **kwargs):
"""Perform a GET request against this endpoint to print labels"""
common.models.InvenTreeUserSetting.set_setting('DEFAULT_' + self.ITEM_KEY.upper() + '_LABEL_TEMPLATE',
self.get_object().pk, None, user=request.user)
common.models.InvenTreeUserSetting.set_setting(
'DEFAULT_' + self.ITEM_KEY.upper() + '_LABEL_TEMPLATE',
self.get_object().pk,
None,
user=request.user,
)
return self.print(request, self.get_items())
def post(self, request, *args, **kwargs):
@ -205,8 +199,10 @@ class LabelPrintMixin(LabelFilterMixin):
if not plugin.is_active():
raise ValidationError(f"Plugin '{plugin_key}' is not enabled")
if not plugin.mixin_enabled("labels"):
raise ValidationError(f"Plugin '{plugin_key}' is not a label printing plugin")
if not plugin.mixin_enabled('labels'):
raise ValidationError(
f"Plugin '{plugin_key}' is not a label printing plugin"
)
# Only return the plugin if it is enabled and has the label printing mixin
return plugin
@ -228,18 +224,24 @@ class LabelPrintMixin(LabelFilterMixin):
raise ValidationError('Label has invalid dimensions')
# if the plugin returns a serializer, validate the data
if serializer := plugin.get_printing_options_serializer(request, data=request.data, context=self.get_serializer_context()):
if serializer := plugin.get_printing_options_serializer(
request, data=request.data, context=self.get_serializer_context()
):
serializer.is_valid(raise_exception=True)
# At this point, we offload the label(s) to the selected plugin.
# The plugin is responsible for handling the request and returning a response.
result = plugin.print_labels(label, items_to_print, request, printing_options=request.data)
result = plugin.print_labels(
label, items_to_print, request, printing_options=request.data
)
if isinstance(result, JsonResponse):
result['plugin'] = plugin.plugin_slug()
return result
raise ValidationError(f"Plugin '{plugin.plugin_slug()}' returned invalid response type '{type(result)}'")
raise ValidationError(
f"Plugin '{plugin.plugin_slug()}' returned invalid response type '{type(result)}'"
)
class StockItemLabelMixin:
@ -261,16 +263,19 @@ class StockItemLabelList(StockItemLabelMixin, LabelListView):
- item: Filter by single stock item
- items: Filter by list of stock items
"""
pass
class StockItemLabelDetail(StockItemLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single StockItemLabel object."""
pass
class StockItemLabelPrint(StockItemLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a StockItemLabel object."""
pass
@ -293,21 +298,25 @@ class StockLocationLabelList(StockLocationLabelMixin, LabelListView):
- location: Filter by a single stock location
- locations: Filter by list of stock locations
"""
pass
class StockLocationLabelDetail(StockLocationLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single StockLocationLabel object."""
pass
class StockLocationLabelPrint(StockLocationLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a StockLocationLabel object."""
pass
class PartLabelMixin:
"""Mixin for PartLabel endpoints"""
queryset = label.models.PartLabel.objects.all()
serializer_class = label.serializers.PartLabelSerializer
@ -317,16 +326,19 @@ class PartLabelMixin:
class PartLabelList(PartLabelMixin, LabelListView):
"""API endpoint for viewing list of PartLabel objects."""
pass
class PartLabelDetail(PartLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single PartLabel object."""
pass
class PartLabelPrint(PartLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a PartLabel object."""
pass
@ -342,70 +354,147 @@ class BuildLineLabelMixin:
class BuildLineLabelList(BuildLineLabelMixin, LabelListView):
"""API endpoint for viewing a list of BuildLineLabel objects"""
pass
class BuildLineLabelDetail(BuildLineLabelMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for a single BuildLineLabel object"""
pass
class BuildLineLabelPrint(BuildLineLabelMixin, LabelPrintMixin, RetrieveAPI):
"""API endpoint for printing a BuildLineLabel object"""
pass
label_api_urls = [
# Stock item labels
re_path(r'stock/', include([
# Detail views
path(r'<int:pk>/', include([
re_path(r'print/?', StockItemLabelPrint.as_view(), name='api-stockitem-label-print'),
re_path(r'metadata/', MetadataView.as_view(), {'model': label.models.StockItemLabel}, name='api-stockitem-label-metadata'),
re_path(r'^.*$', StockItemLabelDetail.as_view(), name='api-stockitem-label-detail'),
])),
# List view
re_path(r'^.*$', StockItemLabelList.as_view(), name='api-stockitem-label-list'),
])),
re_path(
r'stock/',
include([
# Detail views
path(
r'<int:pk>/',
include([
re_path(
r'print/?',
StockItemLabelPrint.as_view(),
name='api-stockitem-label-print',
),
re_path(
r'metadata/',
MetadataView.as_view(),
{'model': label.models.StockItemLabel},
name='api-stockitem-label-metadata',
),
re_path(
r'^.*$',
StockItemLabelDetail.as_view(),
name='api-stockitem-label-detail',
),
]),
),
# List view
re_path(
r'^.*$', StockItemLabelList.as_view(), name='api-stockitem-label-list'
),
]),
),
# Stock location labels
re_path(r'location/', include([
# Detail views
path(r'<int:pk>/', include([
re_path(r'print/?', StockLocationLabelPrint.as_view(), name='api-stocklocation-label-print'),
re_path(r'metadata/', MetadataView.as_view(), {'model': label.models.StockLocationLabel}, name='api-stocklocation-label-metadata'),
re_path(r'^.*$', StockLocationLabelDetail.as_view(), name='api-stocklocation-label-detail'),
])),
# List view
re_path(r'^.*$', StockLocationLabelList.as_view(), name='api-stocklocation-label-list'),
])),
re_path(
r'location/',
include([
# Detail views
path(
r'<int:pk>/',
include([
re_path(
r'print/?',
StockLocationLabelPrint.as_view(),
name='api-stocklocation-label-print',
),
re_path(
r'metadata/',
MetadataView.as_view(),
{'model': label.models.StockLocationLabel},
name='api-stocklocation-label-metadata',
),
re_path(
r'^.*$',
StockLocationLabelDetail.as_view(),
name='api-stocklocation-label-detail',
),
]),
),
# List view
re_path(
r'^.*$',
StockLocationLabelList.as_view(),
name='api-stocklocation-label-list',
),
]),
),
# Part labels
re_path(r'^part/', include([
# Detail views
path(r'<int:pk>/', include([
re_path(r'^print/', PartLabelPrint.as_view(), name='api-part-label-print'),
re_path(r'^metadata/', MetadataView.as_view(), {'model': label.models.PartLabel}, name='api-part-label-metadata'),
re_path(r'^.*$', PartLabelDetail.as_view(), name='api-part-label-detail'),
])),
# List view
re_path(r'^.*$', PartLabelList.as_view(), name='api-part-label-list'),
])),
re_path(
r'^part/',
include([
# Detail views
path(
r'<int:pk>/',
include([
re_path(
r'^print/',
PartLabelPrint.as_view(),
name='api-part-label-print',
),
re_path(
r'^metadata/',
MetadataView.as_view(),
{'model': label.models.PartLabel},
name='api-part-label-metadata',
),
re_path(
r'^.*$', PartLabelDetail.as_view(), name='api-part-label-detail'
),
]),
),
# List view
re_path(r'^.*$', PartLabelList.as_view(), name='api-part-label-list'),
]),
),
# BuildLine labels
re_path(r'^buildline/', include([
# Detail views
path(r'<int:pk>/', include([
re_path(r'^print/', BuildLineLabelPrint.as_view(), name='api-buildline-label-print'),
re_path(r'^metadata/', MetadataView.as_view(), {'model': label.models.BuildLineLabel}, name='api-buildline-label-metadata'),
re_path(r'^.*$', BuildLineLabelDetail.as_view(), name='api-buildline-label-detail'),
])),
# List view
re_path(r'^.*$', BuildLineLabelList.as_view(), name='api-buildline-label-list'),
])),
re_path(
r'^buildline/',
include([
# Detail views
path(
r'<int:pk>/',
include([
re_path(
r'^print/',
BuildLineLabelPrint.as_view(),
name='api-buildline-label-print',
),
re_path(
r'^metadata/',
MetadataView.as_view(),
{'model': label.models.BuildLineLabel},
name='api-buildline-label-metadata',
),
re_path(
r'^.*$',
BuildLineLabelDetail.as_view(),
name='api-buildline-label-detail',
),
]),
),
# List view
re_path(
r'^.*$', BuildLineLabelList.as_view(), name='api-buildline-label-list'
),
]),
),
]

View File

@ -14,7 +14,7 @@ from django.db.utils import IntegrityError, OperationalError, ProgrammingError
import InvenTree.ready
logger = logging.getLogger("inventree")
logger = logging.getLogger('inventree')
def hashFile(filename):
@ -36,23 +36,37 @@ class LabelConfig(AppConfig):
def ready(self):
"""This function is called whenever the label app is loaded."""
# skip loading if plugin registry is not loaded or we run in a background thread
if not InvenTree.ready.isPluginRegistryLoaded() or not InvenTree.ready.isInMainThread():
if (
not InvenTree.ready.isPluginRegistryLoaded()
or not InvenTree.ready.isInMainThread()
):
return
if InvenTree.ready.isRunningMigrations():
return
if InvenTree.ready.canAppAccessDatabase(allow_test=False) and not InvenTree.ready.isImportingData():
if (
InvenTree.ready.canAppAccessDatabase(allow_test=False)
and not InvenTree.ready.isImportingData()
):
try:
self.create_labels() # pragma: no cover
except (AppRegistryNotReady, IntegrityError, OperationalError, ProgrammingError):
except (
AppRegistryNotReady,
IntegrityError,
OperationalError,
ProgrammingError,
):
# Database might not yet be ready
warnings.warn('Database was not ready for creating labels', stacklevel=2)
warnings.warn(
'Database was not ready for creating labels', stacklevel=2
)
def create_labels(self):
"""Create all default templates."""
# Test if models are ready
import label.models
assert bool(label.models.StockLocationLabel is not None)
# Create the categories
@ -66,7 +80,7 @@ class LabelConfig(AppConfig):
'description': 'Simple QR code label',
'width': 24,
'height': 24,
},
}
],
)
@ -87,8 +101,8 @@ class LabelConfig(AppConfig):
'description': 'Label with QR code and name of location',
'width': 50,
'height': 24,
}
]
},
],
)
self.create_labels_category(
@ -109,7 +123,7 @@ class LabelConfig(AppConfig):
'width': 70,
'height': 24,
},
]
],
)
self.create_labels_category(
@ -122,24 +136,16 @@ class LabelConfig(AppConfig):
'description': 'Example build line label',
'width': 125,
'height': 48,
},
]
}
],
)
def create_labels_category(self, model, ref_name, labels):
"""Create folder and database entries for the default templates, if they do not already exist."""
# Create root dir for templates
src_dir = Path(__file__).parent.joinpath(
'templates',
'label',
ref_name,
)
src_dir = Path(__file__).parent.joinpath('templates', 'label', ref_name)
dst_dir = settings.MEDIA_ROOT.joinpath(
'label',
'inventree',
ref_name,
)
dst_dir = settings.MEDIA_ROOT.joinpath('label', 'inventree', ref_name)
if not dst_dir.exists():
logger.info("Creating required directory: '%s'", dst_dir)
@ -151,12 +157,7 @@ class LabelConfig(AppConfig):
def create_template_label(self, model, src_dir, ref_name, label):
"""Ensure a label template is in place."""
filename = os.path.join(
'label',
'inventree',
ref_name,
label['file']
)
filename = os.path.join('label', 'inventree', ref_name, label['file'])
src_file = src_dir.joinpath(label['file'])
dst_file = settings.MEDIA_ROOT.joinpath(filename)
@ -187,7 +188,10 @@ class LabelConfig(AppConfig):
if model.objects.filter(label=filename).exists():
return # pragma: no cover
except Exception:
logger.exception("Failed to query label for '%s' - you should run 'invoke update' first!", filename)
logger.exception(
"Failed to query label for '%s' - you should run 'invoke update' first!",
filename,
)
logger.info("Creating entry for %s '%s'", model, label['name'])

View File

@ -25,12 +25,12 @@ from plugin.registry import registry
try:
from django_weasyprint import WeasyTemplateResponseMixin
except OSError as err: # pragma: no cover
print(f"OSError: {err}")
print("You may require some further system packages to be installed.")
print(f'OSError: {err}')
print('You may require some further system packages to be installed.')
sys.exit(1)
logger = logging.getLogger("inventree")
logger = logging.getLogger('inventree')
def rename_label(instance, filename):
@ -97,7 +97,7 @@ class LabelTemplate(MetadataMixin, models.Model):
abstract = True
# Each class of label files will be stored in a separate subdirectory
SUBDIR = "label"
SUBDIR = 'label'
# Object we will be printing against (will be filled out later)
object_to_print = None
@ -109,17 +109,16 @@ class LabelTemplate(MetadataMixin, models.Model):
def __str__(self):
"""Format a string representation of a label instance"""
return f"{self.name} - {self.description}"
return f'{self.name} - {self.description}'
name = models.CharField(
blank=False, max_length=100,
verbose_name=_('Name'),
help_text=_('Label name'),
blank=False, max_length=100, verbose_name=_('Name'), help_text=_('Label name')
)
description = models.CharField(
max_length=250,
blank=True, null=True,
blank=True,
null=True,
verbose_name=_('Description'),
help_text=_('Label description'),
)
@ -127,7 +126,8 @@ class LabelTemplate(MetadataMixin, models.Model):
label = models.FileField(
upload_to=rename_label,
unique=True,
blank=False, null=False,
blank=False,
null=False,
verbose_name=_('Label'),
help_text=_('Label template file'),
validators=[FileExtensionValidator(allowed_extensions=['html'])],
@ -143,18 +143,18 @@ class LabelTemplate(MetadataMixin, models.Model):
default=50,
verbose_name=_('Width [mm]'),
help_text=_('Label width, specified in mm'),
validators=[MinValueValidator(2)]
validators=[MinValueValidator(2)],
)
height = models.FloatField(
default=20,
verbose_name=_('Height [mm]'),
help_text=_('Label height, specified in mm'),
validators=[MinValueValidator(2)]
validators=[MinValueValidator(2)],
)
filename_pattern = models.CharField(
default="label.pdf",
default='label.pdf',
verbose_name=_('Filename Pattern'),
help_text=_('Pattern for generating label filenames'),
max_length=100,
@ -249,11 +249,7 @@ class LabelTemplate(MetadataMixin, models.Model):
context = self.context(request, **kwargs)
return render_to_string(
self.template_name,
context,
request
)
return render_to_string(self.template_name, context, request)
def render(self, request, target_object=None, **kwargs):
"""Render the label template to a PDF file.
@ -269,16 +265,13 @@ class LabelTemplate(MetadataMixin, models.Model):
wp = WeasyprintLabelMixin(
request,
self.template_name,
base_url=request.build_absolute_uri("/"),
base_url=request.build_absolute_uri('/'),
presentational_hints=True,
filename=self.generate_filename(request),
**kwargs
**kwargs,
)
return wp.render_to_response(
context,
**kwargs
)
return wp.render_to_response(context, **kwargs)
class LabelOutput(models.Model):
@ -293,22 +286,14 @@ class LabelOutput(models.Model):
# File will be stored in a subdirectory
label = models.FileField(
upload_to=rename_label_output,
unique=True, blank=False, null=False,
upload_to=rename_label_output, unique=True, blank=False, null=False
)
# Creation date of label output
created = models.DateField(
auto_now_add=True,
editable=False,
)
created = models.DateField(auto_now_add=True, editable=False)
# User who generated the label
user = models.ForeignKey(
User,
on_delete=models.SET_NULL,
blank=True, null=True,
)
user = models.ForeignKey(User, on_delete=models.SET_NULL, blank=True, null=True)
class StockItemLabel(LabelTemplate):
@ -319,15 +304,14 @@ class StockItemLabel(LabelTemplate):
"""Return the API URL associated with the StockItemLabel model"""
return reverse('api-stockitem-label-list') # pragma: no cover
SUBDIR = "stockitem"
SUBDIR = 'stockitem'
filters = models.CharField(
blank=True, max_length=250,
blank=True,
max_length=250,
help_text=_('Query filters (comma-separated list of key=value pairs)'),
verbose_name=_('Filters'),
validators=[
validate_stock_item_filters
]
validators=[validate_stock_item_filters],
)
def get_context_data(self, request):
@ -348,7 +332,6 @@ class StockItemLabel(LabelTemplate):
'qr_url': request.build_absolute_uri(stock_item.get_absolute_url()),
'tests': stock_item.testResultMap(),
'parameters': stock_item.part.parameters_map(),
}
@ -360,24 +343,21 @@ class StockLocationLabel(LabelTemplate):
"""Return the API URL associated with the StockLocationLabel model"""
return reverse('api-stocklocation-label-list') # pragma: no cover
SUBDIR = "stocklocation"
SUBDIR = 'stocklocation'
filters = models.CharField(
blank=True, max_length=250,
blank=True,
max_length=250,
help_text=_('Query filters (comma-separated list of key=value pairs)'),
verbose_name=_('Filters'),
validators=[
validate_stock_location_filters]
validators=[validate_stock_location_filters],
)
def get_context_data(self, request):
"""Generate context data for each provided StockLocation."""
location = self.object_to_print
return {
'location': location,
'qr_data': location.format_barcode(brief=True),
}
return {'location': location, 'qr_data': location.format_barcode(brief=True)}
class PartLabel(LabelTemplate):
@ -391,12 +371,11 @@ class PartLabel(LabelTemplate):
SUBDIR = 'part'
filters = models.CharField(
blank=True, max_length=250,
blank=True,
max_length=250,
help_text=_('Query filters (comma-separated list of key=value pairs)'),
verbose_name=_('Filters'),
validators=[
validate_part_filters
]
validators=[validate_part_filters],
)
def get_context_data(self, request):
@ -427,12 +406,11 @@ class BuildLineLabel(LabelTemplate):
SUBDIR = 'buildline'
filters = models.CharField(
blank=True, max_length=250,
blank=True,
max_length=250,
help_text=_('Query filters (comma-separated list of key=value pairs)'),
verbose_name=_('Filters'),
validators=[
validate_build_line_filters
]
validators=[validate_build_line_filters],
)
def get_context_data(self, request):

View File

@ -1,8 +1,10 @@
"""API serializers for the label app"""
import label.models
from InvenTree.serializers import (InvenTreeAttachmentSerializerField,
InvenTreeModelSerializer)
from InvenTree.serializers import (
InvenTreeAttachmentSerializerField,
InvenTreeModelSerializer,
)
class LabelSerializerBase(InvenTreeModelSerializer):
@ -13,14 +15,7 @@ class LabelSerializerBase(InvenTreeModelSerializer):
@staticmethod
def label_fields():
"""Generic serializer fields for a label template"""
return [
'pk',
'name',
'description',
'label',
'filters',
'enabled',
]
return ['pk', 'name', 'description', 'label', 'filters', 'enabled']
class StockItemLabelSerializer(LabelSerializerBase):

View File

@ -8,17 +8,9 @@ from InvenTree.unit_test import InvenTreeAPITestCase
class TestReportTests(InvenTreeAPITestCase):
"""Tests for the StockItem TestReport templates."""
fixtures = [
'category',
'part',
'location',
'stock',
]
fixtures = ['category', 'part', 'location', 'stock']
roles = [
'stock.view',
'stock_location.view',
]
roles = ['stock.view', 'stock_location.view']
list_url = reverse('api-stockitem-testreport-list')
@ -42,22 +34,10 @@ class TestReportTests(InvenTreeAPITestCase):
self.assertEqual(len(response), 0)
# TODO - Add some tests to this response
response = self.do_list(
{
'item': 10,
}
)
response = self.do_list({'item': 10})
# TODO - Add some tests to this response
response = self.do_list(
{
'item': 100000,
}
)
response = self.do_list({'item': 100000})
# TODO - Add some tests to this response
response = self.do_list(
{
'items': [10, 11, 12],
}
)
response = self.do_list({'items': [10, 11, 12]})

View File

@ -24,12 +24,7 @@ from .models import PartLabel, StockItemLabel, StockLocationLabel
class LabelTest(InvenTreeAPITestCase):
"""Unit test class for label models"""
fixtures = [
'category',
'part',
'location',
'stock'
]
fixtures = ['category', 'part', 'location', 'stock']
@classmethod
def setUpTestData(cls):
@ -49,12 +44,9 @@ class LabelTest(InvenTreeAPITestCase):
def test_default_files(self):
"""Test that label files exist in the MEDIA directory."""
def test_subdir(ref_name):
item_dir = settings.MEDIA_ROOT.joinpath(
'label',
'inventree',
ref_name,
)
item_dir = settings.MEDIA_ROOT.joinpath('label', 'inventree', ref_name)
self.assertTrue(len([item_dir.iterdir()]) > 0)
test_subdir('stockitem')
@ -63,13 +55,13 @@ class LabelTest(InvenTreeAPITestCase):
def test_filters(self):
"""Test the label filters."""
filter_string = "part__pk=10"
filter_string = 'part__pk=10'
filters = validateFilterString(filter_string, model=StockItem)
self.assertEqual(type(filters), dict)
bad_filter_string = "part_pk=10"
bad_filter_string = 'part_pk=10'
with self.assertRaises(ValidationError):
validateFilterString(bad_filter_string, model=StockItem)
@ -115,14 +107,11 @@ class LabelTest(InvenTreeAPITestCase):
buffer = io.StringIO()
buffer.write(label_data)
template = ContentFile(buffer.getvalue(), "label.html")
template = ContentFile(buffer.getvalue(), 'label.html')
# Construct a label template
label = PartLabel.objects.create(
name='test',
description='Test label',
enabled=True,
label=template,
name='test', description='Test label', enabled=True, label=template
)
# Ensure we are in "debug" mode (so the report is generated as HTML)
@ -151,7 +140,7 @@ class LabelTest(InvenTreeAPITestCase):
content = f.read()
# Test that each element has been rendered correctly
self.assertIn(f"part: {part_pk} - {part_name}", content)
self.assertIn(f'part: {part_pk} - {part_name}', content)
self.assertIn(f'data: {{"part": {part_pk}}}', content)
self.assertIn(f'http://testserver/part/{part_pk}/', content)

View File

@ -7,17 +7,17 @@ import sys
def main():
"""Run administrative tasks."""
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "InvenTree.settings")
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'InvenTree.settings')
try:
from django.core.management import execute_from_command_line
except ImportError as exc: # pragma: no cover
raise ImportError(
"Couldn't import Django. Are you sure it's installed and "
"available on your PYTHONPATH environment variable? Did you "
"forget to activate a virtual environment?"
'available on your PYTHONPATH environment variable? Did you '
'forget to activate a virtual environment?'
) from exc
execute_from_command_line(sys.argv)
if __name__ == "__main__":
if __name__ == '__main__':
main()

View File

@ -50,52 +50,54 @@ class PriceResourceMixin:
# region general classes
class GeneralExtraLineAdmin:
"""Admin class template for the 'ExtraLineItem' models"""
list_display = (
'order',
'quantity',
'reference'
)
search_fields = [
'order__reference',
'order__customer__name',
'reference',
]
list_display = ('order', 'quantity', 'reference')
autocomplete_fields = ('order', )
search_fields = ['order__reference', 'order__customer__name', 'reference']
autocomplete_fields = ('order',)
class GeneralExtraLineMeta:
"""Metaclass template for the 'ExtraLineItem' models"""
skip_unchanged = True
report_skipped = False
clean_model_instances = True
# endregion
class PurchaseOrderLineItemInlineAdmin(admin.StackedInline):
"""Inline admin class for the PurchaseOrderLineItem model"""
model = models.PurchaseOrderLineItem
extra = 0
class PurchaseOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource):
class PurchaseOrderResource(
ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource
):
"""Class for managing import / export of PurchaseOrder data."""
class Meta:
"""Metaclass"""
model = models.PurchaseOrder
skip_unchanged = True
clean_model_instances = True
exclude = [
'metadata',
]
exclude = ['metadata']
# Add number of line items
line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True)
line_items = Field(
attribute='line_count', widget=widgets.IntegerWidget(), readonly=True
)
# Is this order overdue?
overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True)
overdue = Field(
attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True
)
class PurchaseOrderAdmin(ImportExportModelAdmin):
@ -103,48 +105,39 @@ class PurchaseOrderAdmin(ImportExportModelAdmin):
resource_class = PurchaseOrderResource
exclude = [
'reference_int',
]
exclude = ['reference_int']
list_display = (
'reference',
'supplier',
'status',
'description',
'creation_date'
)
list_display = ('reference', 'supplier', 'status', 'description', 'creation_date')
search_fields = [
'reference',
'supplier__name',
'description',
]
search_fields = ['reference', 'supplier__name', 'description']
inlines = [
PurchaseOrderLineItemInlineAdmin
]
inlines = [PurchaseOrderLineItemInlineAdmin]
autocomplete_fields = ('supplier',)
class SalesOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource):
class SalesOrderResource(
ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource
):
"""Class for managing import / export of SalesOrder data."""
class Meta:
"""Metaclass options"""
model = models.SalesOrder
skip_unchanged = True
clean_model_instances = True
exclude = [
'metadata',
]
exclude = ['metadata']
# Add number of line items
line_items = Field(attribute='line_count', widget=widgets.IntegerWidget(), readonly=True)
line_items = Field(
attribute='line_count', widget=widgets.IntegerWidget(), readonly=True
)
# Is this order overdue?
overdue = Field(attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True)
overdue = Field(
attribute='is_overdue', widget=widgets.BooleanWidget(), readonly=True
)
class SalesOrderAdmin(ImportExportModelAdmin):
@ -152,23 +145,11 @@ class SalesOrderAdmin(ImportExportModelAdmin):
resource_class = SalesOrderResource
exclude = [
'reference_int',
]
exclude = ['reference_int']
list_display = (
'reference',
'customer',
'status',
'description',
'creation_date',
)
list_display = ('reference', 'customer', 'status', 'description', 'creation_date')
search_fields = [
'reference',
'customer__name',
'description',
]
search_fields = ['reference', 'customer__name', 'description']
autocomplete_fields = ('customer',)
@ -178,6 +159,7 @@ class PurchaseOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
class Meta:
"""Metaclass"""
model = models.PurchaseOrderLineItem
skip_unchanged = True
report_skipped = False
@ -212,6 +194,7 @@ class SalesOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
class Meta:
"""Metaclass options"""
model = models.SalesOrderLineItem
skip_unchanged = True
report_skipped = False
@ -240,6 +223,7 @@ class SalesOrderExtraLineResource(PriceResourceMixin, InvenTreeResource):
class Meta(GeneralExtraLineMeta):
"""Metaclass options."""
model = models.SalesOrderExtraLine
@ -248,20 +232,16 @@ class PurchaseOrderLineItemAdmin(ImportExportModelAdmin):
resource_class = PurchaseOrderLineItemResource
list_display = (
'order',
'part',
'quantity',
'reference'
)
list_display = ('order', 'part', 'quantity', 'reference')
search_fields = ('reference',)
autocomplete_fields = ('order', 'part', 'destination',)
autocomplete_fields = ('order', 'part', 'destination')
class PurchaseOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the PurchaseOrderExtraLine model"""
resource_class = PurchaseOrderExtraLineResource
@ -270,12 +250,7 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin):
resource_class = SalesOrderLineItemResource
list_display = (
'order',
'part',
'quantity',
'reference'
)
list_display = ('order', 'part', 'quantity', 'reference')
search_fields = [
'part__name',
@ -284,28 +259,21 @@ class SalesOrderLineItemAdmin(ImportExportModelAdmin):
'reference',
]
autocomplete_fields = ('order', 'part',)
autocomplete_fields = ('order', 'part')
class SalesOrderExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the SalesOrderExtraLine model"""
resource_class = SalesOrderExtraLineResource
class SalesOrderShipmentAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderShipment model"""
list_display = [
'order',
'shipment_date',
'reference',
]
list_display = ['order', 'shipment_date', 'reference']
search_fields = [
'reference',
'order__reference',
'order__customer__name',
]
search_fields = ['reference', 'order__reference', 'order__customer__name']
autocomplete_fields = ('order',)
@ -313,26 +281,23 @@ class SalesOrderShipmentAdmin(ImportExportModelAdmin):
class SalesOrderAllocationAdmin(ImportExportModelAdmin):
"""Admin class for the SalesOrderAllocation model"""
list_display = (
'line',
'item',
'quantity'
)
list_display = ('line', 'item', 'quantity')
autocomplete_fields = ('line', 'shipment', 'item',)
autocomplete_fields = ('line', 'shipment', 'item')
class ReturnOrderResource(ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource):
class ReturnOrderResource(
ProjectCodeResourceMixin, TotalPriceResourceMixin, InvenTreeResource
):
"""Class for managing import / export of ReturnOrder data"""
class Meta:
"""Metaclass options"""
model = models.ReturnOrder
skip_unchanged = True
clean_model_instances = True
exclude = [
'metadata',
]
exclude = ['metadata']
class ReturnOrderAdmin(ImportExportModelAdmin):
@ -340,25 +305,13 @@ class ReturnOrderAdmin(ImportExportModelAdmin):
resource_class = ReturnOrderResource
exclude = [
'reference_int',
]
exclude = ['reference_int']
list_display = [
'reference',
'customer',
'status',
]
list_display = ['reference', 'customer', 'status']
search_fields = [
'reference',
'customer__name',
'description',
]
search_fields = ['reference', 'customer__name', 'description']
autocomplete_fields = [
'customer',
]
autocomplete_fields = ['customer']
class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
@ -366,6 +319,7 @@ class ReturnOrderLineItemResource(PriceResourceMixin, InvenTreeResource):
class Meta:
"""Metaclass options"""
model = models.ReturnOrderLineItem
skip_unchanged = True
report_skipped = False
@ -377,11 +331,7 @@ class ReturnOrderLineItemAdmin(ImportExportModelAdmin):
resource_class = ReturnOrderLineItemResource
list_display = [
'order',
'item',
'reference',
]
list_display = ['order', 'item', 'reference']
class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource):
@ -389,11 +339,13 @@ class ReturnOrderExtraLineClass(PriceResourceMixin, InvenTreeResource):
class Meta(GeneralExtraLineMeta):
"""Metaclass options"""
model = models.ReturnOrderExtraLine
class ReturnOrdeerExtraLineAdmin(GeneralExtraLineAdmin, ImportExportModelAdmin):
"""Admin class for the ReturnOrderExtraLine model"""
resource_class = ReturnOrderExtraLineClass

File diff suppressed because it is too large Load Diff

View File

@ -5,4 +5,5 @@ from django.apps import AppConfig
class OrderConfig(AppConfig):
"""Configuration class for the 'order' app"""
name = 'order'

View File

@ -17,14 +17,16 @@ class OrderMatchItemForm(MatchItemForm):
if 'quantity' in col_guess.lower():
return forms.CharField(
required=False,
widget=forms.NumberInput(attrs={
'name': 'quantity' + str(row['index']),
'class': 'numberinput',
'type': 'number',
'min': '0',
'step': 'any',
'value': clean_decimal(row.get('quantity', '')),
})
widget=forms.NumberInput(
attrs={
'name': 'quantity' + str(row['index']),
'class': 'numberinput',
'type': 'number',
'min': '0',
'step': 'any',
'value': clean_decimal(row.get('quantity', '')),
}
),
)
# set price field
elif 'price' in col_guess.lower():

Some files were not shown because too many files have changed in this diff Show More