Refactor login state management (#7158)

* Refactor login state management

- Previously relied only on presence of cookie
- Cookie may not actually be *valid*
- Inspect actual login state by looking at userState values
- Ensures better sequencing of global state API requests
- Login state is now correctly preseed across browsers

* Ignore errors for user/me/ API endpoint in playwright test

* Do not request notifications unless logged in

* Prevent duplicate licenses

* Update src/frontend/src/views/DesktopAppView.tsx

Co-authored-by: Matthias Mair <code@mjmair.com>

* Simplify checkLoginState

* Fix bug in return types

* Update playwright tests

* linting

* Remove error msg

* Use token auth for API calls

- Will (hopefully) allow us to bypass csrfmiddle request handling?

* Refetch token if not available

* Use cache for DISPLAY_FULL_NAMES setting

* Update src/frontend/tests/baseFixtures.ts

Co-authored-by: Matthias Mair <code@mjmair.com>

* PUI test updates

* Tweak doLogout function

* Revert change to baseFixtures.ts

* Cleanup

* Fix highlighted property

* Test cleanup

---------

Co-authored-by: Matthias Mair <code@mjmair.com>
This commit is contained in:
Oliver 2024-05-07 23:11:38 +10:00 committed by GitHub
parent 6c944c73dd
commit 289af4e924
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 225 additions and 109 deletions

View File

@ -73,8 +73,24 @@ class LicenseView(APIView):
logger.exception("Exception while reading license file '%s': %s", path, e)
return []
# Ensure consistent string between backend and frontend licenses
return [{key.lower(): value for key, value in entry.items()} for entry in data]
output = []
names = set()
# Ensure we do not have any duplicate 'name' values in the list
for entry in data:
name = None
for key in entry.keys():
if key.lower() == 'name':
name = entry[key]
break
if name is None or name in names:
continue
names.add(name)
output.append({key.lower(): value for key, value in entry.items()})
return output
@extend_schema(responses={200: OpenApiResponse(response=LicenseViewSerializer)})
def get(self, request, *args, **kwargs):

View File

@ -70,7 +70,8 @@ class AuthRequiredMiddleware(object):
# API requests are handled by the DRF library
if request.path_info.startswith('/api/'):
return self.get_response(request)
response = self.get_response(request)
return response
# Is the function exempt from auth requirements?
path_func = resolve(request.path).func

View File

@ -34,7 +34,7 @@ logger = logging.getLogger('inventree')
# string representation of a user
def user_model_str(self):
"""Function to override the default Django User __str__."""
if common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES'):
if common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES', cache=True):
if self.first_name or self.last_name:
return f'{self.first_name} {self.last_name}'
return self.username
@ -831,7 +831,9 @@ class Owner(models.Model):
"""Defines the owner string representation."""
if (
self.owner_type.name == 'user'
and common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES')
and common_models.InvenTreeSetting.get_setting(
'DISPLAY_FULL_NAMES', cache=True
)
):
display_name = self.owner.get_full_name()
else:
@ -842,7 +844,9 @@ class Owner(models.Model):
"""Return the 'name' of this owner."""
if (
self.owner_type.name == 'user'
and common_models.InvenTreeSetting.get_setting('DISPLAY_FULL_NAMES')
and common_models.InvenTreeSetting.get_setting(
'DISPLAY_FULL_NAMES', cache=True
)
):
return self.owner.get_full_name() or str(self.owner)
return str(self.owner)

View File

@ -5,6 +5,7 @@ export default defineConfig({
fullyParallel: true,
timeout: 60000,
forbidOnly: !!process.env.CI,
timeout: 5 * 60 * 1000,
retries: process.env.CI ? 1 : 0,
workers: process.env.CI ? 2 : undefined,
reporter: process.env.CI ? [['html', { open: 'never' }], ['github']] : 'list',

View File

@ -2,6 +2,7 @@ import { QueryClient } from '@tanstack/react-query';
import axios from 'axios';
import { useLocalState } from './states/LocalState';
import { useUserState } from './states/UserState';
// Global API instance
export const api = axios.create({});
@ -11,6 +12,7 @@ export const api = axios.create({});
*/
export function setApiDefaults() {
const host = useLocalState.getState().host;
const token = useUserState.getState().token;
api.defaults.baseURL = host;
api.defaults.timeout = 2500;
@ -19,6 +21,12 @@ export function setApiDefaults() {
api.defaults.withXSRFToken = true;
api.defaults.xsrfCookieName = 'csrftoken';
api.defaults.xsrfHeaderName = 'X-CSRFToken';
if (token) {
api.defaults.headers['Authorization'] = `Token ${token}`;
} else {
delete api.defaults.headers['Authorization'];
}
}
export const queryClient = new QueryClient();

View File

@ -17,9 +17,10 @@ import { useLocation, useNavigate } from 'react-router-dom';
import { api } from '../../App';
import { ApiEndpoints } from '../../enums/ApiEndpoints';
import { doBasicLogin, doSimpleLogin, isLoggedIn } from '../../functions/auth';
import { doBasicLogin, doSimpleLogin } from '../../functions/auth';
import { showLoginNotification } from '../../functions/notifications';
import { apiUrl, useServerApiState } from '../../states/ApiState';
import { useUserState } from '../../states/UserState';
import { SsoButton } from '../buttons/SSOButton';
export function AuthenticationForm() {
@ -31,6 +32,7 @@ export function AuthenticationForm() {
const [auth_settings] = useServerApiState((state) => [state.auth_settings]);
const navigate = useNavigate();
const location = useLocation();
const { isLoggedIn } = useUserState();
const [isLoggingIn, setIsLoggingIn] = useState<boolean>(false);

View File

@ -45,17 +45,17 @@ function ConditionalDocTooltip({
export function MenuLinks({
links,
highlighted
highlighted = false
}: {
links: MenuLinkItem[];
highlighted?: boolean;
}) {
const { classes } = InvenTreeStyle();
highlighted = highlighted || false;
const filteredLinks = links.filter(
(item) => !highlighted || item.highlight === true
);
return (
<SimpleGrid cols={2} spacing={0}>
{filteredLinks.map((item) => (

View File

@ -11,6 +11,7 @@ import { ApiEndpoints } from '../../enums/ApiEndpoints';
import { InvenTreeStyle } from '../../globalStyle';
import { apiUrl } from '../../states/ApiState';
import { useLocalState } from '../../states/LocalState';
import { useUserState } from '../../states/UserState';
import { ScanButton } from '../buttons/ScanButton';
import { SpotlightButton } from '../buttons/SpotlightButton';
import { MainMenu } from './MainMenu';
@ -37,11 +38,14 @@ export function Header() {
{ open: openNotificationDrawer, close: closeNotificationDrawer }
] = useDisclosure(false);
const { isLoggedIn } = useUserState();
const [notificationCount, setNotificationCount] = useState<number>(0);
// Fetch number of notifications for the current user
const notifications = useQuery({
queryKey: ['notification-count'],
enabled: isLoggedIn(),
queryFn: async () => {
try {
const params = {

View File

@ -6,13 +6,14 @@ import { useEffect, useState } from 'react';
import { Navigate, Outlet, useLocation, useNavigate } from 'react-router-dom';
import { getActions } from '../../defaults/actions';
import { isLoggedIn } from '../../functions/auth';
import { InvenTreeStyle } from '../../globalStyle';
import { useUserState } from '../../states/UserState';
import { Footer } from './Footer';
import { Header } from './Header';
export const ProtectedRoute = ({ children }: { children: JSX.Element }) => {
const location = useLocation();
const { isLoggedIn } = useUserState();
if (!isLoggedIn()) {
return (

View File

@ -51,7 +51,7 @@ export function PartCategoryTree({
)
.catch((error) => {
console.error('Error fetching part category tree:', error);
return error;
return [];
}),
refetchOnMount: true
});

View File

@ -43,7 +43,7 @@ export function StockLocationTree({
)
.catch((error) => {
console.error('Error fetching stock location tree:', error);
return error;
return [];
}),
refetchOnMount: true
});

View File

@ -104,7 +104,7 @@ export function LanguageContext({ children }: { children: JSX.Element }) {
})
/* istanbul ignore next */
.catch((err) => {
console.error('Failed loading translations', err);
console.error('ERR: Failed loading translations', err);
if (isMounted.current) setLoadedState('error');
});

View File

@ -6,6 +6,7 @@ import { api, setApiDefaults } from '../App';
import { ApiEndpoints } from '../enums/ApiEndpoints';
import { apiUrl } from '../states/ApiState';
import { useLocalState } from '../states/LocalState';
import { useUserState } from '../states/UserState';
import { fetchGlobalStates } from '../states/states';
import { showLoginNotification } from './notifications';
@ -16,7 +17,8 @@ import { showLoginNotification } from './notifications';
*/
export const doBasicLogin = async (username: string, password: string) => {
const { host } = useLocalState.getState();
// const apiState = useServerApiState.getState();
const { clearUserState, setToken, fetchUserState, isLoggedIn } =
useUserState.getState();
if (username.length == 0 || password.length == 0) {
return;
@ -26,6 +28,8 @@ export const doBasicLogin = async (username: string, password: string) => {
const login_url = apiUrl(ApiEndpoints.user_login);
let result: boolean = false;
// Attempt login with
await api
.post(
@ -39,18 +43,21 @@ export const doBasicLogin = async (username: string, password: string) => {
}
)
.then((response) => {
switch (response.status) {
case 200:
fetchGlobalStates();
break;
default:
clearCsrfCookie();
break;
if (response.status == 200) {
if (response.data.key) {
setToken(response.data.key);
result = true;
}
}
})
.catch(() => {
clearCsrfCookie();
});
.catch(() => {});
if (result) {
await fetchUserState();
await fetchGlobalStates();
} else {
clearUserState();
}
};
/**
@ -59,16 +66,21 @@ export const doBasicLogin = async (username: string, password: string) => {
* @arg deleteToken: If true, delete the token from the server
*/
export const doLogout = async (navigate: any) => {
const { clearUserState, isLoggedIn } = useUserState.getState();
// Logout from the server session
await api.post(apiUrl(ApiEndpoints.user_logout)).finally(() => {
clearCsrfCookie();
navigate('/login');
if (isLoggedIn() || !!getCsrfCookie()) {
await api.post(apiUrl(ApiEndpoints.user_logout)).catch(() => {});
showLoginNotification({
title: t`Logged Out`,
message: t`Successfully logged out`
});
});
}
clearUserState();
clearCsrfCookie();
navigate('/login');
};
export const doSimpleLogin = async (email: string) => {
@ -122,17 +134,19 @@ export function handleReset(navigate: any, values: { email: string }) {
* - An existing API token is stored in the session
* - An existing CSRF cookie is stored in the browser
*/
export function checkLoginState(
export const checkLoginState = async (
navigate: any,
redirect?: string,
no_redirect?: boolean
) {
) => {
setApiDefaults();
if (redirect == '/') {
redirect = '/home';
}
const { isLoggedIn, fetchUserState } = useUserState.getState();
// Callback function when login is successful
const loginSuccess = () => {
showLoginNotification({
@ -140,6 +154,8 @@ export function checkLoginState(
message: t`Successfully logged in`
});
fetchGlobalStates();
navigate(redirect ?? '/home');
};
@ -150,24 +166,22 @@ export function checkLoginState(
}
};
// Check the 'user_me' endpoint to see if the user is logged in
if (isLoggedIn()) {
api
.get(apiUrl(ApiEndpoints.user_me))
.then((response) => {
if (response.status == 200) {
loginSuccess();
} else {
loginFailure();
}
})
.catch(() => {
loginFailure();
});
// Already logged in
loginSuccess();
return;
}
// Not yet logged in, but we might have a valid session cookie
// Attempt to login
await fetchUserState();
if (isLoggedIn()) {
loginSuccess();
} else {
loginFailure();
}
}
};
/*
* Return the value of the CSRF cookie, if available
@ -181,10 +195,6 @@ export function getCsrfCookie() {
return cookieValue;
}
export function isLoggedIn() {
return !!getCsrfCookie();
}
/*
* Clear out the CSRF and session cookies (force session logout)
*/

View File

@ -17,12 +17,17 @@ export function AccountDetailPanel() {
const form = useForm({ initialValues: user });
const [editing, setEditing] = useToggle([false, true] as const);
function SaveData(values: any) {
api.put(apiUrl(ApiEndpoints.user_me), values).then((res) => {
if (res.status === 200) {
setEditing();
fetchUserState();
}
});
api
.put(apiUrl(ApiEndpoints.user_me), values)
.then((res) => {
if (res.status === 200) {
setEditing();
fetchUserState();
}
})
.catch(() => {
console.error('ERR: Error saving user data');
});
}
return (

View File

@ -5,9 +5,9 @@ import { create, createStore } from 'zustand';
import { api } from '../App';
import { ApiEndpoints } from '../enums/ApiEndpoints';
import { isLoggedIn } from '../functions/auth';
import { isTrue } from '../functions/conversion';
import { PathParams, apiUrl } from './ApiState';
import { useUserState } from './UserState';
import { Setting, SettingsLookup } from './states';
export interface SettingsStateProps {
@ -29,6 +29,8 @@ export const useGlobalSettingsState = create<SettingsStateProps>(
lookup: {},
endpoint: ApiEndpoints.settings_global_list,
fetchSettings: async () => {
const { isLoggedIn } = useUserState.getState();
if (!isLoggedIn()) {
return;
}
@ -63,6 +65,8 @@ export const useUserSettingsState = create<SettingsStateProps>((set, get) => ({
lookup: {},
endpoint: ApiEndpoints.settings_user_list,
fetchSettings: async () => {
const { isLoggedIn } = useUserState.getState();
if (!isLoggedIn()) {
return;
}

View File

@ -6,8 +6,8 @@ import { StatusCodeListInterface } from '../components/render/StatusRenderer';
import { statusCodeList } from '../defaults/backendMappings';
import { ApiEndpoints } from '../enums/ApiEndpoints';
import { ModelType } from '../enums/ModelType';
import { isLoggedIn } from '../functions/auth';
import { apiUrl } from './ApiState';
import { useUserState } from './UserState';
type StatusLookup = Record<ModelType | string, StatusCodeListInterface>;
@ -23,6 +23,8 @@ export const useGlobalStatusState = create<ServerStateProps>()(
status: undefined,
setStatus: (newStatus: StatusLookup) => set({ status: newStatus }),
fetchStatus: async () => {
const { isLoggedIn } = useUserState.getState();
// Fetch status data for rendering labels
if (!isLoggedIn()) {
return;

View File

@ -1,22 +1,28 @@
import { create } from 'zustand';
import { api } from '../App';
import { api, setApiDefaults } from '../App';
import { ApiEndpoints } from '../enums/ApiEndpoints';
import { UserPermissions, UserRoles } from '../enums/Roles';
import { isLoggedIn } from '../functions/auth';
import { clearCsrfCookie } from '../functions/auth';
import { apiUrl } from './ApiState';
import { UserProps } from './states';
interface UserStateProps {
user: UserProps | undefined;
token: string | undefined;
username: () => string;
setUser: (newUser: UserProps) => void;
setToken: (newToken: string) => void;
clearToken: () => void;
fetchUserToken: () => void;
fetchUserState: () => void;
clearUserState: () => void;
checkUserRole: (role: UserRoles, permission: UserPermissions) => boolean;
hasDeleteRole: (role: UserRoles) => boolean;
hasChangeRole: (role: UserRoles) => boolean;
hasAddRole: (role: UserRoles) => boolean;
hasViewRole: (role: UserRoles) => boolean;
isLoggedIn: () => boolean;
isStaff: () => boolean;
isSuperuser: () => boolean;
}
@ -26,6 +32,15 @@ interface UserStateProps {
*/
export const useUserState = create<UserStateProps>((set, get) => ({
user: undefined,
token: undefined,
setToken: (newToken: string) => {
set({ token: newToken });
setApiDefaults();
},
clearToken: () => {
set({ token: undefined });
setApiDefaults();
},
username: () => {
const user: UserProps = get().user as UserProps;
@ -36,9 +51,29 @@ export const useUserState = create<UserStateProps>((set, get) => ({
}
},
setUser: (newUser: UserProps) => set({ user: newUser }),
clearUserState: () => {
set({ user: undefined });
set({ token: undefined });
clearCsrfCookie();
setApiDefaults();
},
fetchUserToken: async () => {
await api
.get(apiUrl(ApiEndpoints.user_token))
.then((response) => {
if (response.status == 200 && response.data.token) {
get().setToken(response.data.token);
} else {
get().clearToken();
}
})
.catch(() => {
get().clearToken();
});
},
fetchUserState: async () => {
if (!isLoggedIn()) {
return;
if (!get().token) {
await get().fetchUserToken();
}
// Fetch user data
@ -47,35 +82,48 @@ export const useUserState = create<UserStateProps>((set, get) => ({
timeout: 2000
})
.then((response) => {
const user: UserProps = {
pk: response.data.pk,
first_name: response.data?.first_name ?? '',
last_name: response.data?.last_name ?? '',
email: response.data.email,
username: response.data.username
};
set({ user: user });
if (response.status == 200) {
const user: UserProps = {
pk: response.data.pk,
first_name: response.data?.first_name ?? '',
last_name: response.data?.last_name ?? '',
email: response.data.email,
username: response.data.username
};
set({ user: user });
} else {
get().clearUserState();
}
})
.catch((error) => {
console.error('ERR: Error fetching user data');
.catch(() => {
get().clearUserState();
});
if (!get().isLoggedIn()) {
return;
}
// Fetch role data
await api
.get(apiUrl(ApiEndpoints.user_roles))
.then((response) => {
const user: UserProps = get().user as UserProps;
if (response.status == 200) {
const user: UserProps = get().user as UserProps;
// Update user with role data
if (user) {
user.roles = response.data?.roles ?? {};
user.is_staff = response.data?.is_staff ?? false;
user.is_superuser = response.data?.is_superuser ?? false;
set({ user: user });
// Update user with role data
if (user) {
user.roles = response.data?.roles ?? {};
user.is_staff = response.data?.is_staff ?? false;
user.is_superuser = response.data?.is_superuser ?? false;
set({ user: user });
}
} else {
get().clearUserState();
}
})
.catch((_error) => {
console.error('ERR: Error fetching user roles');
get().clearUserState();
});
},
checkUserRole: (role: UserRoles, permission: UserPermissions) => {
@ -93,6 +141,13 @@ export const useUserState = create<UserStateProps>((set, get) => ({
return user?.roles[role]?.includes(permission) ?? false;
},
isLoggedIn: () => {
if (!get().token) {
return false;
}
const user: UserProps = get().user as UserProps;
return !!user && !!user.pk;
},
isStaff: () => {
const user: UserProps = get().user as UserProps;
return user?.is_staff ?? false;

View File

@ -1,5 +1,4 @@
import { setApiDefaults } from '../App';
import { isLoggedIn } from '../functions/auth';
import { useServerApiState } from './ApiState';
import { useGlobalSettingsState, useUserSettingsState } from './SettingsState';
import { useGlobalStatusState } from './StatusState';
@ -126,6 +125,8 @@ export type SettingsLookup = {
* Necessary on login, or if locale is changed.
*/
export function fetchGlobalStates() {
const { isLoggedIn } = useUserState.getState();
if (!isLoggedIn()) {
return;
}

View File

@ -1,46 +1,22 @@
import { QueryClientProvider } from '@tanstack/react-query';
import { useEffect, useState } from 'react';
import { useEffect } from 'react';
import { BrowserRouter } from 'react-router-dom';
import { queryClient } from '../App';
import { BaseContext } from '../contexts/BaseContext';
import { defaultHostList } from '../defaults/defaultHostList';
import { isLoggedIn } from '../functions/auth';
import { base_url } from '../main';
import { routes } from '../router';
import { useLocalState } from '../states/LocalState';
import {
useGlobalSettingsState,
useUserSettingsState
} from '../states/SettingsState';
import { useUserState } from '../states/UserState';
export default function DesktopAppView() {
const [hostList] = useLocalState((state) => [state.hostList]);
const [fetchUserState] = useUserState((state) => [state.fetchUserState]);
const [fetchGlobalSettings] = useGlobalSettingsState((state) => [
state.fetchSettings
]);
const [fetchUserSettings] = useUserSettingsState((state) => [
state.fetchSettings
]);
// Server Session
const [fetchedServerSession, setFetchedServerSession] = useState(false);
useEffect(() => {
if (Object.keys(hostList).length === 0) {
useLocalState.setState({ hostList: defaultHostList });
}
if (isLoggedIn() && !fetchedServerSession) {
setFetchedServerSession(true);
fetchUserState();
fetchGlobalSettings();
fetchUserSettings();
}
}, [fetchedServerSession]);
}, [hostList]);
return (
<BaseContext>

View File

@ -59,6 +59,8 @@ export const test = baseTest.extend({
if (
msg.type() === 'error' &&
!msg.text().startsWith('ERR: ') &&
url != 'http://localhost:8000/api/user/me/' &&
url != 'http://localhost:8000/api/user/token/' &&
url != 'http://localhost:8000/api/barcode/' &&
url != 'http://localhost:8000/api/news/?search=&offset=0&limit=25' &&
url != 'https://docs.inventree.org/en/versions.json' &&

View File

@ -9,7 +9,6 @@ export const doLogin = async (page, username?: string, password?: string) => {
password = password ?? user.password;
await page.goto(logoutUrl);
await page.goto(loginUrl);
await expect(page).toHaveTitle(RegExp('^InvenTree.*$'));
await page.waitForURL('**/platform/login');
await page.getByLabel('username').fill(username);

View File

@ -1,5 +1,5 @@
import { expect, test } from './baseFixtures.js';
import { baseUrl, user } from './defaults.js';
import { baseUrl, loginUrl, user } from './defaults.js';
import { doLogin, doQuickLogin } from './login.js';
test('PUI - Basic Login Test', async ({ page }) => {
@ -17,6 +17,22 @@ test('PUI - Basic Login Test', async ({ page }) => {
await page
.getByRole('heading', { name: `Welcome to your Dashboard, ${user.name}` })
.click();
// Check that the username is provided
await page.getByText(user.username);
await expect(page).toHaveTitle(RegExp('^InvenTree'));
// Go to the dashboard
await page.goto(baseUrl);
await page.waitForURL('**/platform');
// Logout (via menu)
await page.getByRole('button', { name: 'Ally Access' }).click();
await page.getByRole('menuitem', { name: 'Logout' }).click();
await page.waitForURL('**/platform/login');
await page.getByLabel('username');
});
test('PUI - Quick Login Test', async ({ page }) => {
@ -34,4 +50,8 @@ test('PUI - Quick Login Test', async ({ page }) => {
await page
.getByRole('heading', { name: `Welcome to your Dashboard, ${user.name}` })
.click();
// Logout (via URL)
await page.goto(`${baseUrl}/logout/`);
await page.waitForURL('**/platform/login');
});

View File

@ -71,9 +71,10 @@ test('PUI - Parts - Supplier Parts', async ({ page }) => {
test('PUI - Sales', async ({ page }) => {
await doQuickLogin(page);
await page.goto(`${baseUrl}/sales/`);
await page.goto(`${baseUrl}/sales/index/`);
await page.waitForURL('**/platform/sales/**');
await page.getByRole('tab', { name: 'Sales Orders' }).click();
await page.waitForURL('**/platform/sales/index/salesorders');
await page.getByRole('tab', { name: 'Return Orders' }).click();

View File

@ -5,8 +5,12 @@ import { doQuickLogin } from './login.js';
test('PUI - Stock', async ({ page }) => {
await doQuickLogin(page);
await page.goto(`${baseUrl}/stock`);
await page.goto(`${baseUrl}/stock/location/index/`);
await page.waitForURL('**/platform/stock/location/**');
await page.getByRole('tab', { name: 'Location Details' }).click();
await page.waitForURL('**/platform/stock/location/index/details');
await page.getByRole('tab', { name: 'Stock Items' }).click();
await page.getByRole('cell', { name: '1551ABK' }).click();
await page.getByRole('tab', { name: 'Stock', exact: true }).click();