From fcbc3fb879d7e27ddd5b1c788a8c5ec8aa9efe17 Mon Sep 17 00:00:00 2001 From: Adela Almasan <88068998+adela-almasan@users.noreply.github.com> Date: Thu, 9 Nov 2023 09:06:14 -0600 Subject: [PATCH] Dashboard: Add ability to stop title/description generation (#77896) --- .../components/GenAI/GenAIButton.test.tsx | 26 ++++++++++++--- .../components/GenAI/GenAIButton.tsx | 32 ++++++++++++------- .../GenAI/GenAIDashDescriptionButton.tsx | 1 - .../components/GenAI/GenAIDashTitleButton.tsx | 1 - .../GenAI/GenAIDashboardChangesButton.tsx | 1 - .../components/GenAI/GenAIHistory.tsx | 21 ++++++++---- .../GenAI/GenAIPanelDescriptionButton.tsx | 1 - .../GenAI/GenAIPanelTitleButton.tsx | 1 - .../dashboard/components/GenAI/hooks.ts | 16 ++++++++++ .../dashboard/components/GenAI/tracking.ts | 1 + 10 files changed, 74 insertions(+), 27 deletions(-) diff --git a/public/app/features/dashboard/components/GenAI/GenAIButton.test.tsx b/public/app/features/dashboard/components/GenAI/GenAIButton.test.tsx index dc94a9ae1e0..e73135fdfaf 100644 --- a/public/app/features/dashboard/components/GenAI/GenAIButton.test.tsx +++ b/public/app/features/dashboard/components/GenAI/GenAIButton.test.tsx @@ -47,6 +47,7 @@ describe('GenAIButton', () => { streamStatus: StreamStatus.IDLE, reply: 'Some completed genereated text', setMessages: jest.fn(), + setStopGeneration: jest.fn(), value: { enabled: false, stream: new Observable().subscribe(), @@ -63,12 +64,14 @@ describe('GenAIButton', () => { describe('when LLM plugin is properly configured, so it is enabled', () => { const setMessagesMock = jest.fn(); + const setShouldStopMock = jest.fn(); beforeEach(() => { jest.mocked(useOpenAIStream).mockReturnValue({ error: undefined, streamStatus: StreamStatus.IDLE, reply: 'Some completed genereated text', setMessages: setMessagesMock, + setStopGeneration: setShouldStopMock, value: { enabled: true, stream: new Observable().subscribe(), @@ -114,12 +117,15 @@ describe('GenAIButton', () => { }); describe('when it is generating data', () => { + const setShouldStopMock = jest.fn(); + beforeEach(() => { jest.mocked(useOpenAIStream).mockReturnValue({ error: undefined, streamStatus: StreamStatus.GENERATING, reply: 'Some incomplete generated text', setMessages: jest.fn(), + setStopGeneration: setShouldStopMock, value: { enabled: true, stream: new Observable().subscribe(), @@ -138,13 +144,12 @@ describe('GenAIButton', () => { waitFor(async () => expect(await screen.findByRole('button')).toBeEnabled()); }); - it('disables the button while generating', async () => { + it('shows the stop button while generating', async () => { const { getByText, getByRole } = setup(); - const generateButton = getByText('Generating'); + const generateButton = getByText('Stop generating'); - // The loading text should be visible and the button disabled expect(generateButton).toBeVisible(); - await waitFor(() => expect(getByRole('button')).toBeDisabled()); + await waitFor(() => expect(getByRole('button')).toBeEnabled()); }); it('should call onGenerate when the text is generating', async () => { @@ -155,16 +160,29 @@ describe('GenAIButton', () => { expect(onGenerate).toHaveBeenCalledWith('Some incomplete generated text'); }); + + it('should stop generating when clicking the button', async () => { + const onGenerate = jest.fn(); + const { getByText } = setup({ onGenerate, messages: [], eventTrackingSrc: eventTrackingSrc }); + const generateButton = getByText('Stop generating'); + + await fireEvent.click(generateButton); + + expect(setShouldStopMock).toHaveBeenCalledTimes(1); + expect(setShouldStopMock).toHaveBeenCalledWith(true); + }); }); describe('when there is an error generating data', () => { const setMessagesMock = jest.fn(); + const setShouldStopMock = jest.fn(); beforeEach(() => { jest.mocked(useOpenAIStream).mockReturnValue({ error: new Error('Something went wrong'), streamStatus: StreamStatus.IDLE, reply: '', setMessages: setMessagesMock, + setStopGeneration: setShouldStopMock, value: { enabled: true, stream: new Observable().subscribe(), diff --git a/public/app/features/dashboard/components/GenAI/GenAIButton.tsx b/public/app/features/dashboard/components/GenAI/GenAIButton.tsx index 133d800e6cc..c94a34412cf 100644 --- a/public/app/features/dashboard/components/GenAI/GenAIButton.tsx +++ b/public/app/features/dashboard/components/GenAI/GenAIButton.tsx @@ -12,8 +12,6 @@ import { OAI_MODEL, DEFAULT_OAI_MODEL, Message, sanitizeReply } from './utils'; export interface GenAIButtonProps { // Button label text text?: string; - // Button label text when loading - loadingText?: string; toggleTipTitle?: string; // Button click handler onClick?: (e: React.MouseEvent) => void; @@ -30,10 +28,10 @@ export interface GenAIButtonProps { // Whether the button should be disabled disabled?: boolean; } +export const STOP_GENERATION_TEXT = 'Stop generating'; export const GenAIButton = ({ text = 'Auto-generate', - loadingText = 'Generating', toggleTipTitle = '', onClick: onClickProp, model = DEFAULT_OAI_MODEL, @@ -45,27 +43,34 @@ export const GenAIButton = ({ }: GenAIButtonProps) => { const styles = useStyles2(getStyles); - const { setMessages, reply, value, error, streamStatus } = useOpenAIStream(model, temperature); + const { setMessages, setStopGeneration, reply, value, error, streamStatus } = useOpenAIStream(model, temperature); const [history, setHistory] = useState([]); const [showHistory, setShowHistory] = useState(true); const hasHistory = history.length > 0; const isFirstHistoryEntry = streamStatus === StreamStatus.GENERATING && !hasHistory; - const isButtonDisabled = disabled || isFirstHistoryEntry || (value && !value.enabled && !error); + const isButtonDisabled = disabled || (value && !value.enabled && !error); const reportInteraction = (item: AutoGenerateItem) => reportAutoGenerateInteraction(eventTrackingSrc, item); const onClick = (e: React.MouseEvent) => { - if (!hasHistory) { - onClickProp?.(e); - setMessages(messages); + if (streamStatus === StreamStatus.GENERATING) { + setStopGeneration(true); } else { - if (setShowHistory) { - setShowHistory(true); + if (!hasHistory) { + onClickProp?.(e); + setMessages(messages); + } else { + if (setShowHistory) { + setShowHistory(true); + } } } + const buttonItem = error ? AutoGenerateItem.erroredRetryButton + : isFirstHistoryEntry + ? AutoGenerateItem.stopGenerationButton : hasHistory ? AutoGenerateItem.improveButton : AutoGenerateItem.autoGenerateButton; @@ -123,7 +128,7 @@ export const GenAIButton = ({ } if (isFirstHistoryEntry) { - buttonText = loadingText; + buttonText = STOP_GENERATION_TEXT; } if (hasHistory) { @@ -176,7 +181,7 @@ export const GenAIButton = ({ return (
- {isFirstHistoryEntry && } + {isFirstHistoryEntry && } {!hasHistory && ( ({ wrapper: css({ display: 'flex', }), + spinner: css({ + color: theme.colors.text.link, + }), }); diff --git a/public/app/features/dashboard/components/GenAI/GenAIDashDescriptionButton.tsx b/public/app/features/dashboard/components/GenAI/GenAIDashDescriptionButton.tsx index 4f9cc429ed0..64789197c91 100644 --- a/public/app/features/dashboard/components/GenAI/GenAIDashDescriptionButton.tsx +++ b/public/app/features/dashboard/components/GenAI/GenAIDashDescriptionButton.tsx @@ -31,7 +31,6 @@ export const GenAIDashDescriptionButton = ({ onGenerate, dashboard }: GenAIDashD diff --git a/public/app/features/dashboard/components/GenAI/GenAIDashTitleButton.tsx b/public/app/features/dashboard/components/GenAI/GenAIDashTitleButton.tsx index cebd5cb9478..3f7c418e869 100644 --- a/public/app/features/dashboard/components/GenAI/GenAIDashTitleButton.tsx +++ b/public/app/features/dashboard/components/GenAI/GenAIDashTitleButton.tsx @@ -31,7 +31,6 @@ export const GenAIDashTitleButton = ({ onGenerate, dashboard }: GenAIDashTitleBu diff --git a/public/app/features/dashboard/components/GenAI/GenAIDashboardChangesButton.tsx b/public/app/features/dashboard/components/GenAI/GenAIDashboardChangesButton.tsx index aa0e8ebbbaa..846ebe48586 100644 --- a/public/app/features/dashboard/components/GenAI/GenAIDashboardChangesButton.tsx +++ b/public/app/features/dashboard/components/GenAI/GenAIDashboardChangesButton.tsx @@ -41,7 +41,6 @@ export const GenAIDashboardChangesButton = ({ dashboard, onGenerate, disabled }: { - onApplySuggestion(history[currentIndex - 1]); + if (isStreamGenerating) { + setStopGeneration(true); + if (reply !== '') { + updateHistory(sanitizeReply(reply)); + } + } else { + onApplySuggestion(history[currentIndex - 1]); + } }; const onNavigate = (index: number) => { @@ -148,9 +158,8 @@ export const GenAIHistory = ({
- {isStreamGenerating && } -
diff --git a/public/app/features/dashboard/components/GenAI/GenAIPanelDescriptionButton.tsx b/public/app/features/dashboard/components/GenAI/GenAIPanelDescriptionButton.tsx index e7381dcfd13..b76b213b874 100644 --- a/public/app/features/dashboard/components/GenAI/GenAIPanelDescriptionButton.tsx +++ b/public/app/features/dashboard/components/GenAI/GenAIPanelDescriptionButton.tsx @@ -30,7 +30,6 @@ export const GenAIPanelDescriptionButton = ({ onGenerate, panel }: GenAIPanelDes diff --git a/public/app/features/dashboard/components/GenAI/GenAIPanelTitleButton.tsx b/public/app/features/dashboard/components/GenAI/GenAIPanelTitleButton.tsx index 64a079ae105..fe56af69ba4 100644 --- a/public/app/features/dashboard/components/GenAI/GenAIPanelTitleButton.tsx +++ b/public/app/features/dashboard/components/GenAI/GenAIPanelTitleButton.tsx @@ -26,7 +26,6 @@ export const GenAIPanelTitleButton = ({ onGenerate, panel }: GenAIPanelTitleButt diff --git a/public/app/features/dashboard/components/GenAI/hooks.ts b/public/app/features/dashboard/components/GenAI/hooks.ts index 1524d7df868..14cc7436bbb 100644 --- a/public/app/features/dashboard/components/GenAI/hooks.ts +++ b/public/app/features/dashboard/components/GenAI/hooks.ts @@ -26,6 +26,7 @@ export function useOpenAIStream( temperature = 1 ): { setMessages: React.Dispatch>; + setStopGeneration: React.Dispatch>; reply: string; streamStatus: StreamStatus; error: Error | undefined; @@ -42,6 +43,7 @@ export function useOpenAIStream( } { // The messages array to send to the LLM, updated when the button is clicked. const [messages, setMessages] = useState([]); + const [stopGeneration, setStopGeneration] = useState(false); // The latest reply from the LLM. const [reply, setReply] = useState(''); const [streamStatus, setStreamStatus] = useState(StreamStatus.IDLE); @@ -52,6 +54,7 @@ export function useOpenAIStream( (e: Error) => { setStreamStatus(StreamStatus.IDLE); setMessages([]); + setStopGeneration(false); setError(e); notifyError( 'Failed to generate content using OpenAI', @@ -104,6 +107,7 @@ export function useOpenAIStream( setStreamStatus(StreamStatus.IDLE); }); setMessages([]); + setStopGeneration(false); setError(undefined); }, }), @@ -119,6 +123,17 @@ export function useOpenAIStream( }; }, [value]); + // Unsubscribe from the stream when user stops the generation. + useEffect(() => { + if (stopGeneration) { + value?.stream?.unsubscribe(); + setStreamStatus(StreamStatus.IDLE); + setStopGeneration(false); + setError(undefined); + setMessages([]); + } + }, [stopGeneration, value?.stream]); + // If the stream is generating and we haven't received a reply, it times out. useEffect(() => { let timeout: NodeJS.Timeout | undefined; @@ -138,6 +153,7 @@ export function useOpenAIStream( return { setMessages, + setStopGeneration, reply, streamStatus, error, diff --git a/public/app/features/dashboard/components/GenAI/tracking.ts b/public/app/features/dashboard/components/GenAI/tracking.ts index 8c51e8a729e..8327fdf430e 100644 --- a/public/app/features/dashboard/components/GenAI/tracking.ts +++ b/public/app/features/dashboard/components/GenAI/tracking.ts @@ -16,6 +16,7 @@ export enum EventTrackingSrc { export enum AutoGenerateItem { autoGenerateButton = 'auto-generate-button', erroredRetryButton = 'errored-retry-button', + stopGenerationButton = 'stop-generating-button', improveButton = 'improve-button', backHistoryItem = 'back-history-item', forwardHistoryItem = 'forward-history-item',