Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/navigate to model confirmation after creation #348

181 changes: 106 additions & 75 deletions frontend/src/app/providers/models-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ import {
TOAST_NOTIFICATIONS,
} from "@/constants";
import { BASE_MODELS, TrainingDatasetOption, TrainingType } from "@/enums";
import { HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY } from "@/config";
import { HOT_FAIR_MODEL_CREATION_LOCAL_STORAGE_KEY } from "@/config";
import { LngLatBoundsLike } from "maplibre-gl";
import { useCreateTrainingDataset } from "@/features/model-creation/hooks/use-training-datasets";
import { useGetTrainingDataset } from "@/features/models/hooks/use-dataset";
import { useLocation, useNavigate, useParams } from "react-router-dom";
import { useModelDetails } from "@/features/models/hooks/use-models";
import { UseMutationResult } from "@tanstack/react-query";
import { useSessionStorage } from "@/hooks/use-storage";
import { useLocalStorage } from "@/hooks/use-storage";

import {
TModelDetails,
TTrainingAreaFeature,
TTrainingDataset,
TTrainingDetails,
Expand All @@ -26,6 +26,7 @@ import {
} from "@/utils";
import React, {
createContext,
useCallback,
useContext,
useEffect,
useMemo,
Expand All @@ -41,6 +42,8 @@ import {
useCreateModelTrainingRequest,
useUpdateModel,
} from "@/features/model-creation/hooks/use-models";
import axios from "axios";
import { useAuth } from "./auth-provider";

/**
* The names here are the same with the `initialFormState` object keys.
Expand Down Expand Up @@ -229,6 +232,10 @@ const ModelsContext = createContext<{
handleModelCreationAndUpdate: () => void;
handleTrainingDatasetCreation: () => void;
validateEditMode: boolean;
isError: boolean;
isPending: boolean;
data: TModelDetails;
isModelOwner: boolean;
}>({
formData: initialFormState,
setFormData: () => {},
Expand All @@ -255,23 +262,24 @@ const ModelsContext = createContext<{
trainingDatasetCreationInProgress: false,
handleTrainingDatasetCreation: () => {},
validateEditMode: false,
isPending: false,
isError: false,
data: {} as TModelDetails,
isModelOwner: false,
});

export const ModelsProvider: React.FC<{
children: React.ReactNode;
}> = ({ children }) => {
const navigate = useNavigate();
const { pathname } = useLocation();
const { modelId } = useParams();
const { getSessionValue, setSessionValue, removeSessionValue } =
useSessionStorage();

const storedFormData = getSessionValue(
HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY,
);
const { modelId, id } = useParams();
const { setValue, removeValue, getValue } = useLocalStorage();
const storedFormData = getValue(HOT_FAIR_MODEL_CREATION_LOCAL_STORAGE_KEY);
const [formData, setFormData] = useState<typeof initialFormState>(
storedFormData ? JSON.parse(storedFormData) : initialFormState,
);
const { user, isAuthenticated } = useAuth();

const handleChange = (
field: string,
Expand All @@ -285,8 +293,8 @@ export const ModelsProvider: React.FC<{
) => {
setFormData((prev) => {
const updatedData = { ...prev, [field]: value };
setSessionValue(
HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY,
setValue(
HOT_FAIR_MODEL_CREATION_LOCAL_STORAGE_KEY,
JSON.stringify(updatedData),
);
return updatedData;
Expand All @@ -300,31 +308,43 @@ export const ModelsProvider: React.FC<{

const isEditMode = Boolean(modelId && !pathname.includes("new"));

const { data, isPending, isError } = useModelDetails(
modelId as string,
isEditMode,
const { data, isPending, isError, error } = useModelDetails(
id ?? (modelId as string),
!!id || !!modelId,
10000,
);

const {
data: trainingDataset,
isPending: trainingDatasetIsPending,
isError: trainingDatasetIsError,
} = useGetTrainingDataset(
Number(data?.dataset),
Boolean(isEditMode && data?.dataset),
);
const isModelOwner = isAuthenticated && data?.user?.osm_id === user?.osm_id;

// Will be used in the route validator component to delay the redirection for a while until the data are retrieved
const validateEditMode =
formData.selectedTrainingDatasetId !== "" && formData.tmsURL !== "";

// Fetch and prefill model details
useEffect(() => {
if (!isEditMode || isPending || !data) return;

if (isError) {
navigate(APPLICATION_ROUTES.NOTFOUND);
if (isError && error) {
const currentPath = pathname;
if (axios.isAxiosError(error)) {
navigate(APPLICATION_ROUTES.NOTFOUND, {
state: {
from: currentPath,
error: error.response?.data?.detail,
},
});
} else {
const err = error as Error;
navigate(APPLICATION_ROUTES.NOTFOUND, {
state: {
from: currentPath,
error: err.message,
},
});
}
}
}, [isError, error, navigate]);

// Fetch and prefill model details and training dataset
useEffect(() => {
if (!isEditMode || isPending || !data || isError) return;

handleChange(MODEL_CREATION_FORM_NAME.BASE_MODELS, data.base_model);
handleChange(
Expand All @@ -334,50 +354,36 @@ export const ModelsProvider: React.FC<{
handleChange(MODEL_CREATION_FORM_NAME.MODEL_NAME, data.name ?? "");
handleChange(
MODEL_CREATION_FORM_NAME.SELECTED_TRAINING_DATASET_ID,
data.dataset,
data.dataset.id,
);
}, [isEditMode, isError, isPending, data]);

// Fetch and prefill training dataset
useEffect(() => {
if (!isEditMode || trainingDatasetIsPending || trainingDatasetIsError)
return;
handleChange(
MODEL_CREATION_FORM_NAME.DATASET_NAME,
trainingDataset.name ?? "",
data.dataset.name ?? "",
);
handleChange(
MODEL_CREATION_FORM_NAME.TMS_URL,
trainingDataset.source_imagery ?? "",
data.dataset.source_imagery ?? "",
);
}, [
isEditMode,
trainingDatasetIsPending,
trainingDataset,
trainingDatasetIsError,
]);
}, [isEditMode, isError, isPending, data]);

const resetState = () => {
removeValue(HOT_FAIR_MODEL_CREATION_LOCAL_STORAGE_KEY);
setFormData(initialFormState);
};

useEffect(() => {
// Cleanup the timeout on component unmount
return () => {
removeSessionValue(HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY);
if (timeOutRef.current) {
clearTimeout(timeOutRef.current);
}
};
}, []);

const resetState = () => {
removeSessionValue(HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY);
setFormData(initialFormState);
};

const createNewTrainingRequestMutation = useCreateModelTrainingRequest({
mutationConfig: {
onSuccess: () => {
showSuccessToast(TOAST_NOTIFICATIONS.trainingRequestSubmittedSuccess);
// Reset the state after 2 second on the model success page.
resetState();
},
onError: (error) => {
showErrorToast(error);
Expand Down Expand Up @@ -405,22 +411,29 @@ export const ModelsProvider: React.FC<{
},
});

const handleModelCreationOrUpdateSuccess = (modelId: string) => {
showSuccessToast(
isEditMode
? TOAST_NOTIFICATIONS.modelUpdateSuccess
: TOAST_NOTIFICATIONS.modelCreationSuccess,
);
navigate(`${getFullPath(MODELS_ROUTES.CONFIRMATION)}?id=${modelId}`);
// Submit the model for training request
const submitTrainingRequest = useCallback(() => {
createNewTrainingRequestMutation.mutate({
model: modelId,
model: modelId as string,
input_boundary_width: formData.boundaryWidth,
input_contact_spacing: formData.contactSpacing,
epochs: formData.epoch,
batch_size: formData.batchSize,
zoom_level: formData.zoomLevels,
});
}, [formData, modelId]);

const handleModelCreationOrUpdateSuccess = (id?: string) => {
if (isModelOwner) {
showSuccessToast(
isEditMode
? TOAST_NOTIFICATIONS.modelUpdateSuccess
: TOAST_NOTIFICATIONS.modelCreationSuccess,
);
}

navigate(`${getFullPath(MODELS_ROUTES.CONFIRMATION)}?id=${id ?? modelId}`);
// Submit the model for training request
submitTrainingRequest();
};

const modelCreateMutation = useCreateModel({
Expand All @@ -446,19 +459,23 @@ export const ModelsProvider: React.FC<{
modelId: modelId as string,
});

// Confirm that all the training areas labels has been retrieved
const hasLabeledTrainingAreas =
formData.trainingAreas.length > 0 &&
formData.trainingAreas.filter(
(aoi: TTrainingAreaFeature) => aoi.properties.label_fetched === null,
).length === 0;
// Confirm that all the training areas labels have been fetched.
const hasLabeledTrainingAreas = useMemo(
() =>
formData.trainingAreas.every(
(aoi: TTrainingAreaFeature) => aoi.properties.label_fetched !== null,
),
[formData],
);

// Confirm that all of the training areas has a geometry
const hasAOIsWithGeometry =
formData.trainingAreas.length > 0 &&
formData.trainingAreas.filter(
(aoi: TTrainingAreaFeature) => aoi.geometry === null,
).length === 0;
// Confirm that all of the training areas have a geometry.
const hasAOIsWithGeometry = useMemo(
() =>
formData.trainingAreas.every(
(aoi: TTrainingAreaFeature) => aoi.geometry !== null,
),
[formData],
);

const handleTrainingDatasetCreation = () => {
createNewTrainingDatasetMutation.mutate({
Expand All @@ -470,14 +487,20 @@ export const ModelsProvider: React.FC<{
createNewTrainingDatasetMutation.isPending;

const handleModelCreationAndUpdate = () => {
if (isEditMode) {
// The user is trying to edit their model.
// In this case, send a PATCH request and submit a training request.
if (isEditMode && isModelOwner) {
modelUpdateMutation.mutate({
dataset: formData.selectedTrainingDatasetId,
name: formData.modelName,
description: formData.modelDescription,
base_model: formData.baseModel as BASE_MODELS,
modelId: modelId as string,
});
// The user is trying to edit another users model training area and settings.
// In this case, directly submit a training request.
} else if (isEditMode && !isModelOwner) {
handleModelCreationOrUpdateSuccess();
} else {
modelCreateMutation.mutate({
dataset: formData.selectedTrainingDatasetId,
Expand All @@ -496,7 +519,6 @@ export const ModelsProvider: React.FC<{
hasLabeledTrainingAreas,
hasAOIsWithGeometry,
formData,
resetState,
createNewTrainingRequestMutation,
isEditMode,
modelId,
Expand All @@ -505,6 +527,11 @@ export const ModelsProvider: React.FC<{
handleTrainingDatasetCreation,
trainingDatasetCreationInProgress,
validateEditMode,
resetState,
data,
isPending,
isError,
isModelOwner,
}),
[
setFormData,
Expand All @@ -513,15 +540,19 @@ export const ModelsProvider: React.FC<{
createNewTrainingDatasetMutation,
hasLabeledTrainingAreas,
hasAOIsWithGeometry,
resetState,
createNewTrainingRequestMutation,
isEditMode,
modelId,
resetState,
getFullPath,
handleModelCreationAndUpdate,
handleTrainingDatasetCreation,
trainingDatasetCreationInProgress,
validateEditMode,
data,
isPending,
isError,
isModelOwner,
],
);

Expand Down
11 changes: 8 additions & 3 deletions frontend/src/app/router.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
RouterProvider,
createBrowserRouter,
} from "react-router-dom";
import { ModelsProvider } from "@/app/providers/models-provider";

const router = createBrowserRouter([
{
Expand Down Expand Up @@ -71,7 +72,12 @@ const router = createBrowserRouter([
"@/app/routes/models/model-details-card"
);
return {
Component: () => <ModelDetailsPage />,
Component: () => (
<ModelsProvider>
{" "}
<ModelDetailsPage />
</ModelsProvider>
),
};
},
},
Expand Down Expand Up @@ -333,11 +339,10 @@ const router = createBrowserRouter([
return { Component: AuthenticationCallbackPage };
},
},

/**
* Auth route ends.
*/

/**
* 404 route
*/
Expand Down
Loading
Loading