Compare commits
10 Commits
34f34e3def
...
ollama-gtx
| Author | SHA1 | Date | |
|---|---|---|---|
| 4dbacab952 | |||
| 0ef3fd1e82 | |||
| f81a76e505 | |||
| 8212edc406 | |||
| 386a7e764d | |||
| 4880229e26 | |||
| bfef23eb72 | |||
| 940b6b80d9 | |||
| f77940dff9 | |||
| 68696e25f0 |
@@ -18,6 +18,7 @@ RUN apk update && \
|
||||
tcl-dev
|
||||
|
||||
RUN pip3 install --no-cache-dir \
|
||||
config \
|
||||
discord \
|
||||
openai \
|
||||
aiohttp \
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
MIT No Attribution
|
||||
|
||||
Copyright 2025 Aaron Crate
|
||||
Copyright 2026 Aaron Crate
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this
|
||||
software and associated documentation files (the "Software"), to deal in the Software
|
||||
|
||||
67
garfmain.py
67
garfmain.py
@@ -1,7 +1,9 @@
|
||||
import re
|
||||
import config
|
||||
import asyncio
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
from garfpy import (
|
||||
help,
|
||||
@@ -16,7 +18,7 @@ from garfpy import (
|
||||
)
|
||||
|
||||
|
||||
gapikey = config.GIF_TOKEN
|
||||
# gapikey = config.GIF_TOKEN
|
||||
garfkey = config.GARFBOT_TOKEN
|
||||
txtmodel = config.TXT_MODEL
|
||||
imgmodel = config.IMG_MODEL
|
||||
@@ -40,6 +42,36 @@ kroger = Kroger()
|
||||
weather = WeatherAPI()
|
||||
|
||||
|
||||
URL_PATTERNS = [
|
||||
r'https?://(?:www\.)?youtube\.com/watch\?[^\s]*',
|
||||
r'https?://youtu\.be/[^\s]*',
|
||||
r'https?://(?:open\.)?spotify\.com/[^\s]*',
|
||||
]
|
||||
|
||||
def clean_url(url):
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
if 'youtube.com' in parsed.hostname:
|
||||
params = parse_qs(parsed.query)
|
||||
video_id = params.get('v', [None])[0]
|
||||
if not video_id:
|
||||
return None
|
||||
# timestamp = params.get('t', [None])[0]
|
||||
# if timestamp:
|
||||
# return f"https://www.youtube.com/watch?v={video_id}&t={timestamp}"
|
||||
return f"https://www.youtube.com/watch?v={video_id}"
|
||||
|
||||
if 'youtu.be' in parsed.hostname:
|
||||
return f"https://youtu.be{parsed.path}"
|
||||
|
||||
if 'spotify.com' in parsed.hostname:
|
||||
return f"https://open.spotify.com{parsed.path}"
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@garfbot.event
|
||||
async def on_ready():
|
||||
try:
|
||||
@@ -85,30 +117,30 @@ async def garfbot_qr(ctx, *, text):
|
||||
f"QR Code Request - User: {ctx.author.name}, Server: {ctx.guild.name}, Text: {text}"
|
||||
)
|
||||
if len(text) > 1000:
|
||||
await ctx.send("❌ Text too long! Maximum 1000 characters.")
|
||||
await ctx.reply("❌ Text too long! Maximum 1000 characters.")
|
||||
else:
|
||||
try:
|
||||
qr_code = await generate_qr(text)
|
||||
sendfile = discord.File(fp=qr_code, filename="qrcode.png")
|
||||
await ctx.send(file=sendfile)
|
||||
await ctx.reply(file=sendfile)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
await ctx.send(e)
|
||||
await ctx.reply(e)
|
||||
|
||||
|
||||
@garfbot.command(name="wiki")
|
||||
async def garfbot_wiki(ctx, *, query):
|
||||
summary = await garfield.wikisum(query)
|
||||
await ctx.send(summary)
|
||||
await ctx.reply(summary)
|
||||
|
||||
|
||||
@garfbot.command(name="shop")
|
||||
async def garfbot_shop(ctx, *, query):
|
||||
try:
|
||||
response = kroger.garfshop(query)
|
||||
await ctx.send(response)
|
||||
await ctx.reply(response)
|
||||
except Exception as e:
|
||||
await ctx.send(f"`GarfBot Error: {str(e)}`")
|
||||
await ctx.reply(f"`GarfBot Error: {str(e)}`")
|
||||
|
||||
|
||||
@garfbot.command(name="weather")
|
||||
@@ -118,11 +150,15 @@ async def garfbot_weather(ctx, *, location):
|
||||
|
||||
@garfbot.command(name="chat")
|
||||
async def garfchat(ctx, *, prompt):
|
||||
if "is this true" in prompt.lower():
|
||||
messages = [msg async for msg in ctx.channel.history(limit=2)]
|
||||
prompt = messages[1].content
|
||||
prompt = f"Is this true: {prompt}"
|
||||
answer = await garfield.generate_chat(prompt)
|
||||
logger.info(
|
||||
f"Chat Request - User: {ctx.author.name}, Server: {ctx.guild.name}, Prompt: {prompt}"
|
||||
)
|
||||
await ctx.send(answer)
|
||||
await ctx.reply(answer)
|
||||
|
||||
|
||||
@garfbot.command(name="pic")
|
||||
@@ -130,7 +166,7 @@ async def garfpic(ctx, *, prompt):
|
||||
logger.info(
|
||||
f"Image Request - User: {ctx.author.name}, Server: {ctx.guild.name}, Prompt: {prompt}"
|
||||
)
|
||||
await ctx.send(f"`Please wait... image generation queued: {prompt}`")
|
||||
await ctx.reply(f"`Please wait... image generation queued: {prompt}`")
|
||||
await garfield.garfpic(ctx, prompt)
|
||||
|
||||
|
||||
@@ -148,6 +184,19 @@ async def on_message(message):
|
||||
content = message.content.strip()
|
||||
lower = content.lower()
|
||||
|
||||
# # Remove tracking stuff from youtube and spotify links
|
||||
# cleaned_urls = []
|
||||
|
||||
# for pattern in URL_PATTERNS:
|
||||
# for match in re.finditer(pattern, message.content):
|
||||
# cleaned = clean_url(match.group(0))
|
||||
# if cleaned and cleaned != match.group(0):
|
||||
# cleaned_urls.append(cleaned)
|
||||
|
||||
# if cleaned_urls:
|
||||
# links = '\n'.join(cleaned_urls)
|
||||
# await message.reply(f"🔗 Cleaned link{'s' if len(cleaned_urls) > 1 else ''}:\n{links}")
|
||||
|
||||
# Chats & pics
|
||||
if lower.startswith("hey garfield") or isinstance(
|
||||
message.channel, discord.DMChannel
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import discord
|
||||
from garfpy import logger
|
||||
@@ -121,3 +122,17 @@ async def aod_message(garfbot, message):
|
||||
for field, values in zip(table_fields, table_columns):
|
||||
stats_embed.add_field(name=field, value="\n".join(values), inline=True)
|
||||
await message.channel.send(embed=stats_embed)
|
||||
|
||||
# # Boy You Said It
|
||||
# words = re.findall(r"[a-zA-Z']+", message.content.lower())
|
||||
# stops = {"a", "an", "the", "and", "or", "but", "is", "it", "in", "on", "at", "to", "of"}
|
||||
# words = [w for w in words if w not in stops]
|
||||
|
||||
# if words:
|
||||
# firsts = [w[0] for w in words]
|
||||
# commons = max(set(firsts), key=firsts.count)
|
||||
# count = firsts.count(commons)
|
||||
|
||||
# if count >= 3 or (len(words) >= 2 and count / len(words) >= 0.75):
|
||||
# await message.channel.send("Boy, you said it!")
|
||||
|
||||
|
||||
253
garfpy/garfai.py
253
garfpy/garfai.py
@@ -1,4 +1,6 @@
|
||||
import io
|
||||
import time
|
||||
import uuid
|
||||
import openai
|
||||
import config
|
||||
import aiohttp
|
||||
@@ -9,90 +11,212 @@ from openai import AsyncOpenAI
|
||||
from garfpy import logger
|
||||
|
||||
|
||||
INVOKEAI_BASE = config.INVOKEAI_URL
|
||||
|
||||
_MODEL_KEY = "0eb50094-5c9b-431b-ba01-87e145edb849"
|
||||
_VAE_KEY = "dde3627c-8a45-4088-93d1-66c44acbb337"
|
||||
_ENCODER_KEY = "7ba22542-4687-4946-a52e-c92f925f4b75"
|
||||
|
||||
_MODEL_REF = {"key": _MODEL_KEY, "hash": "blake3:c3ee838d71d99497db01fae6f304eafd9e734e935f3b783e968d50febb56be2c", "name": "FLUX.2 Klein 4B (GGUF Q4)", "base": "flux2", "type": "main"}
|
||||
_VAE_REF = {"key": _VAE_KEY, "hash": "blake3:531855de70db993d0f6181f82cde27d15411d58b7ffa3b2fdce2b9434c0173c2", "name": "FLUX.2 VAE", "base": "flux2", "type": "vae"}
|
||||
_ENCODER_REF = {"key": _ENCODER_KEY, "hash": "blake3:af5840e6770dc99f678e69867949c8b9264835915eb82a990e940fa6e4fa6c81", "name": "FLUX.2 Klein Qwen3 4B Encoder", "base": "any", "type": "qwen3_encoder"}
|
||||
|
||||
_POLL_INTERVAL = 2
|
||||
_POLL_ATTEMPTS = 60
|
||||
_MIN_IMAGE_BYTES = 1000
|
||||
|
||||
|
||||
def _node_id(prefix: str) -> str:
|
||||
return f"{prefix}:{uuid.uuid4().hex[:10]}"
|
||||
|
||||
|
||||
def _build_graph(prompt: str) -> dict:
|
||||
seed = int(time.time() * 1000) % (2 ** 31)
|
||||
|
||||
p = _node_id("positive_prompt")
|
||||
ml = _node_id("flux2_klein_model_loader")
|
||||
te = _node_id("flux2_klein_text_encoder")
|
||||
dn = _node_id("flux2_denoise")
|
||||
out = _node_id("canvas_output")
|
||||
|
||||
nodes = {
|
||||
p: {"id": p, "is_intermediate": True, "use_cache": True, "value": prompt, "type": "string"},
|
||||
ml: {"id": ml, "is_intermediate": True, "use_cache": True, "type": "flux2_klein_model_loader",
|
||||
"model": _MODEL_REF, "vae_model": _VAE_REF, "qwen3_encoder_model": _ENCODER_REF},
|
||||
te: {"id": te, "is_intermediate": True, "use_cache": True, "type": "flux2_klein_text_encoder"},
|
||||
dn: {"id": dn, "is_intermediate": True, "use_cache": True, "type": "flux2_denoise", "seed": seed},
|
||||
out: {"id": out, "is_intermediate": False, "use_cache": False, "type": "flux2_vae_decode"},
|
||||
}
|
||||
|
||||
edges = [
|
||||
{"source": {"node_id": ml, "field": "qwen3_encoder"}, "destination": {"node_id": te, "field": "qwen3_encoder"}},
|
||||
{"source": {"node_id": ml, "field": "max_seq_len"}, "destination": {"node_id": te, "field": "max_seq_len"}},
|
||||
{"source": {"node_id": p, "field": "value"}, "destination": {"node_id": te, "field": "prompt"}},
|
||||
{"source": {"node_id": ml, "field": "transformer"}, "destination": {"node_id": dn, "field": "transformer"}},
|
||||
{"source": {"node_id": ml, "field": "vae"}, "destination": {"node_id": dn, "field": "vae"}},
|
||||
{"source": {"node_id": te, "field": "conditioning"}, "destination": {"node_id": dn, "field": "positive_text_conditioning"}},
|
||||
{"source": {"node_id": ml, "field": "vae"}, "destination": {"node_id": out, "field": "vae"}},
|
||||
{"source": {"node_id": dn, "field": "latents"}, "destination": {"node_id": out, "field": "latents"}},
|
||||
]
|
||||
|
||||
return {"nodes": nodes, "edges": edges}
|
||||
|
||||
|
||||
async def _poll_batch(session: aiohttp.ClientSession, base: str, batch_id: str) -> bool:
|
||||
"""Poll batch status until completed, failed, or timed out. Returns True on success."""
|
||||
for _ in range(_POLL_ATTEMPTS):
|
||||
await asyncio.sleep(_POLL_INTERVAL)
|
||||
try:
|
||||
async with session.get(f"{base}/api/v1/queue/default/b/{batch_id}/status") as resp:
|
||||
if not resp.ok:
|
||||
continue
|
||||
s = await resp.json(content_type=None)
|
||||
total = s.get("total", 0)
|
||||
completed = s.get("completed", 0)
|
||||
failed = s.get("failed", 0)
|
||||
if total > 0 and failed >= total:
|
||||
logger.error(f"Batch {batch_id} failed")
|
||||
return False
|
||||
if total > 0 and completed >= total:
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"InvokeAI poll error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _get_image_name(session: aiohttp.ClientSession, base: str, batch_id: str) -> str | None:
|
||||
try:
|
||||
async with session.get(f"{base}/api/v1/queue/default/i/{batch_id}") as resp:
|
||||
if resp.ok:
|
||||
data = await resp.json(content_type=None)
|
||||
for node_result in data.get("session", {}).get("results", {}).values():
|
||||
if node_result.get("type") == "image_output":
|
||||
return node_result["image"]["image_name"]
|
||||
except Exception as e:
|
||||
logger.error(f"Item fetch error: {e}")
|
||||
|
||||
try:
|
||||
async with session.get(
|
||||
f"{base}/api/v1/images/",
|
||||
params={"limit": 1, "order_by": "created_at", "direction": "DESC"},
|
||||
) as resp:
|
||||
if resp.ok:
|
||||
items = (await resp.json(content_type=None)).get("items", [])
|
||||
if items:
|
||||
return items[0]["image_name"]
|
||||
except Exception as e:
|
||||
logger.error(f"Image list fallback error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _fetch_image_bytes(session: aiohttp.ClientSession, base: str, name: str) -> bytes | None:
|
||||
"""Try the full image endpoint, then fall back to thumbnail."""
|
||||
urls = [
|
||||
f"{base}/api/v1/images/i/{name}/full",
|
||||
f"{base}/api/v1/images/i/{name}/thumbnail",
|
||||
]
|
||||
for url in urls:
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
ct = resp.headers.get("Content-Type", "")
|
||||
data = await resp.read()
|
||||
logger.info(f"Image fetch {url}: status={resp.status} content-type={ct} size={len(data)}")
|
||||
if "html" not in ct and len(data) >= _MIN_IMAGE_BYTES:
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"InvokeAI image fetch error ({url}): {e}")
|
||||
return None
|
||||
|
||||
|
||||
class GarfAI:
|
||||
def __init__(self):
|
||||
self.openaikey = config.OPENAI_TOKEN
|
||||
self.txtmodel = config.TXT_MODEL
|
||||
self.imgmodel = config.IMG_MODEL
|
||||
self.baseurl = config.BASE_URL
|
||||
self.sysprompt = config.SYSTEM_PROMPT
|
||||
self.txtmodel = config.TXT_MODEL
|
||||
self.imgmodel = config.IMG_MODEL
|
||||
|
||||
self._oai = AsyncOpenAI(
|
||||
api_key=config.OPENAI_TOKEN,
|
||||
base_url=config.BASE_URL,
|
||||
)
|
||||
self.image_request_queue = asyncio.Queue()
|
||||
|
||||
async def garfpic(self, ctx, prompt):
|
||||
await self.image_request_queue.put({"ctx": ctx, "prompt": prompt})
|
||||
|
||||
async def generate_image(self, prompt):
|
||||
client = AsyncOpenAI(api_key=self.openaikey)
|
||||
async def generate_image(self, session: aiohttp.ClientSession, prompt: str) -> bytes | str:
|
||||
base = INVOKEAI_BASE
|
||||
|
||||
try:
|
||||
response = await client.images.generate(
|
||||
model=self.imgmodel, prompt=prompt, n=1, size="1024x1024"
|
||||
)
|
||||
except openai.BadRequestError as e:
|
||||
return f"`GarfBot Error: ({e.status_code}) - Your request was rejected as a result of our safety system.`"
|
||||
except openai.InternalServerError as e:
|
||||
logger.error(e)
|
||||
return f"`GarfBot Error: ({e.status_code}) - Monday`"
|
||||
async with session.post(
|
||||
f"{base}/api/v1/queue/default/enqueue_batch",
|
||||
json={"batch": {"graph": _build_graph(prompt), "runs": 1}},
|
||||
) as resp:
|
||||
if not resp.ok:
|
||||
text = await resp.text()
|
||||
logger.error(f"InvokeAI enqueue failed {resp.status}: {text}")
|
||||
return "`GarfBot Error: InvokeAI rejected the request`"
|
||||
data = await resp.json(content_type=None)
|
||||
batch_id = data["batch"]["batch_id"]
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return "`GarfBot Error: Lasagna`"
|
||||
data = getattr(response, "data", None)
|
||||
if not data:
|
||||
logger.error("No data in response")
|
||||
return "`GarfBot Error: No images generated`"
|
||||
logger.error(f"InvokeAI enqueue error: {e}")
|
||||
return "`GarfBot Error: Couldn't reach InvokeAI`"
|
||||
|
||||
first_image = data[0] if len(data) > 0 else None
|
||||
if not first_image:
|
||||
logger.error("No image in response data")
|
||||
return "`GarfBot Error: No images generated`"
|
||||
logger.info(f"InvokeAI batch queued: {batch_id}")
|
||||
|
||||
image_url = getattr(first_image, "url", None)
|
||||
if not image_url:
|
||||
logger.error("No URL in image response")
|
||||
return "`GarfBot Error: No image URL returned`"
|
||||
if not await _poll_batch(session, base, batch_id):
|
||||
return "`GarfBot Error: InvokeAI generation failed or timed out`"
|
||||
|
||||
return image_url
|
||||
image_name = await _get_image_name(session, base, batch_id)
|
||||
if not image_name:
|
||||
return "`GarfBot Error: Could not resolve image name`"
|
||||
|
||||
logger.info(f"Got image: {image_name}")
|
||||
|
||||
data = await _fetch_image_bytes(session, base, image_name)
|
||||
if data:
|
||||
return data
|
||||
|
||||
logger.error("All image download attempts failed")
|
||||
return "`GarfBot Error: Odie`"
|
||||
|
||||
async def process_image_requests(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(headers={"Accept": "application/json"}) as session:
|
||||
while True:
|
||||
request = await self.image_request_queue.get()
|
||||
ctx = request["ctx"]
|
||||
ctx = request["ctx"]
|
||||
prompt = request["prompt"]
|
||||
image_url = await self.generate_image(prompt)
|
||||
if image_url and "GarfBot Error" not in image_url:
|
||||
logger.info("Downloading & sending image...")
|
||||
async with session.get(image_url) as resp:
|
||||
if resp.status == 200:
|
||||
image_data = await resp.read()
|
||||
image = io.BytesIO(image_data)
|
||||
image.seek(0)
|
||||
timestamp = ctx.message.created_at.strftime("%Y%m%d%H%M%S")
|
||||
filename = f"{timestamp}_generated_image.png"
|
||||
sendfile = discord.File(fp=image, filename=filename)
|
||||
try:
|
||||
await ctx.send(file=sendfile)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
else:
|
||||
await ctx.send("`GarfBot Error: Odie`")
|
||||
|
||||
result = await self.generate_image(session, prompt)
|
||||
|
||||
if isinstance(result, bytes):
|
||||
logger.info("Sending image...")
|
||||
image = io.BytesIO(result)
|
||||
timestamp = ctx.message.created_at.strftime("%Y%m%d%H%M%S")
|
||||
filename = f"{timestamp}_generated_image.png"
|
||||
try:
|
||||
await ctx.reply(file=discord.File(fp=image, filename=filename))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
else:
|
||||
await ctx.send(image_url)
|
||||
await ctx.reply(result)
|
||||
|
||||
self.image_request_queue.task_done()
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def generate_chat(self, question):
|
||||
async def generate_chat(self, question: str) -> str:
|
||||
try:
|
||||
client = AsyncOpenAI(api_key=self.openaikey)
|
||||
response = await client.chat.completions.create(
|
||||
response = await self._oai.chat.completions.create(
|
||||
model=self.txtmodel,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Pretend you are sarcastic Garfield.",
|
||||
},
|
||||
{"role": "user", "content": f"{question}"},
|
||||
{"role": "system", "content": self.sysprompt},
|
||||
{"role": "user", "content": question},
|
||||
],
|
||||
max_tokens=400,
|
||||
temperature=1.2,
|
||||
)
|
||||
answer = str(response.choices[0].message.content)
|
||||
answer = response.choices[0].message.content
|
||||
return answer.replace("an AI language model", "a cartoon animal")
|
||||
except openai.BadRequestError as e:
|
||||
logger.error(e)
|
||||
@@ -104,12 +228,15 @@ class GarfAI:
|
||||
logger.error(e)
|
||||
return "`GarfBot Error: Lasagna`"
|
||||
|
||||
async def wikisum(self, query):
|
||||
async def wikisum(self, query: str) -> str:
|
||||
try:
|
||||
summary = wikipedia.summary(query)
|
||||
garfsum = await self.generate_chat(
|
||||
f"Please summarize in your own words: {summary}"
|
||||
)
|
||||
return garfsum
|
||||
return await self.generate_chat(f"Please summarize in your own words: {summary}")
|
||||
except wikipedia.exceptions.DisambiguationError as e:
|
||||
options = ", ".join(e.options[:3])
|
||||
return f"`GarfBot Error: Ambiguous query — did you mean: {options}?`"
|
||||
except wikipedia.exceptions.PageError:
|
||||
return "`GarfBot Error: No Wikipedia page found for that query`"
|
||||
except Exception as e:
|
||||
return e
|
||||
logger.error(e)
|
||||
return f"`GarfBot Error: {e}`"
|
||||
|
||||
@@ -6,9 +6,9 @@ async def help(message):
|
||||
embed.add_field(
|
||||
name="hey garfield `prompt`", value="*Responds with text.*", inline=True
|
||||
)
|
||||
embed.add_field(
|
||||
name="garfpic `prompt`", value="*Responds with an image.*", inline=True
|
||||
)
|
||||
# embed.add_field(
|
||||
# name="garfpic `prompt`", value="*Responds with an image.*", inline=True
|
||||
# )
|
||||
embed.add_field(
|
||||
name="garfping `target`",
|
||||
value="*Responds with iputils-ping result from target.*",
|
||||
|
||||
@@ -126,7 +126,7 @@ class WeatherAPI:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Error fetching weather data for '{location}'")
|
||||
logger.error(f"Error fetching weather data for '{location}' - {e}")
|
||||
await ctx.send(f"`Error fetching weather data for '{location}'`")
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user