import {useCallback, useEffect, useRef, useState} from "react";
import {EventStreamContentType, fetchEventSource} from "@microsoft/fetch-event-source";
import {ChatCompletionResultInterface} from "@app/interfaces/ChatCompletionResultInterface";
import {ChatMessageInterface} from "@app/interfaces/ChatMessageInterface";

/**
 * Chat message roles: each PascalCase key maps to the lowercase string value.
 * `as const` makes the object readonly and keeps every value as its literal type
 * (e.g. `'system'` rather than `string`).
 */
export const ChatCompletionRequestMessageRoleEnum = {
    System: "system",
    User: "user",
    Assistant: "assistant",
    Function: "function",
} as const;

/**
 * State and actions returned by {@link useEnhancedChatGpt}.
 */
type UseCompletionHelpers = {
    /**
     * Text generated so far for the current request. Cleared when a new request
     * starts; replaced by the full text from the final `end` event.
     */
    completion: string
    /**
     * Failure of the most recent request. The hook resets it to `null`
     * when a request starts or is stopped.
     */
    error: undefined | Error
    /**
     * Abort the current API request. Tokens already generated stay in `completion`;
     * `isLoading`, `isComplete` and `error` are reset.
     */
    stop: () => void
    /** The prompt passed to the most recent `sendRequest` call (not a live input value). */
    prompt: string
    /** True from the start of a request until it completes, fails, or is stopped. */
    isLoading: boolean
    /** True once the server has sent the `end` event for the current request. */
    isComplete: boolean
    /**
     * Start streaming a completion from the API endpoint. Does nothing when `prompt` is empty.
     * @param prompt - text sent as `prompt` in the JSON request body
     * @param useServerSideHistory - forwarded unchanged in the request body
     * @param previousMessages - earlier messages sent in the request body (`[]` when omitted)
     * @param model - model for this request; a default is used when omitted
     * @param maxTokens - token limit for this request; a default is used when omitted
     */
    sendRequest: (
        prompt: string,
        useServerSideHistory?: boolean,
        previousMessages?: {
            message: ChatMessageInterface
        }[],
        model?: string,
        maxTokens?: number,
    ) => void
}

/**
 * React hook that POSTs a prompt and streams the completion back as server-sent events
 * (`chunk` events append text, a final `end` event carries the full result, an `error`
 * event reports a server-side failure).
 *
 * Inspired by Vercel AI useCompletion hook.
 *
 * Options:
 * - `url`: endpoint the request is POSTed to (default `"/api/completion"`).
 * - `model`: default model, used when `sendRequest` is called without one (default `"gpt-3.5-turbo"`).
 * - `maxTokens`: default token limit, used when `sendRequest` is called without one (default `200`).
 * - `apiToken`: sent as `Authorization: Bearer <apiToken>`.
 * - `apiDevice`: sent as the `x-device` header.
 * - `onResponse(data, prompt, completionSoFar)`: optional, called for every `chunk` event.
 * - `onFinish(data, prompt, completion)`: optional, called for the `end` event.
 * - `onError(reason)`: optional, called at most once per request. `reason` is a status text,
 *   the `error` event payload, `'Connection closed unexpectedly'`, or the thrown value.
 *
 * Options are read when a request starts. Callbacks are read when they fire, so the
 * memoized `sendRequest` and `stop` never use stale values.
 */
const useEnhancedChatGpt = (
    {
        url = "/api/completion",
        model: defaultModel = "gpt-3.5-turbo",
        maxTokens: defaultMaxTokens = 200,
        apiToken = "token",
        apiDevice = "",
        onResponse,
        onFinish,
        onError
    }
): UseCompletionHelpers => {
    const [isLoading, setIsLoading] = useState<boolean>(false);
    const [isComplete, setIsComplete] = useState<boolean>(false);
    const [prompt, setPrompt] = useState<string>('');
    const [completion, setCompletion] = useState<string>('');
    const [error, setError] = useState<Error | null>(null);

    // Latest options, refreshed every render so the memoized callbacks below never see stale values.
    const latestOptions = {url, defaultModel, defaultMaxTokens, apiToken, apiDevice, onResponse, onFinish, onError};
    const optionsRef = useRef(latestOptions);
    optionsRef.current = latestOptions;

    // Controller of the request currently streaming into this hook's state (if any).
    const abortControllerRef = useRef<AbortController | null>(null);

    // Close a stream that is still running when the component unmounts.
    useEffect(() => () => {
        abortControllerRef.current?.abort('unmounted');
    }, []);

    /**
     * Run one streaming request.
     * The promise resolves when the stream ends, fails, or is aborted. It never rejects.
     * Failures are surfaced through `error` and `onError` instead.
     */
    const callServer = useCallback(async (
        prompt: string,
        useServerSideHistory?: boolean,
        previousMessages?: {
            message: ChatMessageInterface
        }[],
        model?: string,
        maxTokens?: number,
    ): Promise<void> => {
        const {url, defaultModel, defaultMaxTokens, apiToken, apiDevice} = optionsRef.current;

        // Only one stream may write into the hook state at a time: cancel a request still running.
        abortControllerRef.current?.abort('superseded by a new request');
        const ctrl = new AbortController();
        abortControllerRef.current = ctrl;

        setPrompt(prompt);
        setCompletion('');
        setError(null);
        setIsLoading(true);
        setIsComplete(false);

        let callCompleted = false;
        let errorReported = false;
        let currentCompletion = '';

        // Record a failure and notify onError at most once per request.
        // For example, an `error` event followed by the connection closing is reported once.
        const reportError = (reason: unknown): void => {
            if (errorReported) {
                return;
            }
            errorReported = true;
            setError(reason instanceof Error ? reason : new Error(String(reason)));
            setIsLoading(false);
            optionsRef.current.onError?.(reason);
        };

        try {
            // see https://github.com/Azure/fetch-event-source
            await fetchEventSource(url, {
                method: 'POST',
                headers: {
                    // NOTE(review): the body is JSON but Content-Type says text/event-stream.
                    // Kept as-is because the server may rely on it; confirm before changing.
                    'Content-Type': 'text/event-stream',
                    // `Connection` and `Transfer-Encoding` are forbidden request headers under the
                    // Fetch spec, so browsers ignore them.
                    'Connection': 'keep-alive',
                    'Cache-Control': 'no-cache',
                    'x-device': apiDevice,
                    'Authorization': `Bearer ${apiToken}`,
                    'accept': "text/event-stream",
                    'Transfer-Encoding': "chunked"
                },
                body: JSON.stringify({
                    model: model ?? defaultModel,
                    maxTokens: maxTokens ?? defaultMaxTokens,
                    prompt: prompt,
                    useServerSideHistory: useServerSideHistory,
                    previousMessages: previousMessages ?? [],
                }),
                signal: ctrl.signal,
                // By default the library aborts when the tab is hidden and re-POSTs when it is shown
                // again, which would restart the completion and append it to `currentCompletion`.
                openWhenHidden: true,

                async onopen(response) {
                    const contentType = response.headers.get('content-type') ?? '';
                    if (response.ok && contentType.startsWith(EventStreamContentType)) {
                        return; // everything's good
                    }
                    if (!response.ok) {
                        // The request was rejected. Stop instead of reading the error body as a stream.
                        const reason = response.statusText || `HTTP ${response.status}`;
                        reportError(reason);
                        throw new Error(reason);
                    }
                    // A 2xx response with an unexpected content type is still read as a stream.
                    // If it never sends `end`, onclose reports the failure.
                },

                onmessage(msg) {
                    switch (msg.event) {
                        case 'chunk':
                            // A malformed chunk (or a throwing onResponse) is logged and does not stop the stream.
                            try {
                                const data = JSON.parse(msg.data) as ChatCompletionResultInterface;
                                currentCompletion += data.choices[0].delta?.content ?? '';
                                setCompletion(currentCompletion);
                                optionsRef.current.onResponse?.(data, prompt, currentCompletion);
                            } catch (chunkError) {
                                console.error(chunkError);
                            }
                            break;
                        case 'end':
                            try {
                                const data = JSON.parse(msg.data) as ChatCompletionResultInterface;
                                const finalCompletion = data.choices[0].message.content;
                                setCompletion(finalCompletion);
                                optionsRef.current.onFinish?.(data, prompt, finalCompletion);
                            } catch (endError) {
                                console.error(endError);
                            }

                            callCompleted = true;
                            setIsComplete(true);
                            setIsLoading(false);
                            break;
                        case 'error':
                            console.log('error', msg.data);
                            reportError(msg.data);
                            break;
                    }
                },

                onclose() {
                    // The server ended the stream. Without a prior `end` event the response is incomplete.
                    if (!callCompleted) {
                        reportError('Connection closed unexpectedly');
                    }
                },

                onerror(err) {
                    // Returning here would make fetchEventSource retry (re-POSTing the prompt) every
                    // second indefinitely. Rethrowing stops the request and rejects its promise.
                    reportError(err);
                    throw err;
                }
            });
        } catch (err) {
            // Reached when a callback above rethrows. Usually already reported; this guards the rest.
            reportError(err);
        }
    }, []);

    /**
     * Cancel current request
     */
    const stop = useCallback(() => {
        const ctrl = abortControllerRef.current;
        if (!ctrl) {
            return;
        }
        ctrl.abort('cancel by user');
        abortControllerRef.current = null;
        setIsLoading(false);
        setIsComplete(false);
        setError(null);
    }, []);

    /**
     * Start request; ignored when `prompt` is empty.
     */
    const sendRequest = useCallback(
        (
            prompt: string,
            useServerSideHistory?: boolean,
            previousMessages?: {
                message: ChatMessageInterface
            }[],
            model?: string,
            maxTokens?: number,
        ) => {
            if (!prompt) {
                return;
            }
            return callServer(
                prompt,
                useServerSideHistory,
                previousMessages,
                model,
                maxTokens
            );
        },
        [callServer]
    );

    return {
        completion,
        isComplete,
        error,
        stop,
        prompt,
        sendRequest,
        isLoading
    };
};

export default useEnhancedChatGpt;
