forked from jshiffer/matterbridge
246 lines
8.0 KiB
Go
246 lines
8.0 KiB
Go
|
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||
|
|
||
|
package sqlstore
|
||
|
|
||
|
import (
|
||
|
"crypto/rand"
|
||
|
"database/sql"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
mathRand "math/rand"
|
||
|
|
||
|
waProto "go.mau.fi/whatsmeow/binary/proto"
|
||
|
"go.mau.fi/whatsmeow/store"
|
||
|
"go.mau.fi/whatsmeow/types"
|
||
|
"go.mau.fi/whatsmeow/util/keys"
|
||
|
waLog "go.mau.fi/whatsmeow/util/log"
|
||
|
)
|
||
|
|
||
|
// Container is a wrapper for a SQL database that can contain multiple whatsmeow sessions.
|
||
|
type Container struct {
|
||
|
db *sql.DB
|
||
|
dialect string
|
||
|
log waLog.Logger
|
||
|
}
|
||
|
|
||
|
// Compile-time check that *Container satisfies store.DeviceContainer.
var _ store.DeviceContainer = (*Container)(nil)
|
||
|
|
||
|
// New connects to the given SQL database and wraps it in a Container.
|
||
|
//
|
||
|
// Only SQLite and Postgres are currently fully supported.
|
||
|
//
|
||
|
// The logger can be nil and will default to a no-op logger.
|
||
|
//
|
||
|
// When using SQLite, it's strongly recommended to enable foreign keys by adding `?_foreign_keys=true`:
|
||
|
// container, err := sqlstore.New("sqlite3", "file:yoursqlitefile.db?_foreign_keys=on", nil)
|
||
|
func New(dialect, address string, log waLog.Logger) (*Container, error) {
|
||
|
db, err := sql.Open(dialect, address)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||
|
}
|
||
|
container := NewWithDB(db, dialect, log)
|
||
|
err = container.Upgrade()
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to upgrade database: %w", err)
|
||
|
}
|
||
|
return container, nil
|
||
|
}
|
||
|
|
||
|
// NewWithDB wraps an existing SQL connection in a Container.
|
||
|
//
|
||
|
// Only SQLite and Postgres are currently fully supported.
|
||
|
//
|
||
|
// The logger can be nil and will default to a no-op logger.
|
||
|
//
|
||
|
// When using SQLite, it's strongly recommended to enable foreign keys by adding `?_foreign_keys=true`:
|
||
|
// db, err := sql.Open("sqlite3", "file:yoursqlitefile.db?_foreign_keys=on")
|
||
|
// if err != nil {
|
||
|
// panic(err)
|
||
|
// }
|
||
|
// container, err := sqlstore.NewWithDB(db, "sqlite3", nil)
|
||
|
func NewWithDB(db *sql.DB, dialect string, log waLog.Logger) *Container {
|
||
|
if log == nil {
|
||
|
log = waLog.Noop
|
||
|
}
|
||
|
return &Container{
|
||
|
db: db,
|
||
|
dialect: dialect,
|
||
|
log: log,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// getAllDevicesQuery selects every column of every row in whatsmeow_device.
// The column order must match the destination order used in Container.scanDevice.
const getAllDevicesQuery = `
SELECT jid, registration_id, noise_key, identity_key,
       signed_pre_key, signed_pre_key_id, signed_pre_key_sig,
       adv_key, adv_details, adv_account_sig, adv_device_sig,
       platform, business_name, push_name
FROM whatsmeow_device
`
|
||
|
|
||
|
// getDeviceQuery is getAllDevicesQuery restricted to the single row whose jid equals the first query argument.
const getDeviceQuery = getAllDevicesQuery + " WHERE jid=$1"
|
||
|
|
||
|
// scannable is the Scan method that scanDevice needs from a query result, so
// the same scanning code can be fed from both QueryRow and Query results.
type scannable interface {
	// Scan copies the columns of the current row into dest.
	Scan(dest ...interface{}) error
}
|
||
|
|
||
|
func (c *Container) scanDevice(row scannable) (*store.Device, error) {
|
||
|
var device store.Device
|
||
|
device.Log = c.log
|
||
|
device.SignedPreKey = &keys.PreKey{}
|
||
|
var noisePriv, identityPriv, preKeyPriv, preKeySig []byte
|
||
|
var account waProto.ADVSignedDeviceIdentity
|
||
|
|
||
|
err := row.Scan(
|
||
|
&device.ID, &device.RegistrationID, &noisePriv, &identityPriv,
|
||
|
&preKeyPriv, &device.SignedPreKey.KeyID, &preKeySig,
|
||
|
&device.AdvSecretKey, &account.Details, &account.AccountSignature, &account.DeviceSignature,
|
||
|
&device.Platform, &device.BusinessName, &device.PushName)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to scan session: %w", err)
|
||
|
} else if len(noisePriv) != 32 || len(identityPriv) != 32 || len(preKeyPriv) != 32 || len(preKeySig) != 64 {
|
||
|
return nil, ErrInvalidLength
|
||
|
}
|
||
|
|
||
|
device.NoiseKey = keys.NewKeyPairFromPrivateKey(*(*[32]byte)(noisePriv))
|
||
|
device.IdentityKey = keys.NewKeyPairFromPrivateKey(*(*[32]byte)(identityPriv))
|
||
|
device.SignedPreKey.KeyPair = *keys.NewKeyPairFromPrivateKey(*(*[32]byte)(preKeyPriv))
|
||
|
device.SignedPreKey.Signature = (*[64]byte)(preKeySig)
|
||
|
device.Account = &account
|
||
|
|
||
|
innerStore := NewSQLStore(c, *device.ID)
|
||
|
device.Identities = innerStore
|
||
|
device.Sessions = innerStore
|
||
|
device.PreKeys = innerStore
|
||
|
device.SenderKeys = innerStore
|
||
|
device.AppStateKeys = innerStore
|
||
|
device.AppState = innerStore
|
||
|
device.Contacts = innerStore
|
||
|
device.ChatSettings = innerStore
|
||
|
device.Container = c
|
||
|
device.Initialized = true
|
||
|
|
||
|
return &device, nil
|
||
|
}
|
||
|
|
||
|
// GetAllDevices finds all the devices in the database.
|
||
|
func (c *Container) GetAllDevices() ([]*store.Device, error) {
|
||
|
res, err := c.db.Query(getAllDevicesQuery)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("failed to query sessions: %w", err)
|
||
|
}
|
||
|
sessions := make([]*store.Device, 0)
|
||
|
for res.Next() {
|
||
|
sess, scanErr := c.scanDevice(res)
|
||
|
if scanErr != nil {
|
||
|
return sessions, scanErr
|
||
|
}
|
||
|
sessions = append(sessions, sess)
|
||
|
}
|
||
|
return sessions, nil
|
||
|
}
|
||
|
|
||
|
// GetFirstDevice is a convenience method for getting the first device in the store. If there are
|
||
|
// no devices, then a new device will be created. You should only use this if you don't want to
|
||
|
// have multiple sessions simultaneously.
|
||
|
func (c *Container) GetFirstDevice() (*store.Device, error) {
|
||
|
devices, err := c.GetAllDevices()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if len(devices) == 0 {
|
||
|
return c.NewDevice(), nil
|
||
|
} else {
|
||
|
return devices[0], nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// GetDevice finds the device with the specified JID in the database.
|
||
|
//
|
||
|
// If the device is not found, nil is returned instead.
|
||
|
//
|
||
|
// Note that the parameter usually must be an AD-JID.
|
||
|
func (c *Container) GetDevice(jid types.JID) (*store.Device, error) {
|
||
|
sess, err := c.scanDevice(c.db.QueryRow(getDeviceQuery, jid))
|
||
|
if errors.Is(err, sql.ErrNoRows) {
|
||
|
return nil, nil
|
||
|
}
|
||
|
return sess, err
|
||
|
}
|
||
|
|
||
|
const (
	// insertDeviceQuery inserts a full device row. When a row with the same jid
	// already exists, only platform, business_name and push_name are updated.
	insertDeviceQuery = `
		INSERT INTO whatsmeow_device (jid, registration_id, noise_key, identity_key,
		                              signed_pre_key, signed_pre_key_id, signed_pre_key_sig,
		                              adv_key, adv_details, adv_account_sig, adv_device_sig,
		                              platform, business_name, push_name)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)
		ON CONFLICT (jid) DO UPDATE SET platform=$12, business_name=$13, push_name=$14
	`
	// deleteDeviceQuery removes the device row with the given jid.
	deleteDeviceQuery = `DELETE FROM whatsmeow_device WHERE jid=$1`
)
|
||
|
|
||
|
// NewDevice creates a new device in this database.
|
||
|
//
|
||
|
// No data is actually stored before Save is called. However, the pairing process will automatically
|
||
|
// call Save after a successful pairing, so you most likely don't need to call it yourself.
|
||
|
func (c *Container) NewDevice() *store.Device {
|
||
|
device := &store.Device{
|
||
|
Log: c.log,
|
||
|
Container: c,
|
||
|
|
||
|
NoiseKey: keys.NewKeyPair(),
|
||
|
IdentityKey: keys.NewKeyPair(),
|
||
|
RegistrationID: mathRand.Uint32(),
|
||
|
AdvSecretKey: make([]byte, 32),
|
||
|
}
|
||
|
_, err := rand.Read(device.AdvSecretKey)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
device.SignedPreKey = device.IdentityKey.CreateSignedPreKey(1)
|
||
|
return device
|
||
|
}
|
||
|
|
||
|
// ErrDeviceIDMustBeSet is the error returned by PutDevice and DeleteDevice when
// the given device does not have a JID yet.
var ErrDeviceIDMustBeSet = errors.New("device JID must be known before accessing database")
|
||
|
|
||
|
// PutDevice stores the given device in this database. This should be called through Device.Save()
|
||
|
// (which usually doesn't need to be called manually, as the library does that automatically when relevant).
|
||
|
func (c *Container) PutDevice(device *store.Device) error {
|
||
|
if device.ID == nil {
|
||
|
return ErrDeviceIDMustBeSet
|
||
|
}
|
||
|
_, err := c.db.Exec(insertDeviceQuery,
|
||
|
device.ID.String(), device.RegistrationID, device.NoiseKey.Priv[:], device.IdentityKey.Priv[:],
|
||
|
device.SignedPreKey.Priv[:], device.SignedPreKey.KeyID, device.SignedPreKey.Signature[:],
|
||
|
device.AdvSecretKey, device.Account.Details, device.Account.AccountSignature, device.Account.DeviceSignature,
|
||
|
device.Platform, device.BusinessName, device.PushName)
|
||
|
|
||
|
if !device.Initialized {
|
||
|
innerStore := NewSQLStore(c, *device.ID)
|
||
|
device.Identities = innerStore
|
||
|
device.Sessions = innerStore
|
||
|
device.PreKeys = innerStore
|
||
|
device.SenderKeys = innerStore
|
||
|
device.AppStateKeys = innerStore
|
||
|
device.AppState = innerStore
|
||
|
device.Contacts = innerStore
|
||
|
device.ChatSettings = innerStore
|
||
|
device.Initialized = true
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// DeleteDevice deletes the given device from this database. This should be called through Device.Delete()
|
||
|
func (c *Container) DeleteDevice(store *store.Device) error {
|
||
|
if store.ID == nil {
|
||
|
return ErrDeviceIDMustBeSet
|
||
|
}
|
||
|
_, err := c.db.Exec(deleteDeviceQuery, store.ID.String())
|
||
|
return err
|
||
|
}
|