From 818d8296c439dbcbc0154247cf2e6eb6a3a96a09 Mon Sep 17 00:00:00 2001 From: FoxxMD Date: Fri, 1 Mar 2024 08:55:35 -0500 Subject: [PATCH 1/2] Use job presence to detect end of generation --- core/stablecog.py | 52 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/core/stablecog.py b/core/stablecog.py index b71de84..c385ef3 100644 --- a/core/stablecog.py +++ b/core/stablecog.py @@ -19,18 +19,36 @@ from core import settingscog from threading import Thread -async def update_progress(event_loop, status_message_task, s, queue_object, tries): +async def update_progress(event_loop, status_message_task, s, queue_object, tries, any_job, tries_since_no_job): status_message = status_message_task.result() try: progress_data = s.get(url=f'{settings.global_var.url}/sdapi/v1/progress').json() - - if progress_data["current_image"] is None and tries <= 10: - time.sleep(3) - event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, tries + 1)) - return - - if progress_data["current_image"] is None and tries > 10: - return + job_name = progress_data.get('state').get('job') + if job_name != '': + any_job = True + + if progress_data["current_image"] is None: + if job_name == '': + if any_job: + if tries_since_no_job >= 2: + return + time.sleep(3) + event_loop.create_task( + update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job + 1)) + return + else: + # escape hatch + if tries > 10: + return + time.sleep(3) + event_loop.create_task( + update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job)) + return + else: + time.sleep(3) + event_loop.create_task( + update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0)) + return image = Image.open(io.BytesIO(base64.b64decode(progress_data["current_image"]))) @@ -38,15 +56,15 @@ async def update_progress(event_loop, status_message_task, s, queue_object, trie buffer = stack.enter_context(io.BytesIO()) image.save(buffer, 'PNG') buffer.seek(0) - filename=f'{queue_object.seed}.png' + filename = f'{queue_object.seed}.png' if queue_object.spoiler: - filename=f'SPOILER_{queue_object.seed}.png' - fp=buffer + filename = f'SPOILER_{queue_object.seed}.png' + fp = buffer file = discord.File(fp, filename) - ips = '?' if progress_data["eta_relative"] != 0: - ips = round((int(queue_object.steps) - progress_data["state"]["sampling_step"]) / progress_data["eta_relative"], 2) + ips = round( + (int(queue_object.steps) - progress_data["state"]["sampling_step"]) / progress_data["eta_relative"], 2) view = viewhandler.ProgressView() @@ -61,7 +79,9 @@ async def update_progress(event_loop, status_message_task, s, queue_object, trie print('Something goes wrong...', str(e)) time.sleep(1) - event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, tries)) + + event_loop.create_task( + update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0)) class StableCog(commands.Cog, name='Stable Diffusion', description='Create images from natural language.'): ctx_parse = discord.ApplicationContext @@ -435,7 +455,7 @@ def dream(self, event_loop: queuehandler.GlobalQueue.event_loop, queue_object: q f'\n**Relative ETA**: initialization...')) def worker(): - event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, 0)) + event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, 0, False, 0)) return status_thread = threading.Thread(target=worker) From a7f46ab214061e99915c57e082ea095a847c8f4a Mon Sep 17 00:00:00 2001 From: FoxxMD Date: Fri, 1 Mar 2024 09:04:49 -0500 Subject: [PATCH 2/2] Make progress interval check configurable --- core/settings.py | 5 +++++ core/stablecog.py | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/core/settings.py b/core/settings.py index e3f590d..3214239 100644 --- a/core/settings.py +++ b/core/settings.py @@ -59,6 +59,8 @@ display_ignored_words = "False" # These words will be added to the beginning of the negative prompt negative_prompt_prefix = [] +# the time, in seconds, between when AIYA checks for generation progress from SD -- can be a float +preview_update_interval = 3 # the fallback channel defaults template for AIYA if nothing is set @@ -125,6 +127,7 @@ class GlobalVar: negative_prompt_prefix = [] spoiler = False spoiler_role = None + preview_update_interval = 3 global_var = GlobalVar() @@ -512,6 +515,8 @@ def populate_global_vars(): global_var.prompt_ignore_list = [x for x in config['prompt_ignore_list']] global_var.display_ignored_words = config['display_ignored_words'] global_var.negative_prompt_prefix = [x for x in config['negative_prompt_prefix']] + if config['preview_update_interval'] is not None: + global_var.preview_update_interval = float(config['preview_update_interval']) # create persistent session since we'll need to do a few API calls s = authenticate_user() diff --git a/core/stablecog.py b/core/stablecog.py index c385ef3..1bcf3e1 100644 --- a/core/stablecog.py +++ b/core/stablecog.py @@ -32,7 +32,7 @@ async def update_progress(event_loop, status_message_task, s, queue_object, trie if any_job: if tries_since_no_job >= 2: return - time.sleep(3) + time.sleep(settings.global_var.preview_update_interval) event_loop.create_task( update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job + 1)) return @@ -40,12 +40,12 @@ async def update_progress(event_loop, status_message_task, s, queue_object, trie # escape hatch if tries > 10: return - time.sleep(3) + time.sleep(settings.global_var.preview_update_interval) event_loop.create_task( update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job)) return else: - time.sleep(3) + time.sleep(settings.global_var.preview_update_interval) event_loop.create_task( update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0)) return