diff --git a/discord/__tests__/helpers.test.ts b/discord/__tests__/helpers.test.ts new file mode 100644 index 0000000..b9eb3e5 --- /dev/null +++ b/discord/__tests__/helpers.test.ts @@ -0,0 +1,440 @@ +/** + * Tests for helpers.ts functions + */ + +jest.mock('../../logging', () => ({ + logInfo: jest.fn(), + logWarn: jest.fn(), + logError: jest.fn(), +})); + +jest.mock('../util', () => ({ + REAL_NAMES: {}, + LOSER_WHITELIST: [], +})); + +jest.mock('node:path', () => ({ + join: jest.fn(() => '/tmp/streaks.json'), +})); + +jest.mock('node:fs', () => ({ + existsSync: jest.fn(() => false), + readFileSync: jest.fn(), + writeFileSync: jest.fn(), +})); + +// Mock Discord.js Collection class (mimics Map with filter method) +class MockCollection { + private map: Map; + + constructor(entries?: Array<[any, any]>) { + this.map = new Map(entries || []); + } + + get size() { + return this.map.size; + } + + filter(fn: (value: any, key: any) => boolean) { + const result = new MockCollection(); + for (const [key, value] of this.map.entries()) { + if (fn(value, key)) { + result.map.set(key, value); + } + } + return result; + } + + values() { + return this.map.values(); + } + + entries() { + return this.map.entries(); + } + + [Symbol.iterator]() { + return this.map[Symbol.iterator](); + } +} + +const { + dateToSnowflake, + triggerThrowback, + KAWAII_PHRASES, + parseLoadingEmojis, + getRandomLoadingEmoji, + getRandomKawaiiPhrase, + createStatusEmbed, + createSimpleStatusEmbed, +} = require('../commands/helpers'); + +describe('helpers.ts', () => { + describe('dateToSnowflake', () => { + it('should convert Discord epoch to snowflake 0', () => { + const discordEpoch = new Date('2015-01-01T00:00:00.000Z'); + const result = dateToSnowflake(discordEpoch); + expect(result).toBe('0'); + }); + + it('should convert a known date to snowflake', () => { + const testDate = new Date('2024-01-01T00:00:00.000Z'); + const result = dateToSnowflake(testDate); + expect(result).toMatch(/^\d+$/); + expect(result.length).toBeGreaterThan(10); + }); + + it('should produce increasing snowflakes for increasing dates', () => { + const date1 = new Date('2024-01-01T00:00:00.000Z'); + const date2 = new Date('2024-01-02T00:00:00.000Z'); + const snowflake1 = dateToSnowflake(date1); + const snowflake2 = dateToSnowflake(date2); + expect(BigInt(snowflake2)).toBeGreaterThan(BigInt(snowflake1)); + }); + }); + + describe('KAWAII_PHRASES', () => { + it('should contain kawaii phrases', () => { + expect(KAWAII_PHRASES.length).toBeGreaterThan(0); + expect(KAWAII_PHRASES).toContain('Hmm... let me think~ ♪'); + }); + }); + + describe('parseLoadingEmojis', () => { + it('should parse emojis from environment variable', () => { + const original = process.env.LOADING_EMOJIS; + process.env.LOADING_EMOJIS = + '<:clueless:123>,,,'; + const result = parseLoadingEmojis(); + process.env.LOADING_EMOJIS = original; + expect(result).toHaveLength(4); + expect(result).toEqual([ + '<:clueless:123>', + '', + '', + '', + ]); + }); + + it('should return default emojis when LOADING_EMOJIS is empty', () => { + const original = process.env.LOADING_EMOJIS; + process.env.LOADING_EMOJIS = ''; + const result = parseLoadingEmojis(); + process.env.LOADING_EMOJIS = original; + expect(result).toEqual(['🤔', '✨', '🎵']); + }); + + it('should handle whitespace in emoji list', () => { + const original = process.env.LOADING_EMOJIS; + process.env.LOADING_EMOJIS = ' <:test:123> , '; + const result = parseLoadingEmojis(); + process.env.LOADING_EMOJIS = original; + expect(result).toEqual(['<:test:123>', '']); + }); + }); + + describe('getRandomLoadingEmoji', () => { + it('should return a valid emoji from the list', () => { + const result = getRandomLoadingEmoji(); + const validEmojis = parseLoadingEmojis(); + expect(validEmojis).toContain(result); + }); + }); + + describe('getRandomKawaiiPhrase', () => { + it('should return a valid kawaii phrase', () => { + const result = getRandomKawaiiPhrase(); + expect(KAWAII_PHRASES).toContain(result); + }); + }); + + describe('createStatusEmbed', () => { + it('should create an embed with emoji, phrase, and status', () => { + const embed = createStatusEmbed('🤔', 'Hmm... let me think~ ♪', 'Processing...'); + expect(embed).toBeDefined(); + expect(embed.data.author).toBeDefined(); + expect(embed.data.author?.name).toBe('Hmm... let me think~ ♪'); + }); + }); + + describe('createSimpleStatusEmbed', () => { + it('should create an embed with random emoji and phrase', () => { + const embed = createSimpleStatusEmbed('Working...'); + expect(embed).toBeDefined(); + expect(embed.data.author).toBeDefined(); + }); + }); + + describe('triggerThrowback', () => { + const mockClient = { + guilds: { + fetch: jest.fn(), + }, + }; + + const mockProvider = { + requestLLMResponse: jest.fn(), + }; + + const mockSysprompt = 'You are a helpful assistant.'; + const mockLlmconf = { + msg_context: 10, + streaming: false, + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should fetch messages from 1 year ago', async () => { + const mockMessage = { + id: '123456789', + author: { username: 'testuser', bot: false }, + cleanContent: 'Hello from a year ago!', + type: 0, + reply: jest.fn(), + }; + + const mockChannel = { + messages: { + fetch: jest.fn().mockResolvedValue(new MockCollection([['123456789', mockMessage]])), + }, + }; + + mockProvider.requestLLMResponse.mockResolvedValue('Nice throwback!'); + + await triggerThrowback( + mockClient as any, + mockChannel as any, + mockChannel as any, + mockProvider, + mockSysprompt, + mockLlmconf + ); + + // Verify messages.fetch was called with around date from 1 year ago + const fetchCall = mockChannel.messages.fetch.mock.calls[0][0]; + expect(fetchCall.around).toBeDefined(); + expect(fetchCall.limit).toBe(50); + }); + + it('should fetch message history for context before generating LLM response', async () => { + const mockReply = jest.fn(); + const mockMessage = { + id: '123456789', + author: { username: 'testuser', bot: false }, + cleanContent: 'Hello from a year ago!', + type: 0, + reply: mockReply, + }; + + const mockHistoryMessage = { + id: '123456788', + author: { username: 'testuser', bot: false }, + cleanContent: 'Previous context', + type: 0, + }; + + const mockChannel = { + messages: { + fetch: jest + .fn() + .mockResolvedValueOnce( + new MockCollection([ + ['123456788', mockHistoryMessage], + ['123456789', mockMessage], + ]) + ) + .mockResolvedValueOnce(new MockCollection([['123456788', mockHistoryMessage]])), + }, + }; + + mockProvider.requestLLMResponse.mockResolvedValue('Nice throwback!'); + + await triggerThrowback( + mockClient as any, + mockChannel as any, + mockChannel as any, + mockProvider, + mockSysprompt, + mockLlmconf + ); + + // Verify messages.fetch was called twice: once for throwback, once for history + expect(mockChannel.messages.fetch).toHaveBeenCalledTimes(2); + + // Verify history fetch used msg_context from llmconf + const historyFetchCall = mockChannel.messages.fetch.mock.calls[1][0]; + expect(historyFetchCall.limit).toBe(mockLlmconf.msg_context - 1); + expect(historyFetchCall.before).toBe(mockMessage.id); + + // Verify LLM was called with context (history + selected message) + expect(mockProvider.requestLLMResponse).toHaveBeenCalledWith( + expect.arrayContaining([expect.objectContaining({ id: '123456788' })]), + mockSysprompt, + mockLlmconf + ); + }); + + it('should reply to the original message', async () => { + const mockReply = jest.fn(); + const mockMessage = { + id: '123456789', + author: { username: 'testuser', bot: false }, + cleanContent: 'Hello from a year ago!', + type: 0, + reply: mockReply, + }; + + const mockChannel = { + messages: { + fetch: jest.fn().mockResolvedValue(new MockCollection([['123456789', mockMessage]])), + }, + }; + + mockProvider.requestLLMResponse.mockResolvedValue('Nice throwback!'); + + await triggerThrowback( + mockClient as any, + mockChannel as any, + mockChannel as any, + mockProvider, + mockSysprompt, + mockLlmconf + ); + + // Verify reply was called on the original message, not send on channel + expect(mockReply).toHaveBeenCalledWith('Nice throwback!'); + }); + + it('should throw error when no messages found from 1 year ago', async () => { + const mockChannel = { + messages: { + fetch: jest.fn().mockResolvedValue(new MockCollection()), + }, + }; + + await expect( + triggerThrowback( + mockClient as any, + mockChannel as any, + mockChannel as any, + mockProvider, + mockSysprompt, + mockLlmconf + ) + ).rejects.toThrow('No messages found from 1 year ago.'); + }); + + it('should filter out bot messages', async () => { + const mockBotMessage = { + id: '111', + author: { username: 'bot', bot: true }, + cleanContent: 'Bot message', + type: 0, + }; + + const mockUserMessage = { + id: '222', + author: { username: 'user', bot: false }, + cleanContent: 'User message', + type: 0, + reply: jest.fn(), + }; + + const mockChannel = { + messages: { + fetch: jest + .fn() + .mockResolvedValue(new MockCollection([['111', mockBotMessage], ['222', mockUserMessage]])), + }, + }; + + mockProvider.requestLLMResponse.mockResolvedValue('Reply!'); + + await triggerThrowback( + mockClient as any, + mockChannel as any, + mockChannel as any, + mockProvider, + mockSysprompt, + mockLlmconf + ); + + // Verify only user message was considered (bot filtered out) + expect(mockProvider.requestLLMResponse).toHaveBeenCalled(); + }); + + it('should filter out messages without content', async () => { + const mockEmptyMessage = { + id: '111', + author: { username: 'user1', bot: false }, + cleanContent: '', + type: 0, + }; + + const mockValidMessage = { + id: '222', + author: { username: 'user2', bot: false }, + cleanContent: 'Valid message', + type: 0, + reply: jest.fn(), + }; + + const mockChannel = { + messages: { + fetch: jest + .fn() + .mockResolvedValue(new MockCollection([['111', mockEmptyMessage], ['222', mockValidMessage]])), + }, + }; + + mockProvider.requestLLMResponse.mockResolvedValue('Reply!'); + + await triggerThrowback( + mockClient as any, + mockChannel as any, + mockChannel as any, + mockProvider, + mockSysprompt, + mockLlmconf + ); + + // Verify only valid message was considered + expect(mockProvider.requestLLMResponse).toHaveBeenCalled(); + }); + + it('should return throwback result with original message, author, and response', async () => { + const mockMessage = { + id: '123456789', + author: { username: 'testuser', bot: false }, + cleanContent: 'Hello from a year ago!', + type: 0, + reply: jest.fn(), + }; + + const mockChannel = { + messages: { + fetch: jest.fn().mockResolvedValue(new MockCollection([['123456789', mockMessage]])), + }, + }; + + mockProvider.requestLLMResponse.mockResolvedValue('Nice throwback!'); + + const result = await triggerThrowback( + mockClient as any, + mockChannel as any, + mockChannel as any, + mockProvider, + mockSysprompt, + mockLlmconf + ); + + expect(result).toEqual({ + originalMessage: 'Hello from a year ago!', + author: 'testuser', + response: 'Nice throwback!', + }); + }); + }); +}); diff --git a/discord/bot.ts b/discord/bot.ts index 18d0769..5c070b8 100644 --- a/discord/bot.ts +++ b/discord/bot.ts @@ -4,7 +4,6 @@ */ import { - Attachment, AttachmentBuilder, Client, Collection, @@ -25,7 +24,6 @@ import { import fs = require('node:fs'); import path = require('node:path'); import fetch, { Blob as NodeFetchBlob } from 'node-fetch'; -import FormData = require('form-data'); import tmp = require('tmp'); import { JSDOM } from 'jsdom'; import { logError, logInfo, logWarn } from '../logging'; @@ -34,6 +32,7 @@ import { openDb, reactionEmojis, recordReaction, + requestRVCResponse, requestTTSResponse, serializeMessageHistory, sync, @@ -354,31 +353,6 @@ async function onNewMessage(message: Message) { } } -async function requestRVCResponse(src: Attachment): Promise { - logInfo(`[bot] Downloading audio message ${src.url}`); - const srcres = await fetch(src.url); - const srcbuf = await srcres.arrayBuffer(); - const tmpFile = tmp.fileSync(); - const tmpFileName = tmpFile.name; - fs.writeFileSync(tmpFileName, Buffer.from(srcbuf)); - logInfo(`[bot] Got audio file: ${srcbuf.byteLength} bytes`); - - const queryParams = new URLSearchParams(); - queryParams.append('token', process.env.LLM_TOKEN || ''); - - const fd = new FormData(); - fd.append('file', fs.readFileSync(tmpFileName), 'voice-message.ogg'); - - const rvcEndpoint = `${process.env.RVC_HOST}/rvc?${queryParams.toString()}`; - logInfo(`[bot] Requesting RVC response for ${src.id}`); - const res = await fetch(rvcEndpoint, { - method: 'POST', - body: fd, - }); - const resContents = await res.blob(); - return resContents; -} - async function scheduleRandomMessage(firstTime = false) { if (!firstTime) { if (!process.env.MOTD_CHANNEL) { diff --git a/discord/util.ts b/discord/util.ts index b9aafcf..d870da8 100644 --- a/discord/util.ts +++ b/discord/util.ts @@ -4,6 +4,7 @@ */ import { + Attachment, Collection, GuildManager, GuildTextBasedChannel, @@ -13,12 +14,14 @@ import { User, } from 'discord.js'; import { get as getEmojiName } from 'emoji-unicode-map'; -import { createWriteStream, existsSync, unlinkSync } from 'fs'; +import { createWriteStream, existsSync, readFileSync, unlinkSync, writeFileSync } from 'fs'; import { get as httpGet } from 'https'; import { Database, open } from 'sqlite'; import { Database as Database3 } from 'sqlite3'; import 'dotenv/config'; +import FormData = require('form-data'); import fetch, { Blob as NodeFetchBlob } from 'node-fetch'; +import tmp = require('tmp'); import { logError, logInfo, logWarn } from '../logging'; import { ScoreboardMessageRow } from '../models'; import { LLMDiscordMessage } from './provider/provider'; @@ -331,12 +334,37 @@ async function requestTTSResponse( return resContents; } +async function requestRVCResponse(src: Attachment, pitch?: number): Promise { + logInfo(`[bot] Downloading audio message ${src.url}`); + const srcres = await fetch(src.url); + const srcbuf = await srcres.arrayBuffer(); + const tmpFile = tmp.fileSync(); + const tmpFileName = tmpFile.name; + writeFileSync(tmpFileName, Buffer.from(srcbuf)); + logInfo(`[bot] Got audio file: ${srcbuf.byteLength} bytes`); + + const fd = new FormData(); + fd.append('input_audio', readFileSync(tmpFileName), 'voice-message.ogg'); + fd.append('modelpath', 'model.pth'); + fd.append('f0_up_key', pitch ?? 0); + + const rvcEndpoint = `${process.env.RVC_HOST}/inference`; + logInfo(`[bot] Requesting RVC response for ${src.id}`); + const res = await fetch(rvcEndpoint, { + method: 'POST', + body: fd, + }); + const resContents = await res.blob(); + return resContents; +} + export { db, clearDb, openDb, reactionEmojis, recordReaction, + requestRVCResponse, requestTTSResponse, serializeMessageHistory, sync,