Files
FemScoreboard/discord/provider/ollama.ts

163 lines
5.7 KiB
TypeScript

import { Message } from 'discord.js';
import { LLMProvider } from './provider';
import 'dotenv/config';
import { serializeMessageHistory } from '../util';
import { logError, logInfo } from '../../logging';
import { LLMConfig } from '../commands/types';
import { Ollama } from 'ollama';
/**
 * Instruction text placed in the user turn, directly ahead of the serialized
 * conversation. Each array element is one line of the prompt; the joined text
 * ends with a newline so the appended conversation starts on a fresh line.
 */
const USER_PROMPT =
  [
    'Continue the following Discord conversation by completing the next message, playing the role of Hatsune Miku. The conversation must progress forward, and you must avoid repeating yourself.',
    'Each message is represented as a line of JSON. Refer to other users by their "name" instead of their "author" field whenever possible.',
    'The conversation is as follows. The last line is the message you have to complete. Please ONLY return the string contents of the "content" field, that go in place of the ellipses "...". Do not include the enclosing quotation marks, or any JSON syntax, in your response.',
  ].join('\n') + '\n';
/**
 * LLMProvider implementation that talks to an Ollama server via the `ollama` client.
 *
 * The conversation is sent as a single user turn: `USER_PROMPT` followed by one
 * JSON line per serialized Discord message. The final line is a placeholder
 * "Hatsune Miku" message whose `content` is "...", which the model is asked to fill in.
 */
export class OllamaProvider implements LLMProvider {
  private readonly client: Ollama;
  private model: string;

  /**
   * @param host Ollama server URL; defaults to the `LLM_HOST` environment variable.
   * @param model Ollama model tag used for requests until `setModel` is called.
   * @throws {TypeError} if no host is passed and `LLM_HOST` is unset or empty.
   */
  constructor(
    host: string | undefined = process.env.LLM_HOST,
    model = 'socialnetwooky/hermes3-llama3.1-abliterated:8b-q5_k_m-64k'
  ) {
    if (!host) {
      throw new TypeError(
        'Ollama host was not passed in, and environment variable LLM_HOST was unset!'
      );
    }
    this.client = new Ollama({ host });
    this.model = model;
  }

  /** Display label for this provider, including the active model tag. */
  name() {
    return `Ollama (${this.model})`;
  }

  /** Replace the model tag used for subsequent requests. */
  setModel(id: string) {
    this.model = id;
  }

  /**
   * Request one complete reply continuing `history`.
   *
   * @param history Discord messages forming the conversation so far.
   * @param sysprompt Text sent as the system message.
   * @param params Sampling settings; unset fields fall back to
   *   temperature 0.5, top_p 0.9 and 128 predicted tokens.
   * @returns The model's reply text.
   * @throws {TypeError} if no message in `history` serializes to content, or if
   *   the API returns an empty reply. API errors are logged and rethrown.
   */
  async requestLLMResponse(
    history: Message[],
    sysprompt: string,
    params: LLMConfig
  ): Promise<string> {
    const messageHistoryTxt = await this.buildMessageHistoryText(history);
    logInfo(`[ollama] Requesting response for message history: ${messageHistoryTxt}`);
    try {
      const chatCompletion = await this.client.chat({
        model: this.model,
        messages: [
          { role: 'system', content: sysprompt },
          { role: 'user', content: USER_PROMPT + messageHistoryTxt },
        ],
        options: this.buildOptions(params),
      });
      const response = chatCompletion.message.content;
      logInfo(`[ollama] API response: ${response}`);
      if (!response) {
        throw new TypeError('Ollama chat API returned no message.');
      }
      return response;
    } catch (err) {
      logError(`[ollama] API Error: ` + err);
      throw err;
    }
  }

  /**
   * Stream a reply continuing `history`.
   *
   * Each yielded value's `content` is the full text accumulated so far, not
   * just the newest chunk. Chunks with empty content are skipped.
   *
   * @param history Discord messages forming the conversation so far.
   * @param sysprompt Text sent as the system message.
   * @param params Sampling settings; unset fields use the same defaults as
   *   `requestLLMResponse`.
   * @returns (as the generator's return value) the complete reply text.
   * @throws {TypeError} if no message in `history` serializes to content.
   *   API errors are logged and rethrown.
   */
  async *requestLLMResponseStreaming(
    history: Message[],
    sysprompt: string,
    params: LLMConfig
  ): AsyncGenerator<{ reasoning?: string; content?: string; done?: boolean }, string, unknown> {
    const messageHistoryTxt = await this.buildMessageHistoryText(history);
    logInfo(`[ollama] Requesting streaming response for message history: ${messageHistoryTxt}`);
    try {
      const stream = await this.client.chat({
        model: this.model,
        messages: [
          { role: 'system', content: sysprompt },
          { role: 'user', content: USER_PROMPT + messageHistoryTxt },
        ],
        stream: true,
        options: this.buildOptions(params),
      });
      let fullContent = '';
      for await (const chunk of stream) {
        const delta = chunk.message?.content;
        if (delta) {
          fullContent += delta;
          // Consumers receive the cumulative text on every update.
          yield { content: fullContent };
        }
      }
      logInfo(`[ollama] Streaming API response: ${fullContent}`);
      return fullContent;
    } catch (err) {
      logError(`[ollama] Streaming API Error: ` + err);
      throw err;
    }
  }

  /**
   * Serialize `history` into newline-separated JSON lines, followed by the
   * placeholder line the model is asked to complete.
   *
   * @throws {TypeError} if no message in `history` serializes to a truthy value.
   */
  private async buildMessageHistoryText(history: Message[]): Promise<string> {
    const messageList = (await Promise.all(history.map(serializeMessageHistory))).filter(
      (x) => !!x
    );
    // Every remaining entry is truthy, so an undefined last element means the list is empty.
    const lastMsg = messageList[messageList.length - 1];
    if (!lastMsg) {
      throw new TypeError('No messages with content provided in history!');
    }
    // Date the placeholder 5 seconds after the last real message.
    const newDate = new Date(lastMsg.timestamp);
    newDate.setSeconds(newDate.getSeconds() + 5);
    const templateMsgTxt = JSON.stringify({
      timestamp: newDate.toUTCString(),
      author: 'Hatsune Miku',
      name: 'Hatsune Miku',
      context: lastMsg.content,
      content: '...',
    });
    return [...messageList.map((msg) => JSON.stringify(msg)), templateMsgTxt].join('\n');
  }

  /**
   * Map LLMConfig onto Ollama sampling options.
   *
   * Uses `??` so explicit zero values (e.g. temperature 0) are honoured
   * instead of being replaced by the defaults.
   */
  private buildOptions(params: LLMConfig) {
    return {
      temperature: params?.temperature ?? 0.5,
      top_p: params?.top_p ?? 0.9,
      num_predict: params?.max_new_tokens ?? 128,
    };
  }
}