# File: garfbot/garfpy/garfai.py (Python, 243 lines, 10 KiB)

import io
import time
import uuid
import openai
import config
import aiohttp
import asyncio
import discord
import wikipedia
from openai import AsyncOpenAI
from garfpy import logger
# Base URL of the InvokeAI server, read from the bot configuration.
INVOKEAI_BASE = config.INVOKEAI_URL

# Keys of the model records referenced by the generation graph.
_MODEL_KEY = "0eb50094-5c9b-431b-ba01-87e145edb849"
_VAE_KEY = "dde3627c-8a45-4088-93d1-66c44acbb337"
_ENCODER_KEY = "7ba22542-4687-4946-a52e-c92f925f4b75"

# Model reference dicts passed to the model-loader node of the graph.
_MODEL_REF = {
    "key": _MODEL_KEY,
    "hash": "blake3:c3ee838d71d99497db01fae6f304eafd9e734e935f3b783e968d50febb56be2c",
    "name": "FLUX.2 Klein 4B (GGUF Q4)",
    "base": "flux2",
    "type": "main",
}
_VAE_REF = {
    "key": _VAE_KEY,
    "hash": "blake3:531855de70db993d0f6181f82cde27d15411d58b7ffa3b2fdce2b9434c0173c2",
    "name": "FLUX.2 VAE",
    "base": "flux2",
    "type": "vae",
}
_ENCODER_REF = {
    "key": _ENCODER_KEY,
    "hash": "blake3:af5840e6770dc99f678e69867949c8b9264835915eb82a990e940fa6e4fa6c81",
    "name": "FLUX.2 Klein Qwen3 4B Encoder",
    "base": "any",
    "type": "qwen3_encoder",
}

# Batch polling: seconds slept before each status request, and the maximum
# number of status requests made before giving up.
_POLL_INTERVAL = 2
_POLL_ATTEMPTS = 60

# Downloaded image payloads smaller than this many bytes are rejected.
_MIN_IMAGE_BYTES = 1000
def _node_id(prefix: str) -> str:
return f"{prefix}:{uuid.uuid4().hex[:10]}"
def _build_graph(prompt: str) -> dict:
    """Assemble the node/edge graph submitted to InvokeAI for one prompt.

    Five nodes are created, each with a freshly generated id: a string node
    holding ``prompt``, a model loader, a text encoder, a denoise node and a
    non-intermediate VAE-decode output node. The denoise seed is the current
    time in milliseconds modulo ``2 ** 31``.

    Args:
        prompt: Text stored in the string node and wired to the text encoder.

    Returns:
        A dict with ``"nodes"`` (node id -> node definition) and ``"edges"``
        (list of source/destination field connections).
    """
    seed = int(time.time() * 1000) % (2 ** 31)

    prompt_id = _node_id("positive_prompt")
    loader_id = _node_id("flux2_klein_model_loader")
    encoder_id = _node_id("flux2_klein_text_encoder")
    denoise_id = _node_id("flux2_denoise")
    output_id = _node_id("canvas_output")

    nodes = {
        prompt_id: {
            "id": prompt_id,
            "is_intermediate": True,
            "use_cache": True,
            "value": prompt,
            "type": "string",
        },
        loader_id: {
            "id": loader_id,
            "is_intermediate": True,
            "use_cache": True,
            "type": "flux2_klein_model_loader",
            "model": _MODEL_REF,
            "vae_model": _VAE_REF,
            "qwen3_encoder_model": _ENCODER_REF,
        },
        encoder_id: {
            "id": encoder_id,
            "is_intermediate": True,
            "use_cache": True,
            "type": "flux2_klein_text_encoder",
        },
        denoise_id: {
            "id": denoise_id,
            "is_intermediate": True,
            "use_cache": True,
            "type": "flux2_denoise",
            "seed": seed,
        },
        output_id: {
            "id": output_id,
            "is_intermediate": False,
            "use_cache": False,
            "type": "flux2_vae_decode",
        },
    }

    # (source node, source field, destination node, destination field)
    wiring = [
        (loader_id, "qwen3_encoder", encoder_id, "qwen3_encoder"),
        (loader_id, "max_seq_len", encoder_id, "max_seq_len"),
        (prompt_id, "value", encoder_id, "prompt"),
        (loader_id, "transformer", denoise_id, "transformer"),
        (loader_id, "vae", denoise_id, "vae"),
        (encoder_id, "conditioning", denoise_id, "positive_text_conditioning"),
        (loader_id, "vae", output_id, "vae"),
        (denoise_id, "latents", output_id, "latents"),
    ]
    edges = [
        {
            "source": {"node_id": src, "field": src_field},
            "destination": {"node_id": dst, "field": dst_field},
        }
        for src, src_field, dst, dst_field in wiring
    ]

    return {"nodes": nodes, "edges": edges}
async def _poll_batch(session: aiohttp.ClientSession, base: str, batch_id: str) -> bool:
    """Poll a batch's status until it reaches a terminal state or polling times out.

    Sleeps ``_POLL_INTERVAL`` seconds before each of at most ``_POLL_ATTEMPTS``
    status requests. Non-2xx responses and request/parsing errors are treated
    as transient: they are skipped (errors are logged) and the next attempt
    proceeds as usual.

    Args:
        session: HTTP session used for the status requests.
        base: Base URL of the InvokeAI server.
        batch_id: Id of the batch to watch.

    Returns:
        True when every item of the batch completed. False when the batch
        finished with any failed or canceled item, or when no terminal state
        was observed within the polling budget.
    """
    status_url = f"{base}/api/v1/queue/default/b/{batch_id}/status"
    for _ in range(_POLL_ATTEMPTS):
        await asyncio.sleep(_POLL_INTERVAL)
        try:
            async with session.get(status_url) as resp:
                if not resp.ok:
                    continue
                status = await resp.json(content_type=None)

            total = status.get("total", 0)
            completed = status.get("completed", 0)
            failed = status.get("failed", 0)
            canceled = status.get("canceled", 0)

            # Nothing is terminal until every item is accounted for. Counting
            # canceled items (and mixed completed/failed outcomes) as finished
            # avoids polling until the timeout for a batch that can no longer
            # make progress.
            if total <= 0 or completed + failed + canceled < total:
                continue
            if failed or canceled:
                logger.error(f"Batch {batch_id} failed")
                return False
            return True
        except Exception as e:
            logger.error(f"InvokeAI poll error: {e}")
    return False
async def _get_image_name(session: aiohttp.ClientSession, base: str, batch_id: str) -> str | None:
    """Resolve the name of the image produced for ``batch_id``.

    Two lookups are attempted in order; errors in either are logged and do
    not propagate:

    1. The queue item endpoint, scanning the session results for the first
       entry whose type is ``"image_output"``.
    2. The image listing endpoint, asking for the single most recently
       created image.

    Args:
        session: HTTP session used for both requests.
        base: Base URL of the InvokeAI server.
        batch_id: Identifier inserted into the queue item URL.

    Returns:
        The image name, or None when neither lookup yields one.
    """
    # NOTE(review): this path is ``.../i/{batch_id}``; confirm the endpoint
    # accepts a batch id rather than a queue item id, otherwise this lookup
    # never matches and only the fallback below is effective.
    try:
        async with session.get(f"{base}/api/v1/queue/default/i/{batch_id}") as resp:
            if resp.ok:
                payload = await resp.json(content_type=None)
                results = payload.get("session", {}).get("results", {})
                for result in results.values():
                    if result.get("type") == "image_output":
                        return result["image"]["image_name"]
    except Exception as e:
        logger.error(f"Item fetch error: {e}")

    # Fallback: newest image on the server. The request is not scoped to this
    # batch, so it can return an image generated by a different request.
    try:
        async with session.get(
            f"{base}/api/v1/images/",
            params={"limit": 1, "order_by": "created_at", "direction": "DESC"},
        ) as resp:
            if resp.ok:
                listing = await resp.json(content_type=None)
                newest = listing.get("items", [])
                if newest:
                    return newest[0]["image_name"]
    except Exception as e:
        logger.error(f"Image list fallback error: {e}")

    return None
async def _fetch_image_bytes(session: aiohttp.ClientSession, base: str, name: str) -> bytes | None:
    """Download an image by name, trying the full image and then the thumbnail.

    A response is accepted only when it has a 2xx status, its Content-Type
    does not mention HTML, and the body is at least ``_MIN_IMAGE_BYTES`` long.
    Every attempt is logged; request errors are logged and the next URL is
    tried.

    Args:
        session: HTTP session used for the downloads.
        base: Base URL of the InvokeAI server.
        name: Image name inserted into the image URLs.

    Returns:
        The raw response body of the first acceptable response, or None when
        neither URL yields one.
    """
    for variant in ("full", "thumbnail"):
        url = f"{base}/api/v1/images/i/{name}/{variant}"
        try:
            async with session.get(url) as resp:
                ct = resp.headers.get("Content-Type", "")
                data = await resp.read()
                logger.info(f"Image fetch {url}: status={resp.status} content-type={ct} size={len(data)}")
                # An error response can carry a large non-HTML body (e.g. a
                # JSON error document); without the status check it would be
                # handed back as if it were image data.
                if not resp.ok:
                    continue
                if "html" not in ct.lower() and len(data) >= _MIN_IMAGE_BYTES:
                    return data
        except Exception as e:
            logger.error(f"InvokeAI image fetch error ({url}): {e}")
    return None
class GarfAI:
    """Chat, image-generation and Wikipedia-summary helpers for GarfBot.

    Image prompts are queued with :meth:`garfpic` and consumed one at a time
    by the long-running :meth:`process_image_requests` worker. Chat answers
    come from an OpenAI-compatible endpoint via :meth:`generate_chat`.
    """

    def __init__(self) -> None:
        """Read settings from ``config`` and create the chat client and queue."""
        self.baseurl = config.BASE_URL
        self.sysprompt = config.SYSTEM_PROMPT
        self.txtmodel = config.TXT_MODEL
        self.imgmodel = config.IMG_MODEL
        self._oai = AsyncOpenAI(
            api_key=config.OPENAI_TOKEN,
            base_url=config.BASE_URL,
        )
        # Pending image requests; each entry is {"ctx": ..., "prompt": ...}.
        self.image_request_queue = asyncio.Queue()

    async def garfpic(self, ctx, prompt) -> None:
        """Queue an image request for ``prompt`` to be answered on ``ctx``."""
        await self.image_request_queue.put({"ctx": ctx, "prompt": prompt})

    async def generate_image(self, session: aiohttp.ClientSession, prompt: str) -> bytes | str:
        """Generate one image for ``prompt`` through InvokeAI.

        Enqueues a batch built from the prompt, polls it to completion,
        resolves the resulting image name and downloads the image.

        Args:
            session: HTTP session used for all InvokeAI requests.
            prompt: Text prompt for the image.

        Returns:
            The image bytes on success, otherwise a backtick-wrapped
            ``GarfBot Error`` message describing which step failed.
        """
        base = INVOKEAI_BASE
        try:
            async with session.post(
                f"{base}/api/v1/queue/default/enqueue_batch",
                json={"batch": {"graph": _build_graph(prompt), "runs": 1}},
            ) as resp:
                if not resp.ok:
                    text = await resp.text()
                    logger.error(f"InvokeAI enqueue failed {resp.status}: {text}")
                    return "`GarfBot Error: InvokeAI rejected the request`"
                data = await resp.json(content_type=None)
                batch_id = data["batch"]["batch_id"]
        except Exception as e:
            logger.error(f"InvokeAI enqueue error: {e}")
            return "`GarfBot Error: Couldn't reach InvokeAI`"

        logger.info(f"InvokeAI batch queued: {batch_id}")
        if not await _poll_batch(session, base, batch_id):
            return "`GarfBot Error: InvokeAI generation failed or timed out`"

        image_name = await _get_image_name(session, base, batch_id)
        if not image_name:
            return "`GarfBot Error: Could not resolve image name`"
        logger.info(f"Got image: {image_name}")

        data = await _fetch_image_bytes(session, base, image_name)
        if data:
            return data
        logger.error("All image download attempts failed")
        return "`GarfBot Error: Odie`"

    async def process_image_requests(self) -> None:
        """Run forever, answering queued image requests one at a time.

        Each request is replied to with the generated image as a PNG
        attachment, or with the error text returned by
        :meth:`generate_image`. Any exception raised while handling a request
        is logged and the worker moves on to the next one; the queue entry is
        always marked done. The worker pauses two seconds between requests.
        """
        async with aiohttp.ClientSession(headers={"Accept": "application/json"}) as session:
            while True:
                request = await self.image_request_queue.get()
                try:
                    ctx = request["ctx"]
                    prompt = request["prompt"]
                    result = await self.generate_image(session, prompt)
                    if isinstance(result, bytes):
                        logger.info("Sending image...")
                        image = io.BytesIO(result)
                        timestamp = ctx.message.created_at.strftime("%Y%m%d%H%M%S")
                        filename = f"{timestamp}_generated_image.png"
                        await ctx.reply(file=discord.File(fp=image, filename=filename))
                    else:
                        await ctx.reply(result)
                except Exception as e:
                    # A failed reply (or malformed request) must not end the
                    # worker loop, or every later request would sit in the
                    # queue unanswered.
                    logger.error(e)
                finally:
                    self.image_request_queue.task_done()
                await asyncio.sleep(2)

    async def generate_chat(self, question: str) -> str:
        """Ask the chat model ``question`` using the configured system prompt.

        Returns:
            The model's answer with "an AI language model" replaced by
            "a cartoon animal", or a backtick-wrapped ``GarfBot Error``
            message when the request fails.
        """
        try:
            response = await self._oai.chat.completions.create(
                model=self.txtmodel,
                messages=[
                    {"role": "system", "content": self.sysprompt},
                    {"role": "user", "content": question},
                ],
                max_tokens=400,
                temperature=1.2,
            )
            answer = response.choices[0].message.content
            return answer.replace("an AI language model", "a cartoon animal")
        except openai.BadRequestError as e:
            logger.error(e)
            return f"`GarfBot Error: {e}`"
        except openai.APIError as e:
            logger.error(e)
            return "`GarfBot Error: Monday`"
        except Exception as e:
            logger.error(e)
            return "`GarfBot Error: Lasagna`"

    async def wikisum(self, query: str) -> str:
        """Fetch the Wikipedia summary for ``query`` and have the chat model restate it.

        Returns:
            The chat model's summary, or a backtick-wrapped ``GarfBot Error``
            message for ambiguous queries, missing pages or other failures.
        """
        try:
            # wikipedia.summary is a plain synchronous call; run it in a
            # worker thread so a slow lookup does not stall the event loop.
            summary = await asyncio.to_thread(wikipedia.summary, query)
            return await self.generate_chat(f"Please summarize in your own words: {summary}")
        except wikipedia.exceptions.DisambiguationError as e:
            options = ", ".join(e.options[:3])
            return f"`GarfBot Error: Ambiguous query — did you mean: {options}?`"
        except wikipedia.exceptions.PageError:
            return "`GarfBot Error: No Wikipedia page found for that query`"
        except Exception as e:
            logger.error(e)
            return f"`GarfBot Error: {e}`"