import io
import time
import uuid
import openai
import config
import aiohttp
import asyncio
import discord
import wikipedia
from openai import AsyncOpenAI
from garfpy import logger

# Base URL of the InvokeAI server used for image generation.
INVOKEAI_BASE = config.INVOKEAI_URL

# Model-manager identifiers (key + hash) of the installed FLUX.2 Klein models.
# NOTE(review): these look install-specific; presumably they must be updated if
# the models are reinstalled on the InvokeAI server — confirm.
_MODEL_KEY = "0eb50094-5c9b-431b-ba01-87e145edb849"
_VAE_KEY = "dde3627c-8a45-4088-93d1-66c44acbb337"
_ENCODER_KEY = "7ba22542-4687-4946-a52e-c92f925f4b75"

_MODEL_REF = {
    "key": _MODEL_KEY,
    "hash": "blake3:c3ee838d71d99497db01fae6f304eafd9e734e935f3b783e968d50febb56be2c",
    "name": "FLUX.2 Klein 4B (GGUF Q4)",
    "base": "flux2",
    "type": "main",
}
_VAE_REF = {
    "key": _VAE_KEY,
    "hash": "blake3:531855de70db993d0f6181f82cde27d15411d58b7ffa3b2fdce2b9434c0173c2",
    "name": "FLUX.2 VAE",
    "base": "flux2",
    "type": "vae",
}
_ENCODER_REF = {
    "key": _ENCODER_KEY,
    "hash": "blake3:af5840e6770dc99f678e69867949c8b9264835915eb82a990e940fa6e4fa6c81",
    "name": "FLUX.2 Klein Qwen3 4B Encoder",
    "base": "any",
    "type": "qwen3_encoder",
}

_POLL_INTERVAL = 2  # seconds to wait before each batch-status poll
_POLL_ATTEMPTS = 60  # number of polls before giving up (~120 s total)
_MIN_IMAGE_BYTES = 1000  # smaller downloads are rejected as "not an image"


def _node_id(prefix: str) -> str:
    """Return a graph node id made of ``prefix`` and a random 10-hex-char suffix."""
    return f"{prefix}:{uuid.uuid4().hex[:10]}"


def _build_graph(prompt: str) -> dict:
    """Build an InvokeAI FLUX.2 Klein text-to-image graph for ``prompt``.

    The graph is: string prompt -> Klein text encoder -> FLUX.2 denoise ->
    VAE decode. Only the decode node is non-intermediate, so it is the one
    whose output image is kept.

    Args:
        prompt: The user's text prompt.

    Returns:
        A ``{"nodes": ..., "edges": ...}`` dict for ``enqueue_batch``.
    """
    # Seed from wall-clock milliseconds, clamped to a 31-bit range.
    seed = int(time.time() * 1000) % (2 ** 31)

    p = _node_id("positive_prompt")
    ml = _node_id("flux2_klein_model_loader")
    te = _node_id("flux2_klein_text_encoder")
    dn = _node_id("flux2_denoise")
    out = _node_id("canvas_output")

    nodes = {
        p: {
            "id": p,
            "is_intermediate": True,
            "use_cache": True,
            "value": prompt,
            "type": "string",
        },
        ml: {
            "id": ml,
            "is_intermediate": True,
            "use_cache": True,
            "type": "flux2_klein_model_loader",
            "model": _MODEL_REF,
            "vae_model": _VAE_REF,
            "qwen3_encoder_model": _ENCODER_REF,
        },
        te: {
            "id": te,
            "is_intermediate": True,
            "use_cache": True,
            "type": "flux2_klein_text_encoder",
        },
        dn: {
            "id": dn,
            "is_intermediate": True,
            "use_cache": True,
            "type": "flux2_denoise",
            "seed": seed,
        },
        out: {
            "id": out,
            "is_intermediate": False,
            "use_cache": False,
            "type": "flux2_vae_decode",
        },
    }

    def edge(src: str, src_field: str, dst: str, dst_field: str) -> dict:
        return {
            "source": {"node_id": src, "field": src_field},
            "destination": {"node_id": dst, "field": dst_field},
        }

    edges = [
        edge(ml, "qwen3_encoder", te, "qwen3_encoder"),
        edge(ml, "max_seq_len", te, "max_seq_len"),
        edge(p, "value", te, "prompt"),
        edge(ml, "transformer", dn, "transformer"),
        edge(ml, "vae", dn, "vae"),
        edge(te, "conditioning", dn, "positive_text_conditioning"),
        edge(ml, "vae", out, "vae"),
        edge(dn, "latents", out, "latents"),
    ]
    return {"nodes": nodes, "edges": edges}


async def _poll_batch(session: aiohttp.ClientSession, base: str, batch_id: str) -> bool:
    """Poll batch status until completed, failed, or timed out.

    Sleeps ``_POLL_INTERVAL`` seconds before each of up to ``_POLL_ATTEMPTS``
    status requests. Non-OK responses and request/parse errors are logged or
    skipped and the poll continues.

    Returns:
        True once every item in the batch has completed; False if the batch
        finished with failed/canceled items or the attempts ran out.
    """
    url = f"{base}/api/v1/queue/default/b/{batch_id}/status"
    for _ in range(_POLL_ATTEMPTS):
        await asyncio.sleep(_POLL_INTERVAL)
        try:
            async with session.get(url) as resp:
                if not resp.ok:
                    continue
                s = await resp.json(content_type=None)
            total = s.get("total", 0)
            if total <= 0:
                continue
            completed = s.get("completed", 0)
            failed = s.get("failed", 0)
            # NOTE(review): "canceled" is read defensively; older servers may omit it.
            canceled = s.get("canceled", 0)
            if completed >= total:
                return True
            # Every item reached a terminal state but not all completed: waiting
            # longer cannot help, so report failure now instead of timing out.
            if completed + failed + canceled >= total:
                logger.error(f"Batch {batch_id} failed")
                return False
        except Exception as e:
            logger.error(f"InvokeAI poll error: {e}")
    logger.error(f"Batch {batch_id} timed out after {_POLL_ATTEMPTS} polls")
    return False


async def _get_image_name(
    session: aiohttp.ClientSession,
    base: str,
    batch_id: str,
    item_id: int | None = None,
) -> str | None:
    """Resolve the output image name of a finished generation.

    First reads the queue item's session results and returns the first
    ``image_output``. If that fails, falls back to the newest image listed by
    the server.

    Args:
        session: Shared HTTP session.
        base: InvokeAI base URL.
        batch_id: Batch id returned by ``enqueue_batch``.
        item_id: Queue-item id returned by ``enqueue_batch``, if known.

    Returns:
        The image name, or None if neither lookup produced one.
    """
    # NOTE(review): the /queue/{queue_id}/i/{...} route appears to take the
    # integer queue-item id, not a batch id. Use item_id when enqueue returned
    # one, and keep the original batch_id lookup otherwise.
    lookup_id = batch_id if item_id is None else item_id
    try:
        async with session.get(f"{base}/api/v1/queue/default/i/{lookup_id}") as resp:
            if resp.ok:
                data = await resp.json(content_type=None)
                for node_result in data.get("session", {}).get("results", {}).values():
                    if node_result.get("type") == "image_output":
                        return node_result["image"]["image_name"]
    except Exception as e:
        logger.error(f"Item fetch error: {e}")

    # Fallback: newest image on the server. This can pick up an image produced
    # by another client, so log that it happened.
    logger.info(f"Falling back to newest listed image for batch {batch_id}")
    try:
        async with session.get(
            f"{base}/api/v1/images/",
            params={"limit": 1, "order_by": "created_at", "direction": "DESC"},
        ) as resp:
            if resp.ok:
                items = (await resp.json(content_type=None)).get("items", [])
                if items:
                    return items[0]["image_name"]
    except Exception as e:
        logger.error(f"Image list fallback error: {e}")
    return None


async def _fetch_image_bytes(session: aiohttp.ClientSession, base: str, name: str) -> bytes | None:
    """Try the full image endpoint, then fall back to thumbnail.

    A response is accepted only if its status is OK, its Content-Type does not
    mention HTML, and its body is at least ``_MIN_IMAGE_BYTES`` long.

    Returns:
        The raw image bytes, or None if every URL failed.
    """
    urls = [
        f"{base}/api/v1/images/i/{name}/full",
        f"{base}/api/v1/images/i/{name}/thumbnail",
    ]
    for url in urls:
        try:
            async with session.get(url) as resp:
                ct = resp.headers.get("Content-Type", "")
                data = await resp.read()
                logger.info(f"Image fetch {url}: status={resp.status} content-type={ct} size={len(data)}")
                # Check the status too: a large non-HTML error body (e.g. JSON)
                # must not be mistaken for image data.
                if resp.ok and "html" not in ct and len(data) >= _MIN_IMAGE_BYTES:
                    return data
        except Exception as e:
            logger.error(f"InvokeAI image fetch error ({url}): {e}")
    return None


class GarfAI:
    """Chat, Wikipedia-summary and image-generation helpers for GarfBot.

    Chat goes to the OpenAI-compatible endpoint at ``config.BASE_URL``. Images
    are generated by the InvokeAI server at ``config.INVOKEAI_URL``, one request
    at a time, through ``image_request_queue``.
    """

    def __init__(self):
        self.baseurl = config.BASE_URL
        self.sysprompt = config.SYSTEM_PROMPT
        self.txtmodel = config.TXT_MODEL
        self.imgmodel = config.IMG_MODEL
        self._oai = AsyncOpenAI(
            api_key=config.OPENAI_TOKEN,
            base_url=config.BASE_URL,
        )
        # Pending {"ctx": ..., "prompt": ...} requests consumed by
        # process_image_requests().
        self.image_request_queue = asyncio.Queue()

    async def garfpic(self, ctx, prompt):
        """Queue an image request; ``process_image_requests`` replies to ``ctx``."""
        await self.image_request_queue.put({"ctx": ctx, "prompt": prompt})

    async def generate_image(self, session: aiohttp.ClientSession, prompt: str) -> bytes | str:
        """Generate one image for ``prompt`` on InvokeAI.

        Enqueues the graph, polls until the batch finishes, resolves the output
        image name, and downloads it.

        Returns:
            The image bytes on success, otherwise a user-facing error string.
        """
        base = INVOKEAI_BASE
        try:
            async with session.post(
                f"{base}/api/v1/queue/default/enqueue_batch",
                json={"batch": {"graph": _build_graph(prompt), "runs": 1}},
            ) as resp:
                if not resp.ok:
                    text = await resp.text()
                    logger.error(f"InvokeAI enqueue failed {resp.status}: {text}")
                    return "`GarfBot Error: InvokeAI rejected the request`"
                data = await resp.json(content_type=None)
                batch_id = data["batch"]["batch_id"]
                # NOTE(review): newer InvokeAI versions return the created queue
                # item ids as "item_ids"; read it defensively.
                item_ids = data.get("item_ids")
        except Exception as e:
            logger.error(f"InvokeAI enqueue error: {e}")
            return "`GarfBot Error: Couldn't reach InvokeAI`"

        item_id = item_ids[0] if isinstance(item_ids, list) and item_ids else None
        logger.info(f"InvokeAI batch queued: {batch_id}")

        if not await _poll_batch(session, base, batch_id):
            return "`GarfBot Error: InvokeAI generation failed or timed out`"

        image_name = await _get_image_name(session, base, batch_id, item_id)
        if not image_name:
            return "`GarfBot Error: Could not resolve image name`"
        logger.info(f"Got image: {image_name}")

        image_bytes = await _fetch_image_bytes(session, base, image_name)
        if image_bytes:
            return image_bytes
        logger.error("All image download attempts failed")
        return "`GarfBot Error: Odie`"

    async def process_image_requests(self):
        """Serve queued image requests forever, one at a time.

        Each request is answered with an attached image or an error string.
        Any exception while handling one request is logged, so it cannot stop
        the worker loop, and ``task_done()`` is always called.
        """
        async with aiohttp.ClientSession(headers={"Accept": "application/json"}) as session:
            while True:
                request = await self.image_request_queue.get()
                try:
                    ctx = request["ctx"]
                    prompt = request["prompt"]
                    result = await self.generate_image(session, prompt)
                    if isinstance(result, bytes):
                        logger.info("Sending image...")
                        image = io.BytesIO(result)
                        timestamp = ctx.message.created_at.strftime("%Y%m%d%H%M%S")
                        filename = f"{timestamp}_generated_image.png"
                        await ctx.reply(file=discord.File(fp=image, filename=filename))
                    else:
                        await ctx.reply(result)
                except Exception as e:
                    logger.error(e)
                finally:
                    self.image_request_queue.task_done()
                await asyncio.sleep(2)

    async def generate_chat(self, question: str) -> str:
        """Answer ``question`` with the configured chat model and system prompt.

        Returns:
            The model's reply, with "an AI language model" replaced by "a
            cartoon animal", or a user-facing error string on failure.
        """
        try:
            response = await self._oai.chat.completions.create(
                model=self.txtmodel,
                messages=[
                    {"role": "system", "content": self.sysprompt},
                    {"role": "user", "content": question},
                ],
                max_tokens=400,
                temperature=1.2,
            )
            answer = response.choices[0].message.content
            # The SDK types message content as Optional; handle None explicitly
            # instead of relying on an AttributeError reaching the catch-all.
            if answer is None:
                logger.error("Chat completion returned no text content")
                return "`GarfBot Error: Lasagna`"
            return answer.replace("an AI language model", "a cartoon animal")
        except openai.BadRequestError as e:
            logger.error(e)
            return f"`GarfBot Error: {e}`"
        except openai.APIError as e:
            logger.error(e)
            return "`GarfBot Error: Monday`"
        except Exception as e:
            logger.error(e)
            return "`GarfBot Error: Lasagna`"

    async def wikisum(self, query: str) -> str:
        """Fetch the Wikipedia summary for ``query`` and have the chat model restate it.

        Returns:
            The chat model's summary, or a user-facing error string.
        """
        try:
            # wikipedia.summary does blocking network I/O; run it in a worker
            # thread so the event loop (and the Discord connection) keeps running.
            summary = await asyncio.to_thread(wikipedia.summary, query)
            return await self.generate_chat(f"Please summarize in your own words: {summary}")
        except wikipedia.exceptions.DisambiguationError as e:
            options = ", ".join(e.options[:3])
            return f"`GarfBot Error: Ambiguous query — did you mean: {options}?`"
        except wikipedia.exceptions.PageError:
            return "`GarfBot Error: No Wikipedia page found for that query`"
        except Exception as e:
            logger.error(e)
            return f"`GarfBot Error: {e}`"