diff --git a/garfmain.py b/garfmain.py
index 2aee6e1..fb4b04a 100644
--- a/garfmain.py
+++ b/garfmain.py
@@ -161,13 +161,13 @@ async def garfchat(ctx, *, prompt):
     await ctx.reply(answer)
 
 
-# @garfbot.command(name="pic")
-# async def garfpic(ctx, *, prompt):
-#     logger.info(
-#         f"Image Request - User: {ctx.author.name}, Server: {ctx.guild.name}, Prompt: {prompt}"
-#     )
-#     await ctx.reply(f"`Please wait... image generation queued: {prompt}`")
-#     await garfield.garfpic(ctx, prompt)
+@garfbot.command(name="pic")
+async def garfpic(ctx, *, prompt):
+    logger.info(
+        f"Image Request - User: {ctx.author.name}, Server: {ctx.guild.name}, Prompt: {prompt}"
+    )
+    await ctx.reply(f"`Please wait... image generation queued: {prompt}`")
+    await garfield.garfpic(ctx, prompt)
 
 
 @garfbot.command(name="help")
diff --git a/garfpy/garfai.py b/garfpy/garfai.py
index 3b92f89..0b6f35a 100644
--- a/garfpy/garfai.py
+++ b/garfpy/garfai.py
@@ -1,4 +1,6 @@
 import io
+import time
+import uuid
 import openai
 import config
 import aiohttp
@@ -9,96 +11,245 @@ from openai import AsyncOpenAI
 from garfpy import logger
 
 
+INVOKEAI_BASE = config.INVOKEAI_URL
+
+# Model record keys referenced by the graph's model-loader node.
+_MODEL_KEY = "0eb50094-5c9b-431b-ba01-87e145edb849"
+_VAE_KEY = "dde3627c-8a45-4088-93d1-66c44acbb337"
+_ENCODER_KEY = "7ba22542-4687-4946-a52e-c92f925f4b75"
+
+_MODEL_REF = {"key": _MODEL_KEY, "hash": "blake3:c3ee838d71d99497db01fae6f304eafd9e734e935f3b783e968d50febb56be2c", "name": "FLUX.2 Klein 4B (GGUF Q4)", "base": "flux2", "type": "main"}
+_VAE_REF = {"key": _VAE_KEY, "hash": "blake3:531855de70db993d0f6181f82cde27d15411d58b7ffa3b2fdce2b9434c0173c2", "name": "FLUX.2 VAE", "base": "flux2", "type": "vae"}
+_ENCODER_REF = {"key": _ENCODER_KEY, "hash": "blake3:af5840e6770dc99f678e69867949c8b9264835915eb82a990e940fa6e4fa6c81", "name": "FLUX.2 Klein Qwen3 4B Encoder", "base": "any", "type": "qwen3_encoder"}
+
+_POLL_INTERVAL = 2  # seconds slept before each status poll
+_POLL_ATTEMPTS = 60  # status polls before giving up
+_MIN_IMAGE_BYTES = 1000  # shorter downloads are not accepted as images
+
+
+def _node_id(prefix: str) -> str:
+    """Return ``prefix`` plus a random 10-hex-digit suffix, used as a graph node id."""
+    return f"{prefix}:{uuid.uuid4().hex[:10]}"
+
+
+def _build_graph(prompt: str) -> dict:
+    """Build the node graph (``nodes`` and ``edges``) that renders ``prompt``."""
+    seed = int(time.time() * 1000) % (2 ** 31)
+
+    p = _node_id("positive_prompt")
+    ml = _node_id("flux2_klein_model_loader")
+    te = _node_id("flux2_klein_text_encoder")
+    dn = _node_id("flux2_denoise")
+    out = _node_id("canvas_output")
+
+    nodes = {
+        p: {"id": p, "is_intermediate": True, "use_cache": True, "value": prompt, "type": "string"},
+        ml: {"id": ml, "is_intermediate": True, "use_cache": True, "type": "flux2_klein_model_loader",
+             "model": _MODEL_REF, "vae_model": _VAE_REF, "qwen3_encoder_model": _ENCODER_REF},
+        te: {"id": te, "is_intermediate": True, "use_cache": True, "type": "flux2_klein_text_encoder"},
+        dn: {"id": dn, "is_intermediate": True, "use_cache": True, "type": "flux2_denoise", "seed": seed},
+        out: {"id": out, "is_intermediate": False, "use_cache": False, "type": "flux2_vae_decode"},
+    }
+
+    edges = [
+        {"source": {"node_id": ml, "field": "qwen3_encoder"}, "destination": {"node_id": te, "field": "qwen3_encoder"}},
+        {"source": {"node_id": ml, "field": "max_seq_len"}, "destination": {"node_id": te, "field": "max_seq_len"}},
+        {"source": {"node_id": p, "field": "value"}, "destination": {"node_id": te, "field": "prompt"}},
+        {"source": {"node_id": ml, "field": "transformer"}, "destination": {"node_id": dn, "field": "transformer"}},
+        {"source": {"node_id": ml, "field": "vae"}, "destination": {"node_id": dn, "field": "vae"}},
+        {"source": {"node_id": te, "field": "conditioning"}, "destination": {"node_id": dn, "field": "positive_text_conditioning"}},
+        {"source": {"node_id": ml, "field": "vae"}, "destination": {"node_id": out, "field": "vae"}},
+        {"source": {"node_id": dn, "field": "latents"}, "destination": {"node_id": out, "field": "latents"}},
+    ]
+
+    return {"nodes": nodes, "edges": edges}
+
+
+async def _poll_batch(session: aiohttp.ClientSession, base: str, batch_id: str) -> bool:
+    """Poll the batch status until every item is finished or attempts run out.
+
+    Returns True only if the batch finished with no failed or canceled items.
+    """
+    for _ in range(_POLL_ATTEMPTS):
+        await asyncio.sleep(_POLL_INTERVAL)
+        try:
+            async with session.get(f"{base}/api/v1/queue/default/b/{batch_id}/status") as resp:
+                if not resp.ok:
+                    continue
+                s = await resp.json(content_type=None)
+                total = s.get("total", 0)
+                completed = s.get("completed", 0)
+                # Canceled items never complete; count them with failures so a
+                # canceled batch ends the wait instead of polling until timeout.
+                lost = s.get("failed", 0) + s.get("canceled", 0)
+                if total > 0 and completed + lost >= total:
+                    if lost:
+                        logger.error(f"Batch {batch_id} failed")
+                        return False
+                    return True
+        except Exception as e:
+            logger.error(f"InvokeAI poll error: {e}")
+    return False
+
+
+async def _get_image_name(
+    session: aiohttp.ClientSession,
+    base: str,
+    batch_id: str,
+    item_id: int | None = None,
+) -> str | None:
+    """Resolve the name of the image produced by a finished queue item.
+
+    The queue-item endpoint is queried with ``item_id`` when one is given and
+    with ``batch_id`` otherwise.  If that yields no ``image_output`` result,
+    the newest image on the server is returned as a best-effort fallback,
+    which may belong to a different request.
+    """
+    item_ref = batch_id if item_id is None else item_id
+    try:
+        async with session.get(f"{base}/api/v1/queue/default/i/{item_ref}") as resp:
+            if resp.ok:
+                data = await resp.json(content_type=None)
+                for node_result in data.get("session", {}).get("results", {}).values():
+                    if node_result.get("type") == "image_output":
+                        return node_result["image"]["image_name"]
+    except Exception as e:
+        logger.error(f"Item fetch error: {e}")
+
+    logger.info(f"No image_output found for batch {batch_id}; falling back to newest image")
+    try:
+        async with session.get(
+            f"{base}/api/v1/images/",
+            params={"limit": 1, "order_by": "created_at", "direction": "DESC"},
+        ) as resp:
+            if resp.ok:
+                items = (await resp.json(content_type=None)).get("items", [])
+                if items:
+                    return items[0]["image_name"]
+    except Exception as e:
+        logger.error(f"Image list fallback error: {e}")
+
+    return None
+
+
+async def _fetch_image_bytes(session: aiohttp.ClientSession, base: str, name: str) -> bytes | None:
+    """Try the full image endpoint, then fall back to thumbnail."""
+    urls = [
+        f"{base}/api/v1/images/i/{name}/full",
+        f"{base}/api/v1/images/i/{name}/thumbnail",
+    ]
+    for url in urls:
+        try:
+            async with session.get(url) as resp:
+                ct = resp.headers.get("Content-Type", "")
+                data = await resp.read()
+                logger.info(f"Image fetch {url}: status={resp.status} content-type={ct} size={len(data)}")
+                # An error response is never a usable image, whatever its size.
+                if resp.ok and "html" not in ct and len(data) >= _MIN_IMAGE_BYTES:
+                    return data
+        except Exception as e:
+            logger.error(f"InvokeAI image fetch error ({url}): {e}")
+    return None
+
+
 class GarfAI:
     def __init__(self):
         self.baseurl = config.BASE_URL
-        self.openaikey = config.OPENAI_TOKEN
         self.sysprompt = config.SYSTEM_PROMPT
         self.txtmodel = config.TXT_MODEL
         self.imgmodel = config.IMG_MODEL
+
+        self._oai = AsyncOpenAI(
+            api_key=config.OPENAI_TOKEN,
+            base_url=config.BASE_URL,
+        )
         self.image_request_queue = asyncio.Queue()
 
     async def garfpic(self, ctx, prompt):
         await self.image_request_queue.put({"ctx": ctx, "prompt": prompt})
 
-    async def generate_image(self, prompt):
-        client = AsyncOpenAI(api_key=self.openaikey)
+    async def generate_image(self, session: aiohttp.ClientSession, prompt: str) -> bytes | str:
+        """Render ``prompt`` through InvokeAI and return the image bytes.
+
+        On any failure a backtick-wrapped ``GarfBot Error`` message string is
+        returned instead of raising.
+        """
+        base = INVOKEAI_BASE
+
         try:
-            response = await client.images.generate(
-                model=self.imgmodel, prompt=prompt, n=1, size="1024x1024"
-            )
-        except openai.BadRequestError as e:
-            return f"`GarfBot Error: ({e.status_code}) - Your request was rejected as a result of our safety system.`"
-        except openai.InternalServerError as e:
-            logger.error(e)
-            return f"`GarfBot Error: ({e.status_code}) - Monday`"
+            async with session.post(
+                f"{base}/api/v1/queue/default/enqueue_batch",
+                json={"batch": {"graph": _build_graph(prompt), "runs": 1}},
+            ) as resp:
+                if not resp.ok:
+                    text = await resp.text()
+                    logger.error(f"InvokeAI enqueue failed {resp.status}: {text}")
+                    return "`GarfBot Error: InvokeAI rejected the request`"
+                data = await resp.json(content_type=None)
+                batch_id = data["batch"]["batch_id"]
+                # The queue-item endpoint is keyed by item id, not batch id.
+                item_ids = data.get("item_ids") or []
+                item_id = item_ids[0] if item_ids else None
         except Exception as e:
-            logger.error(e)
-            return "`GarfBot Error: Lasagna`"
-        data = getattr(response, "data", None)
-        if not data:
-            logger.error("No data in response")
-            return "`GarfBot Error: No images generated`"
+            logger.error(f"InvokeAI enqueue error: {e}")
+            return "`GarfBot Error: Couldn't reach InvokeAI`"
 
-        first_image = data[0] if len(data) > 0 else None
-        if not first_image:
-            logger.error("No image in response data")
-            return "`GarfBot Error: No images generated`"
+        logger.info(f"InvokeAI batch queued: {batch_id}")
 
-        image_url = getattr(first_image, "url", None)
-        if not image_url:
-            logger.error("No URL in image response")
-            return "`GarfBot Error: No image URL returned`"
+        if not await _poll_batch(session, base, batch_id):
+            return "`GarfBot Error: InvokeAI generation failed or timed out`"
 
-        return image_url
+        image_name = await _get_image_name(session, base, batch_id, item_id)
+        if not image_name:
+            return "`GarfBot Error: Could not resolve image name`"
+
+        logger.info(f"Got image: {image_name}")
+
+        data = await _fetch_image_bytes(session, base, image_name)
+        if data:
+            return data
+
+        logger.error("All image download attempts failed")
+        return "`GarfBot Error: Odie`"
 
     async def process_image_requests(self):
-        async with aiohttp.ClientSession() as session:
+        async with aiohttp.ClientSession(headers={"Accept": "application/json"}) as session:
             while True:
                 request = await self.image_request_queue.get()
                 ctx = request["ctx"]
                 prompt = request["prompt"]
-                image_url = await self.generate_image(prompt)
-                if image_url and "GarfBot Error" not in image_url:
-                    logger.info("Downloading & sending image...")
-                    async with session.get(image_url) as resp:
-                        if resp.status == 200:
-                            image_data = await resp.read()
-                            image = io.BytesIO(image_data)
-                            image.seek(0)
-                            timestamp = ctx.message.created_at.strftime("%Y%m%d%H%M%S")
-                            filename = f"{timestamp}_generated_image.png"
-                            sendfile = discord.File(fp=image, filename=filename)
-                            try:
-                                await ctx.send(file=sendfile)
-                            except Exception as e:
-                                logger.error(e)
-                        else:
-                            await ctx.send("`GarfBot Error: Odie`")
-                else:
-                    await ctx.send(image_url)
-                self.image_request_queue.task_done()
+
+                # One bad request must not kill the worker loop, and the queue
+                # item is always marked done, even when the request fails.
+                try:
+                    result = await self.generate_image(session, prompt)
+                    if isinstance(result, bytes):
+                        logger.info("Sending image...")
+                        image = io.BytesIO(result)
+                        timestamp = ctx.message.created_at.strftime("%Y%m%d%H%M%S")
+                        filename = f"{timestamp}_generated_image.png"
+                        await ctx.reply(file=discord.File(fp=image, filename=filename))
+                    else:
+                        await ctx.reply(result)
+                except Exception as e:
+                    logger.error(e)
+                finally:
+                    self.image_request_queue.task_done()
+
                 await asyncio.sleep(2)
 
-    async def generate_chat(self, question):
+    async def generate_chat(self, question: str) -> str:
         try:
-            client = AsyncOpenAI(
-                api_key=self.openaikey,
-                base_url=self.baseurl
-            )
-            response = await client.chat.completions.create(
+            response = await self._oai.chat.completions.create(
                 model=self.txtmodel,
                 messages=[
-                    {
-                        "role": "system",
-                        "content": self.sysprompt,
-                    },
+                    {"role": "system", "content": self.sysprompt},
                     {"role": "user", "content": question},
                 ],
                 max_tokens=400,
                 temperature=1.2,
             )
-            answer = str(response.choices[0].message.content)
+            answer = response.choices[0].message.content
             return answer.replace("an AI language model", "a cartoon animal")
         except openai.BadRequestError as e:
             logger.error(e)
@@ -110,12 +261,17 @@ class GarfAI:
             logger.error(e)
             return "`GarfBot Error: Lasagna`"
 
-    async def wikisum(self, query):
+    async def wikisum(self, query: str) -> str:
+        """Return a chat-model rewrite of the Wikipedia summary for ``query``."""
         try:
-            summary = wikipedia.summary(query)
-            garfsum = await self.generate_chat(
-                f"Please summarize in your own words: {summary}"
-            )
-            return garfsum
+            # wikipedia.summary does blocking HTTP; keep it off the event loop.
+            summary = await asyncio.to_thread(wikipedia.summary, query)
+            return await self.generate_chat(f"Please summarize in your own words: {summary}")
+        except wikipedia.exceptions.DisambiguationError as e:
+            options = ", ".join(e.options[:3])
+            return f"`GarfBot Error: Ambiguous query — did you mean: {options}?`"
+        except wikipedia.exceptions.PageError:
+            return "`GarfBot Error: No Wikipedia page found for that query`"
         except Exception as e:
-            return e
+            logger.error(e)
+            return f"`GarfBot Error: {e}`"