Skip to content

Commit

Permalink
Merge pull request #11 from FoxxMD/Live-preview-foxx
Browse files Browse the repository at this point in the history
Use job presence to detect end of generation and make interval check configurable
  • Loading branch information
sebaxakerhtc authored Mar 1, 2024
2 parents e70bf01 + a7f46ab commit 9548e40
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
5 changes: 5 additions & 0 deletions core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
display_ignored_words = "False"
# These words will be added to the beginning of the negative prompt
negative_prompt_prefix = []
# the time, in seconds, between when AIYA checks for generation progress from SD -- can be a float
preview_update_interval = 3
# the fallback channel defaults template for AIYA if nothing is set
Expand Down Expand Up @@ -125,6 +127,7 @@ class GlobalVar:
negative_prompt_prefix = []
spoiler = False
spoiler_role = None
preview_update_interval = 3


global_var = GlobalVar()
Expand Down Expand Up @@ -512,6 +515,8 @@ def populate_global_vars():
global_var.prompt_ignore_list = [x for x in config['prompt_ignore_list']]
global_var.display_ignored_words = config['display_ignored_words']
global_var.negative_prompt_prefix = [x for x in config['negative_prompt_prefix']]
if config['preview_update_interval'] is not None:
global_var.preview_update_interval = float(config['preview_update_interval'])

# create persistent session since we'll need to do a few API calls
s = authenticate_user()
Expand Down
52 changes: 36 additions & 16 deletions core/stablecog.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,52 @@
from core import settingscog
from threading import Thread

async def update_progress(event_loop, status_message_task, s, queue_object, tries):
async def update_progress(event_loop, status_message_task, s, queue_object, tries, any_job, tries_since_no_job):
status_message = status_message_task.result()
try:
progress_data = s.get(url=f'{settings.global_var.url}/sdapi/v1/progress').json()

if progress_data["current_image"] is None and tries <= 10:
time.sleep(3)
event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, tries + 1))
return

if progress_data["current_image"] is None and tries > 10:
return
job_name = progress_data.get('state').get('job')
if job_name != '':
any_job = True

if progress_data["current_image"] is None:
if job_name == '':
if any_job:
if tries_since_no_job >= 2:
return
time.sleep(settings.global_var.preview_update_interval)
event_loop.create_task(
update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job + 1))
return
else:
# escape hatch
if tries > 10:
return
time.sleep(settings.global_var.preview_update_interval)
event_loop.create_task(
update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, tries_since_no_job))
return
else:
time.sleep(settings.global_var.preview_update_interval)
event_loop.create_task(
update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0))
return

image = Image.open(io.BytesIO(base64.b64decode(progress_data["current_image"])))

with contextlib.ExitStack() as stack:
buffer = stack.enter_context(io.BytesIO())
image.save(buffer, 'PNG')
buffer.seek(0)
filename=f'{queue_object.seed}.png'
filename = f'{queue_object.seed}.png'
if queue_object.spoiler:
filename=f'SPOILER_{queue_object.seed}.png'
fp=buffer
filename = f'SPOILER_{queue_object.seed}.png'
fp = buffer
file = discord.File(fp, filename)

ips = '?'
if progress_data["eta_relative"] != 0:
ips = round((int(queue_object.steps) - progress_data["state"]["sampling_step"]) / progress_data["eta_relative"], 2)
ips = round(
(int(queue_object.steps) - progress_data["state"]["sampling_step"]) / progress_data["eta_relative"], 2)

view = viewhandler.ProgressView()

Expand All @@ -61,7 +79,9 @@ async def update_progress(event_loop, status_message_task, s, queue_object, trie
print('Something goes wrong...', str(e))

time.sleep(1)
event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, tries))

event_loop.create_task(
update_progress(event_loop, status_message_task, s, queue_object, tries + 1, any_job, 0))

class StableCog(commands.Cog, name='Stable Diffusion', description='Create images from natural language.'):
ctx_parse = discord.ApplicationContext
Expand Down Expand Up @@ -435,7 +455,7 @@ def dream(self, event_loop: queuehandler.GlobalQueue.event_loop, queue_object: q
f'\n**Relative ETA**: initialization...'))

def worker():
event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, 0))
event_loop.create_task(update_progress(event_loop, status_message_task, s, queue_object, 0, False, 0))
return

status_thread = threading.Thread(target=worker)
Expand Down

0 comments on commit 9548e40

Please sign in to comment.