diff --git a/bridge/matrix/cache.go b/bridge/matrix/cache.go
new file mode 100644
index 00000000..071cb742
--- /dev/null
+++ b/bridge/matrix/cache.go
@@ -0,0 +1,183 @@
+package bmatrix
+
+import (
+	"sort"
+	"sync"
+	"time"
+
+	"maunium.net/go/mautrix/id"
+)
+
+// UserInRoomCacheEntry holds the profile information cached for a user, either
+// for a specific room or globally. A nil displayName or avatarURL means that
+// this piece of information has not been cached yet.
+type UserInRoomCacheEntry struct {
+	displayName               *string
+	avatarURL                 *string
+	lastUpdated               time.Time
+	conflictWithOtherUsername bool
+}
+
+// UserCacheEntry groups everything cached for a single user: an optional
+// global entry and the room-specific entries.
+type UserCacheEntry struct {
+	globalEntry *UserInRoomCacheEntry
+	perChannel  map[id.RoomID]UserInRoomCacheEntry
+}
+
+// UserInfoCache maps the MXID of a user to its cached profile information.
+// The embedded RWMutex protects the users map.
+type UserInfoCache struct {
+	users map[id.UserID]UserCacheEntry
+	sync.RWMutex
+}
+
+// NewUserInfoCache returns an empty cache, ready to be used.
+func NewUserInfoCache() *UserInfoCache {
+	return &UserInfoCache{
+		users:   make(map[id.UserID]UserCacheEntry),
+		RWMutex: sync.RWMutex{},
+	}
+}
+
+// retrieveUserInRoomFromCache returns a copy of the entry cached for mxid in
+// channelID, falling back to the global entry of the user. It returns nil when
+// nothing is cached for that user.
+// note: cache is locked inside this function
+func (c *UserInfoCache) retrieveUserInRoomFromCache(channelID id.RoomID, mxid id.UserID) *UserInRoomCacheEntry {
+	c.RLock()
+	defer c.RUnlock()
+
+	user, userPresent := c.users[mxid]
+	if !userPresent {
+		return nil
+	}
+
+	// try first the entry of the user in the room, then globally
+	if roomCachedEntry, roomPresent := user.perChannel[channelID]; roomPresent {
+		return &roomCachedEntry
+	}
+
+	if user.globalEntry != nil {
+		// hand out a copy: the entry stored in the cache can be modified by
+		// other goroutines once the lock is released
+		globalCachedEntry := *user.globalEntry
+
+		return &globalCachedEntry
+	}
+
+	return nil
+}
+
+// cacheEntry creates or updates the entry cached for mxid, either for the room
+// channelID or globally when channelID is empty. The callback receives the
+// entry currently cached (or an empty one) and returns the entry to store.
+// Note that old entries are cleaned when this function is called.
+// note: cache is locked inside this function
+func (b *Bmatrix) cacheEntry(channelID id.RoomID, mxid id.UserID, callback func(UserInRoomCacheEntry) UserInRoomCacheEntry) {
+	now := time.Now()
+
+	cache := b.UserCache
+
+	cache.Lock()
+	defer cache.Unlock()
+
+	cache.clearObsoleteEntries(mxid)
+
+	var newEntry UserCacheEntry
+	if user, userPresent := cache.users[mxid]; userPresent {
+		newEntry = user
+	} else {
+		newEntry = UserCacheEntry{
+			globalEntry: nil,
+			perChannel:  make(map[id.RoomID]UserInRoomCacheEntry),
+		}
+	}
+
+	// start from the entry that is already cached (if any), to keep the
+	// fields that the callback does not modify
+	var cacheEntry UserInRoomCacheEntry
+	if channelID == "" && newEntry.globalEntry != nil {
+		cacheEntry = *newEntry.globalEntry
+	} else if channelID != "" {
+		if roomCachedEntry, roomPresent := newEntry.perChannel[channelID]; roomPresent {
+			cacheEntry = roomCachedEntry
+		}
+	}
+
+	// the entry is being refreshed: update its timestamp even when it already
+	// existed, otherwise it would be evicted as if it had never been updated
+	cacheEntry.lastUpdated = now
+
+	newCacheEntry := callback(cacheEntry)
+	if channelID == "" {
+		// no room: this is the global state of the user
+		newEntry.globalEntry = &newCacheEntry
+	} else {
+		// this is a local (room-specific) state, let's cache it as such
+		newEntry.perChannel[channelID] = newCacheEntry
+	}
+
+	cache.users[mxid] = newEntry
+}
+
+// clearObsoleteEntries scans the cache to delete old entries, to stop memory usage from becoming high with obsolete entries.
+// note: assume the cache is already write-locked
+// TODO: should we update the timestamp when the entry is used?
+func (c *UserInfoCache) clearObsoleteEntries(mxid id.UserID) {
+	// we have an "off-by-one" to account for the case where the user being added
+	// to the cache already has obsolete cache entries, as we want to keep them
+	// because we will be refreshing them in a moment
+	if len(c.users) <= MaxNumberOfUsersInCache+1 {
+		return
+	}
+
+	usersLastTimestamp := make(map[id.UserID]int64, len(c.users))
+	// compute the most recent update timestamp of each user
+	for mxidIter, userEntry := range c.users {
+		userLastTimestamp := time.Unix(0, 0)
+		for _, userInChannelCacheEntry := range userEntry.perChannel {
+			if userInChannelCacheEntry.lastUpdated.After(userLastTimestamp) {
+				userLastTimestamp = userInChannelCacheEntry.lastUpdated
+			}
+		}
+
+		if userEntry.globalEntry != nil {
+			if userEntry.globalEntry.lastUpdated.After(userLastTimestamp) {
+				userLastTimestamp = userEntry.globalEntry.lastUpdated
+			}
+		}
+
+		usersLastTimestamp[mxidIter] = userLastTimestamp.UnixNano()
+	}
+
+	// get the limit timestamp before which we must clear entries as obsolete
+	sortedTimestamps := make([]int64, 0, len(usersLastTimestamp))
+	for _, value := range usersLastTimestamp {
+		sortedTimestamps = append(sortedTimestamps, value)
+	}
+	sort.Slice(sortedTimestamps, func(i, j int) bool { return sortedTimestamps[i] < sortedTimestamps[j] })
+	limitTimestamp := sortedTimestamps[len(sortedTimestamps)-MaxNumberOfUsersInCache]
+
+	// delete entries older than the limit
+	for mxidIter, timestamp := range usersLastTimestamp {
+		// do not clear the user that we are adding to the cache
+		if timestamp <= limitTimestamp && mxidIter != mxid {
+			delete(c.users, mxidIter)
+		}
+	}
+}
+
+// removeFromCache forgets the room-specific entry cached for mxid in roomID, if any.
+// note: cache is locked inside this function
+func (c *UserInfoCache) removeFromCache(roomID id.RoomID, mxid id.UserID) {
+	c.Lock()
+	defer c.Unlock()
+
+	if user, userPresent := c.users[mxid]; userPresent {
+		if _, roomPresent := user.perChannel[roomID]; roomPresent {
+			delete(user.perChannel, roomID)
+			c.users[mxid] = user
+		}
+	}
+}
diff --git a/bridge/matrix/handlers.go b/bridge/matrix/handlers.go
index 0900a94f..c8ef604c 100644
--- a/bridge/matrix/handlers.go
+++ b/bridge/matrix/handlers.go
@@ -145,12 +145,12 @@ func (b *Bmatrix) handleMemberChange(ev *event.Event) {
 	// Update the displayname on join messages, according to https://spec.matrix.org/v1.3/client-server-api/#events-on-change-of-profile-information
 	if member.Membership == event.MembershipJoin {
 		b.cacheDisplayName(ev.RoomID, ev.Sender, member.Displayname)
+		b.cacheAvatarURL(ev.RoomID, ev.Sender, member.AvatarURL)
 	} else if member.Membership == event.MembershipLeave || member.Membership == event.MembershipBan {
-		b.removeDisplayNameFromCache(ev.Sender, ev.RoomID)
+		b.UserCache.removeFromCache(ev.RoomID, ev.Sender)
 	}
 }
 
-//nolint: funlen
 func (b *Bmatrix) handleMessage(rmsg config.Message, ev *event.Event) {
 	msg := ev.Content.AsMessage()
 	if msg == nil {
@@ -162,13 +162,7 @@ func (b *Bmatrix) handleMessage(rmsg config.Message, ev *event.Event) {
 
 	rmsg.Text = msg.Body
 
-	// TODO: cache the avatars
-	avatarURL := b.getAvatarURL(ev.Sender)
-	contentURI, err := id.ParseContentURI(avatarURL)
-	if err == nil {
-		avatarURL = b.mc.GetDownloadURL(contentURI)
-		rmsg.Avatar = avatarURL
-	}
+	rmsg.Avatar = b.getAvatarURL(ev.RoomID, ev.Sender)
 
 	//nolint: exhaustive
 	switch msg.MsgType {
diff --git a/bridge/matrix/helpers.go b/bridge/matrix/helpers.go
index bc9ca403..c2bc94ab 100644
--- a/bridge/matrix/helpers.go
+++ b/bridge/matrix/helpers.go
@@ -4,8 +4,6 @@ import (
 	"errors"
 	"fmt"
 	"html"
-	"sort"
-	"sync"
 	"time"
 
 	matrix "maunium.net/go/mautrix"
@@ -45,55 +43,6 @@ func (b *Bmatrix) getRoomID(channelName string) id.RoomID {
 	return ""
 }
 
-type NicknameCacheEntry struct {
-	displayName               string
-	lastUpdated               time.Time
-	conflictWithOtherUsername bool
-}
-
-type NicknameUserEntry struct {
-	globalEntry *NicknameCacheEntry
-	perChannel  map[id.RoomID]NicknameCacheEntry
-}
-
-type NicknameCache struct {
-	users map[id.UserID]NicknameUserEntry
-	sync.RWMutex
-}
-
-func NewNicknameCache() *NicknameCache {
-	return &NicknameCache{
-		users:   make(map[id.UserID]NicknameUserEntry),
-		RWMutex: sync.RWMutex{},
-	}
-}
-
-// note: cache is not locked here
-func (c *NicknameCache) retrieveDisplaynameFromCache(channelID id.RoomID, mxid id.UserID) string {
-	var cachedEntry *NicknameCacheEntry = nil
-
-	c.RLock()
-	if user, userPresent := c.users[mxid]; userPresent {
-		// try first the name of the user in the room, then globally
-		if roomCachedEntry, roomPresent := user.perChannel[channelID]; roomPresent {
-			cachedEntry = &roomCachedEntry
-		} else if user.globalEntry != nil {
-			cachedEntry = user.globalEntry
-		}
-	}
-	c.RUnlock()
-
-	if cachedEntry == nil {
-		return ""
-	}
-
-	if cachedEntry.conflictWithOtherUsername {
-		return fmt.Sprintf("%s (%s)", cachedEntry.displayName, mxid)
-	}
-
-	return cachedEntry.displayName
-}
-
 func (b *Bmatrix) retrieveGlobalDisplayname(mxid id.UserID) string {
 	displayName, err := b.mc.GetDisplayName(mxid)
 	var httpError *matrix.HTTPError
@@ -114,67 +63,28 @@ func (b *Bmatrix) getDisplayName(channelID id.RoomID, mxid id.UserID) string {
 		return string(mxid)[1:]
 	}
 
-	displayname := b.NicknameCache.retrieveDisplaynameFromCache(channelID, mxid)
-	if displayname != "" {
-		return displayname
+	cachedEntry := b.UserCache.retrieveUserInRoomFromCache(channelID, mxid)
+	if cachedEntry == nil || cachedEntry.displayName == nil {
+		// retrieve the global display name
+		return b.cacheDisplayName("", mxid, b.retrieveGlobalDisplayname(mxid))
 	}
 
-	// retrieve the global display name
-	return b.cacheDisplayName("", mxid, b.retrieveGlobalDisplayname(mxid))
-}
-
-// scan to delete old entries, to stop memory usage from becoming high with obsolete entries.
-// note: assume the cache is already write-locked
-// TODO: should we update the timestamp when the entry is used?
-func (c *NicknameCache) clearObsoleteEntries(mxid id.UserID) {
-	// we have a "off-by-one" to account for when the user being added to the
-	// cache already have obsolete cache entries, as we want to keep it because
-	// we will be refreshing it in a minute
-	if len(c.users) <= MaxNumberOfUsersInCache+1 {
-		return
+	if cachedEntry.conflictWithOtherUsername {
+		return fmt.Sprintf("%s (%s)", *cachedEntry.displayName, mxid)
 	}
 
-	usersLastTimestamp := make(map[id.UserID]int64, len(c.users))
-	// compute the last updated timestamp entry for each user
-	for mxidIter, NicknameCacheIter := range c.users {
-		userLastTimestamp := time.Unix(0, 0)
-		for _, userInChannelCacheEntry := range NicknameCacheIter.perChannel {
-			if userInChannelCacheEntry.lastUpdated.After(userLastTimestamp) {
-				userLastTimestamp = userInChannelCacheEntry.lastUpdated
-			}
-		}
-
-		if NicknameCacheIter.globalEntry != nil {
-			if NicknameCacheIter.globalEntry.lastUpdated.After(userLastTimestamp) {
-				userLastTimestamp = NicknameCacheIter.globalEntry.lastUpdated
-			}
-		}
-
-		usersLastTimestamp[mxidIter] = userLastTimestamp.UnixNano()
-	}
-
-	// get the limit timestamp before which we must clear entries as obsolete
-	sortedTimestamps := make([]int64, 0, len(usersLastTimestamp))
-	for _, value := range usersLastTimestamp {
-		sortedTimestamps = append(sortedTimestamps, value)
-	}
-	sort.Slice(sortedTimestamps, func(i, j int) bool { return sortedTimestamps[i] < sortedTimestamps[j] })
-	limitTimestamp := sortedTimestamps[len(sortedTimestamps)-MaxNumberOfUsersInCache]
-
-	// delete entries older than the limit
-	for mxidIter, timestamp := range usersLastTimestamp {
-		// do not clear the user that we are adding to the cache
-		if timestamp <= limitTimestamp && mxidIter != mxid {
-			delete(c.users, mxidIter)
-		}
-	}
+	return *cachedEntry.displayName
 }
 
 // to prevent username reuse across matrix rooms - or even inside the same room, if a user uses multiple servers -
-// identify users with naming conflicts
-func (c *NicknameCache) detectConflict(mxid id.UserID, displayName string) bool {
+// identify users with naming conflicts.
+// Note: this function write-locks the cache, as it flags the conflicting entries
+func (c *UserInfoCache) detectDisplayNameConflicts(mxid id.UserID, displayName string) bool {
 	conflict := false
 
+	c.Lock()
+	defer c.Unlock()
+
 	for mxidIter, NicknameCacheIter := range c.users {
 		// skip conflict detection against ourselves, obviously
 		if mxidIter == mxid {
@@ -182,14 +92,14 @@ func (c *NicknameCache) detectConflict(mxid id.UserID, displayName string) bool
 		}
 
 		for channelID, userInChannelCacheEntry := range NicknameCacheIter.perChannel {
-			if userInChannelCacheEntry.displayName == displayName {
+			if userInChannelCacheEntry.displayName != nil && *userInChannelCacheEntry.displayName == displayName {
 				userInChannelCacheEntry.conflictWithOtherUsername = true
 				c.users[mxidIter].perChannel[channelID] = userInChannelCacheEntry
 				conflict = true
 			}
 		}
 
-		if NicknameCacheIter.globalEntry != nil && NicknameCacheIter.globalEntry.displayName == displayName {
+		if NicknameCacheIter.globalEntry != nil && NicknameCacheIter.globalEntry.displayName != nil && *NicknameCacheIter.globalEntry.displayName == displayName {
 			c.users[mxidIter].globalEntry.conflictWithOtherUsername = true
 			conflict = true
 		}
@@ -202,68 +112,56 @@ func (c *NicknameCache) detectConflict(mxid id.UserID, displayName string) bool
 // later without performing a query to the homeserver.
 // Note that old entries are cleaned when this function is called.
 func (b *Bmatrix) cacheDisplayName(channelID id.RoomID, mxid id.UserID, displayName string) string {
-	now := time.Now()
+	conflict := b.UserCache.detectDisplayNameConflicts(mxid, displayName)
 
-	cache := b.NicknameCache
-
-	cache.Lock()
-	defer cache.Unlock()
-
-	conflict := cache.detectConflict(mxid, displayName)
-
-	cache.clearObsoleteEntries(mxid)
-
-	var newEntry NicknameUserEntry
-	if user, userPresent := cache.users[mxid]; userPresent {
-		newEntry = user
-	} else {
-		newEntry = NicknameUserEntry{
-			globalEntry: nil,
-			perChannel:  make(map[id.RoomID]NicknameCacheEntry),
-		}
-	}
-
-	cacheEntry := NicknameCacheEntry{
-		displayName:               displayName,
-		lastUpdated:               now,
-		conflictWithOtherUsername: conflict,
-	}
-
-	if channelID == "" {
-		newEntry.globalEntry = &cacheEntry
-	} else {
-		// this is a local (room-specific) display name, let's cache it as such
-		newEntry.perChannel[channelID] = cacheEntry
-	}
-
-	cache.users[mxid] = newEntry
+	b.cacheEntry(channelID, mxid, func(entry UserInRoomCacheEntry) UserInRoomCacheEntry {
+		entry.displayName = &displayName
+		entry.conflictWithOtherUsername = conflict
+		return entry
+	})
 
 	return displayName
 }
 
-func (b *Bmatrix) removeDisplayNameFromCache(mxid id.UserID, roomID id.RoomID) {
-	cache := b.NicknameCache
-
-	cache.Lock()
-	defer cache.Unlock()
-
-	if user, userPresent := cache.users[mxid]; userPresent {
-		if _, roomPresent := user.perChannel[roomID]; roomPresent {
-			delete(user.perChannel, roomID)
-			cache.users[mxid] = user
-		}
-	}
-}
-
-// getAvatarURL returns the avatar URL of the specified sender.
-func (b *Bmatrix) getAvatarURL(sender id.UserID) string {
-	url, err := b.mc.GetAvatarURL(sender)
+// retrieveGlobalAvatarURL returns the global avatar URL of the specified user.
+func (b *Bmatrix) retrieveGlobalAvatarURL(mxid id.UserID) id.ContentURIString {
+	url, err := b.mc.GetAvatarURL(mxid)
 	if err != nil {
-		b.Log.Errorf("Couldn't retrieve the URL of the avatar for MXID %s", sender)
+		b.Log.Errorf("Couldn't retrieve the URL of the avatar for MXID %s: %v", mxid, err)
 		return ""
 	}
 
-	return url.String()
+	return id.ContentURIString(url.String())
+}
+
+// getAvatarURL retrieves the avatar URL for mxid, querying the homeserver if the mxid is not in the cache.
+func (b *Bmatrix) getAvatarURL(channelID id.RoomID, mxid id.UserID) string {
+	cachedEntry := b.UserCache.retrieveUserInRoomFromCache(channelID, mxid)
+	if cachedEntry == nil || cachedEntry.avatarURL == nil {
+		// nothing cached: retrieve the global avatar URL and cache it as the global entry
+		return b.cacheAvatarURL("", mxid, b.retrieveGlobalAvatarURL(mxid))
+	}
+
+	return *cachedEntry.avatarURL
+}
+
+// cacheAvatarURL stores the mapping between a mxid and the URL of the user avatar, to be reused
+// later without performing a query to the homeserver.
+// Note that old entries are cleaned when this function is called.
+func (b *Bmatrix) cacheAvatarURL(channelID id.RoomID, mxid id.UserID, avatarURL id.ContentURIString) string {
+	fullURL := ""
+	if contentURI, err := id.ParseContentURI(string(avatarURL)); err != nil {
+		return ""
+	} else if avatarURL != "" { // a user without avatar is cached with an empty URL, not a download URL
+		fullURL = b.mc.GetDownloadURL(contentURI)
+	}
+
+	b.cacheEntry(channelID, mxid, func(entry UserInRoomCacheEntry) UserInRoomCacheEntry {
+		entry.avatarURL = &fullURL
+		return entry
+	})
+
+	return fullURL
 }
 
 // handleRatelimit handles the ratelimit errors and return if we're ratelimited and the amount of time to sleep
diff --git a/bridge/matrix/matrix.go b/bridge/matrix/matrix.go
index 100924e0..20a33114 100644
--- a/bridge/matrix/matrix.go
+++ b/bridge/matrix/matrix.go
@@ -33,13 +33,13 @@ type RoomInfo struct {
 }
 
 type Bmatrix struct {
-	mc            *matrix.Client
-	UserID        id.UserID
-	appService    *AppServiceWrapper
-	NicknameCache *NicknameCache
-	RoomMap       map[id.RoomID]RoomInfo
-	rateMutex     sync.RWMutex
-	joinedRooms   []id.RoomID
+	mc          *matrix.Client
+	UserID      id.UserID
+	appService  *AppServiceWrapper
+	UserCache   *UserInfoCache
+	RoomMap     map[id.RoomID]RoomInfo
+	rateMutex   sync.RWMutex
+	joinedRooms []id.RoomID
 	sync.RWMutex
 	*bridge.Config
 	stopNormalSync    chan struct{}
@@ -54,7 +54,7 @@ type matrixUsername struct {
 func New(cfg *bridge.Config) bridge.Bridger {
 	b := &Bmatrix{Config: cfg}
 	b.RoomMap = make(map[id.RoomID]RoomInfo)
-	b.NicknameCache = NewNicknameCache()
+	b.UserCache = NewUserInfoCache()
 	b.stopNormalSync = make(chan struct{}, 1)
 	b.stopNormalSyncAck = make(chan struct{}, 1)
 	return b
@@ -333,6 +333,7 @@ func (b *Bmatrix) Send(msg config.Message) (string, error) {
 
 // DontProcessOldEvents returns true if a sync event should be considered for further processing.
 // We use that function to filter out events we have already read.
+//nolint: gocognit
 func (b *Bmatrix) DontProcessOldEvents(resp *matrix.RespSync, since string) bool {
 	// we only filter old events in the initial sync(), because subsequent sync()
 	// (where since != "") should only return new events