vendor/github.com/status-im/status-go/wakuv2/persistence/dbkey.go (new file, 56 lines, generated, vendored)
@@ -0,0 +1,56 @@
package persistence

import (
	"encoding/binary"
	"errors"

	"github.com/waku-org/go-waku/waku/v2/hash"
)

const (
	TimestampLength   = 8
	HashLength        = 32
	DigestLength      = HashLength
	PubsubTopicLength = HashLength
	DBKeyLength       = TimestampLength + PubsubTopicLength + DigestLength
)

type Hash [HashLength]byte

var (
	// ErrInvalidByteSize is returned when a DBKey can't be created
	// from a byte slice because it has an invalid length.
	ErrInvalidByteSize = errors.New("byte slice has invalid length")
)

// DBKey is a key to be stored in a db.
type DBKey struct {
	raw []byte
}

// Bytes returns the byte representation of the DBKey.
func (k *DBKey) Bytes() []byte {
	return k.raw
}

// NewDBKey creates a new DBKey with the given values.
func NewDBKey(senderTimestamp uint64, receiverTimestamp uint64, pubsubTopic string, digest []byte) *DBKey {
	pubSubHash := make([]byte, PubsubTopicLength)
	if pubsubTopic != "" {
		pubSubHash = hash.SHA256([]byte(pubsubTopic))
	}

	var k DBKey
	k.raw = make([]byte, DBKeyLength)

	if senderTimestamp == 0 {
		binary.BigEndian.PutUint64(k.raw, receiverTimestamp)
	} else {
		binary.BigEndian.PutUint64(k.raw, senderTimestamp)
	}

	copy(k.raw[TimestampLength:], pubSubHash[:])
	copy(k.raw[TimestampLength+PubsubTopicLength:], digest)

	return &k
}
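For orientation, here is a small standalone sketch (not part of the vendored file) of the 72-byte layout NewDBKey produces: an 8-byte big-endian timestamp, a 32-byte topic hash, then a 32-byte message digest. The topic, timestamp, and payload values are made up, and crypto/sha256 stands in for the go-waku hash helper, which is assumed to be plain SHA-256.

package main

import (
	"crypto/sha256"
	"encoding/binary"
	"fmt"
)

func main() {
	const (
		timestampLength   = 8
		pubsubTopicLength = 32
		digestLength      = 32
	)

	topic := "/waku/2/default-waku/proto"           // illustrative topic
	digest := sha256.Sum256([]byte("some payload")) // illustrative message digest

	key := make([]byte, timestampLength+pubsubTopicLength+digestLength)
	binary.BigEndian.PutUint64(key, 1700000000000000000) // sender timestamp in ns
	topicHash := sha256.Sum256([]byte(topic))
	copy(key[timestampLength:], topicHash[:])
	copy(key[timestampLength+pubsubTopicLength:], digest[:])

	// Because the timestamp comes first and is big-endian, byte-wise ordering
	// of keys is also chronological ordering, which is what the store's
	// "ORDER BY ... id" queries rely on.
	fmt.Printf("timestamp recovered from key: %d\n", binary.BigEndian.Uint64(key[:timestampLength]))
}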
vendor/github.com/status-im/status-go/wakuv2/persistence/dbstore.go (new file, 421 lines, generated, vendored)
@@ -0,0 +1,421 @@
package persistence

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	gowakuPersistence "github.com/waku-org/go-waku/waku/persistence"
	"github.com/waku-org/go-waku/waku/v2/protocol"
	"github.com/waku-org/go-waku/waku/v2/protocol/pb"
	storepb "github.com/waku-org/go-waku/waku/v2/protocol/store/pb"
	"github.com/waku-org/go-waku/waku/v2/timesource"
	"github.com/waku-org/go-waku/waku/v2/utils"

	"go.uber.org/zap"
)

var ErrInvalidCursor = errors.New("invalid cursor")

var ErrFutureMessage = errors.New("message timestamp in the future")
var ErrMessageTooOld = errors.New("message too old")

// MaxTimeVariance is the maximum duration in the future allowed for a message timestamp
const MaxTimeVariance = time.Duration(20) * time.Second

// DBStore is a MessageProvider that has a *sql.DB connection
type DBStore struct {
	db  *sql.DB
	log *zap.Logger

	maxMessages int
	maxDuration time.Duration

	wg     sync.WaitGroup
	cancel context.CancelFunc
}

// DBOption is an optional setting that can be used to configure the DBStore
type DBOption func(*DBStore) error

// WithDB is a DBOption that lets you use any custom *sql.DB with a DBStore.
func WithDB(db *sql.DB) DBOption {
	return func(d *DBStore) error {
		d.db = db
		return nil
	}
}

// WithRetentionPolicy is a DBOption that specifies the max number of messages
// to be stored and the duration before they're removed from the message store
func WithRetentionPolicy(maxMessages int, maxDuration time.Duration) DBOption {
	return func(d *DBStore) error {
		d.maxDuration = maxDuration
		d.maxMessages = maxMessages
		return nil
	}
}

// NewDBStore creates a new DB store using the db specified via options.
// It will create a messages table if it does not exist and
// clean up records according to the retention policy used
func NewDBStore(log *zap.Logger, options ...DBOption) (*DBStore, error) {
	result := new(DBStore)
	result.log = log.Named("dbstore")

	for _, opt := range options {
		err := opt(result)
		if err != nil {
			return nil, err
		}
	}

	return result, nil
}

func (d *DBStore) Start(ctx context.Context, timesource timesource.Timesource) error {
	ctx, cancel := context.WithCancel(ctx)

	d.cancel = cancel

	err := d.cleanOlderRecords()
	if err != nil {
		return err
	}

	d.wg.Add(1)
	go d.checkForOlderRecords(ctx, 60*time.Second)

	return nil
}

func (d *DBStore) Validate(env *protocol.Envelope) error {
	n := time.Unix(0, env.Index().ReceiverTime)
	upperBound := n.Add(MaxTimeVariance)
	lowerBound := n.Add(-MaxTimeVariance)

	// Ensure that messages don't "jump" to the front of the queue with future timestamps
	if env.Message().GetTimestamp() > upperBound.UnixNano() {
		return ErrFutureMessage
	}

	if env.Message().GetTimestamp() < lowerBound.UnixNano() {
		return ErrMessageTooOld
	}

	return nil
}

func (d *DBStore) cleanOlderRecords() error {
	d.log.Debug("Cleaning older records...")

	// Delete older messages
	if d.maxDuration > 0 {
		start := time.Now()
		sqlStmt := `DELETE FROM store_messages WHERE receiverTimestamp < ?`
		_, err := d.db.Exec(sqlStmt, utils.GetUnixEpochFrom(time.Now().Add(-d.maxDuration)))
		if err != nil {
			return err
		}
		elapsed := time.Since(start)
		d.log.Debug("deleting older records from the DB", zap.Duration("duration", elapsed))
	}

	// Limit number of records to a max N
	if d.maxMessages > 0 {
		start := time.Now()
		sqlStmt := `DELETE FROM store_messages WHERE id IN (SELECT id FROM store_messages ORDER BY receiverTimestamp DESC LIMIT -1 OFFSET ?)`
		_, err := d.db.Exec(sqlStmt, d.maxMessages)
		if err != nil {
			return err
		}
		elapsed := time.Since(start)
		d.log.Debug("deleting excess records from the DB", zap.Duration("duration", elapsed))
	}

	return nil
}

func (d *DBStore) checkForOlderRecords(ctx context.Context, t time.Duration) {
	defer d.wg.Done()

	ticker := time.NewTicker(t)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			err := d.cleanOlderRecords()
			if err != nil {
				d.log.Error("cleaning older records", zap.Error(err))
			}
		}
	}
}

// Stop closes a DB connection
func (d *DBStore) Stop() {
	if d.cancel == nil {
		return
	}

	d.cancel()
	d.wg.Wait()
	d.db.Close()
}

// Put inserts a WakuMessage into the DB
func (d *DBStore) Put(env *protocol.Envelope) error {
	stmt, err := d.db.Prepare("INSERT INTO store_messages (id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version) VALUES (?, ?, ?, ?, ?, ?, ?)")
	if err != nil {
		return err
	}

	cursor := env.Index()
	dbKey := NewDBKey(uint64(cursor.SenderTime), uint64(env.Index().ReceiverTime), env.PubsubTopic(), env.Index().Digest)
	_, err = stmt.Exec(dbKey.Bytes(), cursor.ReceiverTime, env.Message().Timestamp, env.Message().ContentTopic, env.PubsubTopic(), env.Message().Payload, env.Message().Version)
	if err != nil {
		return err
	}

	err = stmt.Close()
	if err != nil {
		return err
	}

	return nil
}

// Query retrieves messages from the DB
func (d *DBStore) Query(query *storepb.HistoryQuery) (*storepb.Index, []gowakuPersistence.StoredMessage, error) {
	start := time.Now()
	defer func() {
		elapsed := time.Since(start)
		d.log.Info(fmt.Sprintf("Loading records from the DB took %s", elapsed))
	}()

	sqlQuery := `SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version
					FROM store_messages
					%s
					ORDER BY senderTimestamp %s, id %s, pubsubTopic %s, receiverTimestamp %s `

	var conditions []string
	var parameters []interface{}
	paramCnt := 0

	if query.PubsubTopic != "" {
		paramCnt++
		conditions = append(conditions, fmt.Sprintf("pubsubTopic = $%d", paramCnt))
		parameters = append(parameters, query.PubsubTopic)
	}

	if len(query.ContentFilters) != 0 {
		var ctPlaceHolder []string
		for _, ct := range query.ContentFilters {
			if ct.ContentTopic != "" {
				paramCnt++
				ctPlaceHolder = append(ctPlaceHolder, fmt.Sprintf("$%d", paramCnt))
				parameters = append(parameters, ct.ContentTopic)
			}
		}
		conditions = append(conditions, "contentTopic IN ("+strings.Join(ctPlaceHolder, ", ")+")")
	}

	usesCursor := false
	if query.PagingInfo.Cursor != nil {
		usesCursor = true
		var exists bool
		cursorDBKey := NewDBKey(uint64(query.PagingInfo.Cursor.SenderTime), uint64(query.PagingInfo.Cursor.ReceiverTime), query.PagingInfo.Cursor.PubsubTopic, query.PagingInfo.Cursor.Digest)

		err := d.db.QueryRow("SELECT EXISTS(SELECT 1 FROM store_messages WHERE id = $1)",
			cursorDBKey.Bytes(),
		).Scan(&exists)

		if err != nil {
			return nil, nil, err
		}

		if exists {
			eqOp := ">"
			if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD {
				eqOp = "<"
			}
			paramCnt++
			conditions = append(conditions, fmt.Sprintf("id %s $%d", eqOp, paramCnt))

			parameters = append(parameters, cursorDBKey.Bytes())
		} else {
			return nil, nil, ErrInvalidCursor
		}
	}

	if query.GetStartTime() != 0 {
		if !usesCursor || query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD {
			paramCnt++
			conditions = append(conditions, fmt.Sprintf("id >= $%d", paramCnt))
			startTimeDBKey := NewDBKey(uint64(query.GetStartTime()), uint64(query.GetStartTime()), "", []byte{})
			parameters = append(parameters, startTimeDBKey.Bytes())
		}
	}

	if query.GetEndTime() != 0 {
		if !usesCursor || query.PagingInfo.Direction == storepb.PagingInfo_FORWARD {
			paramCnt++
			conditions = append(conditions, fmt.Sprintf("id <= $%d", paramCnt))
			endTimeDBKey := NewDBKey(uint64(query.GetEndTime()), uint64(query.GetEndTime()), "", []byte{})
			parameters = append(parameters, endTimeDBKey.Bytes())
		}
	}

	conditionStr := ""
	if len(conditions) != 0 {
		conditionStr = "WHERE " + strings.Join(conditions, " AND ")
	}

	orderDirection := "ASC"
	if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD {
		orderDirection = "DESC"
	}

	paramCnt++
	sqlQuery += fmt.Sprintf("LIMIT $%d", paramCnt)
	sqlQuery = fmt.Sprintf(sqlQuery, conditionStr, orderDirection, orderDirection, orderDirection, orderDirection)

	stmt, err := d.db.Prepare(sqlQuery)
	if err != nil {
		return nil, nil, err
	}
	defer stmt.Close()

	pageSize := query.PagingInfo.PageSize + 1

	parameters = append(parameters, pageSize)
	rows, err := stmt.Query(parameters...)
	if err != nil {
		return nil, nil, err
	}

	var result []gowakuPersistence.StoredMessage
	for rows.Next() {
		record, err := d.GetStoredMessage(rows)
		if err != nil {
			return nil, nil, err
		}
		result = append(result, record)
	}
	defer rows.Close()

	var cursor *storepb.Index
	if len(result) != 0 {
		if len(result) > int(query.PagingInfo.PageSize) {
			result = result[0:query.PagingInfo.PageSize]
			lastMsgIdx := len(result) - 1
			cursor = protocol.NewEnvelope(result[lastMsgIdx].Message, result[lastMsgIdx].ReceiverTime, result[lastMsgIdx].PubsubTopic).Index()
		}
	}

	// The retrieved messages list should always be in chronological order
	if query.PagingInfo.Direction == storepb.PagingInfo_BACKWARD {
		for i, j := 0, len(result)-1; i < j; i, j = i+1, j-1 {
			result[i], result[j] = result[j], result[i]
		}
	}

	return cursor, result, nil
}

// MostRecentTimestamp returns a unix timestamp with the most recent senderTimestamp
// in the message table
func (d *DBStore) MostRecentTimestamp() (int64, error) {
	result := sql.NullInt64{}

	err := d.db.QueryRow(`SELECT max(senderTimestamp) FROM store_messages`).Scan(&result)
	if err != nil && err != sql.ErrNoRows {
		return 0, err
	}
	return result.Int64, nil
}

// Count returns the number of rows in the message table
func (d *DBStore) Count() (int, error) {
	var result int
	err := d.db.QueryRow(`SELECT COUNT(*) FROM store_messages`).Scan(&result)
	if err != nil && err != sql.ErrNoRows {
		return 0, err
	}
	return result, nil
}

// GetAll returns all the stored WakuMessages
func (d *DBStore) GetAll() ([]gowakuPersistence.StoredMessage, error) {
	start := time.Now()
	defer func() {
		elapsed := time.Since(start)
		d.log.Info("loading records from the DB", zap.Duration("duration", elapsed))
	}()

	rows, err := d.db.Query("SELECT id, receiverTimestamp, senderTimestamp, contentTopic, pubsubTopic, payload, version FROM store_messages ORDER BY senderTimestamp ASC")
	if err != nil {
		return nil, err
	}

	var result []gowakuPersistence.StoredMessage

	defer rows.Close()

	for rows.Next() {
		record, err := d.GetStoredMessage(rows)
		if err != nil {
			return nil, err
		}
		result = append(result, record)
	}

	d.log.Info("DB returned records", zap.Int("count", len(result)))

	err = rows.Err()
	if err != nil {
		return nil, err
	}

	return result, nil
}

// GetStoredMessage is a helper function used to convert a `*sql.Rows` into a `StoredMessage`
func (d *DBStore) GetStoredMessage(row *sql.Rows) (gowakuPersistence.StoredMessage, error) {
	var id []byte
	var receiverTimestamp int64
	var senderTimestamp int64
	var contentTopic string
	var payload []byte
	var version uint32
	var pubsubTopic string

	err := row.Scan(&id, &receiverTimestamp, &senderTimestamp, &contentTopic, &pubsubTopic, &payload, &version)
	if err != nil {
		d.log.Error("scanning messages from db", zap.Error(err))
		return gowakuPersistence.StoredMessage{}, err
	}

	msg := new(pb.WakuMessage)
	msg.ContentTopic = contentTopic
	msg.Payload = payload
	msg.Timestamp = &senderTimestamp
	msg.Version = &version

	record := gowakuPersistence.StoredMessage{
		ID:           id,
		PubsubTopic:  pubsubTopic,
		ReceiverTime: receiverTimestamp,
		Message:      msg,
	}

	return record, nil
}
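A minimal wiring sketch for the store above, not taken from the diff: it assumes the status-go migrations that create store_messages have already been applied, uses the mattn/go-sqlite3 driver purely for illustration, and passes a nil timesource because the Start shown above never dereferences it (a real node would pass its own timesource.Timesource).

package main

import (
	"context"
	"database/sql"
	"log"
	"time"

	_ "github.com/mattn/go-sqlite3" // driver choice is an assumption

	"github.com/status-im/status-go/wakuv2/persistence"
	"go.uber.org/zap"
)

func main() {
	// The store_messages schema is assumed to exist already (migrations applied elsewhere).
	db, err := sql.Open("sqlite3", "waku.db")
	if err != nil {
		log.Fatal(err)
	}

	store, err := persistence.NewDBStore(
		zap.NewNop(),
		persistence.WithDB(db),
		// Keep at most 100k messages and drop anything older than 30 days.
		persistence.WithRetentionPolicy(100_000, 30*24*time.Hour),
	)
	if err != nil {
		log.Fatal(err)
	}

	// Start runs an initial cleanup and then re-checks every 60 seconds.
	// The timesource argument is unused by the Start shown above, so nil is
	// passed here only to keep the sketch short.
	if err := store.Start(context.Background(), nil); err != nil {
		log.Fatal(err)
	}
	defer store.Stop() // stops the cleanup loop and closes the DB
}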
vendor/github.com/status-im/status-go/wakuv2/persistence/queries.go (new file, 103 lines, generated, vendored)
@@ -0,0 +1,103 @@
package persistence

import (
	"database/sql"
	"fmt"
)

// Queries are the sqlite queries for a given table.
type Queries struct {
	deleteQuery  string
	existsQuery  string
	getQuery     string
	putQuery     string
	queryQuery   string
	prefixQuery  string
	limitQuery   string
	offsetQuery  string
	getSizeQuery string
}

// NewQueries creates a new set of queries for the passed table
func NewQueries(tbl string, db *sql.DB) (*Queries, error) {
	err := CreateTable(db, tbl)
	if err != nil {
		return nil, err
	}
	return &Queries{
		deleteQuery:  fmt.Sprintf("DELETE FROM %s WHERE key = $1", tbl),
		existsQuery:  fmt.Sprintf("SELECT exists(SELECT 1 FROM %s WHERE key=$1)", tbl),
		getQuery:     fmt.Sprintf("SELECT data FROM %s WHERE key = $1", tbl),
		putQuery:     fmt.Sprintf("INSERT INTO %s (key, data) VALUES ($1, $2)", tbl),
		queryQuery:   fmt.Sprintf("SELECT key, data FROM %s", tbl),
		prefixQuery:  ` WHERE key LIKE '%s%%' ORDER BY key`,
		limitQuery:   ` LIMIT %d`,
		offsetQuery:  ` OFFSET %d`,
		getSizeQuery: fmt.Sprintf("SELECT length(data) FROM %s WHERE key = $1", tbl),
	}, nil
}

// Delete returns the query for deleting a row.
func (q Queries) Delete() string {
	return q.deleteQuery
}

// Exists returns the query for determining if a row exists.
func (q Queries) Exists() string {
	return q.existsQuery
}

// Get returns the query for getting a row.
func (q Queries) Get() string {
	return q.getQuery
}

// Put returns the query for putting a row.
func (q Queries) Put() string {
	return q.putQuery
}

// Query returns the query for getting multiple rows.
func (q Queries) Query() string {
	return q.queryQuery
}

// Prefix returns the query fragment for getting rows with a key prefix.
func (q Queries) Prefix() string {
	return q.prefixQuery
}

// Limit returns the query fragment for limiting results.
func (q Queries) Limit() string {
	return q.limitQuery
}

// Offset returns the query fragment for returning rows from a given offset.
func (q Queries) Offset() string {
	return q.offsetQuery
}

// GetSize returns the query for determining the size of a value.
func (q Queries) GetSize() string {
	return q.getSizeQuery
}

// CreateTable creates the table that will persist the peers
func CreateTable(db *sql.DB, tableName string) error {
	sqlStmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (key TEXT NOT NULL PRIMARY KEY ON CONFLICT REPLACE, data BYTEA);", tableName)
	_, err := db.Exec(sqlStmt)
	if err != nil {
		return err
	}
	return nil
}

func Clean(db *sql.DB, tableName string) error {
	// This is fully controlled by us
	sqlStmt := fmt.Sprintf("DELETE FROM %s;", tableName) // nolint: gosec
	_, err := db.Exec(sqlStmt)
	if err != nil {
		return err
	}
	return nil
}
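The query strings above are meant to be composed by the caller: Query() is a complete SELECT, while Prefix(), Limit(), and Offset() are fragments carrying fmt verbs the caller fills in before appending. A short sketch under those assumptions (the table name "kv_example" and the key prefix are made up; the in-memory SQLite driver is again only for illustration):

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"

	"github.com/status-im/status-go/wakuv2/persistence"
)

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		log.Fatal(err)
	}

	// NewQueries also creates the key/data table if it does not exist.
	q, err := persistence.NewQueries("kv_example", db)
	if err != nil {
		log.Fatal(err)
	}

	// Compose: base SELECT + prefix filter + limit.
	stmt := q.Query() + fmt.Sprintf(q.Prefix(), "peer:") + fmt.Sprintf(q.Limit(), 10)
	fmt.Println(stmt) // SELECT key, data FROM kv_example WHERE key LIKE 'peer:%' ORDER BY key LIMIT 10

	rows, err := db.Query(stmt)
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for rows.Next() {
		var key string
		var data []byte
		if err := rows.Scan(&key, &data); err != nil {
			log.Fatal(err)
		}
		fmt.Printf("%s => %x\n", key, data)
	}
}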
vendor/github.com/status-im/status-go/wakuv2/persistence/signed_messages.go (new file, 125 lines, generated, vendored)
@@ -0,0 +1,125 @@
package persistence

import (
	"crypto/ecdsa"
	"database/sql"
	"errors"

	"go.uber.org/zap"

	"github.com/ethereum/go-ethereum/crypto"
)

// ProtectedTopicsStore persists the signing keys of protected pubsub topics
// using a *sql.DB connection
type ProtectedTopicsStore struct {
	db  *sql.DB
	log *zap.Logger

	insertStmt       *sql.Stmt
	fetchPrivKeyStmt *sql.Stmt
	deleteStmt       *sql.Stmt
}

// NewProtectedTopicsStore creates a new store for protected-topic signing keys
// using the given db, preparing the statements it needs up front.
func NewProtectedTopicsStore(log *zap.Logger, db *sql.DB) (*ProtectedTopicsStore, error) {
	insertStmt, err := db.Prepare("INSERT OR REPLACE INTO pubsubtopic_signing_key (topic, priv_key, pub_key) VALUES (?, ?, ?)")
	if err != nil {
		return nil, err
	}

	fetchPrivKeyStmt, err := db.Prepare("SELECT priv_key FROM pubsubtopic_signing_key WHERE topic = ?")
	if err != nil {
		return nil, err
	}

	deleteStmt, err := db.Prepare("DELETE FROM pubsubtopic_signing_key WHERE topic = ?")
	if err != nil {
		return nil, err
	}

	result := new(ProtectedTopicsStore)
	result.log = log.Named("protected-topics-store")
	result.db = db
	result.insertStmt = insertStmt
	result.fetchPrivKeyStmt = fetchPrivKeyStmt
	result.deleteStmt = deleteStmt

	return result, nil
}

func (p *ProtectedTopicsStore) Close() error {
	err := p.insertStmt.Close()
	if err != nil {
		return err
	}

	return p.fetchPrivKeyStmt.Close()
}

func (p *ProtectedTopicsStore) Insert(pubsubTopic string, privKey *ecdsa.PrivateKey, publicKey *ecdsa.PublicKey) error {
	var privKeyBytes []byte
	if privKey != nil {
		privKeyBytes = crypto.FromECDSA(privKey)
	}

	pubKeyBytes := crypto.FromECDSAPub(publicKey)

	_, err := p.insertStmt.Exec(pubsubTopic, privKeyBytes, pubKeyBytes)

	return err
}

func (p *ProtectedTopicsStore) Delete(pubsubTopic string) error {
	_, err := p.deleteStmt.Exec(pubsubTopic)
	return err
}

func (p *ProtectedTopicsStore) FetchPrivateKey(topic string) (privKey *ecdsa.PrivateKey, err error) {
	var privKeyBytes []byte
	err = p.fetchPrivKeyStmt.QueryRow(topic).Scan(&privKeyBytes)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return nil, err
	}

	return crypto.ToECDSA(privKeyBytes)
}

type ProtectedTopic struct {
	PubKey *ecdsa.PublicKey
	Topic  string
}

func (p *ProtectedTopicsStore) ProtectedTopics() ([]ProtectedTopic, error) {
	rows, err := p.db.Query("SELECT pub_key, topic FROM pubsubtopic_signing_key")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []ProtectedTopic
	for rows.Next() {
		var pubKeyBytes []byte
		var topic string
		err := rows.Scan(&pubKeyBytes, &topic)
		if err != nil {
			return nil, err
		}

		pubk, err := crypto.UnmarshalPubkey(pubKeyBytes)
		if err != nil {
			return nil, err
		}

		result = append(result, ProtectedTopic{
			PubKey: pubk,
			Topic:  topic,
		})
	}

	return result, nil
}
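A usage sketch, also not from the diff: it assumes the pubsubtopic_signing_key table has already been created by status-go's migrations, the sqlite driver is again only for illustration, and the topic name is invented.

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3"

	"github.com/ethereum/go-ethereum/crypto"
	"github.com/status-im/status-go/wakuv2/persistence"
	"go.uber.org/zap"
)

func main() {
	db, err := sql.Open("sqlite3", "waku.db") // pubsubtopic_signing_key table assumed to exist
	if err != nil {
		log.Fatal(err)
	}

	store, err := persistence.NewProtectedTopicsStore(zap.NewNop(), db)
	if err != nil {
		log.Fatal(err)
	}
	defer store.Close()

	// Register a signing key for a made-up protected topic.
	key, err := crypto.GenerateKey()
	if err != nil {
		log.Fatal(err)
	}
	topic := "/waku/2/example-protected-topic/proto"
	if err := store.Insert(topic, key, &key.PublicKey); err != nil {
		log.Fatal(err)
	}

	// The owner can read the private key back; other consumers can list the
	// registered public keys via ProtectedTopics().
	priv, err := store.FetchPrivateKey(topic)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("have private key:", priv != nil)

	topics, err := store.ProtectedTopics()
	if err != nil {
		log.Fatal(err)
	}
	for _, t := range topics {
		fmt.Println(t.Topic)
	}
}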