import React, { useEffect, useRef, useState } from "react";
import { useNavigate, useLocation } from "react-router-dom";
import { FormProvider, useForm } from "react-hook-form";
import ModelSelect from "./components/model_select";
import ImageDetails from "./components/image_details";
import StepWrapper from "./components/step_wrapper";
import { Button } from "../../common/components/ui/button";
import LoadingPage from "./loading_page";
import { GenerativeModel, SelectedModels } from "./models/image_generator";
import WizardFooter from "../../common/components/ui/wizard_footer";
import { useGenerativeAPI } from "../../api/generative_ai_api";
import { aspectRatioOptions } from "./models/constants";

// Validation-error map handed to <ImageDetails> via its `errors` prop.
// It is always an empty object in this module; nothing here writes to it.
// NOTE(review): react-hook-form's own `formState.errors` is not wired in — confirm this is intended.
const errors = {};



/** Delay between two task-queue status polls. */
const POLL_INTERVAL_MS = 1000;
/** Maximum time to keep polling before giving up on a generation task. */
const POLL_TIMEOUT_MS = 3 * 60 * 1000;
/** How long a toast stays on screen. */
const TOAST_DURATION_MS = 3000;
/** Progress shown on the loading page before the first status poll returns. */
const INITIAL_PROGRESS = { percent_complete: 0, progress_message: "Generating image..." };

/**
 * Four-step wizard (product, brand, talent, image details) that builds a prompt,
 * submits an image generation request and polls the task queue until it finishes.
 *
 * - Selecting a model appends its `#generative_tag` to the prompt; removing the tag
 *   from the prompt deselects the model.
 * - When the route state carries `generativeData`, the wizard opens on the last step
 *   with the prompt, aspect ratio and matching models pre-filled.
 * - While a generation task is running, a loading page with progress replaces the form.
 *   On success the user is navigated to `/image/editor/<taskQueueId>`; on failure,
 *   request error or timeout the form is shown again together with a toast.
 */
const ImageGenerator = () => {
    const location = useLocation();
    const generativeData = location.state?.generativeData;
    const form = useForm();
    const navigate = useNavigate();
    const [activeStep, setActiveStep] = useState(0);
    // Stored for reference; nothing in this component currently reads the value back.
    const [taskQueueId, setTaskQueueId] = useState<string | null>(null);

    const [actionToast, setActionToast] = useState<
        {
            showToast: boolean,
            toastMessage: string,
        }>({
            showToast: false,
            toastMessage: "",
        });

    const [isLoadingPage, setIsLoadingPage] = useState(false);
    const [progress, setProgress] = useState(INITIAL_PROGRESS);

    const [modelList, setModelList] = useState<GenerativeModel[]>([]);
    const [selectedModels, setSelectedModels] = useState<SelectedModels>({
        product: null,
        brand: null,
        talent: null,
    })

    // Timer handles live in refs so they survive re-renders and can be cleared on unmount.
    const pollIntervalRef = useRef<ReturnType<typeof setInterval> | null>(null);
    const pollTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
    const toastTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
    const isMountedRef = useRef(true);

    const { getModels, generateImageMedia, checkTaskQueueStatus } = useGenerativeAPI();

    /** Stops status polling and cancels the pending polling timeout, if any. */
    const stopPolling = () => {
        if (pollIntervalRef.current !== null) {
            clearInterval(pollIntervalRef.current);
            pollIntervalRef.current = null;
        }
        if (pollTimeoutRef.current !== null) {
            clearTimeout(pollTimeoutRef.current);
            pollTimeoutRef.current = null;
        }
    };

    /** Shows a toast with `message` and hides it again after TOAST_DURATION_MS. */
    const showToast = (message: string) => {
        // Cancel the previous hide timer so it cannot dismiss this newer toast early.
        if (toastTimeoutRef.current !== null) {
            clearTimeout(toastTimeoutRef.current);
        }
        setActionToast({ showToast: true, toastMessage: message });
        toastTimeoutRef.current = setTimeout(() => {
            toastTimeoutRef.current = null;
            setActionToast(prevState => ({ ...prevState, showToast: false, toastMessage: "" }));
        }, TOAST_DURATION_MS);
    }

    // Track mount state and release every timer when the component unmounts.
    useEffect(() => {
        isMountedRef.current = true;
        return () => {
            isMountedRef.current = false;
            stopPolling();
            if (toastTimeoutRef.current !== null) {
                clearTimeout(toastTimeoutRef.current);
                toastTimeoutRef.current = null;
            }
        };
    }, []);

    // Load the product, brand and talent models once on mount.
    useEffect(() => {
        let cancelled = false;
        const fetchData = async () => {
            try {
                const [productResponse, brandResponse, talentResponse] = await Promise.all([
                    getModels("USER_PRODUCT"),
                    getModels("BRAND"),
                    getModels("LICENSABLE_PROPERTY"),
                ]);

                const combinedModels = [
                    ...productResponse.map(model => ({ ...model, model_type: "product" })),
                    ...brandResponse.map(model => ({ ...model, model_type: "brand" })),
                    ...talentResponse.map(model => ({ ...model, model_type: "talent" })),
                ]
                // Skip the state update if the component unmounted while the requests were in flight.
                if (!cancelled) {
                    setModelList(combinedModels);
                }
            } catch (error) {
                console.error("Error fetching models", error);
            }
        };
        fetchData();
        return () => {
            cancelled = true;
        };
    }, []);

    // Restore a previous generation passed through the route state: jump to the last step,
    // re-select every model whose tag appears in the prompt and pre-fill the form.
    useEffect(() => {
        if (generativeData) {
            setActiveStep(3);

            const { prompt } = generativeData;

            const updatedSelectedModels: SelectedModels = {
                product: null,
                brand: null,
                talent: null,
            };

            modelList.forEach((model) => {
                const generativeTag = `#${model.generative_tag}`;

                if (prompt && prompt.includes(generativeTag)) {
                    if (model.entity_type === "USER_PRODUCT") {
                        updatedSelectedModels.product = model;
                    } else if (model.entity_type === "BRAND") {
                        updatedSelectedModels.brand = model;
                    } else if (model.entity_type === "LICENSABLE_PROPERTY") {
                        updatedSelectedModels.talent = model;
                    }
                }
            });

            setSelectedModels(updatedSelectedModels);
            form.setValue("prompt", generativeData.prompt);
            form.setValue("aspect_ratio", generativeData.aspect_ratio);

        }
    }, [generativeData, modelList, form]);

    const watchedPrompt = form.watch("prompt");

    // Remove selected models if they are not in the prompt.
    // The check deliberately reads `selectedModels` from this render (not from the state
    // updater's `prev`) so a selection queued in the same commit is not dropped against
    // a prompt value that has not been re-rendered yet.
    useEffect(() => {
        const prompt = watchedPrompt || "";
        const modelTypes: Array<keyof SelectedModels> = ["product", "brand", "talent"];

        modelTypes.forEach((modelType) => {
            const model = selectedModels[modelType];
            if (model && !prompt.includes(`#${model.generative_tag}`)) {
                setSelectedModels(prev => ({
                    ...prev,
                    [modelType]: null
                }));
            }
        });
    }, [watchedPrompt]);

    /** Jumps directly to the step at index `step`. */
    const handleStepChange = (step: number) => {
        setActiveStep(step);
    }

    /** Advances one step; does nothing on the last step. */
    const handleNext = () => {
        setActiveStep(step => (step < steps.length - 1 ? step + 1 : step));
    };

    /** Goes back one step; does nothing on the first step. */
    const handlePrevious = () => {
        setActiveStep(step => (step > 0 ? step - 1 : step));
    };

    /**
     * Toggles `model` for the given slot and keeps the prompt in sync.
     *
     * Clicking the currently selected model deselects it and removes its tag from the
     * prompt. Picking a different model removes the previous tag (if any) and appends
     * the new one.
     */
    // TODO extract out prompt updates into a separate function
    const handleModelSelection = (modelType: keyof SelectedModels, model: GenerativeModel) => {
        const previousModel = selectedModels[modelType];
        const isDeselect = previousModel != null && previousModel.id === model.id;

        let updatedPrompt: string = form.getValues("prompt") || "";
        if (previousModel) {
            // Only the first occurrence of the tag is removed.
            updatedPrompt = updatedPrompt.replace(`#${previousModel.generative_tag}`, "").trim();
        }
        if (!isDeselect) {
            updatedPrompt = `${updatedPrompt} #${model.generative_tag}`.trim();
        }

        form.setValue("prompt", updatedPrompt);
        setSelectedModels((prev) => ({
            ...prev,
            [modelType]: isDeselect ? null : model
        }));
    }

    // Wizard definition: three optional model pickers followed by the mandatory details step.
    const steps = [
        {
            panel: "Product",
            heading: "Select product",
            component: <ModelSelect modelList={modelList.filter(model => model.entity_type === "USER_PRODUCT")}
                handleModelSelect={(model: any) => handleModelSelection("product", model)}
                selectedModel={selectedModels.product} />,
            optional: true,
            canAdvance: true,
        },
        {
            panel: "Brand",
            heading: "Select brand",
            component: <ModelSelect modelList={modelList.filter(model => model.entity_type === "BRAND")}
                handleModelSelect={(model: any) => handleModelSelection("brand", model)}
                selectedModel={selectedModels.brand} />,
            optional: true,
            canAdvance: true,
        },
        {
            panel: "Talent",
            heading: "Select talent",
            component: <ModelSelect
                modelList={modelList.filter(model => model.entity_type === "LICENSABLE_PROPERTY")}
                handleModelSelect={(model: any) => handleModelSelection("talent", model)}
                selectedModel={selectedModels.talent} />,
            optional: true,
            canAdvance: true,
        },
        {
            panel: "Image details",
            heading: "Image details",
            component: <ImageDetails form={form} errors={errors} navigateToModel={handleStepChange} selectedModels={selectedModels} aspectRatios={aspectRatioOptions} />,
            optional: false,
            canAdvance: false,
        },
    ]
    const currentStep = steps[activeStep];

    /**
     * Polls the task queue once for `queueId` and reacts to the result:
     * failure -> stop polling, show the form and a toast; success -> stop polling and
     * navigate to the editor; otherwise -> update the progress indicator.
     */
    const checkStatus = async (queueId: string) => {
        const activePoll = pollIntervalRef.current;
        if (!queueId || activePoll === null) return;

        try {
            const response = await checkTaskQueueStatus(queueId);

            // Polling was stopped or restarted (completion, timeout, unmount) while this
            // request was in flight; ignore the stale answer.
            if (pollIntervalRef.current !== activePoll) return;

            const completedTask = response.find((task: any) => task.status === "completed" && task.progress_status !== "processing");

            if (completedTask && completedTask.progress_status === "failed") {
                stopPolling();
                setIsLoadingPage(false);
                showToast("Failed to generate image: " + completedTask.progress_message);
                return;
            }

            if (completedTask && completedTask.progress_status === "completed") {
                stopPolling();
                navigate(`/image/editor/${queueId}`);
                setIsLoadingPage(false);
                return;
            }

            const latestTask = response[response.length - 1];
            // Nothing reported yet; keep the current progress.
            if (!latestTask) return;

            setProgress(prev => ({
                percent_complete: latestTask.percent_complete || 0,
                progress_message: latestTask.progress_message || prev.progress_message
            }))

        } catch (error) {
            // A single failed poll is not fatal; the next tick retries until the timeout hits.
            console.error("Error checking status", error);
        }
    }

    /**
     * Form submit handler: validates the prompt, submits the generation request and
     * starts polling for its status. Any failure returns the user to the form with a toast.
     */
    const generateImage = async (data: any) => {
        if (!data.prompt) {
            showToast("Please enter a prompt");
            return;
        }

        stopPolling();
        setProgress(INITIAL_PROGRESS);
        setIsLoadingPage(true);
        try {
            // NOTE(review): the brand model is never added to model_weights; only its
            // #tag in the prompt is sent — confirm this is intended.
            const modelWeights: Array<{ id: GenerativeModel["id"]; weight: number }> = [];
            if (selectedModels.talent) {
                modelWeights.push({ id: selectedModels.talent.id, weight: 0.8 });
            }
            if (selectedModels.product) {
                // Lower the weight of the product model if there is a talent model
                modelWeights.push({ id: selectedModels.product.id, weight: (modelWeights.length === 0 ? 0.8 : 0.2) });
            }
            const payload = {
                prompt: data.prompt,
                aspect_ratio: data.aspect_ratio || "SQUARE",
                image_file_id: data.image_file_id,
                model_weights: modelWeights
            }

            const response = await generateImageMedia(payload);
            // The user left the page while the request was in flight; do not start timers.
            if (!isMountedRef.current) return;

            const queueId = response.id;
            setTaskQueueId(queueId);

            pollIntervalRef.current = setInterval(() => { void checkStatus(queueId); }, POLL_INTERVAL_MS);
            pollTimeoutRef.current = setTimeout(() => {
                // Give up: leave the loading page instead of spinning forever.
                stopPolling();
                setIsLoadingPage(false);
                showToast("Image generation timed out. Please try again.");
            }, POLL_TIMEOUT_MS);

        } catch (error) {
            console.error("Error generating image", error);
            stopPolling();
            setIsLoadingPage(false);
            showToast("Failed to generate image. Please try again.");
        }
    }


    return (
        <div className="pb-11">
            {isLoadingPage ? <LoadingPage percentComplete={progress.percent_complete} progressMessage={progress.progress_message} /> :
                <div className="p-20 md:p-[120px]">
                    <FormProvider {...form}>
                        <form onSubmit={form.handleSubmit(generateImage)}>
                            {/* <Button type="submit" variant="primary-negative" className="z-50 absolute top-5 right-20">Generate image</Button> */}
                            <StepWrapper
                                steps={steps}
                                activeStep={activeStep}
                                onStepChange={handleStepChange}
                                optional={currentStep.optional}
                                heading={currentStep.heading}
                                selectedModels={selectedModels}
                            >
                                {currentStep.component}
                            </StepWrapper>
                        </form>


                        {actionToast.showToast && <div className="fixed bottom-[150px] rounded-3xl right-9 bg-black px-6 py-5 text-white w-[453px]">{actionToast.toastMessage}</div>}

                        <WizardFooter
                            panels={steps}
                            onPanelChange={handleStepChange}
                            next={handleNext}
                            previous={handlePrevious}
                            activePanel={activeStep}
                            onSubmit={form.handleSubmit(generateImage)}
                            forImageGenerator
                        />
                    </FormProvider>
                </div>
            }


        </div>
    )
}


export default ImageGenerator;


