Skip to content

Commit

Permalink
Merge pull request #348 from jeafreezy/fix/navigate-to-model-confirma…
Browse files Browse the repository at this point in the history
…tion-after-creation
  • Loading branch information
kshitijrajsharma authored Feb 27, 2025
2 parents 51459df + 655ccfc commit a2a37bd
Show file tree
Hide file tree
Showing 36 changed files with 671 additions and 554 deletions.
181 changes: 106 additions & 75 deletions frontend/src/app/providers/models-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@ import {
TOAST_NOTIFICATIONS,
} from "@/constants";
import { BASE_MODELS, TrainingDatasetOption, TrainingType } from "@/enums";
import { HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY } from "@/config";
import { HOT_FAIR_MODEL_CREATION_LOCAL_STORAGE_KEY } from "@/config";
import { LngLatBoundsLike } from "maplibre-gl";
import { useCreateTrainingDataset } from "@/features/model-creation/hooks/use-training-datasets";
import { useGetTrainingDataset } from "@/features/models/hooks/use-dataset";
import { useLocation, useNavigate, useParams } from "react-router-dom";
import { useModelDetails } from "@/features/models/hooks/use-models";
import { UseMutationResult } from "@tanstack/react-query";
import { useSessionStorage } from "@/hooks/use-storage";
import { useLocalStorage } from "@/hooks/use-storage";

import {
TModelDetails,
TTrainingAreaFeature,
TTrainingDataset,
TTrainingDetails,
Expand All @@ -26,6 +26,7 @@ import {
} from "@/utils";
import React, {
createContext,
useCallback,
useContext,
useEffect,
useMemo,
Expand All @@ -41,6 +42,8 @@ import {
useCreateModelTrainingRequest,
useUpdateModel,
} from "@/features/model-creation/hooks/use-models";
import axios from "axios";
import { useAuth } from "./auth-provider";

/**
* The names here are the same with the `initialFormState` object keys.
Expand Down Expand Up @@ -229,6 +232,10 @@ const ModelsContext = createContext<{
handleModelCreationAndUpdate: () => void;
handleTrainingDatasetCreation: () => void;
validateEditMode: boolean;
isError: boolean;
isPending: boolean;
data: TModelDetails;
isModelOwner: boolean;
}>({
formData: initialFormState,
setFormData: () => {},
Expand All @@ -255,23 +262,24 @@ const ModelsContext = createContext<{
trainingDatasetCreationInProgress: false,
handleTrainingDatasetCreation: () => {},
validateEditMode: false,
isPending: false,
isError: false,
data: {} as TModelDetails,
isModelOwner: false,
});

export const ModelsProvider: React.FC<{
children: React.ReactNode;
}> = ({ children }) => {
const navigate = useNavigate();
const { pathname } = useLocation();
const { modelId } = useParams();
const { getSessionValue, setSessionValue, removeSessionValue } =
useSessionStorage();

const storedFormData = getSessionValue(
HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY,
);
const { modelId, id } = useParams();
const { setValue, removeValue, getValue } = useLocalStorage();
const storedFormData = getValue(HOT_FAIR_MODEL_CREATION_LOCAL_STORAGE_KEY);
const [formData, setFormData] = useState<typeof initialFormState>(
storedFormData ? JSON.parse(storedFormData) : initialFormState,
);
const { user, isAuthenticated } = useAuth();

const handleChange = (
field: string,
Expand All @@ -285,8 +293,8 @@ export const ModelsProvider: React.FC<{
) => {
setFormData((prev) => {
const updatedData = { ...prev, [field]: value };
setSessionValue(
HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY,
setValue(
HOT_FAIR_MODEL_CREATION_LOCAL_STORAGE_KEY,
JSON.stringify(updatedData),
);
return updatedData;
Expand All @@ -300,31 +308,43 @@ export const ModelsProvider: React.FC<{

const isEditMode = Boolean(modelId && !pathname.includes("new"));

const { data, isPending, isError } = useModelDetails(
modelId as string,
isEditMode,
const { data, isPending, isError, error } = useModelDetails(
id ?? (modelId as string),
!!id || !!modelId,
10000,
);

const {
data: trainingDataset,
isPending: trainingDatasetIsPending,
isError: trainingDatasetIsError,
} = useGetTrainingDataset(
Number(data?.dataset),
Boolean(isEditMode && data?.dataset),
);
const isModelOwner = isAuthenticated && data?.user?.osm_id === user?.osm_id;

// Will be used in the route validator component to delay the redirection for a while until the data are retrieved
const validateEditMode =
formData.selectedTrainingDatasetId !== "" && formData.tmsURL !== "";

// Fetch and prefill model details
useEffect(() => {
if (!isEditMode || isPending || !data) return;

if (isError) {
navigate(APPLICATION_ROUTES.NOTFOUND);
if (isError && error) {
const currentPath = pathname;
if (axios.isAxiosError(error)) {
navigate(APPLICATION_ROUTES.NOTFOUND, {
state: {
from: currentPath,
error: error.response?.data?.detail,
},
});
} else {
const err = error as Error;
navigate(APPLICATION_ROUTES.NOTFOUND, {
state: {
from: currentPath,
error: err.message,
},
});
}
}
}, [isError, error, navigate]);

// Fetch and prefill model details and training dataset
useEffect(() => {
if (!isEditMode || isPending || !data || isError) return;

handleChange(MODEL_CREATION_FORM_NAME.BASE_MODELS, data.base_model);
handleChange(
Expand All @@ -334,50 +354,36 @@ export const ModelsProvider: React.FC<{
handleChange(MODEL_CREATION_FORM_NAME.MODEL_NAME, data.name ?? "");
handleChange(
MODEL_CREATION_FORM_NAME.SELECTED_TRAINING_DATASET_ID,
data.dataset,
data.dataset.id,
);
}, [isEditMode, isError, isPending, data]);

// Fetch and prefill training dataset
useEffect(() => {
if (!isEditMode || trainingDatasetIsPending || trainingDatasetIsError)
return;
handleChange(
MODEL_CREATION_FORM_NAME.DATASET_NAME,
trainingDataset.name ?? "",
data.dataset.name ?? "",
);
handleChange(
MODEL_CREATION_FORM_NAME.TMS_URL,
trainingDataset.source_imagery ?? "",
data.dataset.source_imagery ?? "",
);
}, [
isEditMode,
trainingDatasetIsPending,
trainingDataset,
trainingDatasetIsError,
]);
}, [isEditMode, isError, isPending, data]);

const resetState = () => {
removeValue(HOT_FAIR_MODEL_CREATION_LOCAL_STORAGE_KEY);
setFormData(initialFormState);
};

useEffect(() => {
// Cleanup the timeout on component unmount
return () => {
removeSessionValue(HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY);
if (timeOutRef.current) {
clearTimeout(timeOutRef.current);
}
};
}, []);

const resetState = () => {
removeSessionValue(HOT_FAIR_MODEL_CREATION_SESSION_STORAGE_KEY);
setFormData(initialFormState);
};

const createNewTrainingRequestMutation = useCreateModelTrainingRequest({
mutationConfig: {
onSuccess: () => {
showSuccessToast(TOAST_NOTIFICATIONS.trainingRequestSubmittedSuccess);
// Reset the state after 2 second on the model success page.
resetState();
},
onError: (error) => {
showErrorToast(error);
Expand Down Expand Up @@ -405,22 +411,29 @@ export const ModelsProvider: React.FC<{
},
});

const handleModelCreationOrUpdateSuccess = (modelId: string) => {
showSuccessToast(
isEditMode
? TOAST_NOTIFICATIONS.modelUpdateSuccess
: TOAST_NOTIFICATIONS.modelCreationSuccess,
);
navigate(`${getFullPath(MODELS_ROUTES.CONFIRMATION)}?id=${modelId}`);
// Submit the model for training request
const submitTrainingRequest = useCallback(() => {
createNewTrainingRequestMutation.mutate({
model: modelId,
model: modelId as string,
input_boundary_width: formData.boundaryWidth,
input_contact_spacing: formData.contactSpacing,
epochs: formData.epoch,
batch_size: formData.batchSize,
zoom_level: formData.zoomLevels,
});
}, [formData, modelId]);

const handleModelCreationOrUpdateSuccess = (id?: string) => {
if (isModelOwner) {
showSuccessToast(
isEditMode
? TOAST_NOTIFICATIONS.modelUpdateSuccess
: TOAST_NOTIFICATIONS.modelCreationSuccess,
);
}

navigate(`${getFullPath(MODELS_ROUTES.CONFIRMATION)}?id=${id ?? modelId}`);
// Submit the model for training request
submitTrainingRequest();
};

const modelCreateMutation = useCreateModel({
Expand All @@ -446,19 +459,23 @@ export const ModelsProvider: React.FC<{
modelId: modelId as string,
});

// Confirm that all the training areas labels has been retrieved
const hasLabeledTrainingAreas =
formData.trainingAreas.length > 0 &&
formData.trainingAreas.filter(
(aoi: TTrainingAreaFeature) => aoi.properties.label_fetched === null,
).length === 0;
// Confirm that all the training areas labels have been fetched.
const hasLabeledTrainingAreas = useMemo(
() =>
formData.trainingAreas.every(
(aoi: TTrainingAreaFeature) => aoi.properties.label_fetched !== null,
),
[formData],
);

// Confirm that all of the training areas has a geometry
const hasAOIsWithGeometry =
formData.trainingAreas.length > 0 &&
formData.trainingAreas.filter(
(aoi: TTrainingAreaFeature) => aoi.geometry === null,
).length === 0;
// Confirm that all of the training areas have a geometry.
const hasAOIsWithGeometry = useMemo(
() =>
formData.trainingAreas.every(
(aoi: TTrainingAreaFeature) => aoi.geometry !== null,
),
[formData],
);

const handleTrainingDatasetCreation = () => {
createNewTrainingDatasetMutation.mutate({
Expand All @@ -470,14 +487,20 @@ export const ModelsProvider: React.FC<{
createNewTrainingDatasetMutation.isPending;

const handleModelCreationAndUpdate = () => {
if (isEditMode) {
// The user is trying to edit their model.
// In this case, send a PATCH request and submit a training request.
if (isEditMode && isModelOwner) {
modelUpdateMutation.mutate({
dataset: formData.selectedTrainingDatasetId,
name: formData.modelName,
description: formData.modelDescription,
base_model: formData.baseModel as BASE_MODELS,
modelId: modelId as string,
});
// The user is trying to edit another users model training area and settings.
// In this case, directly submit a training request.
} else if (isEditMode && !isModelOwner) {
handleModelCreationOrUpdateSuccess();
} else {
modelCreateMutation.mutate({
dataset: formData.selectedTrainingDatasetId,
Expand All @@ -496,7 +519,6 @@ export const ModelsProvider: React.FC<{
hasLabeledTrainingAreas,
hasAOIsWithGeometry,
formData,
resetState,
createNewTrainingRequestMutation,
isEditMode,
modelId,
Expand All @@ -505,6 +527,11 @@ export const ModelsProvider: React.FC<{
handleTrainingDatasetCreation,
trainingDatasetCreationInProgress,
validateEditMode,
resetState,
data,
isPending,
isError,
isModelOwner,
}),
[
setFormData,
Expand All @@ -513,15 +540,19 @@ export const ModelsProvider: React.FC<{
createNewTrainingDatasetMutation,
hasLabeledTrainingAreas,
hasAOIsWithGeometry,
resetState,
createNewTrainingRequestMutation,
isEditMode,
modelId,
resetState,
getFullPath,
handleModelCreationAndUpdate,
handleTrainingDatasetCreation,
trainingDatasetCreationInProgress,
validateEditMode,
data,
isPending,
isError,
isModelOwner,
],
);

Expand Down
11 changes: 8 additions & 3 deletions frontend/src/app/router.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
RouterProvider,
createBrowserRouter,
} from "react-router-dom";
import { ModelsProvider } from "@/app/providers/models-provider";

const router = createBrowserRouter([
{
Expand Down Expand Up @@ -71,7 +72,12 @@ const router = createBrowserRouter([
"@/app/routes/models/model-details-card"
);
return {
Component: () => <ModelDetailsPage />,
Component: () => (
<ModelsProvider>
{" "}
<ModelDetailsPage />
</ModelsProvider>
),
};
},
},
Expand Down Expand Up @@ -333,11 +339,10 @@ const router = createBrowserRouter([
return { Component: AuthenticationCallbackPage };
},
},

/**
* Auth route ends.
*/

/**
* 404 route
*/
Expand Down
Loading

0 comments on commit a2a37bd

Please sign in to comment.