Add persistent message map

Resolves #541
This commit is contained in:
Yousef Mansy
2023-02-28 22:50:17 -08:00
parent 0527f01904
commit c0f5d0c5f7
106 changed files with 22946 additions and 6 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/d5/tengo/v2/stdlib"
lru "github.com/hashicorp/golang-lru"
"github.com/kyokomi/emoji/v2"
"github.com/philippgille/gokv"
"github.com/sirupsen/logrus"
)
@@ -29,14 +30,17 @@ type Gateway struct {
Message chan config.Message
Name string
Messages *lru.Cache
MessageStore gokv.Store
CanonicalStore gokv.Store
logger *logrus.Entry
}
type BrMsgID struct {
br *bridge.Bridge
ID string
Protocol string
DestName string
ChannelID string
ID string
}
const apiProtocol = "api"
@@ -59,12 +63,41 @@ func New(rootLogger *logrus.Logger, cfg *config.Gateway, r *Router) *Gateway {
if err := gw.AddConfig(cfg); err != nil {
logger.Errorf("Failed to add configuration to gateway: %#v", err)
}
persistentMessageStorePath, usePersistent := gw.Config.GetString("PersistentMessageStorePath")
if usePersistent {
rootPath := fmt.Sprintf("%s/%s", persistentMessageStorePath, gw.Name)
os.MkdirAll(rootPath, os.ModePerm)
gw.MessageStore = gw.getMessageMapStore(fmt.Sprintf("%s/Messages", rootPath))
gw.CanonicalStore = gw.getMessageMapStore(fmt.Sprintf("%s/Canonical", rootPath))
}
return gw
}
func (gw *Gateway) SetMessageMap(canonicalMsgID string, msgIDs []*BrMsgID) {
_, usePersistent := gw.Config.GetString("PersistentMessageStorePath")
if usePersistent {
gw.setDestMessagesToStore(canonicalMsgID, msgIDs)
} else {
gw.Messages.Add(canonicalMsgID, msgIDs)
}
}
// FindCanonicalMsgID returns the ID under which a message was stored in the cache.
func (gw *Gateway) FindCanonicalMsgID(protocol string, mID string) string {
ID := protocol + " " + mID
_, usePersistent := gw.Config.GetString("PersistentMessageStorePath")
if usePersistent {
return gw.getCanonicalMessageFromStore(ID)
} else {
return gw.getCanonicalMessageFromMemCache(ID)
}
}
func (gw *Gateway) getCanonicalMessageFromMemCache(ID string) string {
if gw.Messages.Contains(ID) {
return ID
}
@@ -259,13 +292,26 @@ func (gw *Gateway) getDestChannel(msg *config.Message, dest bridge.Bridge) []con
}
func (gw *Gateway) getDestMsgID(msgID string, dest *bridge.Bridge, channel *config.ChannelInfo) string {
var destID string
_, usePersistent := gw.Config.GetString("PersistentMessageStorePath")
if usePersistent {
destID = gw.getDestMessagesFromStore(msgID, dest, channel)
} else {
destID = gw.getDestMessageFromMemCache(msgID, dest, channel)
}
return strings.Replace(destID, dest.Protocol+" ", "", 1)
}
func (gw *Gateway) getDestMessageFromMemCache(msgID string, dest *bridge.Bridge, channel *config.ChannelInfo) string {
if res, ok := gw.Messages.Get(msgID); ok {
IDs := res.([]*BrMsgID)
for _, id := range IDs {
// check protocol, bridge name and channelname
// for people that reuse the same bridge multiple times. see #342
if dest.Protocol == id.br.Protocol && dest.Name == id.br.Name && channel.ID == id.ChannelID {
return strings.Replace(id.ID, dest.Protocol+" ", "", 1)
if dest.Protocol == id.Protocol && dest.Name == id.DestName && channel.ID == id.ChannelID {
return id.ID
}
}
}

View File

@@ -231,7 +231,13 @@ func (gw *Gateway) handleMessage(rmsg *config.Message, dest *bridge.Bridge) []*B
if msgID == "" {
continue
}
brMsgIDs = append(brMsgIDs, &BrMsgID{dest, dest.Protocol + " " + msgID, channel.ID})
brMsgIDs = append(brMsgIDs,
&BrMsgID{
Protocol: dest.Protocol,
DestName: dest.Name,
ChannelID: channel.ID,
ID: msgID,
})
}
return brMsgIDs
}

83
gateway/persistent.go Normal file
View File

@@ -0,0 +1,83 @@
package gateway
import (
"github.com/42wim/matterbridge/bridge"
"github.com/42wim/matterbridge/bridge/config"
"github.com/philippgille/gokv"
"github.com/philippgille/gokv/badgerdb"
"github.com/philippgille/gokv/encoding"
)
func (gw *Gateway) getMessageMapStore(path string) gokv.Store {
options := badgerdb.Options{
Dir: path,
Codec: encoding.Gob,
}
store, err := badgerdb.NewStore(options)
if err != nil {
gw.logger.Error(err)
gw.logger.Errorf("Could not connect to db: %s", path)
}
return store
}
func (gw *Gateway) getCanonicalMessageFromStore(messageID string) string {
if messageID == "" {
return ""
}
canonicalMsgID := new(string)
found, err := gw.CanonicalStore.Get(messageID, canonicalMsgID)
if err != nil {
gw.logger.Error(err)
}
if found {
return *canonicalMsgID
}
return ""
}
func (gw *Gateway) setCanonicalMessageToStore(messageID string, canonicalMsgID string) {
err := gw.CanonicalStore.Set(messageID, canonicalMsgID)
if err != nil {
gw.logger.Error(err)
}
}
func (gw *Gateway) getDestMessagesFromStore(canonicalMsgID string, dest *bridge.Bridge, channel *config.ChannelInfo) string {
if canonicalMsgID == "" {
return ""
}
destMessageIds := new([]BrMsgID)
found, err := gw.MessageStore.Get(canonicalMsgID, destMessageIds)
if err != nil {
gw.logger.Error(err)
}
if found {
for _, id := range *destMessageIds {
// check protocol, bridge name and channelname
// for people that reuse the same bridge multiple times. see #342
if dest.Protocol == id.Protocol && dest.Name == id.DestName && channel.ID == id.ChannelID {
return id.ID
}
}
}
return ""
}
func (gw *Gateway) setDestMessagesToStore(canonicalMsgID string, msgIDs []*BrMsgID) {
for _, msgID := range msgIDs {
gw.setCanonicalMessageToStore(msgID.Protocol+" "+msgID.ID, canonicalMsgID)
}
err := gw.MessageStore.Set(canonicalMsgID, msgIDs)
if err != nil {
gw.logger.Error(err)
}
}

View File

@@ -163,7 +163,21 @@ func (r *Router) handleReceive() {
// This is necessary as msgIDs will change if a bridge returns
// a different ID in response to edits.
if !exists {
gw.Messages.Add(msg.Protocol+" "+msg.ID, msgIDs)
// we're adding the original message as a "dest message"
// as when we get the dest messages for a delete the source message isnt in the list
// therefore the delete doesnt happen on the source platform.
/* ! use this when merging #1991 (these many branches are getting hard to keep track of)
msgIDs = append(msgIDs,
&BrMsgID{
Protocol: srcBridge.Protocol,
DestName: srcBridge.Name,
ChannelID: msg.Channel + srcBridge.Account,
ID: msg.ID,
})
*/
gw.SetMessageMap(msg.Protocol+" "+msg.ID, msgIDs)
}
}
}