[PUI] SSO Support (#6333)

* Add sso buttons
Fixes #5753

* Added more icons

* fix callback url

* made heading dynamic

* allow either sso or normal reg

* Added SSO registration

* added divider

* added preferred ui API

* fix test

* fix update function

* refactor

* fix naming

* fix import

* add coverage ignore

* more ignore

* fixed missing key

* renamed button

* revert coverage statements

* set prefered mode before sso login

* added dynamic login redirect

* fixed test assert

* use API Endpoints instead of hardcoding

* fix lookup
This commit is contained in:
Matthias Mair 2024-02-02 09:15:13 +00:00 committed by GitHub
parent c9c93bce39
commit be30cec2ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 337 additions and 118 deletions

View File

@ -1148,5 +1148,4 @@ if CUSTOM_FLAGS:
# Magic login django-sesame # Magic login django-sesame
SESAME_MAX_AGE = 300 SESAME_MAX_AGE = 300
# LOGIN_REDIRECT_URL = f"/{FRONTEND_URL_BASE}/logged-in/" LOGIN_REDIRECT_URL = '/api/auth/login-redirect/'
LOGIN_REDIRECT_URL = '/index/'

View File

@ -1247,9 +1247,6 @@ class MagicLoginTest(InvenTreeTestCase):
# Check that the login works # Check that the login works
resp = self.client.get(reverse('sesame-login') + '?sesame=' + token) resp = self.client.get(reverse('sesame-login') + '?sesame=' + token)
self.assertEqual(resp.status_code, 302) self.assertEqual(resp.status_code, 302)
self.assertEqual(resp.url, '/index/') self.assertEqual(resp.url, '/api/auth/login-redirect/')
# Note: 2023-08-08 - This test has been changed because "platform UI" is not generally available yet
# TODO: In the future, the URL comparison will need to be reverted
# self.assertEqual(resp.url, f'/{settings.FRONTEND_URL_BASE}/logged-in/')
# And we should be logged in again # And we should be logged in again
self.assertEqual(resp.wsgi_request.user, self.user) self.assertEqual(resp.wsgi_request.user, self.user)

View File

@ -35,6 +35,7 @@ from order.urls import order_urls
from part.urls import part_urls from part.urls import part_urls
from plugin.urls import get_plugin_urls from plugin.urls import get_plugin_urls
from stock.urls import stock_urls from stock.urls import stock_urls
from web.urls import api_urls as web_api_urls
from web.urls import urlpatterns as platform_urls from web.urls import urlpatterns as platform_urls
from .api import APISearchView, InfoView, NotFoundView, VersionTextView, VersionView from .api import APISearchView, InfoView, NotFoundView, VersionTextView, VersionView
@ -84,6 +85,7 @@ apipatterns = [
path('report/', include(report.api.report_api_urls)), path('report/', include(report.api.report_api_urls)),
path('user/', include(users.api.user_urls)), path('user/', include(users.api.user_urls)),
path('admin/', include(common.api.admin_api_urls)), path('admin/', include(common.api.admin_api_urls)),
path('web/', include(web_api_urls)),
# Plugin endpoints # Plugin endpoints
path('', include(plugin.api.plugin_api_urls)), path('', include(plugin.api.plugin_api_urls)),
# Common endpoints endpoint # Common endpoints endpoint
@ -149,6 +151,11 @@ apipatterns = [
name='social_account_disconnect', name='social_account_disconnect',
), ),
path('logout/', users.api.Logout.as_view(), name='api-logout'), path('logout/', users.api.Logout.as_view(), name='api-logout'),
path(
'login-redirect/',
users.api.LoginRedirect.as_view(),
name='api-login-redirect',
),
path('', include('dj_rest_auth.urls')), path('', include('dj_rest_auth.urls')),
]), ]),
), ),

View File

@ -6,6 +6,7 @@ import logging
from django.contrib.auth import get_user, login from django.contrib.auth import get_user, login
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from django.urls import include, path, re_path from django.urls import include, path, re_path
from django.views.generic.base import RedirectView
from dj_rest_auth.views import LogoutView from dj_rest_auth.views import LogoutView
from rest_framework import exceptions, permissions from rest_framework import exceptions, permissions
@ -22,6 +23,7 @@ from InvenTree.mixins import (
RetrieveUpdateDestroyAPI, RetrieveUpdateDestroyAPI,
) )
from InvenTree.serializers import ExendedUserSerializer, UserCreateSerializer from InvenTree.serializers import ExendedUserSerializer, UserCreateSerializer
from InvenTree.settings import FRONTEND_URL_BASE
from users.models import ApiToken, Owner, RuleSet, check_user_role from users.models import ApiToken, Owner, RuleSet, check_user_role
from users.serializers import GroupSerializer, OwnerSerializer from users.serializers import GroupSerializer, OwnerSerializer
@ -279,6 +281,17 @@ class GetAuthToken(APIView):
raise exceptions.NotAuthenticated() raise exceptions.NotAuthenticated()
class LoginRedirect(RedirectView):
"""Redirect to the correct starting page after backend login."""
def get_redirect_url(self, *args, **kwargs):
"""Return the URL to redirect to."""
session = self.request.session
if session.get('preferred_method', 'cui') == 'pui':
return f'/{FRONTEND_URL_BASE}/logged-in/'
return '/index/'
user_urls = [ user_urls = [
path('roles/', RoleDetails.as_view(), name='api-user-roles'), path('roles/', RoleDetails.as_view(), name='api-user-roles'),
path('token/', GetAuthToken.as_view(), name='api-token'), path('token/', GetAuthToken.as_view(), name='api-token'),

View File

@ -5,8 +5,10 @@ import os
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
from django.urls import reverse
from InvenTree.config import get_frontend_settings from InvenTree.config import get_frontend_settings
from InvenTree.unit_test import InvenTreeTestCase from InvenTree.unit_test import InvenTreeAPITestCase, InvenTreeTestCase
from .templatetags import spa_helper from .templatetags import spa_helper
@ -73,3 +75,26 @@ class TemplateTagTest(InvenTreeTestCase):
rsp = get_frontend_settings(False) rsp = get_frontend_settings(False)
self.assertFalse('show_server_selector' in rsp) self.assertFalse('show_server_selector' in rsp)
self.assertEqual(rsp['server_list'], ['aa', 'bb']) self.assertEqual(rsp['server_list'], ['aa', 'bb'])
class TestWebHelpers(InvenTreeAPITestCase):
"""Tests for the web helpers."""
def test_ui_preference(self):
"""Test the UI preference API."""
url = reverse('api-ui-preference')
# Test default
resp = self.get(url)
data = json.loads(resp.content)
self.assertTrue(data['cui'])
self.assertFalse(data['pui'])
self.assertEqual(data['preferred_method'], 'cui')
# Set to PUI
resp = self.put(url, {'preferred_method': 'pui'})
data = json.loads(resp.content)
self.assertEqual(resp.status_code, 200)
self.assertFalse(data['cui'])
self.assertTrue(data['pui'])
self.assertEqual(data['preferred_method'], 'pui')

View File

@ -1,11 +1,16 @@
"""URLs for web app.""" """URLs for web app."""
from django.conf import settings from django.conf import settings
from django.http import JsonResponse
from django.shortcuts import redirect from django.shortcuts import redirect
from django.urls import include, path, re_path from django.urls import include, path, re_path
from django.views.decorators.csrf import ensure_csrf_cookie from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import TemplateView from django.views.generic import TemplateView
from rest_framework import permissions, serializers
from InvenTree.mixins import RetrieveUpdateAPI
class RedirectAssetView(TemplateView): class RedirectAssetView(TemplateView):
"""View to redirect to static asset.""" """View to redirect to static asset."""
@ -17,6 +22,55 @@ class RedirectAssetView(TemplateView):
) )
class PreferredSerializer(serializers.Serializer):
"""Serializer for the preferred serializer session setting."""
preferred_method = serializers.ChoiceField(choices=['cui', 'pui'])
pui = serializers.SerializerMethodField(read_only=True)
cui = serializers.SerializerMethodField(read_only=True)
def get_pui(self, obj):
"""Return true if preferred method is PUI."""
return obj['preferred_method'] == 'pui'
def get_cui(self, obj):
"""Return true if preferred method is CUI."""
return obj['preferred_method'] == 'cui'
class Meta:
"""Meta class for PreferedSerializer."""
fields = '__all__'
class PreferredUiView(RetrieveUpdateAPI):
"""Set preferred UI (CIU/PUI)."""
permission_classes = [permissions.AllowAny]
serializer_class = PreferredSerializer
http_method_names = ['get', 'post', 'put', 'head', 'options']
def retrieve(self, request, *args, **kwargs):
"""Retrieve the preferred UI method."""
session = self.request.session
session['preferred_method'] = session.get('preferred_method', 'cui')
serializer = self.get_serializer(data=dict(session))
serializer.is_valid(raise_exception=True)
return JsonResponse(serializer.data)
def update(self, request, *args, **kwargs):
"""Update the preferred UI method."""
serializer = self.get_serializer(data=self.clean_data(request.data))
serializer.is_valid(raise_exception=True)
# Run update
session = self.request.session
session['preferred_method'] = serializer.validated_data['preferred_method']
session.modified = True
return JsonResponse(serializer.data)
spa_view = ensure_csrf_cookie(TemplateView.as_view(template_name='web/index.html')) spa_view = ensure_csrf_cookie(TemplateView.as_view(template_name='web/index.html'))
assets_path = path('assets/<path:path>', RedirectAssetView.as_view()) assets_path = path('assets/<path:path>', RedirectAssetView.as_view())
@ -37,3 +91,8 @@ urlpatterns = [
assets_path, assets_path,
path(settings.FRONTEND_URL_BASE, spa_view, name='platform'), path(settings.FRONTEND_URL_BASE, spa_view, name='platform'),
] ]
api_urls = [
# UI Preference
path('ui_preference/', PreferredUiView.as_view(), name='api-ui-preference')
]

View File

@ -0,0 +1,64 @@
import { Button } from '@mantine/core';
import {
IconBrandAzure,
IconBrandBitbucket,
IconBrandDiscord,
IconBrandFacebook,
IconBrandFlickr,
IconBrandGithub,
IconBrandGitlab,
IconBrandGoogle,
IconBrandReddit,
IconBrandTwitch,
IconBrandTwitter,
IconLogin
} from '@tabler/icons-react';
import { api } from '../../App';
import { ApiEndpoints } from '../../enums/ApiEndpoints';
import { apiUrl } from '../../states/ApiState';
import { Provider } from '../../states/states';
const brandIcons: { [key: string]: JSX.Element } = {
google: <IconBrandGoogle />,
github: <IconBrandGithub />,
facebook: <IconBrandFacebook />,
discord: <IconBrandDiscord />,
twitter: <IconBrandTwitter />,
bitbucket: <IconBrandBitbucket />,
flickr: <IconBrandFlickr />,
gitlab: <IconBrandGitlab />,
reddit: <IconBrandReddit />,
twitch: <IconBrandTwitch />,
microsoft: <IconBrandAzure />
};
export function SsoButton({ provider }: { provider: Provider }) {
function login() {
// set preferred provider
api
.put(
apiUrl(ApiEndpoints.ui_preference),
{ preferred_method: 'pui' },
{ headers: { Authorization: '' } }
)
.then(() => {
// redirect to login
window.location.href = provider.login;
});
}
return (
<Button
leftIcon={getBrandIcon(provider)}
radius="xl"
component="a"
onClick={login}
>
{provider.display_name}{' '}
</Button>
);
}
function getBrandIcon(provider: Provider) {
return brandIcons[provider.id] || <IconLogin />;
}

View File

@ -2,6 +2,7 @@ import { Trans, t } from '@lingui/macro';
import { import {
Anchor, Anchor,
Button, Button,
Divider,
Group, Group,
Loader, Loader,
PasswordInput, PasswordInput,
@ -21,6 +22,7 @@ import { ApiEndpoints } from '../../enums/ApiEndpoints';
import { doBasicLogin, doSimpleLogin } from '../../functions/auth'; import { doBasicLogin, doSimpleLogin } from '../../functions/auth';
import { apiUrl, useServerApiState } from '../../states/ApiState'; import { apiUrl, useServerApiState } from '../../states/ApiState';
import { useSessionState } from '../../states/SessionState'; import { useSessionState } from '../../states/SessionState';
import { SsoButton } from '../buttons/SSOButton';
export function AuthenticationForm() { export function AuthenticationForm() {
const classicForm = useForm({ const classicForm = useForm({
@ -83,76 +85,93 @@ export function AuthenticationForm() {
} }
return ( return (
<form onSubmit={classicForm.onSubmit(() => {})}> <>
{classicLoginMode ? ( {auth_settings?.sso_enabled === true ? (
<Stack spacing={0}> <>
<TextInput <Group grow mb="md" mt="md">
required {auth_settings.providers.map((provider) => (
label={t`Username`} <SsoButton provider={provider} key={provider.id} />
placeholder={t`Your username`} ))}
{...classicForm.getInputProps('username')} </Group>
/>
<PasswordInput
required
label={t`Password`}
placeholder={t`Your password`}
{...classicForm.getInputProps('password')}
/>
{auth_settings?.password_forgotten_enabled === true && (
<Group position="apart" mt="0">
<Anchor
component="button"
type="button"
color="dimmed"
size="xs"
onClick={() => navigate('/reset-password')}
>
<Trans>Reset password</Trans>
</Anchor>
</Group>
)}
</Stack>
) : (
<Stack>
<TextInput
required
label={t`Email`}
description={t`We will send you a link to login - if you are registered`}
placeholder="email@example.org"
{...simpleForm.getInputProps('email')}
/>
</Stack>
)}
<Group position="apart" mt="xl"> <Divider
<Anchor label={t`Or continue with other methods`}
component="button" labelPosition="center"
type="button" my="lg"
color="dimmed" />
size="xs" </>
onClick={() => setMode.toggle()} ) : null}
> <form onSubmit={classicForm.onSubmit(() => {})}>
{classicLoginMode ? ( {classicLoginMode ? (
<Trans>Send me an email</Trans> <Stack spacing={0}>
) : ( <TextInput
<Trans>Use username and password</Trans> required
)} label={t`Username`}
</Anchor> placeholder={t`Your username`}
<Button type="submit" disabled={isLoggingIn} onClick={handleLogin}> {...classicForm.getInputProps('username')}
{isLoggingIn ? ( />
<Loader size="sm" /> <PasswordInput
) : ( required
<> label={t`Password`}
{classicLoginMode ? ( placeholder={t`Your password`}
<Trans>Log In</Trans> {...classicForm.getInputProps('password')}
) : ( />
<Trans>Send Email</Trans> {auth_settings?.password_forgotten_enabled === true && (
)} <Group position="apart" mt="0">
</> <Anchor
)} component="button"
</Button> type="button"
</Group> color="dimmed"
</form> size="xs"
onClick={() => navigate('/reset-password')}
>
<Trans>Reset password</Trans>
</Anchor>
</Group>
)}
</Stack>
) : (
<Stack>
<TextInput
required
label={t`Email`}
description={t`We will send you a link to login - if you are registered`}
placeholder="email@example.org"
{...simpleForm.getInputProps('email')}
/>
</Stack>
)}
<Group position="apart" mt="xl">
<Anchor
component="button"
type="button"
color="dimmed"
size="xs"
onClick={() => setMode.toggle()}
>
{classicLoginMode ? (
<Trans>Send me an email</Trans>
) : (
<Trans>Use username and password</Trans>
)}
</Anchor>
<Button type="submit" disabled={isLoggingIn} onClick={handleLogin}>
{isLoggingIn ? (
<Loader size="sm" />
) : (
<>
{classicLoginMode ? (
<Trans>Log In</Trans>
) : (
<Trans>Send Email</Trans>
)}
</>
)}
</Button>
</Group>
</form>
</>
); );
} }
@ -161,6 +180,7 @@ export function RegistrationForm() {
initialValues: { username: '', email: '', password1: '', password2: '' } initialValues: { username: '', email: '', password1: '', password2: '' }
}); });
const navigate = useNavigate(); const navigate = useNavigate();
const [auth_settings] = useServerApiState((state) => [state.auth_settings]);
const [isRegistering, setIsRegistering] = useState<boolean>(false); const [isRegistering, setIsRegistering] = useState<boolean>(false);
function handleRegistration() { function handleRegistration() {
@ -201,47 +221,63 @@ export function RegistrationForm() {
}); });
} }
const both_reg_enabled =
auth_settings?.registration_enabled && auth_settings?.sso_registration;
return ( return (
<form onSubmit={registrationForm.onSubmit(() => {})}> <>
<Stack spacing={0}> {auth_settings?.registration_enabled && (
<TextInput <form onSubmit={registrationForm.onSubmit(() => {})}>
required <Stack spacing={0}>
label={t`Username`} <TextInput
placeholder={t`Your username`} required
{...registrationForm.getInputProps('username')} label={t`Username`}
/> placeholder={t`Your username`}
<TextInput {...registrationForm.getInputProps('username')}
required />
label={t`Email`} <TextInput
description={t`This will be used for a confirmation`} required
placeholder="email@example.org" label={t`Email`}
{...registrationForm.getInputProps('email')} description={t`This will be used for a confirmation`}
/> placeholder="email@example.org"
<PasswordInput {...registrationForm.getInputProps('email')}
required />
label={t`Password`} <PasswordInput
placeholder={t`Your password`} required
{...registrationForm.getInputProps('password1')} label={t`Password`}
/> placeholder={t`Your password`}
<PasswordInput {...registrationForm.getInputProps('password1')}
required />
label={t`Password repeat`} <PasswordInput
placeholder={t`Repeat password`} required
{...registrationForm.getInputProps('password2')} label={t`Password repeat`}
/> placeholder={t`Repeat password`}
</Stack> {...registrationForm.getInputProps('password2')}
/>
</Stack>
<Group position="apart" mt="xl"> <Group position="apart" mt="xl">
<Button <Button
type="submit" type="submit"
disabled={isRegistering} disabled={isRegistering}
onClick={handleRegistration} onClick={handleRegistration}
fullWidth fullWidth
> >
<Trans>Register</Trans> <Trans>Register</Trans>
</Button> </Button>
</Group> </Group>
</form> </form>
)}
{both_reg_enabled && (
<Divider label={t`Or use SSO`} labelPosition="center" my="lg" />
)}
{auth_settings?.sso_registration === true && (
<Group grow mb="md" mt="md">
{auth_settings.providers.map((provider) => (
<SsoButton provider={provider} key={provider.id} />
))}
</Group>
)}
</>
); );
} }
@ -253,8 +289,12 @@ export function ModeSelector({
setMode: any; setMode: any;
}) { }) {
const [auth_settings] = useServerApiState((state) => [state.auth_settings]); const [auth_settings] = useServerApiState((state) => [state.auth_settings]);
const registration_enabled =
auth_settings?.registration_enabled ||
auth_settings?.sso_registration ||
false;
if (auth_settings?.registration_enabled === false) return null; if (registration_enabled === false) return null;
return ( return (
<Text ta="center" size={'xs'} mt={'md'}> <Text ta="center" size={'xs'} mt={'md'}>
{loginMode ? ( {loginMode ? (

View File

@ -26,6 +26,7 @@ export enum ApiEndpoints {
user_logout = 'auth/logout/', user_logout = 'auth/logout/',
user_register = 'auth/registration/', user_register = 'auth/registration/',
// Generic API endpoints
currency_list = 'currency/exchange/', currency_list = 'currency/exchange/',
currency_refresh = 'currency/refresh/', currency_refresh = 'currency/refresh/',
task_overview = 'background-task/', task_overview = 'background-task/',
@ -43,6 +44,8 @@ export enum ApiEndpoints {
sso_providers = 'auth/providers/', sso_providers = 'auth/providers/',
group_list = 'user/group/', group_list = 'user/group/',
owner_list = 'user/owner/', owner_list = 'user/owner/',
// Build API endpoints
build_order_list = 'build/', build_order_list = 'build/',
build_order_attachment_list = 'build/attachment/', build_order_attachment_list = 'build/attachment/',
build_line_list = 'build/line/', build_line_list = 'build/line/',
@ -61,6 +64,7 @@ export enum ApiEndpoints {
part_attachment_list = 'part/attachment/', part_attachment_list = 'part/attachment/',
part_test_template_list = 'part/test-template/', part_test_template_list = 'part/test-template/',
// Company API endpoints
company_list = 'company/', company_list = 'company/',
contact_list = 'company/contact/', contact_list = 'company/contact/',
address_list = 'company/address/', address_list = 'company/address/',
@ -76,6 +80,8 @@ export enum ApiEndpoints {
stock_location_list = 'stock/location/', stock_location_list = 'stock/location/',
stock_location_tree = 'stock/location/tree/', stock_location_tree = 'stock/location/tree/',
stock_attachment_list = 'stock/attachment/', stock_attachment_list = 'stock/attachment/',
// Order API endpoints
purchase_order_list = 'order/po/', purchase_order_list = 'order/po/',
purchase_order_line_list = 'order/po-line/', purchase_order_line_list = 'order/po-line/',
purchase_order_attachment_list = 'order/po/attachment/', purchase_order_attachment_list = 'order/po/attachment/',
@ -84,12 +90,17 @@ export enum ApiEndpoints {
sales_order_shipment_list = 'order/so/shipment/', sales_order_shipment_list = 'order/so/shipment/',
return_order_list = 'order/ro/', return_order_list = 'order/ro/',
return_order_attachment_list = 'order/ro/attachment/', return_order_attachment_list = 'order/ro/attachment/',
// Plugin API endpoints
plugin_list = 'plugins/', plugin_list = 'plugins/',
plugin_setting_list = 'plugins/:plugin/settings/', plugin_setting_list = 'plugins/:plugin/settings/',
plugin_registry_status = 'plugins/status/', plugin_registry_status = 'plugins/status/',
plugin_install = 'plugins/install/', plugin_install = 'plugins/install/',
plugin_reload = 'plugins/reload/', plugin_reload = 'plugins/reload/',
// Miscellaneous API endpoints
error_report_list = 'error-report/', error_report_list = 'error-report/',
project_code_list = 'project-code/', project_code_list = 'project-code/',
custom_unit_list = 'units/' custom_unit_list = 'units/',
ui_preference = 'web/ui_preference/'
} }

View File

@ -71,7 +71,11 @@ export default function Login() {
<> <>
<Paper radius="md" p="xl" withBorder> <Paper radius="md" p="xl" withBorder>
<Text size="lg" weight={500}> <Text size="lg" weight={500}>
<Trans>Welcome, log in below</Trans> {loginMode ? (
<Trans>Welcome, log in below</Trans>
) : (
<Trans>Register below</Trans>
)}
</Text> </Text>
{loginMode ? <AuthenticationForm /> : <RegistrationForm />} {loginMode ? <AuthenticationForm /> : <RegistrationForm />}
<ModeSelector loginMode={loginMode} setMode={setMode} /> <ModeSelector loginMode={loginMode} setMode={setMode} />