MikuAI features: LLM and RVC

This commit is contained in:
James Shiffer
2024-03-31 21:36:09 +00:00
parent 8346f52f23
commit a8efab7788
5 changed files with 225 additions and 65 deletions

View File

@@ -4,12 +4,17 @@
*/
import {
Attachment,
AttachmentBuilder,
Client,
Collection,
Events,
GatewayIntentBits,
Interaction,
Message,
MessageFlags,
MessageReaction,
MessageType,
PartialMessageReaction,
Partials, SlashCommandBuilder,
TextChannel,
@@ -18,8 +23,10 @@ import {
import fs = require('node:fs');
import path = require('node:path');
import fetch from 'node-fetch';
import FormData = require('form-data');
import tmp = require('tmp');
import { JSDOM } from 'jsdom';
import {logError, logInfo, logWarn} from '../logging';
import { logError, logInfo, logWarn } from '../logging';
import {
db,
openDb,
@@ -27,14 +34,16 @@ import {
recordReaction,
sync
} from './util';
import 'dotenv/config';
const config = {};
interface CommandClient extends Client {
commands?: Collection<string, { data: SlashCommandBuilder, execute: (interaction: Interaction) => Promise<void> }>
}
const client: CommandClient = new Client({
intents: [GatewayIntentBits.Guilds, GatewayIntentBits.GuildMessages, GatewayIntentBits.GuildMessageReactions],
intents: [GatewayIntentBits.Guilds, GatewayIntentBits.GuildMessages, GatewayIntentBits.GuildMessageReactions, GatewayIntentBits.MessageContent],
partials: [Partials.Message, Partials.Channel, Partials.Reaction],
});
client.commands = new Collection();
@@ -75,6 +84,86 @@ async function onMessageReactionChanged(reaction: MessageReaction | PartialMessa
await recordReaction(<MessageReaction> reaction);
}
/**
 * True when a message has visible text content and is an ordinary message
 * or a reply (filters out system messages such as joins, pins, boosts).
 */
function textOnlyMessages(message: Message)
{
    const hasText = message.cleanContent.length > 0;
    const isPlainType = message.type === MessageType.Default
        || message.type === MessageType.Reply;
    return hasText && isPlainType;
}
/**
 * Sanity-checks an LLM reply before it is posted.
 *
 * Rejects empty strings and a small blocklist of known-bad outputs
 * (stray @-mention strings the model tends to parrot back).
 *
 * Bug fix: the original used `response in [...]`, which tests the array's
 * *index keys* ("0", "1", "2"), not membership — so the blocklist never
 * matched, and any response equal to "0", "1" or "2" was wrongly rejected.
 * `Array.prototype.includes` performs the intended membership test.
 */
function isGoodResponse(response: string)
{
    const blocklist = [
        '@Today Man-San(1990)🍁🍂',
        '@1981 Celical Man🍁🍂',
        '@Exiled Sammy 🔒🏝⏱'
    ];
    return response.length > 0 && !blocklist.includes(response);
}
/**
 * Handles every incoming message.
 *
 * Voice messages are converted through the RVC endpoint and posted back as
 * an audio reply. Text messages may receive an LLM-generated reply: always
 * when the bot is mentioned or "miku" appears in the text, otherwise with a
 * small random probability controlled by REPLY_CHANCE.
 */
async function onNewMessage(message: Message)
{
    // Ignore bots (including ourselves) to avoid reply feedback loops.
    if (message.author.bot) {
        return;
    }

    /** First, handle audio messages */
    if (message.flags.has(MessageFlags.IsVoiceMessage)) {
        try {
            const audio = await requestRVCResponse(message.attachments.first());
            const audioBuf = await audio.arrayBuffer();
            const audioFile = new AttachmentBuilder(Buffer.from(audioBuf)).setName('mikuified.wav');
            await message.reply({
                files: [audioFile]
            });
        } catch (err) {
            logError(`[bot] Failed to generate audio message reply: ${err}`);
        }
    }

    /** Text messages */
    if (!textOnlyMessages(message)) {
        return;
    }

    // Miku must reply when spoken to
    // NOTE(review): assumes process.env.CLIENT holds the bot's user id — confirm.
    const mustReply = message.mentions.has(process.env.CLIENT) || message.cleanContent.toLowerCase().includes('miku');

    // Fetch the few messages preceding this one as conversational context.
    const history = await message.channel.messages.fetch({
        limit: 4,
        before: message.id
    });

    // change Miku's message probability depending on current message frequency
    const historyMessages = [...history.values()].reverse();
    //const historyTimes = historyMessages.map((m: Message) => m.createdAt.getTime());
    //const historyAvgDelayMins = (historyTimes[historyTimes.length - 1] - historyTimes[0]) / 60000;

    // Random gate — presumably REPLY_CHANCE is a probability fraction
    // (e.g. 0.05 → floor(random * 20) === 0 ≈ 5% of messages); verify env value.
    const replyChance = Math.floor(Math.random() * 1/Number(process.env.REPLY_CHANCE)) === 0;
    const willReply = mustReply || replyChance;
    if (!willReply) {
        return;
    }

    // Keep only plain text messages in the context window, then append the
    // triggering message as the final turn.
    const cleanHistory = historyMessages.filter(textOnlyMessages);
    const cleanHistoryList = [
        ...cleanHistory,
        message
    ];

    // Show the typing indicator while the LLM generates.
    await message.channel.sendTyping();

    try {
        const response = await requestLLMResponse(cleanHistoryList);
        // evaluate response
        if (!isGoodResponse(response)) {
            logWarn(`[bot] Burning bad response: "${response}"`);
            return;
        }
        await message.reply(response);
    } catch (err) {
        logError(`[bot] Error while generating LLM response: ${err}`);
    }
}
async function fetchMotd()
{
const res = await fetch(process.env.MOTD_HREF);
@@ -84,6 +173,68 @@ async function fetchMotd()
return doc.querySelector(process.env.MOTD_QUERY).textContent;
}
async function requestRVCResponse(src: Attachment): Promise<Blob>
{
logInfo(`[bot] Downloading audio message ${src.url}`);
const srcres = await fetch(src.url);
const srcbuf = await srcres.arrayBuffer();
const tmpFile = tmp.fileSync();
const tmpFileName = tmpFile.name;
fs.writeFileSync(tmpFileName, Buffer.from(srcbuf));
logInfo(`[bot] Got audio file: ${srcbuf.size} bytes`);
const queryParams = new URLSearchParams();
queryParams.append("token", process.env.LLM_TOKEN);
const fd = new FormData();
fd.append('file', fs.readFileSync(tmpFileName), 'voice-message.ogg');
const rvcEndpoint = `http://${process.env.LLM_HOST}:${process.env.LLM_PORT}/rvc?${queryParams.toString()}`;
logInfo(`[bot] Requesting RVC response for ${src.id}`);
const res = await fetch(rvcEndpoint, {
method: 'POST',
body: fd
});
const resContents = await res.blob();
return resContents;
}
/**
 * Requests a chat completion for the given message history from the LLM
 * HTTP endpoint (LLM_HOST:LLM_PORT) and extracts the assistant's reply
 * from the raw ChatML-formatted output.
 *
 * NOTE(review): `messages` appears to be an array of discord.js Message
 * objects (see the .map() below) — an explicit type annotation would help.
 */
async function requestLLMResponse(messages)
{
    const queryParams = new URLSearchParams();
    queryParams.append("token", process.env.LLM_TOKEN);
    // Forward the stored /llmconf sampler settings as query parameters.
    // NOTE(review): assumes config["llmconf"] was populated during startup
    // command registration — confirm ordering if this can run earlier.
    for (const field of Object.keys(config["llmconf"].llmSettings)) {
        queryParams.append(field, config["llmconf"].llmSettings[field]);
    }
    const llmEndpoint = `http://${process.env.LLM_HOST}:${process.env.LLM_PORT}/?${queryParams.toString()}`;
    // Bot-authored messages become "assistant" turns, everything else "user".
    const messageList = messages.map((m: Message) => ({
        role: m.author.bot ? "assistant" : "user",
        content: m.cleanContent,
    }));
    // The system prompt always leads the conversation.
    const reqBody = [
        {
            "role": "system",
            "content": config["llmconf"].sys_prompt
        },
        ...messageList
    ];
    logInfo("[bot] Requesting LLM response with message list: " + reqBody.map(m => m.content));
    const res = await fetch(llmEndpoint, {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify(reqBody)
    });
    const txt = await res.json();
    const txtRaw: string = txt["raw"][0];
    // The service returns the full ChatML transcript. Slice out the text
    // between the last assistant start tag and the following end tag; if no
    // end tag appears after the start tag, take everything to end-of-string.
    const prefix = "<|im_start|>assistant\n";
    const suffix = "<|im_end|>";
    const txtStart = txtRaw.lastIndexOf(prefix);
    const txtEnd = txtRaw.lastIndexOf(suffix) > txtStart ? txtRaw.lastIndexOf(suffix) : txtRaw.length;
    return txtRaw.slice(txtStart + prefix.length, txtEnd);
}
async function scheduleRandomMessage(firstTime = false)
{
if (!firstTime) {
@@ -108,6 +259,7 @@ client.on(Events.InteractionCreate, async interaction => {
if (!interaction.isChatInputCommand()) return;
});
client.on(Events.MessageCreate, onNewMessage);
client.on(Events.MessageReactionAdd, onMessageReactionChanged);
client.on(Events.MessageReactionRemove, onMessageReactionChanged);
client.on(Events.InteractionCreate, async interaction => {
@@ -135,6 +287,7 @@ client.on(Events.InteractionCreate, async interaction => {
// startup
(async () => {
tmp.setGracefulCleanup();
logInfo("[db] Opening...");
await openDb();
logInfo("[db] Migrating...");
@@ -151,6 +304,8 @@ client.on(Events.InteractionCreate, async interaction => {
const filePath = path.join(commandsPath, file);
const command = require(filePath);
client.commands.set(command.data.name, command);
config[command.data.name] = command.config;
logInfo(`[bot] Found command: /${command.data.name}`);
}
}