diff --git a/InvenTree/InvenTree/settings.py b/InvenTree/InvenTree/settings.py index ff750a2166..0eaae34726 100644 --- a/InvenTree/InvenTree/settings.py +++ b/InvenTree/InvenTree/settings.py @@ -1148,5 +1148,4 @@ if CUSTOM_FLAGS: # Magic login django-sesame SESAME_MAX_AGE = 300 -# LOGIN_REDIRECT_URL = f"/{FRONTEND_URL_BASE}/logged-in/" -LOGIN_REDIRECT_URL = '/index/' +LOGIN_REDIRECT_URL = '/api/auth/login-redirect/' diff --git a/InvenTree/InvenTree/tests.py b/InvenTree/InvenTree/tests.py index a054185075..b58f53534c 100644 --- a/InvenTree/InvenTree/tests.py +++ b/InvenTree/InvenTree/tests.py @@ -1247,9 +1247,6 @@ class MagicLoginTest(InvenTreeTestCase): # Check that the login works resp = self.client.get(reverse('sesame-login') + '?sesame=' + token) self.assertEqual(resp.status_code, 302) - self.assertEqual(resp.url, '/index/') - # Note: 2023-08-08 - This test has been changed because "platform UI" is not generally available yet - # TODO: In the future, the URL comparison will need to be reverted - # self.assertEqual(resp.url, f'/{settings.FRONTEND_URL_BASE}/logged-in/') + self.assertEqual(resp.url, '/api/auth/login-redirect/') # And we should be logged in again self.assertEqual(resp.wsgi_request.user, self.user) diff --git a/InvenTree/InvenTree/urls.py b/InvenTree/InvenTree/urls.py index 038b75ace2..1d5b421c8d 100644 --- a/InvenTree/InvenTree/urls.py +++ b/InvenTree/InvenTree/urls.py @@ -35,6 +35,7 @@ from order.urls import order_urls from part.urls import part_urls from plugin.urls import get_plugin_urls from stock.urls import stock_urls +from web.urls import api_urls as web_api_urls from web.urls import urlpatterns as platform_urls from .api import APISearchView, InfoView, NotFoundView, VersionTextView, VersionView @@ -84,6 +85,7 @@ apipatterns = [ path('report/', include(report.api.report_api_urls)), path('user/', include(users.api.user_urls)), path('admin/', include(common.api.admin_api_urls)), + path('web/', include(web_api_urls)), # Plugin endpoints path('', include(plugin.api.plugin_api_urls)), # Common endpoints endpoint @@ -149,6 +151,11 @@ apipatterns = [ name='social_account_disconnect', ), path('logout/', users.api.Logout.as_view(), name='api-logout'), + path( + 'login-redirect/', + users.api.LoginRedirect.as_view(), + name='api-login-redirect', + ), path('', include('dj_rest_auth.urls')), ]), ), diff --git a/InvenTree/users/api.py b/InvenTree/users/api.py index 582b5f7f32..b5a8728f53 100644 --- a/InvenTree/users/api.py +++ b/InvenTree/users/api.py @@ -6,6 +6,7 @@ import logging from django.contrib.auth import get_user, login from django.contrib.auth.models import Group, User from django.urls import include, path, re_path +from django.views.generic.base import RedirectView from dj_rest_auth.views import LogoutView from rest_framework import exceptions, permissions @@ -22,6 +23,7 @@ from InvenTree.mixins import ( RetrieveUpdateDestroyAPI, ) from InvenTree.serializers import ExendedUserSerializer, UserCreateSerializer +from InvenTree.settings import FRONTEND_URL_BASE from users.models import ApiToken, Owner, RuleSet, check_user_role from users.serializers import GroupSerializer, OwnerSerializer @@ -279,6 +281,17 @@ class GetAuthToken(APIView): raise exceptions.NotAuthenticated() +class LoginRedirect(RedirectView): + """Redirect to the correct starting page after backend login.""" + + def get_redirect_url(self, *args, **kwargs): + """Return the URL to redirect to.""" + session = self.request.session + if session.get('preferred_method', 'cui') == 'pui': + return f'/{FRONTEND_URL_BASE}/logged-in/' + return '/index/' + + user_urls = [ path('roles/', RoleDetails.as_view(), name='api-user-roles'), path('token/', GetAuthToken.as_view(), name='api-token'), diff --git a/InvenTree/web/tests.py b/InvenTree/web/tests.py index 62212a5b75..eca79fa94b 100644 --- a/InvenTree/web/tests.py +++ b/InvenTree/web/tests.py @@ -5,8 +5,10 @@ import os from pathlib import Path from unittest import mock +from django.urls import reverse + from InvenTree.config import get_frontend_settings -from InvenTree.unit_test import InvenTreeTestCase +from InvenTree.unit_test import InvenTreeAPITestCase, InvenTreeTestCase from .templatetags import spa_helper @@ -73,3 +75,26 @@ class TemplateTagTest(InvenTreeTestCase): rsp = get_frontend_settings(False) self.assertFalse('show_server_selector' in rsp) self.assertEqual(rsp['server_list'], ['aa', 'bb']) + + +class TestWebHelpers(InvenTreeAPITestCase): + """Tests for the web helpers.""" + + def test_ui_preference(self): + """Test the UI preference API.""" + url = reverse('api-ui-preference') + + # Test default + resp = self.get(url) + data = json.loads(resp.content) + self.assertTrue(data['cui']) + self.assertFalse(data['pui']) + self.assertEqual(data['preferred_method'], 'cui') + + # Set to PUI + resp = self.put(url, {'preferred_method': 'pui'}) + data = json.loads(resp.content) + self.assertEqual(resp.status_code, 200) + self.assertFalse(data['cui']) + self.assertTrue(data['pui']) + self.assertEqual(data['preferred_method'], 'pui') diff --git a/InvenTree/web/urls.py b/InvenTree/web/urls.py index 16cdb57f40..56043c8ed9 100644 --- a/InvenTree/web/urls.py +++ b/InvenTree/web/urls.py @@ -1,11 +1,16 @@ """URLs for web app.""" from django.conf import settings +from django.http import JsonResponse from django.shortcuts import redirect from django.urls import include, path, re_path from django.views.decorators.csrf import ensure_csrf_cookie from django.views.generic import TemplateView +from rest_framework import permissions, serializers + +from InvenTree.mixins import RetrieveUpdateAPI + class RedirectAssetView(TemplateView): """View to redirect to static asset.""" @@ -17,6 +22,55 @@ class RedirectAssetView(TemplateView): ) +class PreferredSerializer(serializers.Serializer): + """Serializer for the preferred serializer session setting.""" + + preferred_method = serializers.ChoiceField(choices=['cui', 'pui']) + pui = serializers.SerializerMethodField(read_only=True) + cui = serializers.SerializerMethodField(read_only=True) + + def get_pui(self, obj): + """Return true if preferred method is PUI.""" + return obj['preferred_method'] == 'pui' + + def get_cui(self, obj): + """Return true if preferred method is CUI.""" + return obj['preferred_method'] == 'cui' + + class Meta: + """Meta class for PreferedSerializer.""" + + fields = '__all__' + + +class PreferredUiView(RetrieveUpdateAPI): + """Set preferred UI (CIU/PUI).""" + + permission_classes = [permissions.AllowAny] + serializer_class = PreferredSerializer + http_method_names = ['get', 'post', 'put', 'head', 'options'] + + def retrieve(self, request, *args, **kwargs): + """Retrieve the preferred UI method.""" + session = self.request.session + session['preferred_method'] = session.get('preferred_method', 'cui') + serializer = self.get_serializer(data=dict(session)) + serializer.is_valid(raise_exception=True) + return JsonResponse(serializer.data) + + def update(self, request, *args, **kwargs): + """Update the preferred UI method.""" + serializer = self.get_serializer(data=self.clean_data(request.data)) + serializer.is_valid(raise_exception=True) + + # Run update + session = self.request.session + session['preferred_method'] = serializer.validated_data['preferred_method'] + session.modified = True + + return JsonResponse(serializer.data) + + spa_view = ensure_csrf_cookie(TemplateView.as_view(template_name='web/index.html')) assets_path = path('assets/', RedirectAssetView.as_view()) @@ -37,3 +91,8 @@ urlpatterns = [ assets_path, path(settings.FRONTEND_URL_BASE, spa_view, name='platform'), ] + +api_urls = [ + # UI Preference + path('ui_preference/', PreferredUiView.as_view(), name='api-ui-preference') +] diff --git a/src/frontend/src/components/buttons/SSOButton.tsx b/src/frontend/src/components/buttons/SSOButton.tsx new file mode 100644 index 0000000000..503f5dafd5 --- /dev/null +++ b/src/frontend/src/components/buttons/SSOButton.tsx @@ -0,0 +1,64 @@ +import { Button } from '@mantine/core'; +import { + IconBrandAzure, + IconBrandBitbucket, + IconBrandDiscord, + IconBrandFacebook, + IconBrandFlickr, + IconBrandGithub, + IconBrandGitlab, + IconBrandGoogle, + IconBrandReddit, + IconBrandTwitch, + IconBrandTwitter, + IconLogin +} from '@tabler/icons-react'; + +import { api } from '../../App'; +import { ApiEndpoints } from '../../enums/ApiEndpoints'; +import { apiUrl } from '../../states/ApiState'; +import { Provider } from '../../states/states'; + +const brandIcons: { [key: string]: JSX.Element } = { + google: , + github: , + facebook: , + discord: , + twitter: , + bitbucket: , + flickr: , + gitlab: , + reddit: , + twitch: , + microsoft: +}; + +export function SsoButton({ provider }: { provider: Provider }) { + function login() { + // set preferred provider + api + .put( + apiUrl(ApiEndpoints.ui_preference), + { preferred_method: 'pui' }, + { headers: { Authorization: '' } } + ) + .then(() => { + // redirect to login + window.location.href = provider.login; + }); + } + + return ( + + ); +} +function getBrandIcon(provider: Provider) { + return brandIcons[provider.id] || ; +} diff --git a/src/frontend/src/components/forms/AuthenticationForm.tsx b/src/frontend/src/components/forms/AuthenticationForm.tsx index 3b2019c19b..fdadec91ea 100644 --- a/src/frontend/src/components/forms/AuthenticationForm.tsx +++ b/src/frontend/src/components/forms/AuthenticationForm.tsx @@ -2,6 +2,7 @@ import { Trans, t } from '@lingui/macro'; import { Anchor, Button, + Divider, Group, Loader, PasswordInput, @@ -21,6 +22,7 @@ import { ApiEndpoints } from '../../enums/ApiEndpoints'; import { doBasicLogin, doSimpleLogin } from '../../functions/auth'; import { apiUrl, useServerApiState } from '../../states/ApiState'; import { useSessionState } from '../../states/SessionState'; +import { SsoButton } from '../buttons/SSOButton'; export function AuthenticationForm() { const classicForm = useForm({ @@ -83,76 +85,93 @@ export function AuthenticationForm() { } return ( -
{})}> - {classicLoginMode ? ( - - - - {auth_settings?.password_forgotten_enabled === true && ( - - navigate('/reset-password')} - > - Reset password - - - )} - - ) : ( - - - - )} + <> + {auth_settings?.sso_enabled === true ? ( + <> + + {auth_settings.providers.map((provider) => ( + + ))} + - - setMode.toggle()} - > - {classicLoginMode ? ( - Send me an email - ) : ( - Use username and password - )} - - - -
+ + + ) : null} +
{})}> + {classicLoginMode ? ( + + + + {auth_settings?.password_forgotten_enabled === true && ( + + navigate('/reset-password')} + > + Reset password + + + )} + + ) : ( + + + + )} + + + setMode.toggle()} + > + {classicLoginMode ? ( + Send me an email + ) : ( + Use username and password + )} + + + +
+ ); } @@ -161,6 +180,7 @@ export function RegistrationForm() { initialValues: { username: '', email: '', password1: '', password2: '' } }); const navigate = useNavigate(); + const [auth_settings] = useServerApiState((state) => [state.auth_settings]); const [isRegistering, setIsRegistering] = useState(false); function handleRegistration() { @@ -201,47 +221,63 @@ export function RegistrationForm() { }); } + const both_reg_enabled = + auth_settings?.registration_enabled && auth_settings?.sso_registration; return ( -
{})}> - - - - - - + <> + {auth_settings?.registration_enabled && ( + {})}> + + + + + + - - - -
+ + + + + )} + {both_reg_enabled && ( + + )} + {auth_settings?.sso_registration === true && ( + + {auth_settings.providers.map((provider) => ( + + ))} + + )} + ); } @@ -253,8 +289,12 @@ export function ModeSelector({ setMode: any; }) { const [auth_settings] = useServerApiState((state) => [state.auth_settings]); + const registration_enabled = + auth_settings?.registration_enabled || + auth_settings?.sso_registration || + false; - if (auth_settings?.registration_enabled === false) return null; + if (registration_enabled === false) return null; return ( {loginMode ? ( diff --git a/src/frontend/src/enums/ApiEndpoints.tsx b/src/frontend/src/enums/ApiEndpoints.tsx index 764f20f506..d07d1c13c7 100644 --- a/src/frontend/src/enums/ApiEndpoints.tsx +++ b/src/frontend/src/enums/ApiEndpoints.tsx @@ -26,6 +26,7 @@ export enum ApiEndpoints { user_logout = 'auth/logout/', user_register = 'auth/registration/', + // Generic API endpoints currency_list = 'currency/exchange/', currency_refresh = 'currency/refresh/', task_overview = 'background-task/', @@ -43,6 +44,8 @@ export enum ApiEndpoints { sso_providers = 'auth/providers/', group_list = 'user/group/', owner_list = 'user/owner/', + + // Build API endpoints build_order_list = 'build/', build_order_attachment_list = 'build/attachment/', build_line_list = 'build/line/', @@ -61,6 +64,7 @@ export enum ApiEndpoints { part_attachment_list = 'part/attachment/', part_test_template_list = 'part/test-template/', + // Company API endpoints company_list = 'company/', contact_list = 'company/contact/', address_list = 'company/address/', @@ -76,6 +80,8 @@ export enum ApiEndpoints { stock_location_list = 'stock/location/', stock_location_tree = 'stock/location/tree/', stock_attachment_list = 'stock/attachment/', + + // Order API endpoints purchase_order_list = 'order/po/', purchase_order_line_list = 'order/po-line/', purchase_order_attachment_list = 'order/po/attachment/', @@ -84,12 +90,17 @@ export enum ApiEndpoints { sales_order_shipment_list = 'order/so/shipment/', return_order_list = 'order/ro/', return_order_attachment_list = 'order/ro/attachment/', + + // Plugin API endpoints plugin_list = 'plugins/', plugin_setting_list = 'plugins/:plugin/settings/', plugin_registry_status = 'plugins/status/', plugin_install = 'plugins/install/', plugin_reload = 'plugins/reload/', + + // Miscellaneous API endpoints error_report_list = 'error-report/', project_code_list = 'project-code/', - custom_unit_list = 'units/' + custom_unit_list = 'units/', + ui_preference = 'web/ui_preference/' } diff --git a/src/frontend/src/pages/Auth/Login.tsx b/src/frontend/src/pages/Auth/Login.tsx index f794bbcfd4..49d571b97d 100644 --- a/src/frontend/src/pages/Auth/Login.tsx +++ b/src/frontend/src/pages/Auth/Login.tsx @@ -71,7 +71,11 @@ export default function Login() { <> - Welcome, log in below + {loginMode ? ( + Welcome, log in below + ) : ( + Register below + )} {loginMode ? : }