Use new RVC endpoint; more unit tests
This commit is contained in:
440
discord/__tests__/helpers.test.ts
Normal file
440
discord/__tests__/helpers.test.ts
Normal file
@@ -0,0 +1,440 @@
|
||||
/**
|
||||
* Tests for helpers.ts functions
|
||||
*/
|
||||
|
||||
jest.mock('../../logging', () => ({
|
||||
logInfo: jest.fn(),
|
||||
logWarn: jest.fn(),
|
||||
logError: jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock('../util', () => ({
|
||||
REAL_NAMES: {},
|
||||
LOSER_WHITELIST: [],
|
||||
}));
|
||||
|
||||
jest.mock('node:path', () => ({
|
||||
join: jest.fn(() => '/tmp/streaks.json'),
|
||||
}));
|
||||
|
||||
jest.mock('node:fs', () => ({
|
||||
existsSync: jest.fn(() => false),
|
||||
readFileSync: jest.fn(),
|
||||
writeFileSync: jest.fn(),
|
||||
}));
|
||||
|
||||
// Mock Discord.js Collection class (mimics Map with filter method)
|
||||
class MockCollection {
|
||||
private map: Map<any, any>;
|
||||
|
||||
constructor(entries?: Array<[any, any]>) {
|
||||
this.map = new Map(entries || []);
|
||||
}
|
||||
|
||||
get size() {
|
||||
return this.map.size;
|
||||
}
|
||||
|
||||
filter(fn: (value: any, key: any) => boolean) {
|
||||
const result = new MockCollection();
|
||||
for (const [key, value] of this.map.entries()) {
|
||||
if (fn(value, key)) {
|
||||
result.map.set(key, value);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
values() {
|
||||
return this.map.values();
|
||||
}
|
||||
|
||||
entries() {
|
||||
return this.map.entries();
|
||||
}
|
||||
|
||||
[Symbol.iterator]() {
|
||||
return this.map[Symbol.iterator]();
|
||||
}
|
||||
}
|
||||
|
||||
const {
|
||||
dateToSnowflake,
|
||||
triggerThrowback,
|
||||
KAWAII_PHRASES,
|
||||
parseLoadingEmojis,
|
||||
getRandomLoadingEmoji,
|
||||
getRandomKawaiiPhrase,
|
||||
createStatusEmbed,
|
||||
createSimpleStatusEmbed,
|
||||
} = require('../commands/helpers');
|
||||
|
||||
describe('helpers.ts', () => {
|
||||
describe('dateToSnowflake', () => {
|
||||
it('should convert Discord epoch to snowflake 0', () => {
|
||||
const discordEpoch = new Date('2015-01-01T00:00:00.000Z');
|
||||
const result = dateToSnowflake(discordEpoch);
|
||||
expect(result).toBe('0');
|
||||
});
|
||||
|
||||
it('should convert a known date to snowflake', () => {
|
||||
const testDate = new Date('2024-01-01T00:00:00.000Z');
|
||||
const result = dateToSnowflake(testDate);
|
||||
expect(result).toMatch(/^\d+$/);
|
||||
expect(result.length).toBeGreaterThan(10);
|
||||
});
|
||||
|
||||
it('should produce increasing snowflakes for increasing dates', () => {
|
||||
const date1 = new Date('2024-01-01T00:00:00.000Z');
|
||||
const date2 = new Date('2024-01-02T00:00:00.000Z');
|
||||
const snowflake1 = dateToSnowflake(date1);
|
||||
const snowflake2 = dateToSnowflake(date2);
|
||||
expect(BigInt(snowflake2)).toBeGreaterThan(BigInt(snowflake1));
|
||||
});
|
||||
});
|
||||
|
||||
describe('KAWAII_PHRASES', () => {
|
||||
it('should contain kawaii phrases', () => {
|
||||
expect(KAWAII_PHRASES.length).toBeGreaterThan(0);
|
||||
expect(KAWAII_PHRASES).toContain('Hmm... let me think~ ♪');
|
||||
});
|
||||
});
|
||||
|
||||
describe('parseLoadingEmojis', () => {
|
||||
it('should parse emojis from environment variable', () => {
|
||||
const original = process.env.LOADING_EMOJIS;
|
||||
process.env.LOADING_EMOJIS =
|
||||
'<:clueless:123>,<a:hachune:456>,<a:chairspin:789>,<a:nekodance:012>';
|
||||
const result = parseLoadingEmojis();
|
||||
process.env.LOADING_EMOJIS = original;
|
||||
expect(result).toHaveLength(4);
|
||||
expect(result).toEqual([
|
||||
'<:clueless:123>',
|
||||
'<a:hachune:456>',
|
||||
'<a:chairspin:789>',
|
||||
'<a:nekodance:012>',
|
||||
]);
|
||||
});
|
||||
|
||||
it('should return default emojis when LOADING_EMOJIS is empty', () => {
|
||||
const original = process.env.LOADING_EMOJIS;
|
||||
process.env.LOADING_EMOJIS = '';
|
||||
const result = parseLoadingEmojis();
|
||||
process.env.LOADING_EMOJIS = original;
|
||||
expect(result).toEqual(['🤔', '✨', '🎵']);
|
||||
});
|
||||
|
||||
it('should handle whitespace in emoji list', () => {
|
||||
const original = process.env.LOADING_EMOJIS;
|
||||
process.env.LOADING_EMOJIS = ' <:test:123> , <a:spin:456> ';
|
||||
const result = parseLoadingEmojis();
|
||||
process.env.LOADING_EMOJIS = original;
|
||||
expect(result).toEqual(['<:test:123>', '<a:spin:456>']);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRandomLoadingEmoji', () => {
|
||||
it('should return a valid emoji from the list', () => {
|
||||
const result = getRandomLoadingEmoji();
|
||||
const validEmojis = parseLoadingEmojis();
|
||||
expect(validEmojis).toContain(result);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getRandomKawaiiPhrase', () => {
|
||||
it('should return a valid kawaii phrase', () => {
|
||||
const result = getRandomKawaiiPhrase();
|
||||
expect(KAWAII_PHRASES).toContain(result);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createStatusEmbed', () => {
|
||||
it('should create an embed with emoji, phrase, and status', () => {
|
||||
const embed = createStatusEmbed('🤔', 'Hmm... let me think~ ♪', 'Processing...');
|
||||
expect(embed).toBeDefined();
|
||||
expect(embed.data.author).toBeDefined();
|
||||
expect(embed.data.author?.name).toBe('Hmm... let me think~ ♪');
|
||||
});
|
||||
});
|
||||
|
||||
describe('createSimpleStatusEmbed', () => {
|
||||
it('should create an embed with random emoji and phrase', () => {
|
||||
const embed = createSimpleStatusEmbed('Working...');
|
||||
expect(embed).toBeDefined();
|
||||
expect(embed.data.author).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('triggerThrowback', () => {
|
||||
const mockClient = {
|
||||
guilds: {
|
||||
fetch: jest.fn(),
|
||||
},
|
||||
};
|
||||
|
||||
const mockProvider = {
|
||||
requestLLMResponse: jest.fn(),
|
||||
};
|
||||
|
||||
const mockSysprompt = 'You are a helpful assistant.';
|
||||
const mockLlmconf = {
|
||||
msg_context: 10,
|
||||
streaming: false,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should fetch messages from 1 year ago', async () => {
|
||||
const mockMessage = {
|
||||
id: '123456789',
|
||||
author: { username: 'testuser', bot: false },
|
||||
cleanContent: 'Hello from a year ago!',
|
||||
type: 0,
|
||||
reply: jest.fn(),
|
||||
};
|
||||
|
||||
const mockChannel = {
|
||||
messages: {
|
||||
fetch: jest.fn().mockResolvedValue(new MockCollection([['123456789', mockMessage]])),
|
||||
},
|
||||
};
|
||||
|
||||
mockProvider.requestLLMResponse.mockResolvedValue('Nice throwback!');
|
||||
|
||||
await triggerThrowback(
|
||||
mockClient as any,
|
||||
mockChannel as any,
|
||||
mockChannel as any,
|
||||
mockProvider,
|
||||
mockSysprompt,
|
||||
mockLlmconf
|
||||
);
|
||||
|
||||
// Verify messages.fetch was called with around date from 1 year ago
|
||||
const fetchCall = mockChannel.messages.fetch.mock.calls[0][0];
|
||||
expect(fetchCall.around).toBeDefined();
|
||||
expect(fetchCall.limit).toBe(50);
|
||||
});
|
||||
|
||||
it('should fetch message history for context before generating LLM response', async () => {
|
||||
const mockReply = jest.fn();
|
||||
const mockMessage = {
|
||||
id: '123456789',
|
||||
author: { username: 'testuser', bot: false },
|
||||
cleanContent: 'Hello from a year ago!',
|
||||
type: 0,
|
||||
reply: mockReply,
|
||||
};
|
||||
|
||||
const mockHistoryMessage = {
|
||||
id: '123456788',
|
||||
author: { username: 'testuser', bot: false },
|
||||
cleanContent: 'Previous context',
|
||||
type: 0,
|
||||
};
|
||||
|
||||
const mockChannel = {
|
||||
messages: {
|
||||
fetch: jest
|
||||
.fn()
|
||||
.mockResolvedValueOnce(
|
||||
new MockCollection([
|
||||
['123456788', mockHistoryMessage],
|
||||
['123456789', mockMessage],
|
||||
])
|
||||
)
|
||||
.mockResolvedValueOnce(new MockCollection([['123456788', mockHistoryMessage]])),
|
||||
},
|
||||
};
|
||||
|
||||
mockProvider.requestLLMResponse.mockResolvedValue('Nice throwback!');
|
||||
|
||||
await triggerThrowback(
|
||||
mockClient as any,
|
||||
mockChannel as any,
|
||||
mockChannel as any,
|
||||
mockProvider,
|
||||
mockSysprompt,
|
||||
mockLlmconf
|
||||
);
|
||||
|
||||
// Verify messages.fetch was called twice: once for throwback, once for history
|
||||
expect(mockChannel.messages.fetch).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Verify history fetch used msg_context from llmconf
|
||||
const historyFetchCall = mockChannel.messages.fetch.mock.calls[1][0];
|
||||
expect(historyFetchCall.limit).toBe(mockLlmconf.msg_context - 1);
|
||||
expect(historyFetchCall.before).toBe(mockMessage.id);
|
||||
|
||||
// Verify LLM was called with context (history + selected message)
|
||||
expect(mockProvider.requestLLMResponse).toHaveBeenCalledWith(
|
||||
expect.arrayContaining([expect.objectContaining({ id: '123456788' })]),
|
||||
mockSysprompt,
|
||||
mockLlmconf
|
||||
);
|
||||
});
|
||||
|
||||
it('should reply to the original message', async () => {
|
||||
const mockReply = jest.fn();
|
||||
const mockMessage = {
|
||||
id: '123456789',
|
||||
author: { username: 'testuser', bot: false },
|
||||
cleanContent: 'Hello from a year ago!',
|
||||
type: 0,
|
||||
reply: mockReply,
|
||||
};
|
||||
|
||||
const mockChannel = {
|
||||
messages: {
|
||||
fetch: jest.fn().mockResolvedValue(new MockCollection([['123456789', mockMessage]])),
|
||||
},
|
||||
};
|
||||
|
||||
mockProvider.requestLLMResponse.mockResolvedValue('Nice throwback!');
|
||||
|
||||
await triggerThrowback(
|
||||
mockClient as any,
|
||||
mockChannel as any,
|
||||
mockChannel as any,
|
||||
mockProvider,
|
||||
mockSysprompt,
|
||||
mockLlmconf
|
||||
);
|
||||
|
||||
// Verify reply was called on the original message, not send on channel
|
||||
expect(mockReply).toHaveBeenCalledWith('Nice throwback!');
|
||||
});
|
||||
|
||||
it('should throw error when no messages found from 1 year ago', async () => {
|
||||
const mockChannel = {
|
||||
messages: {
|
||||
fetch: jest.fn().mockResolvedValue(new MockCollection()),
|
||||
},
|
||||
};
|
||||
|
||||
await expect(
|
||||
triggerThrowback(
|
||||
mockClient as any,
|
||||
mockChannel as any,
|
||||
mockChannel as any,
|
||||
mockProvider,
|
||||
mockSysprompt,
|
||||
mockLlmconf
|
||||
)
|
||||
).rejects.toThrow('No messages found from 1 year ago.');
|
||||
});
|
||||
|
||||
it('should filter out bot messages', async () => {
|
||||
const mockBotMessage = {
|
||||
id: '111',
|
||||
author: { username: 'bot', bot: true },
|
||||
cleanContent: 'Bot message',
|
||||
type: 0,
|
||||
};
|
||||
|
||||
const mockUserMessage = {
|
||||
id: '222',
|
||||
author: { username: 'user', bot: false },
|
||||
cleanContent: 'User message',
|
||||
type: 0,
|
||||
reply: jest.fn(),
|
||||
};
|
||||
|
||||
const mockChannel = {
|
||||
messages: {
|
||||
fetch: jest
|
||||
.fn()
|
||||
.mockResolvedValue(new MockCollection([['111', mockBotMessage], ['222', mockUserMessage]])),
|
||||
},
|
||||
};
|
||||
|
||||
mockProvider.requestLLMResponse.mockResolvedValue('Reply!');
|
||||
|
||||
await triggerThrowback(
|
||||
mockClient as any,
|
||||
mockChannel as any,
|
||||
mockChannel as any,
|
||||
mockProvider,
|
||||
mockSysprompt,
|
||||
mockLlmconf
|
||||
);
|
||||
|
||||
// Verify only user message was considered (bot filtered out)
|
||||
expect(mockProvider.requestLLMResponse).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should filter out messages without content', async () => {
|
||||
const mockEmptyMessage = {
|
||||
id: '111',
|
||||
author: { username: 'user1', bot: false },
|
||||
cleanContent: '',
|
||||
type: 0,
|
||||
};
|
||||
|
||||
const mockValidMessage = {
|
||||
id: '222',
|
||||
author: { username: 'user2', bot: false },
|
||||
cleanContent: 'Valid message',
|
||||
type: 0,
|
||||
reply: jest.fn(),
|
||||
};
|
||||
|
||||
const mockChannel = {
|
||||
messages: {
|
||||
fetch: jest
|
||||
.fn()
|
||||
.mockResolvedValue(new MockCollection([['111', mockEmptyMessage], ['222', mockValidMessage]])),
|
||||
},
|
||||
};
|
||||
|
||||
mockProvider.requestLLMResponse.mockResolvedValue('Reply!');
|
||||
|
||||
await triggerThrowback(
|
||||
mockClient as any,
|
||||
mockChannel as any,
|
||||
mockChannel as any,
|
||||
mockProvider,
|
||||
mockSysprompt,
|
||||
mockLlmconf
|
||||
);
|
||||
|
||||
// Verify only valid message was considered
|
||||
expect(mockProvider.requestLLMResponse).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return throwback result with original message, author, and response', async () => {
|
||||
const mockMessage = {
|
||||
id: '123456789',
|
||||
author: { username: 'testuser', bot: false },
|
||||
cleanContent: 'Hello from a year ago!',
|
||||
type: 0,
|
||||
reply: jest.fn(),
|
||||
};
|
||||
|
||||
const mockChannel = {
|
||||
messages: {
|
||||
fetch: jest.fn().mockResolvedValue(new MockCollection([['123456789', mockMessage]])),
|
||||
},
|
||||
};
|
||||
|
||||
mockProvider.requestLLMResponse.mockResolvedValue('Nice throwback!');
|
||||
|
||||
const result = await triggerThrowback(
|
||||
mockClient as any,
|
||||
mockChannel as any,
|
||||
mockChannel as any,
|
||||
mockProvider,
|
||||
mockSysprompt,
|
||||
mockLlmconf
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
originalMessage: 'Hello from a year ago!',
|
||||
author: 'testuser',
|
||||
response: 'Nice throwback!',
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -4,7 +4,6 @@
|
||||
*/
|
||||
|
||||
import {
|
||||
Attachment,
|
||||
AttachmentBuilder,
|
||||
Client,
|
||||
Collection,
|
||||
@@ -25,7 +24,6 @@ import {
|
||||
import fs = require('node:fs');
|
||||
import path = require('node:path');
|
||||
import fetch, { Blob as NodeFetchBlob } from 'node-fetch';
|
||||
import FormData = require('form-data');
|
||||
import tmp = require('tmp');
|
||||
import { JSDOM } from 'jsdom';
|
||||
import { logError, logInfo, logWarn } from '../logging';
|
||||
@@ -34,6 +32,7 @@ import {
|
||||
openDb,
|
||||
reactionEmojis,
|
||||
recordReaction,
|
||||
requestRVCResponse,
|
||||
requestTTSResponse,
|
||||
serializeMessageHistory,
|
||||
sync,
|
||||
@@ -354,31 +353,6 @@ async function onNewMessage(message: Message) {
|
||||
}
|
||||
}
|
||||
|
||||
async function requestRVCResponse(src: Attachment): Promise<NodeFetchBlob> {
|
||||
logInfo(`[bot] Downloading audio message ${src.url}`);
|
||||
const srcres = await fetch(src.url);
|
||||
const srcbuf = await srcres.arrayBuffer();
|
||||
const tmpFile = tmp.fileSync();
|
||||
const tmpFileName = tmpFile.name;
|
||||
fs.writeFileSync(tmpFileName, Buffer.from(srcbuf));
|
||||
logInfo(`[bot] Got audio file: ${srcbuf.byteLength} bytes`);
|
||||
|
||||
const queryParams = new URLSearchParams();
|
||||
queryParams.append('token', process.env.LLM_TOKEN || '');
|
||||
|
||||
const fd = new FormData();
|
||||
fd.append('file', fs.readFileSync(tmpFileName), 'voice-message.ogg');
|
||||
|
||||
const rvcEndpoint = `${process.env.RVC_HOST}/rvc?${queryParams.toString()}`;
|
||||
logInfo(`[bot] Requesting RVC response for ${src.id}`);
|
||||
const res = await fetch(rvcEndpoint, {
|
||||
method: 'POST',
|
||||
body: fd,
|
||||
});
|
||||
const resContents = await res.blob();
|
||||
return resContents;
|
||||
}
|
||||
|
||||
async function scheduleRandomMessage(firstTime = false) {
|
||||
if (!firstTime) {
|
||||
if (!process.env.MOTD_CHANNEL) {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
*/
|
||||
|
||||
import {
|
||||
Attachment,
|
||||
Collection,
|
||||
GuildManager,
|
||||
GuildTextBasedChannel,
|
||||
@@ -13,12 +14,14 @@ import {
|
||||
User,
|
||||
} from 'discord.js';
|
||||
import { get as getEmojiName } from 'emoji-unicode-map';
|
||||
import { createWriteStream, existsSync, unlinkSync } from 'fs';
|
||||
import { createWriteStream, existsSync, readFileSync, unlinkSync, writeFileSync } from 'fs';
|
||||
import { get as httpGet } from 'https';
|
||||
import { Database, open } from 'sqlite';
|
||||
import { Database as Database3 } from 'sqlite3';
|
||||
import 'dotenv/config';
|
||||
import FormData = require('form-data');
|
||||
import fetch, { Blob as NodeFetchBlob } from 'node-fetch';
|
||||
import tmp = require('tmp');
|
||||
import { logError, logInfo, logWarn } from '../logging';
|
||||
import { ScoreboardMessageRow } from '../models';
|
||||
import { LLMDiscordMessage } from './provider/provider';
|
||||
@@ -331,12 +334,37 @@ async function requestTTSResponse(
|
||||
return resContents;
|
||||
}
|
||||
|
||||
async function requestRVCResponse(src: Attachment, pitch?: number): Promise<NodeFetchBlob> {
|
||||
logInfo(`[bot] Downloading audio message ${src.url}`);
|
||||
const srcres = await fetch(src.url);
|
||||
const srcbuf = await srcres.arrayBuffer();
|
||||
const tmpFile = tmp.fileSync();
|
||||
const tmpFileName = tmpFile.name;
|
||||
writeFileSync(tmpFileName, Buffer.from(srcbuf));
|
||||
logInfo(`[bot] Got audio file: ${srcbuf.byteLength} bytes`);
|
||||
|
||||
const fd = new FormData();
|
||||
fd.append('input_audio', readFileSync(tmpFileName), 'voice-message.ogg');
|
||||
fd.append('modelpath', 'model.pth');
|
||||
fd.append('f0_up_key', pitch ?? 0);
|
||||
|
||||
const rvcEndpoint = `${process.env.RVC_HOST}/inference`;
|
||||
logInfo(`[bot] Requesting RVC response for ${src.id}`);
|
||||
const res = await fetch(rvcEndpoint, {
|
||||
method: 'POST',
|
||||
body: fd,
|
||||
});
|
||||
const resContents = await res.blob();
|
||||
return resContents;
|
||||
}
|
||||
|
||||
export {
|
||||
db,
|
||||
clearDb,
|
||||
openDb,
|
||||
reactionEmojis,
|
||||
recordReaction,
|
||||
requestRVCResponse,
|
||||
requestTTSResponse,
|
||||
serializeMessageHistory,
|
||||
sync,
|
||||
|
||||
Reference in New Issue
Block a user