From 9c280ba1e6897aff4a07df5c05fc3566c0de88ba Mon Sep 17 00:00:00 2001 From: chlorine Date: Wed, 24 May 2023 10:33:53 +0800 Subject: [PATCH] feat: support upscale and variation in midjourney, support image zoom (#64) --- package.json | 3 +- pnpm-lock.yaml | 22 ++- src/components/MessageBox/index.tsx | 83 ++++++--- src/components/MidjourneyOperations/index.tsx | 60 +++++++ src/interfaces/index.ts | 9 + src/layouts/Layout.astro | 7 + src/modules/Content/index.tsx | 160 ++++++++++++++---- src/modules/Main.tsx | 6 +- src/pages/api/images.ts | 63 ++++++- src/utils/index.ts | 8 + src/utils/midjourney.ts | 78 +++++++++ 11 files changed, 434 insertions(+), 65 deletions(-) create mode 100644 src/components/MidjourneyOperations/index.tsx create mode 100644 src/utils/midjourney.ts diff --git a/package.json b/package.json index c399d9a..038f131 100644 --- a/package.json +++ b/package.json @@ -39,7 +39,8 @@ "markdown-it": "^13.0.1", "markdown-it-highlightjs": "^4.0.1", "markdown-it-kbd": "^2.2.2", - "midjourney-fetch": "0.1.4", + "medium-zoom": "^1.0.8", + "midjourney-fetch": "1.0.0", "react": "^18.0.0", "react-dom": "^18.0.0", "replicate-fetch": "^0.1.1", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9885ba1..b5cea1b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -35,7 +35,8 @@ specifiers: markdown-it: ^13.0.1 markdown-it-highlightjs: ^4.0.1 markdown-it-kbd: ^2.2.2 - midjourney-fetch: 0.1.4 + medium-zoom: ^1.0.8 + midjourney-fetch: 1.0.0 prettier: ^2.8.4 prettier-plugin-astro: ^0.8.0 punycode: ^2.3.0 @@ -64,7 +65,8 @@ dependencies: markdown-it: 13.0.1 markdown-it-highlightjs: 4.0.1 markdown-it-kbd: 2.2.2 - midjourney-fetch: 0.1.4 + medium-zoom: 1.0.8 + midjourney-fetch: 1.0.0 react: 18.2.0 react-dom: 18.2.0_react@18.2.0 replicate-fetch: 0.1.1 @@ -1002,6 +1004,11 @@ packages: picomatch: 2.3.1 dev: false + /@sapphire/snowflake/3.5.1: + resolution: {integrity: sha512-BxcYGzgEsdlG0dKAyOm0ehLGm2CafIrfQTZGWgkfKYbj+pNNsorZ7EotuZukc2MT70E0UbppVbtpBrqpzVzjNA==} + engines: {node: '>=v14.0.0', npm: '>=7.0.0'} + dev: false + /@tailwindcss/typography/0.5.9_tailwindcss@3.2.7: resolution: {integrity: sha512-t8Sg3DyynFysV9f4JDOVISGsjazNb48AeIYQwcL+Bsq5uf4RYL75C1giZ43KISjeDGBaTN3Kxh7Xj/vRSMJUUg==} peerDependencies: @@ -3786,6 +3793,10 @@ packages: resolution: {integrity: sha512-/sKlQJCBYVY9Ers9hqzKou4H6V5UWc/M59TH2dvkt+84itfnq7uFOMLpOiOS4ujvHP4etln18fmIxA5R5fll0g==} dev: false + /medium-zoom/1.0.8: + resolution: {integrity: sha512-CjFVuFq/IfrdqesAXfg+hzlDKu6A2n80ZIq0Kl9kWjoHh9j1N9Uvk5X0/MmN0hOfm5F9YBswlClhcwnmtwz7gA==} + dev: false + /merge-stream/2.0.0: resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==} @@ -4049,9 +4060,12 @@ packages: braces: 3.0.2 picomatch: 2.3.1 - /midjourney-fetch/0.1.4: - resolution: {integrity: sha512-okpsmDLJGyu6kQTlIB+NdqNQVothb6XdmC76WtjXC+nCGWuydauPTyay+ShoVJBxJ6Ryxe+ZP2dK0mMongl+yg==} + /midjourney-fetch/1.0.0: + resolution: {integrity: sha512-UKZ7UxjsaBAxfA5YwMTMUxCkocFjWitjoabWpiDk8II9EsjpgOsHeXw3jqrFw9HpcF3SL5SH1NZSXF2W8mBpQA==} engines: {node: '>=18', pnpm: '>=8'} + dependencies: + '@sapphire/snowflake': 3.5.1 + dayjs: 1.11.7 dev: false /mime/1.6.0: diff --git a/src/components/MessageBox/index.tsx b/src/components/MessageBox/index.tsx index 3498e5a..7b8445e 100644 --- a/src/components/MessageBox/index.tsx +++ b/src/components/MessageBox/index.tsx @@ -1,18 +1,27 @@ import { FC, useCallback, useContext, useEffect } from 'react'; import { throttle } from 'lodash-es'; +import type { MessageType } from 'midjourney-fetch'; import GlobalContext from '@contexts/global'; import { ConversationMode, Message } from '@interfaces'; import markdown from '@utils/markdown'; import { getRelativeTime } from '@utils/date'; import SystemAvatar from '@components/Avatar/system'; import useCopyCode from '@hooks/useCopyCode'; +import MidjourneyOperations from '@components/MidjourneyOperations'; +import { hasUpscaleOrVariation } from '@utils/midjourney'; import './index.css'; const MessageItem: FC<{ message: Message; + onOperationClick?: ( + type: MessageType, + customId: string, + messageId: string, + prompt: string + ) => void; mode?: ConversationMode; index?: number; -}> = ({ message, mode, index }) => { +}> = ({ message, onOperationClick, mode, index }) => { const { i18n } = useContext(GlobalContext); const isExpired = message.expiredAt && message.expiredAt <= Date.now(); const createdAt = getRelativeTime(message.createdAt, true); @@ -25,27 +34,43 @@ const MessageItem: FC<{ {message.role === 'assistant' ? ( ) : null} -
- {createdAt ? ( +
- {createdAt} -
- ) : null} + dangerouslySetInnerHTML={{ + __html: isExpired + ? i18n.status_image_expired + : markdown.render(message.content), + }} + className={`prose message-box shadow-sm p-4 ${ + message.role === 'user' ? 'bg-gradient text-white' : 'bg-[#ebeced]' + } ${ + mode === 'image' ? 'img-no-margin' : '' + } break-words overflow-hidden rounded-[16px]`} + /> + {createdAt ? ( +
+ {createdAt} +
+ ) : null} + {message.midjourneyMessage && + hasUpscaleOrVariation(message.midjourneyMessage) ? ( + + onOperationClick( + type, + customId, + message.midjourneyMessage.id, + message.midjourneyMessage.prompt + ) + } + /> + ) : null} +
); }; @@ -55,7 +80,13 @@ const MessageBox: FC<{ messages: Message[]; mode: ConversationMode; loading: boolean; -}> = ({ streamMessage, messages, mode, loading }) => { + onOperationClick?: ( + type: MessageType, + customId: string, + messageId: string, + prompt: string + ) => void; +}> = ({ streamMessage, messages, mode, loading, onOperationClick }) => { const { i18n } = useContext(GlobalContext); useCopyCode(i18n.success_copy); @@ -100,7 +131,13 @@ const MessageBox: FC<{ /> ) : null} {messages.map((message, index) => ( - + null : onOperationClick} + /> ))} {streamMessage ? ( diff --git a/src/components/MidjourneyOperations/index.tsx b/src/components/MidjourneyOperations/index.tsx new file mode 100644 index 0000000..d1479da --- /dev/null +++ b/src/components/MidjourneyOperations/index.tsx @@ -0,0 +1,60 @@ +import { FC } from 'react'; +import { Tag } from 'antd'; +import type { MessageType } from 'midjourney-fetch'; +import { MidjourneyMessage } from '@interfaces'; +import { + filterUpscale, + filterVariation, + isComponentAvailable, +} from '@utils/midjourney'; + +const { CheckableTag } = Tag; + +const MidjourneyOperations: FC<{ + message: MidjourneyMessage; + onClick?: (type: MessageType, id: string) => void; +}> = ({ message, onClick }) => { + const upscale = filterUpscale(message.components[0]?.components ?? []); + const variation = filterVariation([ + ...(message.components[0]?.components ?? []), + ...(message.components[1]?.components ?? []), + ]); + return ( +
+ {upscale.map((option) => { + const isAvailable = isComponentAvailable(option); + return ( + + isAvailable && onClick?.('upscale', option.custom_id) + } + > + {option.label || option.emoji?.name} + + ); + })} + {variation.map((option) => { + const isAvailable = isComponentAvailable(option); + return ( + + isAvailable && onClick?.('variation', option.custom_id) + } + > + {option.label || option.emoji?.name} + + ); + })} +
+ ); +}; + +export default MidjourneyOperations; diff --git a/src/interfaces/index.ts b/src/interfaces/index.ts index 950f25f..7c352f0 100644 --- a/src/interfaces/index.ts +++ b/src/interfaces/index.ts @@ -1,3 +1,4 @@ +import type { MessageItem } from 'midjourney-fetch'; import { LayoutConfig, SupportedImageModels, @@ -5,10 +6,18 @@ import { SupportedModel, } from '@configs'; +export type MidjourneyMessage = Pick< + MessageItem, + 'id' | 'components' | 'attachments' +> & { + prompt: string; +}; + export interface Message { content: string; role: 'assistant' | 'user'; imageModel?: SupportedImageModels; // distinguish avator + midjourneyMessage?: MidjourneyMessage; createdAt?: number; expiredAt?: number; // for image mode } diff --git a/src/layouts/Layout.astro b/src/layouts/Layout.astro index 3be3c83..c862e0b 100644 --- a/src/layouts/Layout.astro +++ b/src/layouts/Layout.astro @@ -179,6 +179,13 @@ const { title } = Astro.props; overflow: hidden; text-overflow: ellipsis; } + /** medium zoom */ + .medium-zoom-overlay { + z-index: 20; + } + .medium-zoom-image { + z-index: 21; + } /* mobile style */ @media screen and (max-width: 768px) { html { diff --git a/src/modules/Content/index.tsx b/src/modules/Content/index.tsx index be40f6a..633e5d5 100644 --- a/src/modules/Content/index.tsx +++ b/src/modules/Content/index.tsx @@ -7,10 +7,12 @@ import { hasMathJax, initMathJax, renderMaxJax } from '@utils/markdown'; import { hasMath } from '@utils'; import { midjourneyConfigs } from '@configs'; import { - type MessageAttachment, type MessageItem, + type MessageType, isInProgress, + getHashFromCustomId, } from 'midjourney-fetch'; +import { updateComponentStatus } from '@utils/midjourney'; import MessageInput from './MessageInput'; import ContentHeader from './ContentHeader'; @@ -197,15 +199,50 @@ const Content: FC = ({ setActiveSetting }) => { } }; - const sendImageChatMessages = async (content: string) => { + const sendImageChatMessages = async ( + content: string, + type: MessageType = 'imagine', + extraParams: Partial<{ + customId: string; + messageId: string; + index: number; + }> = {} + ) => { + const { messageId, index } = extraParams; const current = currentId; - const allMessages: Message[] = messages.concat([ + let messageInput = content; + let allMessages: Message[] = messages; + if ( + (type === 'upscale' || type === 'variation') && + typeof index === 'number' && + messageId + ) { + // add flag to mark, only in frontend + if (index > 0) { + messageInput = `/${ + type === 'variation' ? 'V' : 'U' + }${index} ${messageInput}`; + } else { + messageInput = `/🔄 ${messageInput}`; + } + + // update status + allMessages = updateComponentStatus({ + type, + messages: allMessages, + messageId, + index, + }); + } + // concat user input + allMessages = allMessages.concat([ { role: 'user', - content, + content: messageInput, createdAt: Date.now(), }, ]); + updateMessages(allMessages); setText(''); setLoadingMap((map) => ({ @@ -213,20 +250,45 @@ const Content: FC = ({ setActiveSetting }) => { [current]: true, })); const model = configs.imageModel; + let params: Record = { + password: configs.password, + model, + prompt: content, + }; + if (model === 'Midjourney') { + params = { + ...params, + serverId: configs.discordServerId, + channelId: configs.discordChannelId, + type, + }; + if (type === 'upscale' || type === 'variation') { + params = { + ...params, + ...extraParams, + }; + } + } else if (model === 'Replicate') { + params = { + ...params, + size: configs.imageSize || '256x256', + }; + } else { + params = { + ...params, + key: configs.openAIApiKey, + size: configs.imageSize || '256x256', + n: configs.imagesCount || 1, + }; + } try { + const timestamp = new Date().toISOString(); const res = await fetch('/api/images', { method: 'POST', - body: JSON.stringify({ - key: configs.openAIApiKey, - prompt: content, - size: configs.imageSize || '256x256', - n: configs.imagesCount || 1, - password: configs.password, - model, - serverId: configs.discordServerId, - channelId: configs.discordChannelId, - token: configs.discordToken, - }), + body: JSON.stringify(params), + headers: { + Authorization: model === 'Midjourney' ? configs.discordToken : '', + }, }); const { data = [], msg } = await res.json(); @@ -234,7 +296,7 @@ const Content: FC = ({ setActiveSetting }) => { if (model === 'Midjourney') { const times = midjourneyConfigs.timeout / midjourneyConfigs.interval; let count = 0; - let image: MessageAttachment | null = null; + let result: MessageItem | undefined; while (count < times) { try { count += 1; @@ -243,13 +305,22 @@ const Content: FC = ({ setActiveSetting }) => { ); const message: MessageItem & { msg?: string } = await ( await fetch( - `/api/images?model=Midjourney&prompt=${content}&serverId=${configs.discordServerId}&channelId=${configs.discordChannelId}&token=${configs.discordToken}` + `/api/images?model=Midjourney&prompt=${content}&serverId=${ + configs.discordServerId + }&channelId=${configs.discordChannelId}&type=${type}&index=${ + index ?? '' + }×tamp=${timestamp}`, + { + headers: { + Authorization: configs.discordToken, + }, + } ) ).json(); console.log(count, JSON.stringify(message)); // msg means error message if (message && !message.msg && !isInProgress(message)) { - [image] = message.attachments; + result = message; break; } } catch (e) { @@ -257,19 +328,38 @@ const Content: FC = ({ setActiveSetting }) => { continue; } } - updateMessages( - allMessages.concat([ - { - role: 'assistant', - content: image ? `![](${image.url})` : 'No result or timeout', - imageModel: model, - createdAt: Date.now(), - }, - ]) - ); + if (result) { + updateMessages( + allMessages.concat([ + { + role: 'assistant', + content: `![](${result.attachments[0].url})`, + imageModel: model, + midjourneyMessage: { + id: result.id, + attachments: result.attachments, + components: result.components, + prompt: content, + }, + createdAt: Date.now(), + }, + ]) + ); + } else { + updateMessages( + allMessages.concat([ + { + role: 'assistant', + content: 'No result or timeout', + imageModel: model, + createdAt: Date.now(), + }, + ]) + ); + } } else { - const params = new URLSearchParams(data?.[0]); - const expiredAt = params.get('se'); + const searchParams = new URLSearchParams(data?.[0]); + const expiredAt = searchParams.get('se'); updateMessages( allMessages.concat([ { @@ -328,6 +418,16 @@ const Content: FC = ({ setActiveSetting }) => { messages={messages} mode={mode} loading={loading} + onOperationClick={(type, customId, messageId, prompt) => { + const { index } = getHashFromCustomId(type, customId); + if (typeof index === 'number') { + sendImageChatMessages(prompt, type, { + customId, + index, + messageId, + }); + } + }} /> = ({ lang, inVercel }) => { window.addEventListener('resize', handleDebounceResize); }, []); + useEffect(() => { + registerMediumZoom(isMobile); + }, [currentId, conversations, isMobile]); + const setConversationsFromLocal = useCallback(() => { try { const localConversation = localStorage.getItem(localConversationKey); diff --git a/src/pages/api/images.ts b/src/pages/api/images.ts index 7562cfb..170e609 100644 --- a/src/pages/api/images.ts +++ b/src/pages/api/images.ts @@ -3,7 +3,11 @@ import type { APIRoute } from 'astro'; import { loadBalancer } from '@utils/server'; import { createOpenjourney } from 'replicate-fetch'; import { SupportedImageModels } from '@configs'; -import { Midjourney } from 'midjourney-fetch'; +import { + Midjourney, + type MessageType, + type MessageTypeProps, +} from 'midjourney-fetch'; import { apiKeyStrategy, apiKeys, @@ -18,14 +22,16 @@ import { export { config }; export const get: APIRoute = async ({ request }) => { - const { url } = request; + const { url, headers } = request; const params = new URL(url).searchParams; const model = params.get('model') as SupportedImageModels; const serverId = params.get('serverId') || dicordServerId; const channelId = params.get('channelId') || discordChannelId; - const token = params.get('token') || discordToken; + const token = headers.get('Authorization') || discordToken; const prompt = params.get('prompt'); + const type = (params.get('type') as MessageType) || 'imagine'; + const timestamp = params.get('timestamp'); if (model === 'Midjourney') { if (!prompt) { @@ -56,12 +62,35 @@ export const get: APIRoute = async ({ request }) => { token, }); midjourney.debugger = true; + try { - const message = await midjourney.getMessage(prompt); + let options: MessageTypeProps = { type: 'imagine', timestamp }; + + if (type === 'upscale' || type === 'variation') { + const index = params.get('index'); + if (!index) { + return new Response( + JSON.stringify({ + msg: `No ${type} index provided`, + }), + { + status: 400, + } + ); + } + options = { + ...options, + index: Number(index), + type, + }; + } + + const message = await midjourney.getMessage(prompt, options); if (message) { return new Response(JSON.stringify(message), { status: 200 }); } + return new Response(JSON.stringify({ msg: 'No content found' }), { status: 200, }); @@ -134,7 +163,8 @@ export const post: APIRoute = async ({ request }) => { if (model === 'Midjourney') { const serverId = body.serverId || dicordServerId; const channelId = body.channelId || discordChannelId; - const token = body.token || discordToken; + const token = request.headers.get('Authorization') || discordToken; + const type: MessageType = body.type || 'imagine'; if (!serverId || !channelId || !token) { return new Response( @@ -154,7 +184,28 @@ export const post: APIRoute = async ({ request }) => { }); midjourney.debugger = true; - await midjourney.interactions(prompt); + if (type === 'upscale' || type === 'variation') { + const { messageId, customId }: { messageId: string; customId: string } = + body; + + if (!messageId || !customId) { + return new Response( + JSON.stringify({ + msg: 'No messageId or customId', + }), + { + status: 400, + } + ); + } + + await midjourney.createUpscaleOrVariation(type, { + messageId, + customId, + }); + } else { + await midjourney.createImage(prompt); + } return new Response('{}', { status: 200 }); } diff --git a/src/utils/index.ts b/src/utils/index.ts index 27ed4f7..b260d27 100644 --- a/src/utils/index.ts +++ b/src/utils/index.ts @@ -1,5 +1,6 @@ import { LayoutConfig } from '@configs'; import { Conversation, Message, RecordCardItem } from '@interfaces'; +import mediumZoom from 'medium-zoom'; export const getMaxIndex = (tabs: RecordCardItem[]) => { let max = tabs.length; @@ -83,3 +84,10 @@ export const setClassByLayout = (layout?: LayoutConfig) => { const targetClass = classMap[layout ?? 'default']; container.className = targetClass; }; + +export const registerMediumZoom = (isMobile = false) => { + mediumZoom('.prose img:not(.medium-zoom-image)', { + background: 'rgba(0, 0, 0, 0.6)', + margin: isMobile ? 16 : 48, + }); +}; diff --git a/src/utils/midjourney.ts b/src/utils/midjourney.ts new file mode 100644 index 0000000..5788d63 --- /dev/null +++ b/src/utils/midjourney.ts @@ -0,0 +1,78 @@ +import type { Message, MidjourneyMessage } from '@interfaces'; +import type { MessageComponent, MessageType } from 'midjourney-fetch'; + +export const hasUpscaleOrVariation = (message: MidjourneyMessage) => { + const { components } = message; + + if (components.length === 0) return false; + + const options = components.map((component) => component.components).flat(1); + + return options.some( + (option) => + option.custom_id?.startsWith('MJ::JOB::upsample') || + option.custom_id?.startsWith('MJ::JOB::variation') + ); +}; + +export const filterUpscale = (options: MessageComponent[]) => + options.filter((option) => option.custom_id?.startsWith('MJ::JOB::upsample')); + +export const filterVariation = (options: MessageComponent[]) => + options.filter( + (option) => + option.custom_id?.startsWith('MJ::JOB::reroll') || + option.custom_id?.startsWith('MJ::JOB::variation') + ); + +export const isComponentAvailable = (option: MessageComponent) => + option.style === 2; + +export const updateComponentStatus = ({ + type, + messageId, + index, + messages, +}: { + type: MessageType; + messageId: string; + index: number; + messages: Message[]; +}) => { + messages.some((msg) => { + const match = + msg.midjourneyMessage?.id && msg.midjourneyMessage.id === messageId; + if (match) { + let typeIndex: number; + let prefix: string; + if (type === 'upscale') { + typeIndex = 0; + prefix = 'MJ::JOB::upsample::'; + } else if (index > 0) { + typeIndex = 1; + prefix = 'MJ::JOB::variation::'; + } else { + typeIndex = 0; + prefix = 'MJ::JOB::reroll::'; + } + msg.midjourneyMessage?.components?.[typeIndex]?.components?.some( + (component) => { + if (component.custom_id?.includes(`${prefix}${index}`)) { + if (type === 'upscale') { + // eslint-disable-next-line no-param-reassign + component.style = 1; + } else if (type === 'variation') { + // eslint-disable-next-line no-param-reassign + component.style = index > 0 ? 3 : 1; + } + return true; + } + return false; + } + ); + } + return false; + }); + + return messages; +};